probnmn.trainers.question_coding_trainer
class probnmn.trainers.question_coding_trainer.QuestionCodingTrainer(config: probnmn.config.Config, serialization_dir: str, gpu_ids: List[int] = [0], cpu_workers: int = 0)[source]

Bases: probnmn.trainers._trainer._Trainer

Performs training for the question_coding phase, using batches of training examples from QuestionCodingDataset.
Parameters

- config: Config
    A Config object with all the relevant configuration parameters.
- serialization_dir: str
    Path to a directory for tensorboard logging and serializing checkpoints.
- gpu_ids: List[int], optional (default=[0])
    List of GPU IDs to use; [-1] to use CPU.
- cpu_workers: int, optional (default = 0)
    Number of CPU workers to use for fetching batch examples in the dataloader.
Examples
>>> config = Config("config.yaml")  # PHASE must be "question_coding"
>>> trainer = QuestionCodingTrainer(config, serialization_dir="/tmp")
>>> evaluator = QuestionCodingEvaluator(config, trainer.models)
>>> for iteration in range(100):
>>>     trainer.step()
>>>     # validation every 100 steps
>>>     if iteration % 100 == 0:
>>>         val_metrics = evaluator.evaluate()
>>>         trainer.after_validation(val_metrics, iteration)
_do_iteration(self, batch: Dict[str, Any]) → Dict[str, Any][source]

Forward and backward passes on models, given a batch sampled from dataloader.

Parameters

- batch: Dict[str, Any]
    A batch of training examples sampled from dataloader. See step() and _cycle() on how this batch is sampled.

Returns

- Dict[str, Any]
    An output dictionary typically returned by the models. This would be passed to _after_iteration() for tensorboard logging.
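The forward/backward contract described above can be pictured with a minimal sketch. The attribute names self._model and self._optimizer below are assumptions made only for illustration; the actual trainer manages its own (possibly multiple) models and optimizer internally.

    from typing import Any, Dict

    def _do_iteration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Forward pass: the model is assumed to return a dict with a scalar "loss".
        output_dict = self._model(batch)      # hypothetical single-model attribute
        loss = output_dict["loss"]

        # Backward pass and parameter update.
        self._optimizer.zero_grad()           # hypothetical optimizer attribute
        loss.backward()
        self._optimizer.step()

        # The returned dict is later consumed by _after_iteration() for tensorboard logging.
        return output_dict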
after_validation(self, val_metrics: Dict[str, Any], iteration: Union[int, NoneType] = None)[source]

Steps to do after an external _Evaluator performs evaluation. This is not called by step(); call it from outside at an appropriate time. Default behavior is to perform learning rate scheduling, serialize a checkpoint, and log validation metrics to tensorboard.

Since this implementation assumes a key "metric" in val_metrics, it is convenient to set this key while overriding this method, when there are multiple models and multiple metrics and there is one metric which decides the best checkpoint.

Parameters

- val_metrics: Dict[str, Any]
    Validation metrics for all the models. Returned by the evaluate method of _Evaluator (or its extended class).
- iteration: int, optional (default = None)
    Iteration number. If None, use the internal self._iteration counter.
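As a sketch of the overriding pattern mentioned above, a subclass with multiple models could populate the "metric" key before delegating to the default behavior. The model name "program_generator" and metric name "bleu" below are hypothetical, chosen only to illustrate the idea.

    from typing import Any, Dict, Optional

    class MyQuestionCodingTrainer(QuestionCodingTrainer):
        def after_validation(self, val_metrics: Dict[str, Any], iteration: Optional[int] = None):
            # Pick the single metric that decides the best checkpoint
            # ("program_generator" / "bleu" are hypothetical names).
            val_metrics["metric"] = val_metrics["program_generator"]["bleu"]
            # Default behavior: LR scheduling, checkpoint serialization, tensorboard logging.
            super().after_validation(val_metrics, iteration)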