probnmn.trainers._trainer
class probnmn.trainers._trainer._Trainer(config: probnmn.config.Config, dataloader: torch.utils.data.dataloader.DataLoader, models: Dict[str, torch.nn.modules.module.Module], serialization_dir: str, gpu_ids: List[int] = [0])

    Bases: object
A base class for generic training of models. This class can have multiple models interacting with each other, rather than a single model, which suits our use-case (for example, the module_training phase has two models: ProgramGenerator and NeuralModuleNetwork). It offers full flexibility, with sensible defaults which may be changed (or disabled) while extending this class:

1. Default Adam optimizer, which updates the parameters of all models in this trainer. Learning rate and weight decay for this optimizer are picked up from the provided config.
2. Default ReduceLROnPlateau learning rate scheduler. Gamma and patience arguments are picked up from the provided config. The observed metric is assumed to be of type "higher is better"; for "lower is better" metrics, make sure to reciprocate.
3. Tensorboard logging of loss curves, metrics etc.
4. Serialization of models and optimizer as checkpoint (.pth) files after every validation. The observed metric for keeping track of the best checkpoint is of type "higher is better"; follow (2) above if the observed metric is of type "lower is better".
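Since both the scheduler and the best-checkpoint bookkeeping treat the observed metric as "higher is better", a lower-is-better quantity such as a validation loss can be reciprocated before being reported. A minimal sketch (the helper name and epsilon are illustrative assumptions, not part of probnmn):

```python
def as_higher_is_better(metric: float, eps: float = 1e-12) -> float:
    """Reciprocate a lower-is-better metric (e.g. validation loss) so that
    the trainer's higher-is-better scheduling and checkpointing see it rise
    as the model improves. eps guards against division by zero."""
    return 1.0 / (metric + eps)

# A falling loss now maps to a rising observed metric.
losses = [2.0, 1.0, 0.5]
observed = [as_higher_is_better(l) for l in losses]
assert observed == sorted(observed)  # monotonically increasing
```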
Extend this class and override suitable methods as per requirements; some important ones are:

1. step(), which provides complete customization. This is the method that comprises one full training iteration, and it internally calls (in order) _before_iteration(), _do_iteration() and _after_iteration(). Most of the time you will not need to override this method; instead, override one of the three methods it calls.
2. _do_iteration(), with the core training loop: what happens every iteration, given a batch from the dataloader this class holds.
3. _before_iteration() and _after_iteration(), for any pre- or post-processing steps. Default behaviour: _before_iteration() calls optimizer.zero_grad(); _after_iteration() calls optimizer.step() and does tensorboard logging.
4. after_validation(), to specify any steps after evaluation. Default behaviour is to do learning rate scheduling and log validation metrics on tensorboard.
Parameters

    config: Config
        A Config object with all the relevant configuration parameters.
    dataloader: torch.utils.data.DataLoader
        A DataLoader which provides batches of training examples. It wraps one of probnmn.data.datasets, depending on the training phase.
    models: Dict[str, Type[nn.Module]]
        All the models which interact with each other during training. These are one or more from probnmn.models, depending on the training phase.
    serialization_dir: str
        Path to a directory for tensorboard logging and serializing checkpoints.
    gpu_ids: List[int], optional (default=[0])
        List of GPU IDs to use ([-1] to use the CPU).
Notes

All models are passed by assignment, so they could be shared with an external evaluator. Do not set self._models = ... anywhere while extending this class.
step(self, iteration: Union[int, NoneType] = None)

    Perform one iteration of training.

    Parameters

        iteration: int, optional (default = None)
            Iteration number (useful to hard-set to any number when loading a checkpoint). If None, use the internal self._iteration counter.
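A typical fixed-iteration training loop hard-sets the iteration after resuming from a checkpoint, so that logging and checkpoint names stay aligned. A sketch using a stub with the same step(iteration=None) contract (the stub and the loop bounds are assumptions, not probnmn code):

```python
from typing import Optional

class StubTrainer:
    """Stand-in exposing the same step(iteration=None) contract."""
    def __init__(self) -> None:
        self._iteration = 0
    def step(self, iteration: Optional[int] = None) -> None:
        # In the real class this runs _before/_do/_after_iteration.
        self._iteration = iteration if iteration is not None else self._iteration + 1

trainer = StubTrainer()
start_iteration = 5000   # e.g. recovered while loading a checkpoint
num_iterations = 5003
for it in range(start_iteration + 1, num_iterations + 1):
    trainer.step(it)     # hard-set so the internal counter matches resume point
assert trainer._iteration == num_iterations
```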
_before_iteration(self)

    Steps to do before the forward pass of an iteration. Default behavior is to simply call zero_grad() on the optimizer. Called inside step().
_do_iteration(self, batch: Dict[str, Any]) → Dict[str, Any]

    Forward and backward passes on models, given a batch sampled from the dataloader.

    Parameters

        batch: Dict[str, Any]
            A batch of training examples sampled from the dataloader held by this class.

    Returns

        Dict[str, Any]
            An output dictionary typically returned by the models. This would be passed to _after_iteration() for tensorboard logging.
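The shape of a _do_iteration override can be sketched without torch: run the models on the batch and return a dict of losses for _after_iteration() to log. The toy "model" and squared-error loss below are stand-ins (the real method would call loss.backward() on torch tensors):

```python
from typing import Any, Callable, Dict

def do_iteration(models: Dict[str, Callable[[float], float]],
                 batch: Dict[str, float]) -> Dict[str, Any]:
    """Sketch of a _do_iteration body: forward pass, compute a loss,
    return it in an output dict. Backward pass is elided in this toy."""
    prediction = models["toy"](batch["x"])
    loss = (prediction - batch["y"]) ** 2  # squared error as a stand-in loss
    return {"loss": loss}

models = {"toy": lambda x: 2 * x}
out = do_iteration(models, {"x": 3.0, "y": 5.0})
print(out)  # {'loss': 1.0}
```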
_after_iteration(self, output_dict: Dict[str, Any])

    Steps to do after the forward pass of an iteration. Default behavior is to simply do a gradient update through optimizer.step(), and log metrics to tensorboard.

    Parameters

        output_dict: Dict[str, Any]
            This is exactly the object returned by _do_iteration(), which would contain all the required losses for tensorboard logging.
after_validation(self, val_metrics: Dict[str, Any], iteration: Union[int, NoneType] = None)

    Steps to do after an external _Evaluator performs evaluation. This is not called by step(); call it from outside at an appropriate time. Default behavior is to perform learning rate scheduling, serialize a checkpoint, and log validation metrics to tensorboard.

    Since this implementation assumes a key "metric" in val_metrics, it is convenient to set this key while overriding this method when there are multiple models and multiple metrics, and a single metric decides the best checkpoint.

    Parameters

        val_metrics: Dict[str, Any]
            Validation metrics for all the models. Returned by the evaluate method of _Evaluator (or its extended class).
        iteration: int, optional (default = None)
            Iteration number. If None, use the internal self._iteration counter.
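The higher-is-better bookkeeping that after_validation() performs through the "metric" key can be sketched as follows (class and method names here are illustrative, not probnmn internals):

```python
from typing import Any, Dict, Optional

class CheckpointTracker:
    """Toy sketch of best-checkpoint tracking on val_metrics['metric'],
    treating the metric as higher-is-better."""

    def __init__(self) -> None:
        self.best_metric = float("-inf")
        self.best_iteration: Optional[int] = None

    def update(self, val_metrics: Dict[str, Any], iteration: int) -> bool:
        """Return True when this validation produced a new best checkpoint."""
        metric = val_metrics["metric"]
        if metric > self.best_metric:
            self.best_metric = metric
            self.best_iteration = iteration
            return True   # the real trainer would serialize a .pth file here
        return False

tracker = CheckpointTracker()
print(tracker.update({"metric": 0.4}, iteration=1000))  # True (new best)
print(tracker.update({"metric": 0.3}, iteration=2000))  # False (worse)
```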
load_checkpoint(self, checkpoint_path: str, iteration: Union[int, NoneType] = None)

    Load a checkpoint to continue training from. The iteration at which this checkpoint was serialized is inferred from its name, so do not rename it after serialization.

    Parameters

        checkpoint_path: str
            Path to a checkpoint containing models and optimizers of the phase being trained.
        iteration: int, optional (default = None)
            Iteration number. If None, get it from the checkpoint.
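Inferring the iteration from the filename could look like the following sketch, assuming a "<name>_<iteration>.pth" convention; probnmn's exact naming scheme is not shown here, so this is an assumption:

```python
import re
from typing import Optional

def iteration_from_name(checkpoint_path: str) -> Optional[int]:
    """Parse the trailing iteration number out of a checkpoint filename,
    e.g. 'checkpoint_10000.pth' -> 10000. This naming convention is an
    assumption for this sketch. Returns None if no number is found,
    which is why renaming checkpoints breaks the inference."""
    match = re.search(r"(\d+)\.pth$", checkpoint_path)
    return int(match.group(1)) if match else None

print(iteration_from_name("save/checkpoint_10000.pth"))  # 10000
print(iteration_from_name("save/best.pth"))              # None
```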
_cycle(self, dataloader: torch.utils.data.dataloader.DataLoader) → Generator[Dict[str, torch.Tensor], NoneType, NoneType]

    A generator which yields a random batch from the dataloader perpetually. This generator is used in the constructor.

    This is done because we train for a fixed number of iterations and do not have the notion of 'epochs'. Using itertools.cycle with a dataloader is harmful and may cause unexpected memory leaks.
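A safe perpetual generator simply restarts iteration on each pass, instead of itertools.cycle, which caches every element it has seen. A minimal sketch using a list as a stand-in for a DataLoader:

```python
from typing import Any, Dict, Generator, Iterable

def cycle(dataloader: Iterable[Dict[str, Any]]) -> Generator[Dict[str, Any], None, None]:
    """Yield batches forever by re-iterating the dataloader each pass.
    Unlike itertools.cycle, nothing is cached: each pass creates a fresh
    iterator, so shuffling still happens and memory use stays flat."""
    while True:
        for batch in dataloader:
            yield batch

loader = [{"idx": 0}, {"idx": 1}]  # stand-in for a torch DataLoader
batches = cycle(loader)
seen = [next(batches)["idx"] for _ in range(5)]
print(seen)  # [0, 1, 0, 1, 0]
```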