probnmn.utils.checkpointing
class probnmn.utils.checkpointing.CheckpointManager(serialization_dir: str = '/tmp', keep_recent: int = 10, **checkpointables: Any)[source]

Bases: object

A CheckpointManager periodically serializes models and other checkpointable objects (which implement a state_dict method) as .pth files during training, and optionally keeps track of the best performing checkpoint based on an observed metric. This class closely follows the API of PyTorch optimizers and learning rate schedulers.
Note

For DataParallel and DistributedDataParallel objects, module.state_dict is called instead of state_dict.

Note

The observed metric for keeping the best checkpoint is assumed to be "higher is better"; flip its sign if lower is better.
- Parameters
- serialization_dir: str
Path to an empty or non-existent directory to save checkpoints.
- keep_recent: int, optional (default=10)
Number of recent checkpoints to keep on disk; older checkpoints will be removed. Set to a very large value to keep all checkpoints.
- checkpointables: Any
Keyword arguments with any checkpointable objects, for example: model, optimizer, learning rate scheduler. Their state dicts can be accessed by their keyword names.
Examples
>>> model = torch.nn.Linear(10, 2)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> ckpt_manager = CheckpointManager("/tmp/ckpt", model=model, optimizer=optimizer)
>>> num_epochs = 20
>>> for epoch in range(num_epochs):
...     train(model)
...     val_loss = validate(model)
...     ckpt_manager.step(- val_loss, epoch)
step(self, iteration: int, metric: Optional[float] = None)[source]

Serialize checkpoint and update best checkpoint based on metric.
load(self, checkpoint_path: str)[source]

Load a serialized checkpoint from a path. This method will try to find each of the checkpointables in the file and load its state dict. Since our checkpointables are held as references, this method does not return them.

- Parameters
- checkpoint_path: str
Path to a checkpoint serialized by step().
- Returns
- int
Iteration corresponding to the loaded checkpoint. Useful for resuming training. This will be -1 for the best checkpoint, or if iteration info does not exist in the checkpoint.
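The step()/load() contract above (serialize state dicts keyed by name, restore them in place through held references, and return the saved iteration or -1) can be sketched with a minimal stdlib-only stand-in. SimpleCheckpointManager and Counter below are hypothetical illustrations using pickle instead of torch's .pth serialization; they are not the real probnmn implementation:

```python
import os
import pickle
import tempfile

class SimpleCheckpointManager:
    """Hypothetical stand-in for CheckpointManager, illustrating the
    step()/load() contract with pickle instead of torch.save."""

    def __init__(self, serialization_dir, **checkpointables):
        os.makedirs(serialization_dir, exist_ok=True)
        self._dir = serialization_dir
        # Checkpointables are held as references, so load() can restore
        # their state in place instead of returning new objects.
        self._checkpointables = checkpointables

    def step(self, iteration, metric=None):
        # Serialize every checkpointable's state dict, plus the iteration.
        payload = {name: obj.state_dict()
                   for name, obj in self._checkpointables.items()}
        payload["iteration"] = iteration
        path = os.path.join(self._dir, f"checkpoint_{iteration}.pkl")
        with open(path, "wb") as f:
            pickle.dump(payload, f)

    def load(self, checkpoint_path):
        # Restore each checkpointable's state dict by reference and
        # return the saved iteration (-1 if it was not recorded).
        with open(checkpoint_path, "rb") as f:
            payload = pickle.load(f)
        for name, obj in self._checkpointables.items():
            if name in payload:
                obj.load_state_dict(payload[name])
        return payload.get("iteration", -1)

class Counter:
    """Toy checkpointable implementing state_dict/load_state_dict."""
    def __init__(self):
        self.value = 0
    def state_dict(self):
        return {"value": self.value}
    def load_state_dict(self, state):
        self.value = state["value"]

ckpt_dir = tempfile.mkdtemp()
counter = Counter()
manager = SimpleCheckpointManager(ckpt_dir, counter=counter)
counter.value = 7
manager.step(iteration=3, metric=-0.5)

# Later: a fresh object is restored in place, and training can resume
# from the returned iteration.
fresh = Counter()
manager2 = SimpleCheckpointManager(ckpt_dir, counter=fresh)
start = manager2.load(os.path.join(ckpt_dir, "checkpoint_3.pkl"))
```

After loading, `start` is 3 and `fresh.value` is 7: the caller keeps its own references and only reads back the iteration to resume the training loop.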