How to train your VirTex model?
We provide training scripts for all types of VirTex models from the paper, including our best-performing model and other ablations. Training jobs are specified by config files (YAML). Execute all commands from the project root to use the provided config files.
Training the base VirTex model
Train the base VirTex model with a ResNet-50 visual backbone and a textual head with L = 1, H = 1024, using all default optimization hyperparameters:
python scripts/pretrain_virtex.py \
--config configs/_base_bicaptioning_R_50_L1_H1024.yaml \
--num-gpus-per-machine 8 \
--cpu-workers 4 \
--serialization-dir /tmp/VIRTEX_R_50_L1_H1024
# Default: --checkpoint-every 2000 --log-every 20
The training job will save checkpoints, tensorboard logs (loss curves and metrics), and a backup of the config in --serialization-dir. Use tensorboard --logdir <serialization_dir> to view training curves, validation metrics, etc. directly in tensorboard.
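For example, pointing tensorboard at the serialization directory used in the command above (the path matches that example; the port is an arbitrary choice):

```shell
# Monitor the training run launched above; any free port works.
tensorboard --logdir /tmp/VIRTEX_R_50_L1_H1024 --port 6006
```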
We recommend training with 8 GPUs on the same machine, although training with multiple GPUs across machines (see: --num-machines and --machine-rank), a single GPU (--num-gpus-per-machine 1), and CPU-only (--num-gpus-per-machine 0) are also supported. Using multiple GPUs for interactive debugging with PDB is not supported, as PDB and the multiprocessing module do not play nicely together.
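As a sketch, a two-machine launch with 8 GPUs each might look like the following, using only the flags mentioned above (any rendezvous settings the script may additionally require are not shown here):

```shell
# On machine 0; repeat on machine 1 with --machine-rank 1.
python scripts/pretrain_virtex.py \
    --config configs/_base_bicaptioning_R_50_L1_H1024.yaml \
    --num-machines 2 \
    --machine-rank 0 \
    --num-gpus-per-machine 8 \
    --cpu-workers 4 \
    --serialization-dir /tmp/VIRTEX_R_50_L1_H1024

# Single-GPU run, suitable for interactive debugging with PDB:
python scripts/pretrain_virtex.py \
    --config configs/_base_bicaptioning_R_50_L1_H1024.yaml \
    --num-gpus-per-machine 1 \
    --cpu-workers 4 \
    --serialization-dir /tmp/VIRTEX_DEBUG
```

The serialization directories here are illustrative; pick any writable path.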
Reproducing all VirTex ablations
To reproduce all ablations from the paper, replace the --config argument in the above command with one of the following (all paths are relative to the project root):
Pretraining Task Ablations
Bicaptioning: configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml
Forward Captioning: configs/task_ablations/captioning_R_50_L1_H2048.yaml
Token Classification: configs/task_ablations/token_classification_R_50.yaml
Multilabel Classification: configs/task_ablations/multilabel_classification_R_50.yaml
Masked Language Modeling: configs/task_ablations/masked_lm_R_50_L1_H2048.yaml
Transformer Size Ablations
Width (H = 512): configs/width_ablations/bicaptioning_R_50_L1_H512.yaml
Width (H = 768): configs/width_ablations/bicaptioning_R_50_L1_H768.yaml
Width (H = 1024): configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml
Width (H = 2048): configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml
Depth (L = 1): configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml
Depth (L = 2): configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml
Depth (L = 3): configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml
Depth (L = 4): configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml
Backbone Ablations
ResNet-50: configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml
ResNet-50 w2x: configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml
ResNet-101: configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml
Note
Pretraining Task Ablations (1), Transformer Size Ablations (3 and 5), and Backbone Ablations (1) are all the exact same model.
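For instance, to train the masked language modeling ablation, the base command is reused with only the config swapped (the serialization directory below is an illustrative choice, renamed to match the config):

```shell
python scripts/pretrain_virtex.py \
    --config configs/task_ablations/masked_lm_R_50_L1_H2048.yaml \
    --num-gpus-per-machine 8 \
    --cpu-workers 4 \
    --serialization-dir /tmp/VIRTEX_masked_lm_R_50_L1_H2048
```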