virtex.utils.checkpointing


class virtex.utils.checkpointing.CheckpointManager(serialization_dir: str = '/tmp', keep_recent: int = 200, **checkpointables: Any)[source]

Bases: object

A helper class to periodically serialize models and other checkpointable objects during training (optimizers, LR schedulers, and anything else that implements a state_dict method), and to optionally record the best-performing checkpoint based on an observed metric.

Note

For DistributedDataParallel objects, the state_dict of the internal (wrapped) model is serialized.
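
For example, a DistributedDataParallel wrapper can be registered directly. A minimal sketch, assuming the distributed process group has already been initialized elsewhere:

>>> ddp_model = torch.nn.parallel.DistributedDataParallel(model)
>>> # Per the note above, the serialized "model" state dict is that of the
>>> # internal (wrapped) model, so it can later be loaded into a plain,
>>> # un-wrapped model.
>>> ckpt_manager = CheckpointManager("/tmp", model=ddp_model)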

Note

The observed metric for keeping the best checkpoint is assumed to be "higher is better"; flip the sign if lower is better.

Parameters
  • serialization_dir – Path to a directory to save checkpoints.

  • keep_recent – Number of most recent checkpoints to keep on disk; older checkpoints are removed. Set to a very large value to keep all checkpoints.

  • checkpointables – Keyword arguments with any checkpointable objects, for example: model, optimizer, learning rate scheduler.

Examples

>>> model = torch.nn.Linear(10, 2)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> ckpt_manager = CheckpointManager("/tmp", model=model, optimizer=optimizer)
>>> num_epochs = 20
>>> for epoch in range(num_epochs):
...     train(model)
...     val_loss = validate(model)
...     ckpt_manager.step(epoch, -val_loss)
step(iteration: int, metric: Optional[float] = None)[source]

Serialize a checkpoint and update the best checkpoint based on the metric. Keys in the serialized checkpoint match those in checkpointables.

Parameters
  • iteration – Current training iteration. Will be saved with other checkpointables.

  • metric – Observed metric (higher is better) for keeping track of the best checkpoint. If this is None, the best checkpoint will not be recorded/updated.
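
A minimal usage sketch; train_one_step and validate are hypothetical helpers supplied by the caller, and checkpoints are written every 100 iterations here:

>>> for iteration in range(1000):
...     train_one_step(model)  # hypothetical training helper
...     if iteration % 100 == 0:
...         val_accuracy = validate(model)  # hypothetical validation helper
...         ckpt_manager.step(iteration, metric=val_accuracy)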

_state_dict()[source]

Return a dict containing the state dict of each checkpointable.

remove_earliest_checkpoint()[source]

Remove the earliest serialized checkpoint from disk.

load(checkpoint_path: str)[source]

Load a serialized checkpoint from a path. This method tries to find each of the checkpointables in the file and load its state dict. Since the checkpointables are held as references, this method does not return them.

Parameters

checkpoint_path – Path to a checkpoint serialized by step().

Returns

Iteration corresponding to the loaded checkpoint, which is useful for resuming training. This will be -1 when loading the best checkpoint, or if iteration info does not exist in the checkpoint.
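
A minimal resumption sketch; the checkpoint filename below is hypothetical, and train is the same placeholder used in the class example above:

>>> ckpt_manager = CheckpointManager("/tmp", model=model, optimizer=optimizer)
>>> iteration = ckpt_manager.load("/tmp/checkpoint_2000.pth")  # hypothetical path
>>> start_iteration = iteration + 1 if iteration >= 0 else 0  # -1: no iteration info
>>> for iteration in range(start_iteration, 1000):
...     train(model)
...     ckpt_manager.step(iteration)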