probnmn.trainers._trainer

class probnmn.trainers._trainer._Trainer(config: probnmn.config.Config, dataloader: torch.utils.data.dataloader.DataLoader, models: Dict[str, torch.nn.modules.module.Module], serialization_dir: str, gpu_ids: List[int] = [0])[source]

Bases: object

A base class for generic training of models. This class can have multiple models interacting with each other rather than a single model, which suits our use-case (for example, the module_training phase has two models: ProgramGenerator and NeuralModuleNetwork). It offers full flexibility, with sensible defaults which may be changed (or disabled) while extending this class.

  1. Default Adam Optimizer, updates parameters of all models in this trainer. Learning rate and weight decay for this optimizer are picked up from the provided config.

  2. Default ReduceLROnPlateau learning rate scheduler. Gamma and patience arguments are picked up from the provided config. The observed metric is assumed to be of type “higher is better”; for “lower is better” metrics, make sure to reciprocate.

  3. Tensorboard logging of loss curves, metrics etc.

  4. Serialization of models and optimizer as checkpoint (.pth) files after every validation. The observed metric for keeping track of the best checkpoint is of type “higher is better”; follow (2) above if the observed metric is of type “lower is better”. A rough sketch of defaults (1) and (2) follows this list.
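These defaults roughly correspond to the setup sketched below. The function and its argument names (lr, weight_decay, lr_gamma, lr_patience) are illustrative assumptions for the values read from the config, not necessarily the attributes exposed by probnmn.config.Config.

    import itertools

    from torch import optim


    def make_default_optimizer_and_scheduler(models, lr, weight_decay, lr_gamma, lr_patience):
        """A rough sketch of the trainer's default optimizer and scheduler.
        Argument names for the config values are assumptions for illustration."""
        # A single Adam optimizer updates parameters of all models in the trainer.
        all_parameters = itertools.chain.from_iterable(m.parameters() for m in models.values())
        optimizer = optim.Adam(all_parameters, lr=lr, weight_decay=weight_decay)

        # "higher is better" observed metric => mode="max"; reciprocate (or negate)
        # "lower is better" metrics before passing them to scheduler.step().
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="max", factor=lr_gamma, patience=lr_patience
        )
        return optimizer, scheduler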

Extend this class and override suitable methods as per requirements, some important ones are:

  1. step(), which provides complete customization. This is the method that comprises one full training iteration and internally calls (in order) _before_iteration(), _do_iteration() and _after_iteration(). Most of the time you will not need to override this method; instead, override one of these three methods called by step().

  2. _do_iteration(), with the core training loop: what happens every iteration, given a batch from the dataloader this class holds (see the sketch after this list).

  3. _before_iteration() and _after_iteration(), for any pre- or post-processing steps. Default behavior is to zero out gradients before the forward pass, and to perform the optimizer step and tensorboard logging after it.

  4. after_validation(), to specify any steps after evaluation. Default behavior is to do learning rate scheduling, serialize a checkpoint, and log validation metrics to tensorboard.
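As an illustration of (2), a subclass might override _do_iteration() roughly as follows. The model key ("program_generator") and batch keys ("question", "program") are hypothetical placeholders; the actual names depend on the training phase and dataset.

    from typing import Any, Dict

    from probnmn.trainers._trainer import _Trainer


    class MyPhaseTrainer(_Trainer):
        def _do_iteration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
            # Forward pass through one of the models held by this trainer.
            # Model and batch keys here are illustrative, not the library's actual names.
            output_dict = self._models["program_generator"](batch["question"], batch["program"])

            # Backward pass; the default _before_iteration() has already zeroed gradients
            # and the default _after_iteration() will perform the optimizer step.
            output_dict["loss"].mean().backward()
            return output_dict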

Parameters
config: Config

A Config object with all the relevant configuration parameters.

dataloader: torch.utils.data.DataLoader

A DataLoader which provides batches of training examples. It wraps one of probnmn.data.datasets depending on the evaluation phase.

models: Dict[str, nn.Module]

All the models which interact with each other during training. These are one or more from probnmn.models depending on the training phase.

serialization_dir: str

Path to a directory for tensorboard logging and serializing checkpoints.

gpu_ids: List[int], optional (default=[0])

List of GPU IDs to use. Set to [-1] to use the CPU.

Notes

All models are passed by assignment, so they can be shared with an external evaluator. Do not set self._models = ... anywhere while extending this class.
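Putting these pieces together, a training script might drive a trainer roughly like the sketch below. Construction of the config, dataloader, models and evaluator is omitted, MyPhaseTrainer is the hypothetical subclass from the earlier sketch, and num_iterations and the validation cadence are illustrative assumptions.

    def train(config, dataloader, models, evaluator, serialization_dir, num_iterations=50000):
        """A hedged usage sketch; the validation interval (500) is an arbitrary choice."""
        trainer = MyPhaseTrainer(
            config=config,
            dataloader=dataloader,
            models=models,
            serialization_dir=serialization_dir,
            gpu_ids=[0],
        )

        # Training runs for a fixed number of iterations; there is no notion of epochs.
        for iteration in range(1, num_iterations + 1):
            trainer.step(iteration)

            # Periodically evaluate with an external _Evaluator (arguments, if any, omitted),
            # then let the trainer do LR scheduling, checkpointing and tensorboard logging.
            if iteration % 500 == 0:
                val_metrics = evaluator.evaluate()
                trainer.after_validation(val_metrics, iteration)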

step(self, iteration:Union[int, NoneType]=None)[source]

Perform one iteration of training.

Parameters
iteration: int, optional (default = None)

Iteration number (useful to hard set to any number when loading checkpoint). If None, use the internal self._iteration counter.

_before_iteration(self)[source]

Steps to do before doing the forward pass of iteration. Default behavior is to simply call zero_grad() for optimizer. Called inside step().

_do_iteration(self, batch:Dict[str, Any]) → Dict[str, Any][source]

Forward and backward passes on models, given a batch sampled from dataloader.

Parameters
batch: Dict[str, Any]

A batch of training examples sampled from dataloader. See step() and _cycle() for how this batch is sampled.

Returns
Dict[str, Any]

An output dictionary typically returned by the models. This would be passed to _after_iteration() for tensorboard logging.

_after_iteration(self, output_dict:Dict[str, Any])[source]

Steps to do after doing the forward pass of iteration. Default behavior is to simply do gradient update through optimizer.step(), and log metrics to tensorboard.

Parameters
output_dict: Dict[str, Any]

This is exactly the object returned by _do_iteration(), which contains all the losses required for tensorboard logging.
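Continuing the hypothetical subclass sketched earlier, an override could add extra work before delegating to this default behavior, for example gradient clipping (an illustrative extension, not something the base class does):

    from typing import Any, Dict

    import torch

    from probnmn.trainers._trainer import _Trainer


    class MyPhaseTrainer(_Trainer):
        def _after_iteration(self, output_dict: Dict[str, Any]):
            # Illustrative extension: clip gradients of all models, then let the
            # base class perform the optimizer step and tensorboard logging.
            for model in self._models.values():
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
            super()._after_iteration(output_dict)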

after_validation(self, val_metrics:Dict[str, Any], iteration:Union[int, NoneType]=None)[source]

Steps to do after an external _Evaluator performs evaluation. This is not called by step(); call it from outside at the appropriate time. Default behavior is to perform learning rate scheduling, serialize a checkpoint, and log validation metrics to tensorboard.

This implementation assumes a key "metric" in val_metrics. When there are multiple models and multiple metrics, it is convenient to set this key, while overriding this method, to the one metric which decides the best checkpoint (see the sketch after this method's parameters).

Parameters
val_metrics: Dict[str, Any]

Validation metrics for all the models. Returned by the evaluate method of _Evaluator (or a class extending it).

iteration: int, optional (default = None)

Iteration number. If None, use the internal self._iteration counter.
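For instance, when the evaluator returns metrics for several models, an override might pick the deciding metric and expose it under "metric" before calling the base implementation; the nested keys ("nmn", "answer_accuracy") are hypothetical.

    from typing import Any, Dict, Optional

    from probnmn.trainers._trainer import _Trainer


    class MyPhaseTrainer(_Trainer):
        def after_validation(self, val_metrics: Dict[str, Any], iteration: Optional[int] = None):
            # Hypothetical keys: expose the one metric that decides the best checkpoint.
            val_metrics["metric"] = val_metrics["nmn"]["answer_accuracy"]
            super().after_validation(val_metrics, iteration)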

load_checkpoint(self, checkpoint_path:str, iteration:Union[int, NoneType]=None)[source]

Load a checkpoint to continue training from. The iteration at which this checkpoint was serialized is inferred from its name, so do not rename it after serialization.

Parameters
checkpoint_path: str

Path to a checkpoint containing models and optimizers of the phase which is being trained on.

iteration: int, optional (default = None)

Iteration number. If None, get it from the checkpoint.

_cycle(self, dataloader:torch.utils.data.dataloader.DataLoader) → Generator[Dict[str, torch.Tensor], NoneType, NoneType][source]

A generator which yields a random batch from dataloader perpetually. This generator is used in the constructor.

This is done because we train for a fixed number of iterations and do not have the notion of “epochs”. Using itertools.cycle with a dataloader is harmful and may cause unexpected memory leaks.
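A minimal sketch of such a perpetual generator, assuming the dataloader can simply be re-iterated once exhausted:

    from typing import Dict, Generator

    import torch
    from torch.utils.data import DataLoader


    def cycle(dataloader: DataLoader) -> Generator[Dict[str, torch.Tensor], None, None]:
        # Re-create the dataloader iterator on every pass instead of using
        # itertools.cycle, which caches every yielded batch and grows memory.
        while True:
            for batch in dataloader:
                yield batch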