VirTex Model Zoo
We provide a collection of pretrained model weights and corresponding config names in this model zoo. Each table below lists the partial path to a model's config file, a download link for its pretrained weights and, for reference, its VOC07 mAP and ImageNet top-1 accuracy.
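The config names in these tables follow a regular pattern. Each name contains the ablation group, the pretraining task, the ResNet depth (`R_50`) and, for some models, an `_L<n>_H<n>` suffix. We assume this suffix encodes the number of layers (`L`) and the hidden width (`H`) of the textual head, because those are the two quantities varied in the depth and width ablations. The hypothetical helper `parse_config_name` below is a minimal sketch that splits a name under that assumption. It is not part of the `virtex` package.

```python
import re
from typing import Any, Dict

# Assumed naming pattern, inferred from the tables in this model zoo:
#   <group>/<task>_R_<resnet depth>[_L<textual layers>_H<textual width>].yaml
CONFIG_NAME = re.compile(
    r"^(?P<group>\w+)/(?P<task>[a-z_]+?)_R_(?P<resnet>\d+)"
    r"(?:_L(?P<layers>\d+)_H(?P<hidden>\d+))?\.yaml$"
)


def parse_config_name(name: str) -> Dict[str, Any]:
    """Split a partial config path into its (assumed) components.

    ``textual_layers`` and ``textual_width`` are None for configs without
    an ``_L<n>_H<n>`` suffix, such as the classification baselines.
    """
    match = CONFIG_NAME.match(name)
    if match is None:
        raise ValueError(f"Unrecognized config name: {name!r}")
    layers, hidden = match["layers"], match["hidden"]
    return {
        "group": match["group"],
        "task": match["task"],
        "resnet_depth": int(match["resnet"]),
        "textual_layers": int(layers) if layers is not None else None,
        "textual_width": int(hidden) if hidden is not None else None,
    }
```

For example, `parse_config_name("depth_ablations/bicaptioning_R_50_L4_H1024.yaml")` returns `textual_layers` of 4 and `textual_width` of 1024.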
The simplest way to download and use a full pretrained model (including both the visual backbone and the textual head) is through the `virtex.model_zoo` API, as follows. This snippet works from anywhere and does not need to be executed from the project root:
```python
import virtex.model_zoo as mz

# Get our best-performing full VirTex model:
model = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=True)

# Optionally extract the torchvision-like visual backbone (with ``avgpool``
# and ``fc`` layers replaced by ``nn.Identity`` modules).
cnn = model.visual.cnn
```
Alternatively, the weights can be downloaded manually from the links in the tables below and loaded as follows. Because the config path is relative, this snippet should be executed from the project root:
```python
from virtex.config import Config
from virtex.factories import PretrainingModelFactory
from virtex.utils.checkpointing import CheckpointManager

# Get the best-performing VirTex model:
_C = Config("configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml")
model = PretrainingModelFactory.from_config(_C)
CheckpointManager(model=model).load("/path/to/downloaded/weights.pth")

# Optionally extract the torchvision-like visual backbone (with ``avgpool``
# and ``fc`` layers replaced by ``nn.Identity`` modules).
cnn = model.visual.cnn
```
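Because the `Config` path in the manual-loading snippet is relative to the project root, one way to run it from another directory is to build an absolute path first. The sketch below assumes that every config lives under a `configs/` directory of your local clone, as that path suggests. `resolve_config` and `PROJECT_ROOT` are hypothetical names and are not part of `virtex`.

```python
from pathlib import Path

# Hypothetical location of your local VirTex clone; adjust to your setup.
PROJECT_ROOT = Path("/path/to/virtex")


def resolve_config(partial_path: str, project_root: Path = PROJECT_ROOT) -> Path:
    """Map a partial config path from the tables to an absolute file path.

    Assumes every config lives under ``<project_root>/configs/``.
    """
    return project_root / "configs" / partial_path
```

The result can then be passed as a string, for example `Config(str(resolve_config("width_ablations/bicaptioning_R_50_L1_H2048.yaml")))`.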
The pretrained ResNet-50 visual backbone of our best-performing model (width_ablations/bicaptioning_R_50_L1_H2048.yaml) can be loaded in a single line, without any installation steps. It requires only PyTorch v1.5:
```python
import torch

# This is a torchvision-like ResNet-50 model, with ``avgpool`` and ``fc``
# layers replaced by ``nn.Identity`` modules.
model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True)

image_batch = torch.randn(1, 3, 224, 224)  # Batch tensor of one image.
features_batch = model(image_batch)  # Shape: (1, 2048, 7, 7)
```
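The 7x7 spatial size in that shape comes from the overall stride of 32 in a ResNet-50. Five stride-2 operations each halve the resolution, so a 224x224 input yields a 7x7 feature map. The sketch below reproduces this arithmetic for other input sizes. It assumes torchvision's standard ResNet-50 layout: a 7x7 stride-2 stem convolution, a 3x3 stride-2 max-pool, and a stride-2 3x3 convolution at the start of `layer2` to `layer4`. The function names are illustrative only.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Standard output-size formula for a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet50_feature_size(size: int) -> int:
    """Side length of the final feature map for a square input of side ``size``.

    Assumes torchvision's ResNet-50 layout (five stride-2 operations).
    """
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1 (stem)
    size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool
    for _ in range(3):  # layer2, layer3 and layer4 each halve the resolution
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


print(resnet50_feature_size(224))  # 7
```

With this layout, the side length works out to ceil(H / 32) for an input of side H, so a 256x256 input gives an 8x8 map.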
Pretraining Task Ablations
| Model Config Name | VOC07 mAP (%) | ImageNet Top-1 Acc. (%) | Model URL |
|---|---|---|---|
| task_ablations/bicaptioning_R_50_L1_H2048.yaml | 88.7 | 53.8 | model |
| task_ablations/captioning_R_50_L1_H2048.yaml | 88.6 | 50.8 | model |
| task_ablations/token_classification_R_50.yaml | 88.8 | 48.6 | model |
| task_ablations/multilabel_classification_R_50.yaml | 86.2 | 46.2 | model |
| task_ablations/masked_lm_R_50_L1_H2048.yaml | 86.4 | 46.7 | model |
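Which pretraining task looks "best" depends on the metric. As a small illustration that uses only the numbers in the task ablation table, the sketch below picks the top config for each column. On VOC07 the token classification model leads (88.8 mAP), while bicaptioning leads on ImageNet (53.8%). The helper `best_config` is hypothetical.

```python
# Rows copied from the "Pretraining Task Ablations" table:
# (config name, VOC07 mAP %, ImageNet top-1 %).
TASK_ABLATIONS = [
    ("task_ablations/bicaptioning_R_50_L1_H2048.yaml", 88.7, 53.8),
    ("task_ablations/captioning_R_50_L1_H2048.yaml", 88.6, 50.8),
    ("task_ablations/token_classification_R_50.yaml", 88.8, 48.6),
    ("task_ablations/multilabel_classification_R_50.yaml", 86.2, 46.2),
    ("task_ablations/masked_lm_R_50_L1_H2048.yaml", 86.4, 46.7),
]

VOC07, IMAGENET = 1, 2  # Column indices into each row tuple.


def best_config(rows, metric_index: int) -> str:
    """Return the config name with the highest value in the given column."""
    return max(rows, key=lambda row: row[metric_index])[0]


best_voc = best_config(TASK_ABLATIONS, VOC07)
best_imagenet = best_config(TASK_ABLATIONS, IMAGENET)
print(best_voc)
print(best_imagenet)
```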
Width Ablations
| Model Config Name | VOC07 mAP (%) | ImageNet Top-1 Acc. (%) | Model URL |
|---|---|---|---|
| width_ablations/bicaptioning_R_50_L1_H512.yaml | 88.4 | 51.8 | model |
| width_ablations/bicaptioning_R_50_L1_H768.yaml | 88.3 | 52.3 | model |
| width_ablations/bicaptioning_R_50_L1_H1024.yaml | 88.3 | 53.2 | model |
| width_ablations/bicaptioning_R_50_L1_H2048.yaml | 88.7 | 53.8 | model |
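To read the width ablation as a trend, one option is to express each model's ImageNet accuracy as a gain over the narrowest (H512) model. The snippet below does this with the table's numbers and rounds to one decimal to hide floating-point noise. `H` is assumed to denote the hidden width of the textual head, and `gains_over_narrowest` is a hypothetical helper.

```python
# (H, ImageNet top-1 %) copied from the "Width Ablations" table;
# H is assumed to be the hidden width of the textual head.
WIDTH_ABLATIONS = [(512, 51.8), (768, 52.3), (1024, 53.2), (2048, 53.8)]


def gains_over_narrowest(rows):
    """ImageNet top-1 gain (in points) of each width over the narrowest one."""
    rows = sorted(rows)  # Sort by width so rows[0] is the narrowest model.
    base = rows[0][1]
    # Round to one decimal to hide floating-point noise such as 0.4999999.
    return {width: round(acc - base, 1) for width, acc in rows}


print(gains_over_narrowest(WIDTH_ABLATIONS))
# {512: 0.0, 768: 0.5, 1024: 1.4, 2048: 2.0}
```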
Depth Ablations
| Model Config Name | VOC07 mAP (%) | ImageNet Top-1 Acc. (%) | Model URL |
|---|---|---|---|
| depth_ablations/bicaptioning_R_50_L1_H1024.yaml | 88.3 | 53.2 | model |
| depth_ablations/bicaptioning_R_50_L2_H1024.yaml | 88.8 | 53.8 | model |
| depth_ablations/bicaptioning_R_50_L3_H1024.yaml | 88.7 | 53.9 | model |
| depth_ablations/bicaptioning_R_50_L4_H1024.yaml | 88.7 | 53.9 | model |
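The depth ablation can be read step by step. The sketch below computes how much each extra layer changes both metrics. In these numbers, going from one layer to two helps most, and each further layer changes either metric by at most 0.1 points. `L` is assumed to be the number of layers in the textual head, and `marginal_gains` is a hypothetical helper.

```python
# (L, VOC07 mAP %, ImageNet top-1 %) copied from the "Depth Ablations" table;
# L is assumed to be the number of layers in the textual head.
DEPTH_ABLATIONS = [(1, 88.3, 53.2), (2, 88.8, 53.8), (3, 88.7, 53.9), (4, 88.7, 53.9)]


def marginal_gains(rows):
    """Change in each metric when adding one layer, as (L, d_voc, d_imagenet)."""
    rows = sorted(rows)  # Order by depth so consecutive rows differ by one step.
    return [
        (layers, round(voc - prev_voc, 1), round(acc - prev_acc, 1))
        for (_, prev_voc, prev_acc), (layers, voc, acc) in zip(rows, rows[1:])
    ]


for layers, d_voc, d_imagenet in marginal_gains(DEPTH_ABLATIONS):
    print(f"L{layers}: VOC07 {d_voc:+.1f}, ImageNet {d_imagenet:+.1f}")
```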