virtex.optim.lookahead
Lookahead Optimizer: k steps forward, 1 step back.
This implementation is adapted, with minimal modifications, from the authors’ implementation.
If you use this code, please cite them:
@inproceedings{zhang2019lookahead,
    title={Lookahead Optimizer: k steps forward, 1 step back},
    author={Zhang, Michael R and Lucas, James and Hinton, Geoffrey and Ba, Jimmy},
    booktitle={NeurIPS},
    year={2019}
}
- class virtex.optim.lookahead.Lookahead(optimizer: torch.optim.optimizer.Optimizer, k: int = 5, alpha: float = 0.8)[source]
Bases: torch.optim.optimizer.Optimizer
Implements Lookahead optimizer.
- Parameters
optimizer – Wrapped inner optimizer. The weights it manages will be the “fast” weights.
k – Number of lookahead steps before updating “slow” weights.
alpha – Linear interpolation factor for the slow-weight update; alpha = 1.0 recovers the inner optimizer.
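A minimal construction sketch, assuming a toy model and SGD as the inner optimizer (both are illustrative choices, not prescribed by this class):

import torch
from virtex.optim.lookahead import Lookahead

model = torch.nn.Linear(10, 2)  # toy model, purely illustrative

# Wrap any inner optimizer; the weights it manages become the "fast" weights.
inner = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = Lookahead(inner, k=5, alpha=0.8)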
- state_dict()[source]
Returns the state of the optimizer as a dict. It contains two entries:
- state – a dict holding current optimization state. Its content differs between optimizer classes.
- param_groups – a list containing all parameter groups, where each parameter group is a dict.
- load_state_dict(state_dict: Dict[str, Any])[source]
Loads the optimizer state.
- Parameters
state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().
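A hedged checkpointing sketch using these two methods with the standard torch.save/torch.load pattern; the file path is a placeholder:

# Save both fast-weight and slow-weight optimizer state.
torch.save(optimizer.state_dict(), "lookahead_ckpt.pth")  # placeholder path

# Later: rebuild the model and Lookahead instance, then restore the state.
optimizer.load_state_dict(torch.load("lookahead_ckpt.pth"))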
- step(closure: Optional[Callable] = None)[source]
Performs a single Lookahead optimization step.
- Parameters
closure – A callable that re-evaluates the model and returns the loss.
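A sketch of step() inside a training loop, continuing the construction sketch above; the data loader and loss are stand-ins:

for inputs, targets in loader:  # hypothetical DataLoader
    optimizer.zero_grad()  # provided by the torch.optim.Optimizer base class
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()  # every k-th call also interpolates the slow weights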
- load_slow_weights()[source]
Load slow weights from the Lookahead optimizer. Useful for performing evaluation on the slow weights (which typically generalize better).
This method backs up the fast weights so they can be restored after evaluation. There is no need to call this method if evaluation happens just after a lookahead step.
- restore_fast_weights()[source]
Restore fast weights for optimization. Call this after evaluation if load_slow_weights() was called.
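A sketch of the intended evaluation pattern, assuming a user-supplied evaluate() routine (hypothetical):

optimizer.load_slow_weights()     # swap in slow weights; fast weights are backed up
metrics = evaluate(model)         # hypothetical evaluation function
optimizer.restore_fast_weights()  # swap fast weights back before resuming training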