virtex.utils.checkpointing
- class virtex.utils.checkpointing.CheckpointManager(serialization_dir: str = '/tmp', keep_recent: int = 200, **checkpointables: Any)[source]
Bases: object
A helper class to periodically serialize models and other checkpointable objects (optimizers, LR schedulers etc., which implement a state_dict method) during training, and optionally record the best performing checkpoint based on an observed metric.
Note
For DistributedDataParallel objects, the state_dict of the internal model is serialized (see the sketch after the Examples below).
Note
The observed metric for keeping the best checkpoint is assumed to be "higher is better"; flip its sign if otherwise.
- Parameters
serialization_dir – Path to a directory to save checkpoints.
keep_recent – Number of recent k checkpoints to keep on disk. Older checkpoints will be removed. Set to a very large value for keeping all checkpoints.
checkpointables – Keyword arguments with any checkpointable objects, for example: model, optimizer, learning rate scheduler.
Examples
>>> model = torch.nn.Linear(10, 2)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> ckpt_manager = CheckpointManager("/tmp", model=model, optimizer=optimizer)
>>> num_epochs = 20
>>> for epoch in range(num_epochs):
...     train(model)
...     val_loss = validate(model)
...     ckpt_manager.step(epoch, metric=-val_loss)
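For the DistributedDataParallel note above, a minimal sketch, assuming the process group has already been initialized and local_rank is the current GPU index (both are assumptions about the surrounding training script, not part of this class). The DDP wrapper is registered directly; per the note, only the internal model's state_dict ends up in the checkpoint file.

>>> ddp_model = torch.nn.parallel.DistributedDataParallel(
...     model.to(local_rank), device_ids=[local_rank]
... )
>>> # Register the DDP wrapper directly; the inner model's state_dict is what gets saved.
>>> ckpt_manager = CheckpointManager("/tmp", model=ddp_model, optimizer=optimizer)
>>> ckpt_manager.step(iteration=1000)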
- step(iteration: int, metric: Optional[float] = None)[source]
Serialize checkpoint and update best checkpoint based on metric. Keys in the serialized checkpoint match those in checkpointables.
- Parameters
iteration – Current training iteration. Will be saved with other checkpointables.
metric – Observed metric (higher is better) for keeping track of the best checkpoint. If this is None, the best checkpoint will not be recorded/updated.
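A brief sketch of metric-free, iteration-based usage (train_step, start_iteration, max_iterations and serialize_every are illustrative names, not part of this API); checkpoints are still serialized periodically, but no best checkpoint is tracked because metric is left as None.

>>> serialize_every = 1000
>>> for iteration in range(start_iteration, max_iterations):
...     train_step(model, optimizer)
...     # Serialize periodically; no metric means no best checkpoint is recorded.
...     if iteration % serialize_every == 0:
...         ckpt_manager.step(iteration)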
- load(checkpoint_path: str)[source]
Load a serialized checkpoint from a path. This method will try to find each of checkpointables in the file and load its state dict. Since our checkpointables are held as references, this method does not return them.
- Parameters
checkpoint_path – Path to a checkpoint serialized by step().
- Returns
Iteration corresponding to the loaded checkpoint. Useful for resuming training. This will be -1 in case of the best checkpoint, or if iteration info does not exist in the checkpoint.
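A minimal resume sketch built on the return value described above (the checkpoint path and train_step are illustrative, not produced by this class); registered checkpointables are updated in place, and the returned iteration seeds the training loop.

>>> ckpt_manager = CheckpointManager("/tmp", model=model, optimizer=optimizer)
>>> # Loads state dicts into model and optimizer in place; returns the saved iteration.
>>> start_iteration = ckpt_manager.load("/tmp/checkpoint_20000.pth")
>>> for iteration in range(start_iteration + 1, max_iterations):
...     train_step(model, optimizer)
...     ckpt_manager.step(iteration)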