VirTex Model Zoo

This model zoo collects pretrained model weights and their corresponding config names. Each table lists the partial path to a model's config file, a download link for its pretrained weights, and, for reference, its VOC07 mAP and ImageNet top-1 accuracy.

The simplest way to download and use a full pretrained model (both the visual backbone and the textual head) is through the virtex.model_zoo API, as follows. This snippet works from any directory and does not need to be run from the project root.

# Get our full best performing VirTex model:
import virtex.model_zoo as mz
model = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=True)

# Optionally extract the torchvision-like visual backbone (with ``avgpool``
# and ``fc`` layers replaced with ``nn.Identity`` module).
cnn = model.visual.cnn

Alternatively, the weights can be downloaded manually from the links below and loaded with the following snippet, run from the project root:

from virtex.config import Config
from virtex.factories import PretrainingModelFactory
from virtex.utils.checkpointing import CheckpointManager

# Get the best performing VirTex model:
_C = Config("configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml")
model = PretrainingModelFactory.from_config(_C)

CheckpointManager(model=model).load("/path/to/downloaded/weights.pth")

# Optionally extract the torchvision-like visual backbone (with ``avgpool``
# and ``fc`` layers replaced with ``nn.Identity`` module).
cnn = model.visual.cnn
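
The manual download step above can also be scripted. The following is a minimal sketch using only the Python standard library; `download_weights` is a hypothetical helper (it is not part of VirTex), and the URL passed to it is assumed to be one copied from the "Model URL" column of the tables below.

```python
import shutil
import urllib.request
from pathlib import Path


def download_weights(url: str, dest: str) -> Path:
    """Download ``url`` to ``dest``, skipping the download if ``dest`` exists.

    Hypothetical helper for illustration; not part of the VirTex codebase.
    """
    dest_path = Path(dest)
    if dest_path.exists():
        # Reuse a previously downloaded file instead of fetching it again.
        return dest_path

    dest_path.parent.mkdir(parents=True, exist_ok=True)

    # Write to a temporary ".part" file first, so an interrupted download
    # never leaves a truncated file at the final path.
    tmp_path = dest_path.with_name(dest_path.name + ".part")
    with urllib.request.urlopen(url) as response, open(tmp_path, "wb") as out:
        shutil.copyfileobj(response, out)
    tmp_path.replace(dest_path)
    return dest_path
```

The returned path can then be passed to `CheckpointManager(model=model).load(...)` in place of `/path/to/downloaded/weights.pth`.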

The pretrained ResNet-50 visual backbone of our best performing model (width_ablations/bicaptioning_R_50_L1_H2048.yaml) can be loaded in a single line, without following any installation steps (it requires only PyTorch v1.5):

import torch

model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True)

# This is a torchvision-like resnet50 model, with ``avgpool`` and ``fc``
# layers replaced with ``nn.Identity`` module.
image_batch = torch.randn(1, 3, 224, 224)  # batch tensor of one image.
features_batch = model(image_batch)  # shape: (1, 2048, 7, 7)

Pretraining Task Ablations

Model Config Name                                   VOC07 mAP   ImageNet Top-1 Acc.   Model URL
task_ablations/bicaptioning_R_50_L1_H2048.yaml      88.7        53.8                  model
task_ablations/captioning_R_50_L1_H2048.yaml        88.6        50.8                  model
task_ablations/token_classification_R_50.yaml       88.8        48.6                  model
task_ablations/multilabel_classification_R_50.yaml  86.2        46.2                  model
task_ablations/masked_lm_R_50_L1_H2048.yaml         86.4        46.7                  model

Width Ablations

Model Config Name                                VOC07 mAP   ImageNet Top-1 Acc.   Model URL
width_ablations/bicaptioning_R_50_L1_H512.yaml   88.4        51.8                  model
width_ablations/bicaptioning_R_50_L1_H768.yaml   88.3        52.3                  model
width_ablations/bicaptioning_R_50_L1_H1024.yaml  88.3        53.2                  model
width_ablations/bicaptioning_R_50_L1_H2048.yaml  88.7        53.8                  model

Depth Ablations

Model Config Name                                VOC07 mAP   ImageNet Top-1 Acc.   Model URL
depth_ablations/bicaptioning_R_50_L1_H1024.yaml  88.3        53.2                  model
depth_ablations/bicaptioning_R_50_L2_H1024.yaml  88.8        53.8                  model
depth_ablations/bicaptioning_R_50_L3_H1024.yaml  88.7        53.9                  model
depth_ablations/bicaptioning_R_50_L4_H1024.yaml  88.7        53.9                  model

Backbone Ablations

Model Config Name                                      VOC07 mAP   ImageNet Top-1 Acc.   Model URL
backbone_ablations/bicaptioning_R_50_L1_H1024.yaml     88.3        53.2                  model
backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml  88.5        52.9                  model
backbone_ablations/bicaptioning_R_101_L1_H1024.yaml    88.7        52.1                  model
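
The config names in the tables above follow a visible pattern: an ablation group, a pretraining task, a backbone tag such as `R_50` or `R_50W2X`, and (for most models) an `L<n>_H<n>` suffix that varies with the depth and width ablations. The sketch below splits a name into those parts. It assumes the pattern holds for every config, which is inferred from the names listed here rather than documented; `ConfigName` and `parse_config_name` are hypothetical names, not part of VirTex.

```python
import re
from typing import NamedTuple, Optional


class ConfigName(NamedTuple):
    group: str              # e.g. "width_ablations"
    task: str               # e.g. "bicaptioning"
    backbone: str           # e.g. "R_50", "R_50W2X", "R_101"
    layers: Optional[int]   # number after "L", or None if absent
    hidden: Optional[int]   # number after "H", or None if absent


# Assumed pattern (inferred from the tables, not an official specification):
#   <group>/<task>_R_<backbone>[_L<layers>_H<hidden>].yaml
_PATTERN = re.compile(
    r"^(?P<group>\w+)/(?P<task>\w+?)_(?P<backbone>R_\d+(?:W\d+X)?)"
    r"(?:_L(?P<layers>\d+)_H(?P<hidden>\d+))?\.yaml$"
)


def parse_config_name(name: str) -> ConfigName:
    """Split a model zoo config name into its apparent components."""
    match = _PATTERN.match(name)
    if match is None:
        raise ValueError(f"unrecognized config name: {name!r}")
    layers = match.group("layers")
    hidden = match.group("hidden")
    return ConfigName(
        group=match.group("group"),
        task=match.group("task"),
        backbone=match.group("backbone"),
        layers=int(layers) if layers is not None else None,
        hidden=int(hidden) if hidden is not None else None,
    )
```

For example, `parse_config_name("width_ablations/bicaptioning_R_50_L1_H2048.yaml")` yields the group `width_ablations`, task `bicaptioning`, backbone `R_50`, and the integers 1 and 2048 for the `L` and `H` fields. The classification configs carry no `L`/`H` suffix, so those two fields come back as `None`.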