probnmn.trainers.module_training_trainer

class probnmn.trainers.module_training_trainer.ModuleTrainingTrainer(config: probnmn.config.Config, serialization_dir: str, gpu_ids: List[int] = [0], cpu_workers: int = 0)

Bases: probnmn.trainers._trainer._Trainer

Performs training for the module_training phase, using batches of training examples from ModuleTrainingDataset.

Parameters
config: Config

A Config object with all the relevant configuration parameters.

serialization_dir: str

Path to a directory for tensorboard logging and serializing checkpoints.

gpu_ids: List[int], optional (default = [0])

List of GPU IDs to use. Set to [-1] to use the CPU.

cpu_workers: int, optional (default = 0)

Number of CPU workers to use for fetching batch examples in the dataloader.
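The gpu_ids convention, where [-1] selects the CPU, can be illustrated with a small helper that turns the list into PyTorch-style device strings. The function name device_strings() and the choice to reject negative IDs mixed with real ones are hypothetical; this is a sketch of how such a list might be interpreted, not the trainer's own code.

```python
from typing import List


def device_strings(gpu_ids: List[int]) -> List[str]:
    """Hypothetical helper: map a gpu_ids list to device strings.

    [-1] selects the CPU; any other list is read as CUDA device indices.
    """
    if gpu_ids == [-1]:
        return ["cpu"]
    if not gpu_ids or any(i < 0 for i in gpu_ids):
        # Assumption: -1 is only meaningful on its own, as the CPU marker.
        raise ValueError("gpu_ids must be [-1] or a list of non-negative IDs")
    return [f"cuda:{i}" for i in gpu_ids]


print(device_strings([0]))       # ['cuda:0']
print(device_strings([-1]))      # ['cpu']
print(device_strings([0, 1]))    # ['cuda:0', 'cuda:1']
```

Under this reading, the default [0] places everything on the first GPU, while multiple IDs would typically be used for data-parallel training.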

Examples

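>>> from probnmn.config import Config
>>> from probnmn.evaluators.module_training_evaluator import ModuleTrainingEvaluator
>>> from probnmn.trainers.module_training_trainer import ModuleTrainingTrainer
>>> config = Config("config.yaml")  # PHASE must be "module_training"
>>> trainer = ModuleTrainingTrainer(config, serialization_dir="/tmp")
>>> evaluator = ModuleTrainingEvaluator(config, trainer.models)
>>> for iteration in range(1000):
...     trainer.step()
...     # validate after every 100 training steps
...     if (iteration + 1) % 100 == 0:
...         val_metrics = evaluator.evaluate()
...         trainer.after_validation(val_metrics, iteration)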

_do_iteration(self, batch:Dict[str, Any]) → Dict[str, Any]

Perform forward and backward passes on the models, given a batch sampled from the dataloader.

Parameters
batch: Dict[str, Any]

A batch of training examples sampled from the dataloader. See step() and _cycle() for how this batch is sampled.

Returns
Dict[str, Any]

An output dictionary, typically returned by the models. It is passed to _after_iteration() for tensorboard logging.
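The flow described above (step() draws a batch via _cycle() and hands it to _do_iteration(), which runs the forward and backward passes and returns an output dictionary) can be sketched with a toy one-weight model trained by plain gradient descent. Everything here is hypothetical: ToyTrainer, cycle(), the "x"/"y" batch keys, and the squared-error loss are stand-ins, and _cycle() is assumed to loop over the dataloader indefinitely. probnmn itself operates on PyTorch models and optimizers.

```python
from typing import Any, Dict, Iterable, Iterator


def cycle(dataloader: Iterable[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
    """Hypothetical stand-in for _cycle(): restart the dataloader forever."""
    while True:
        for batch in dataloader:
            yield batch


class ToyTrainer:
    """Hypothetical trainer: scalar weight w, mean squared error, plain SGD."""

    def __init__(self, dataloader: Iterable[Dict[str, Any]], lr: float = 0.1):
        self.w = 0.0
        self.lr = lr
        self._iteration = 0
        self._batches = cycle(dataloader)

    def _do_iteration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        xs, ys = batch["x"], batch["y"]
        # Forward pass: predictions and mean squared error.
        preds = [self.w * x for x in xs]
        loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
        # Backward pass: analytic gradient d(loss)/dw, then an SGD update.
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
        self.w -= self.lr * grad
        # Output dictionary, analogous to what _after_iteration() would log.
        return {"loss": loss}

    def step(self) -> Dict[str, Any]:
        batch = next(self._batches)
        output = self._do_iteration(batch)
        self._iteration += 1
        return output


trainer = ToyTrainer([{"x": [1.0, 2.0], "y": [2.0, 4.0]}])
for _ in range(50):
    output = trainer.step()
print(round(trainer.w, 6))  # 2.0
```

Because cycle() restarts the iterable once it is exhausted, step() can be called any number of times even though the toy dataloader holds a single batch.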

after_validation(self, val_metrics:Dict[str, Any], iteration:Union[int, NoneType]=None)

Set the "metric" key in val_metrics. This key governs learning rate scheduling and tracking of the best checkpoint (in the super method). For this phase, the metric is answer accuracy.

The super method then performs learning rate scheduling, serializes a checkpoint, and logs all the validation metrics to tensorboard.

Parameters
val_metrics: Dict[str, Any]

Validation metrics of the NeuralModuleNetwork, as returned by the evaluate() method of ModuleTrainingEvaluator.

iteration: int, optional (default = None)

Iteration number. If None, use the internal self._iteration counter.
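The bookkeeping described for after_validation() and its super method can be sketched in pure Python. This is a minimal sketch under stated assumptions: the "answer_accuracy" key name, the patience-based learning rate decay (in the spirit of reduce-on-plateau scheduling), and recording the best iteration in place of serializing a checkpoint are all hypothetical choices, not probnmn's actual logic.

```python
from typing import Any, Dict, Optional


class ValidationBookkeeper:
    """Hypothetical sketch of after_validation() plus its super method."""

    def __init__(self, lr: float = 1e-3, patience: int = 2, factor: float = 0.5):
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.best_metric = float("-inf")
        self.best_iteration: Optional[int] = None
        self._bad_rounds = 0
        self._iteration = 0  # internal counter used when iteration is None

    def after_validation(
        self, val_metrics: Dict[str, Any], iteration: Optional[int] = None
    ) -> None:
        if iteration is None:
            iteration = self._iteration
        # Assumption: answer accuracy is stored under this hypothetical key.
        val_metrics["metric"] = val_metrics["answer_accuracy"]
        metric = val_metrics["metric"]
        if metric > self.best_metric:
            # New best: a real trainer would serialize a checkpoint here.
            self.best_metric = metric
            self.best_iteration = iteration
            self._bad_rounds = 0
        else:
            # No improvement: decay the learning rate after `patience` rounds.
            self._bad_rounds += 1
            if self._bad_rounds >= self.patience:
                self.lr *= self.factor
                self._bad_rounds = 0


book = ValidationBookkeeper()
for it, acc in [(100, 0.50), (200, 0.60), (300, 0.55), (400, 0.58)]:
    book.after_validation({"answer_accuracy": acc}, it)
print(book.best_iteration, book.best_metric, book.lr)  # 200 0.6 0.0005
```

Here two consecutive validations without improvement (iterations 300 and 400) halve the learning rate, while the best checkpoint stays at iteration 200.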