probnmn.trainers.joint_training_trainer
class probnmn.trainers.joint_training_trainer.JointTrainingTrainer(config: probnmn.config.Config, serialization_dir: str, gpu_ids: List[int] = [0], cpu_workers: int = 0)
Bases: probnmn.trainers._trainer._Trainer
Performs training for the joint_training phase, using batches of training examples from JointTrainingDataset.
- Parameters
- config: Config
A Config object with all the relevant configuration parameters.
- serialization_dir: str
Path to a directory for tensorboard logging and serializing checkpoints.
- gpu_ids: List[int], optional (default=[0])
List of GPU IDs to use; [-1] means use the CPU.
- cpu_workers: int, optional (default = 0)
Number of CPU workers to use for fetching batch examples in the dataloader.
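A minimal sketch of the gpu_ids convention described above, assuming the first listed ID is treated as the primary device. resolve_device is a hypothetical helper, not part of probnmn, and it returns plain device strings rather than torch.device objects so that it runs without PyTorch.

```python
from typing import List


def resolve_device(gpu_ids: List[int]) -> str:
    """Hypothetical helper: map a gpu_ids list to a device string.

    Assumes [-1] selects the CPU and that otherwise the first listed GPU
    is the primary device. An empty list also falls back to the CPU
    (an extra assumption for this sketch).
    """
    if not gpu_ids or gpu_ids[0] < 0:
        return "cpu"
    return f"cuda:{gpu_ids[0]}"


print(resolve_device([0]))     # cuda:0
print(resolve_device([2, 3]))  # cuda:2
print(resolve_device([-1]))    # cpu
```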
Examples
>>> config = Config("config.yaml")  # PHASE must be "joint_training"
>>> trainer = JointTrainingTrainer(config, serialization_dir="/tmp")
>>> evaluator = JointTrainingEvaluator(config, trainer.models)
>>> for iteration in range(100):
...     trainer.step()
...     # validation every 100 steps
...     if iteration % 100 == 0:
...         val_metrics = evaluator.evaluate()
...         trainer.after_validation(val_metrics, iteration)
_do_iteration(self, batch: Dict[str, Any]) → Dict[str, Any]
Forward and backward passes on models, given a batch sampled from the dataloader.
- Parameters
- batch: Dict[str, Any]
A batch of training examples sampled from the dataloader. See step() and _cycle() for how this batch is sampled.
- Returns
- Dict[str, Any]
An output dictionary, typically returned by the models. It is passed to _after_iteration() for tensorboard logging.
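The step(), _cycle(), _do_iteration(), _after_iteration() flow described above can be sketched in pure Python as follows. ToyTrainer and cycle are hypothetical stand-ins, not probnmn's implementation: the real _do_iteration runs forward and backward passes on the models, while this toy only derives a fake "loss" from the batch, and the logging list stands in for tensorboard.

```python
from typing import Any, Dict, Iterable, Iterator, List


def cycle(dataloader: Iterable[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
    # Restart iteration over the dataloader whenever it is exhausted, so the
    # trainer can draw batches indefinitely. Assumes at least one batch exists.
    while True:
        for batch in dataloader:
            yield batch


class ToyTrainer:
    """Hypothetical stand-in for the training loop; not probnmn's code."""

    def __init__(self, dataloader: Iterable[Dict[str, Any]]) -> None:
        self._batches = cycle(dataloader)
        self._iteration = 0
        self.logged: List[Dict[str, Any]] = []

    def step(self) -> None:
        # Sample a batch, run one iteration, then log its outputs.
        batch = next(self._batches)
        output_dict = self._do_iteration(batch)
        self._after_iteration(output_dict)
        self._iteration += 1

    def _do_iteration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # A real trainer would do forward/backward passes here.
        return {"loss": float(sum(batch["x"]))}

    def _after_iteration(self, output_dict: Dict[str, Any]) -> None:
        # Stand-in for tensorboard logging: record the iteration and outputs.
        self.logged.append({"iteration": self._iteration, **output_dict})


trainer = ToyTrainer([{"x": [1, 2]}, {"x": [4]}])
for _ in range(3):
    trainer.step()
print(trainer.logged)
# [{'iteration': 0, 'loss': 3.0}, {'iteration': 1, 'loss': 4.0}, {'iteration': 2, 'loss': 3.0}]
```

The third step wraps around to the first batch, which is the behavior a cycling sampler is meant to provide.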
after_validation(self, val_metrics: Dict[str, Any], iteration: Union[int, NoneType] = None)
Set the "metric" key in val_metrics; this key governs learning rate scheduling and tracking of the best checkpoint (in the super method). Here, the metric is answer accuracy.
The super method performs learning rate scheduling, serializes a checkpoint, and logs all the validation metrics to tensorboard.
- Parameters
- val_metrics: Dict[str, Any]
Validation metrics of NeuralModuleNetwork, as returned by the evaluate method of JointTrainingEvaluator.
- iteration: int, optional (default = None)
Iteration number. If None, the internal self._iteration counter is used.
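A minimal sketch of the after_validation bookkeeping described above: setting "metric", falling back to the internal iteration counter, and tracking the best value. ToyCheckpointTracker is hypothetical, and the flat "answer_accuracy" key is an assumed metrics layout, not necessarily what JointTrainingEvaluator returns. Learning rate scheduling, checkpoint serialization, and tensorboard logging are omitted.

```python
from typing import Any, Dict, Optional


class ToyCheckpointTracker:
    """Hypothetical sketch of after_validation bookkeeping; not probnmn's code."""

    def __init__(self) -> None:
        self._iteration = 0
        self.best_metric = float("-inf")
        self.best_iteration: Optional[int] = None

    def after_validation(
        self, val_metrics: Dict[str, Any], iteration: Optional[int] = None
    ) -> None:
        # Fall back to the internal counter when no iteration is given.
        iteration = iteration if iteration is not None else self._iteration
        # Assumed layout: answer accuracy stored under "answer_accuracy".
        val_metrics["metric"] = val_metrics["answer_accuracy"]
        # Track the best checkpoint by the "metric" value (higher is better).
        if val_metrics["metric"] > self.best_metric:
            self.best_metric = val_metrics["metric"]
            self.best_iteration = iteration


tracker = ToyCheckpointTracker()
tracker.after_validation({"answer_accuracy": 0.41}, iteration=100)
tracker.after_validation({"answer_accuracy": 0.38}, iteration=200)
tracker._iteration = 300
tracker.after_validation({"answer_accuracy": 0.45})  # uses tracker._iteration
print(tracker.best_metric, tracker.best_iteration)  # 0.45 300
```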