probnmn.trainers.question_coding_trainer
class probnmn.trainers.question_coding_trainer.QuestionCodingTrainer(config: probnmn.config.Config, serialization_dir: str, gpu_ids: List[int] = [0], cpu_workers: int = 0)[source]

Bases: probnmn.trainers._trainer._Trainer

Performs training for the question_coding phase, using batches of training examples from QuestionCodingDataset.

- Parameters
  - config: Config
    A Config object with all the relevant configuration parameters.
  - serialization_dir: str
    Path to a directory for tensorboard logging and serializing checkpoints.
  - gpu_ids: List[int], optional (default = [0])
    List of GPU IDs to use; pass [-1] to use the CPU.
  - cpu_workers: int, optional (default = 0)
    Number of CPU workers to use for fetching batch examples in the dataloader.
Examples
>>> config = Config("config.yaml")  # PHASE must be "question_coding"
>>> trainer = QuestionCodingTrainer(config, serialization_dir="/tmp")
>>> evaluator = QuestionCodingEvaluator(config, trainer.models)
>>> for iteration in range(100):
>>>     trainer.step()
>>>     # validation every 100 steps
>>>     if iteration % 100 == 0:
>>>         val_metrics = evaluator.evaluate()
>>>         trainer.after_validation(val_metrics, iteration)
_do_iteration(self, batch: Dict[str, Any]) → Dict[str, Any][source]

Forward and backward passes on models, given a batch sampled from the dataloader.

- Parameters
  - batch: Dict[str, Any]
    A batch of training examples sampled from the dataloader. See step() and _cycle() on how this batch is sampled.

- Returns
  - Dict[str, Any]
    An output dictionary, typically returned by the models. This is passed to _after_iteration() for tensorboard logging.
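For intuition, here is a minimal sketch of the shape such an iteration typically takes: zero gradients, forward pass, backward pass, optimizer step, then return the model's output dictionary. The attribute names used below (self._question_reconstructor, self._optimizer) and the "loss" output key are illustrative assumptions, not the actual internals of this class.

from typing import Any, Dict

class SketchQuestionCodingTrainer(QuestionCodingTrainer):
    def _do_iteration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Hypothetical attributes; the real trainer wires up its models and
        # optimizer from the Config object.
        self._optimizer.zero_grad()

        # Forward pass: the model is assumed to return a dict with a scalar
        # "loss" tensor, a common convention in AllenNLP-style models.
        output_dict = self._question_reconstructor(batch["question"])
        loss = output_dict["loss"].mean()

        # Backward pass and parameter update.
        loss.backward()
        self._optimizer.step()

        # This dict is what _after_iteration() receives for tensorboard logging.
        return output_dict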
after_validation(self, val_metrics: Dict[str, Any], iteration: Union[int, NoneType] = None)[source]

Steps to perform after an external _Evaluator runs evaluation. This is not called by step(); call it from outside at an appropriate time. The default behavior is to perform learning rate scheduling, serialize a checkpoint, and log validation metrics to tensorboard.

This implementation assumes a key "metric" in val_metrics, so it is convenient to set this key when overriding this method, especially when there are multiple models and multiple metrics and a single metric decides the best checkpoint.

- Parameters
  - val_metrics: Dict[str, Any]
    Validation metrics for all the models, returned by the evaluate method of _Evaluator (or its extended class).
  - iteration: int, optional (default = None)
    Iteration number. If None, use the internal self._iteration counter.
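Because the default implementation keys its "best checkpoint" decision off val_metrics["metric"], a common pattern is to set that key in an override before delegating to the base class. A hedged sketch, where the nesting and metric name inside val_metrics are assumptions for illustration:

from typing import Any, Dict, Optional

class MyQuestionCodingTrainer(QuestionCodingTrainer):
    def after_validation(self, val_metrics: Dict[str, Any], iteration: Optional[int] = None):
        # Pick one scalar to decide the best checkpoint. The key path below
        # ("question_coding" -> "reconstruction_accuracy") is hypothetical;
        # use whatever your _Evaluator's evaluate() actually returns.
        val_metrics["metric"] = val_metrics["question_coding"]["reconstruction_accuracy"]

        # The base class then does LR scheduling, checkpoint serialization,
        # and tensorboard logging of these metrics.
        super().after_validation(val_metrics, iteration)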