virtex.models


class virtex.models.classification.ClassificationModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead, ignore_indices: List[int])

Bases: torch.nn.modules.module.Module

A model to perform classification (generally, with multiple targets). It is composed of a VisualBackbone and a TextualHead on top of it.

Note

Of the currently available textual heads, only one is supported here: LinearTextualHead.

During training, it minimizes the KL-divergence between its predicted distribution and a K-hot target vector with values 1/K, where K is the number of unique (non-ignored) labels of an instance.
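A minimal, framework-free sketch of this loss for a single instance. It assumes the textual head outputs one logit per vocabulary index and that labels are integer indices; the function names are hypothetical and this is not the virtex implementation.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Numerically stable log-softmax: shift by the max before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def k_hot_kl_loss(
    logits: Sequence[float], labels: Sequence[int], ignore_indices: Sequence[int]
) -> float:
    # Unique labels of this instance, minus special tokens such as [SOS]/[EOS].
    targets = sorted(set(labels) - set(ignore_indices))
    if not targets:
        raise ValueError("instance has no labels left after ignoring special tokens")
    k = len(targets)
    logprobs = log_softmax(logits)
    # KL(t || p) = sum_i t_i * (log t_i - log p_i); only the K entries with
    # t_i = 1/K contribute, because 0 * log 0 is taken as 0.
    return sum((1.0 / k) * (math.log(1.0 / k) - logprobs[i]) for i in targets)


logits = [0.5, 2.0, -1.0, 1.5, 0.0]
# Unique non-ignored labels are {1, 3}, so K = 2.
loss = k_hot_kl_loss(logits, labels=[1, 3, 3, 0], ignore_indices=[0])
```

Because all K non-zero target entries equal 1/K, this KL-divergence equals the mean negative log-probability of the target labels minus log K. The two objectives differ only by a per-instance constant and therefore have the same gradients.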

Parameters
  • visual – A VisualBackbone which computes visual features from an input image.

  • textual – A TextualHead which makes final predictions conditioned on visual features.

  • ignore_indices – Ignore a set of token indices while computing KL-divergence loss. These are special tokens such as [SOS], [EOS] etc.

forward(batch: Dict[str, torch.Tensor])

Given a batch of images and a set of labels, perform classification with multiple targets by minimizing a KL-divergence loss.

Parameters

batch – A batch of images and labels. Possible set of keys: {"image_id", "image", "labels"}

Returns

A dict with the following structure, containing loss for optimization, loss components to log directly to tensorboard, and optionally predictions.

{
    "loss": torch.Tensor,
    "loss_components": {
        "classification": torch.Tensor,
    },
    "predictions": torch.Tensor
}

class virtex.models.classification.TokenClassificationModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead, ignore_indices: List[int])

Bases: virtex.models.classification.ClassificationModel

Convenient extension of ClassificationModel for better readability (this only modifies the tensorboard logging logic).

Ground truth targets here are a set of unique caption tokens (ignoring the special tokens like [SOS], [EOS] etc.).

class virtex.models.classification.MultiLabelClassificationModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead, ignore_indices: List[int])

Bases: virtex.models.classification.ClassificationModel

Convenient extension of ClassificationModel for better readability (this only modifies the tensorboard logging logic).

Ground truth targets here are a set of unique instances in images (ignoring the special background token, category id = 0 in COCO).
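Per the descriptions of TokenClassificationModel and MultiLabelClassificationModel, the two subclasses differ in what the labels represent, not in how the target set is formed. A minimal sketch of turning raw labels into that unique target set, assuming labels arrive as plain integer lists; the helper name and the [PAD] index 0 are hypothetical.

```python
from typing import Iterable, List, Set


def unique_targets(labels: Iterable[int], ignore_indices: Set[int]) -> List[int]:
    # Deduplicate, drop ignored indices, and sort for a deterministic order.
    return sorted(set(labels) - ignore_indices)


# TokenClassificationModel-style targets. Hypothetical vocabulary indices:
# 0 = [PAD], 1 = [SOS], 2 = [EOS] (1 and 2 match CaptioningModel's defaults).
caption_tokens = [1, 17, 42, 17, 9, 2, 0, 0]
token_targets = unique_targets(caption_tokens, {0, 1, 2})  # [9, 17, 42]

# MultiLabelClassificationModel-style targets: COCO category ids of the
# instances in one image, where 0 is the background category.
instance_categories = [0, 1, 1, 3, 62]
instance_targets = unique_targets(instance_categories, {0})  # [1, 3, 62]
```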


class virtex.models.captioning.CaptioningModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead, caption_backward: bool = False, sos_index: int = 1, eos_index: int = 2, decoder: Optional[Any] = None)

Bases: torch.nn.modules.module.Module

A model to perform image captioning (either in the forward direction only, or in both forward and backward directions independently). It is composed of a VisualBackbone and a TextualHead on top of it.

During training, it maximizes the likelihood of the ground-truth caption conditioned on image features. During inference, it predicts a caption for an input image through beam search decoding or nucleus sampling, depending on the decoder.
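Maximizing the likelihood of a caption is equivalent to minimizing the negative log-likelihood of each ground-truth token given the image and the preceding ground-truth tokens (teacher forcing). A framework-free sketch, assuming the image-conditioned model can be abstracted as a callable from a caption prefix to next-token log-probabilities; the names and the per-token averaging are illustrative and not taken from virtex.

```python
import math
from typing import Callable, List, Sequence

# Hypothetical stand-in for the image-conditioned textual head: maps a caption
# prefix to log-probabilities over the vocabulary for the next token.
NextTokenLogProbs = Callable[[Sequence[int]], List[float]]


def caption_nll(caption_tokens: Sequence[int], next_token_logprobs: NextTokenLogProbs) -> float:
    # Teacher forcing: at step t the model sees the ground-truth prefix
    # caption_tokens[: t + 1] and is scored on caption_tokens[t + 1].
    total = 0.0
    for t in range(len(caption_tokens) - 1):
        logprobs = next_token_logprobs(caption_tokens[: t + 1])
        total -= logprobs[caption_tokens[t + 1]]
    # Average over predicted positions (every token after the leading [SOS]).
    return total / (len(caption_tokens) - 1)


def uniform_model(prefix: Sequence[int], vocab_size: int = 5) -> List[float]:
    # A model that ignores its input assigns log(1 / vocab_size) to every token.
    return [-math.log(vocab_size)] * vocab_size


loss = caption_nll([1, 3, 4, 2], uniform_model)
```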

Parameters
  • visual – A VisualBackbone which computes visual features from an input image.

  • textual – A TextualHead which makes final predictions conditioned on visual features.

  • sos_index – The index of the start token ([SOS]) in vocabulary.

  • eos_index – The index of the end token ([EOS]) in vocabulary.

  • caption_backward – Whether to also perform captioning in the backward direction. Default is False – only forward captioning is performed. When True, a clone of the textual head is created, which does not share weights with the “forward” model except for the input/output embeddings.

  • decoder – An AutoRegressiveBeamSearch or AutoRegressiveNucleusSampling object for decoding captions during inference (unused during training).
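With caption_backward=True, a second head learns to predict the caption from right to left; the "noitpac_tokens" key accepted by forward is "caption" spelled backwards. The sketch below shows the (input, target) pairs each direction would train on. It assumes that noitpac_tokens is simply the reversed caption, without padding; this is an illustration, not virtex's data pipeline.

```python
from typing import List, Sequence, Tuple


def shifted_pairs(tokens: Sequence[int]) -> Tuple[List[int], List[int]]:
    # Next-token prediction: inputs drop the last token, targets drop the first.
    return list(tokens[:-1]), list(tokens[1:])


def forward_and_backward_pairs(caption_tokens: Sequence[int]):
    # Assumption: the backward ("noitpac") sequence is the caption reversed.
    noitpac_tokens = list(reversed(caption_tokens))
    return shifted_pairs(caption_tokens), shifted_pairs(noitpac_tokens)


# [SOS] = 1 and [EOS] = 2, as in the default sos_index / eos_index.
(fwd_in, fwd_tgt), (bwd_in, bwd_tgt) = forward_and_backward_pairs([1, 10, 11, 2])
```

Both directions share the same next-token objective; only the token order differs.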

forward(batch: Dict[str, torch.Tensor]) → Dict[str, Any]

Given a batch of images and captions, compute the negative log-likelihood loss per caption token during training. During inference (with images), predict a caption through either beam search decoding or nucleus sampling.
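Nucleus (top-p) sampling draws each next token from the smallest set of most probable tokens whose cumulative probability reaches a threshold p. Below is a generic, framework-free sketch of one such step. It shows the standard technique, not the code of AutoRegressiveNucleusSampling, and top_p = 0.7 is an arbitrary illustrative value.

```python
import math
import random
from typing import List, Sequence


def nucleus_filter(logits: Sequence[float], top_p: float) -> List[float]:
    # Softmax with max-shift for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    # Keep the smallest set of most-probable tokens whose mass reaches top_p.
    kept, cumulative = [], 0.0
    for i in sorted(range(len(probs)), key=lambda j: probs[j], reverse=True):
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    # Zero out everything else and renormalize over the kept tokens.
    mass = sum(probs[i] for i in kept)
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / mass
    return filtered


def sample_next_token(logits: Sequence[float], top_p: float, rng: random.Random) -> int:
    dist = nucleus_filter(logits, top_p)
    return rng.choices(range(len(dist)), weights=dist, k=1)[0]


# Probabilities 0.5 and 0.3 already reach 0.7, so only tokens 0 and 1 survive.
logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
dist = nucleus_filter(logits, top_p=0.7)
```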

Parameters

batch – A batch of images and (optionally) ground truth caption tokens. Possible set of keys: {"image_id", "image", "caption_tokens", "noitpac_tokens", "caption_lengths"}.

Returns

A dict with the following structure, containing loss for optimization, loss components to log directly to tensorboard, and optionally predictions.

{
    "loss": torch.Tensor,
    "loss_components": {
        "captioning_forward": torch.Tensor,
        "captioning_backward": torch.Tensor, (optional)
    },
    "predictions": torch.Tensor
}

decoding_step(visual_features: torch.Tensor, partial_captions: torch.Tensor) → torch.Tensor

Given visual features and a batch of (assumed) partial captions, predict the logits over output vocabulary tokens for next timestep. This method is used by AutoRegressiveBeamSearch and AutoRegressiveNucleusSampling.

Note

For nucleus sampling, beam_size is always 1, so partial_captions has shape (batch_size, timesteps).

Parameters
  • visual_features – A tensor of shape (batch_size, ..., textual_feature_size) with visual features already projected to textual_feature_size.

  • partial_captions – A tensor of shape (batch_size * beam_size, timesteps) containing tokens predicted so far – one for each beam. We need all prior predictions because our model is auto-regressive.

Returns

A tensor of shape (batch_size * beam_size, vocab_size) – logits over output vocabulary tokens for next timestep.
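A decoder that calls decoding_step keeps batch_size * beam_size partial captions flattened into one batch, with row b * beam_size + k holding beam k of image b. The framework-free sketch below illustrates that contract with a plain beam search. It is not the AutoRegressiveBeamSearch implementation. step_fn, beam_search, and the scheme of padding finished beams with [EOS] are all hypothetical, and visual features are assumed to be captured inside step_fn.

```python
import math
from typing import Callable, List

# Mirrors the decoding_step contract: partial captions of shape
# (batch_size * beam_size, timesteps) -> logits (batch_size * beam_size, vocab_size).
StepFn = Callable[[List[List[int]]], List[List[float]]]


def _log_softmax(row: List[float]) -> List[float]:
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def beam_search(step_fn: StepFn, batch_size: int, beam_size: int,
                sos_index: int, eos_index: int, max_steps: int) -> List[List[int]]:
    # Row b * beam_size + k holds beam k of batch item b.
    beams = [[sos_index] for _ in range(batch_size * beam_size)]
    # Only beam 0 starts alive, so identical [SOS] beams are not expanded twice.
    scores = [0.0 if k == 0 else -math.inf
              for _ in range(batch_size) for k in range(beam_size)]

    for _ in range(max_steps):
        logits = step_fn(beams)
        new_beams: List[List[int]] = []
        new_scores: List[float] = []
        for b in range(batch_size):
            candidates = []
            for row in range(b * beam_size, (b + 1) * beam_size):
                if scores[row] == -math.inf:
                    continue  # dead beam, never expanded
                if beams[row][-1] == eos_index:
                    # Finished beam: pad with [EOS] so all rows keep equal length.
                    candidates.append((scores[row], beams[row] + [eos_index]))
                    continue
                for token, lp in enumerate(_log_softmax(logits[row])):
                    candidates.append((scores[row] + lp, beams[row] + [token]))
            candidates.sort(key=lambda c: c[0], reverse=True)
            top = candidates[:beam_size]
            # Fill any missing slots with dead copies of the best beam.
            top += [(-math.inf, top[0][1])] * (beam_size - len(top))
            for score, seq in top:
                new_scores.append(score)
                new_beams.append(seq)
        beams, scores = new_beams, new_scores
        if all(seq[-1] == eos_index
               for seq, s in zip(beams, scores) if s != -math.inf):
            break

    def trim(seq: List[int]) -> List[int]:
        # Drop the [EOS] padding that follows the first [EOS].
        return seq[: seq.index(eos_index, 1) + 1] if eos_index in seq[1:] else seq

    # Candidates are sorted, so the first row of each item is its best beam.
    return [trim(beams[b * beam_size]) for b in range(batch_size)]
```

Keeping every row the same length is what allows the partial captions to be stacked into a single (batch_size * beam_size, timesteps) tensor in a real implementation.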

class virtex.models.captioning.ForwardCaptioningModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead, sos_index: int = 1, eos_index: int = 2, decoder: Optional[Any] = None)

Bases: virtex.models.captioning.CaptioningModel

Convenient extension of CaptioningModel for better readability: this passes caption_backward=False to the superclass.

class virtex.models.captioning.BidirectionalCaptioningModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead, sos_index: int = 1, eos_index: int = 2, decoder: Optional[Any] = None)

Bases: virtex.models.captioning.CaptioningModel

Convenient extension of CaptioningModel for better readability: this passes caption_backward=True to the superclass.

virtex.models.captioning.VirTexModel

alias of virtex.models.captioning.BidirectionalCaptioningModel
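ForwardCaptioningModel, BidirectionalCaptioningModel, and the VirTexModel alias follow a common pattern. Each subclass pins one constructor argument and forwards the rest, and the alias is a second name for the same class. The sketch below shows that pattern using hypothetical stub classes with no torch or virtex dependency, mirroring the signatures shown above.

```python
from typing import Any, Optional


class CaptioningModelStub:
    # Hypothetical stand-in for CaptioningModel: it only records its arguments.
    def __init__(
        self,
        visual: Any,
        textual: Any,
        caption_backward: bool = False,
        sos_index: int = 1,
        eos_index: int = 2,
        decoder: Optional[Any] = None,
    ) -> None:
        self.visual = visual
        self.textual = textual
        self.caption_backward = caption_backward
        self.sos_index = sos_index
        self.eos_index = eos_index
        self.decoder = decoder


class ForwardCaptioningModelStub(CaptioningModelStub):
    def __init__(self, visual: Any, textual: Any, sos_index: int = 1,
                 eos_index: int = 2, decoder: Optional[Any] = None) -> None:
        # Same constructor minus caption_backward, which is pinned to False.
        super().__init__(visual, textual, caption_backward=False,
                         sos_index=sos_index, eos_index=eos_index, decoder=decoder)


class BidirectionalCaptioningModelStub(CaptioningModelStub):
    def __init__(self, visual: Any, textual: Any, sos_index: int = 1,
                 eos_index: int = 2, decoder: Optional[Any] = None) -> None:
        # Same constructor minus caption_backward, which is pinned to True.
        super().__init__(visual, textual, caption_backward=True,
                         sos_index=sos_index, eos_index=eos_index, decoder=decoder)


# An alias binds a second name to the same class object; no new class is made.
VirTexModelStub = BidirectionalCaptioningModelStub
```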


class virtex.models.masked_lm.MaskedLMModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead)

Bases: torch.nn.modules.module.Module

A model to perform BERT-like masked language modeling. It is composed of a VisualBackbone and a TextualHead on top of it.

During training, the model receives caption tokens with certain tokens replaced by the [MASK] token, and it predicts these masked tokens based on the surrounding context and the visual features.
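Below is a simplified sketch of BERT-style masking that produces the masked caption and the matching labels. Several details are assumptions rather than virtex behavior: the sketch always substitutes [MASK] (BERT's original recipe sometimes keeps the token or swaps in a random one), and mask_prob, the index values, storing ignore_label at unmasked positions, and forcing at least one masked position are all illustrative choices.

```python
import random
from typing import List, Sequence, Set, Tuple


def mask_caption(
    caption_tokens: Sequence[int],
    mask_index: int,
    ignore_label: int,
    special_indices: Set[int],
    mask_prob: float,
    rng: random.Random,
) -> Tuple[List[int], List[int]]:
    # Positions eligible for masking: everything except special tokens.
    eligible = [i for i, tok in enumerate(caption_tokens) if tok not in special_indices]
    chosen = [i for i in eligible if rng.random() < mask_prob]
    if not chosen and eligible:
        # Assumption: force at least one masked position so every caption
        # contributes to the loss.
        chosen = [rng.choice(eligible)]
    masked_tokens = list(caption_tokens)
    masked_labels = [ignore_label] * len(caption_tokens)
    for i in chosen:
        masked_labels[i] = caption_tokens[i]  # remember the original token
        masked_tokens[i] = mask_index         # the model only sees [MASK] here
    return masked_tokens, masked_labels


# Hypothetical indices: 0 = [PAD], 1 = [SOS], 2 = [EOS], 3 = [MASK].
rng = random.Random(0)
tokens = [1, 10, 11, 12, 13, 2]
masked_tokens, masked_labels = mask_caption(
    tokens, mask_index=3, ignore_label=0, special_indices={0, 1, 2},
    mask_prob=0.15, rng=rng,
)
```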

Parameters
  • visual – A VisualBackbone which computes visual features from an input image.

  • textual – A TextualHead which makes final predictions conditioned on visual features.

forward(batch: Dict[str, torch.Tensor]) → Dict[str, Any]

Given a batch of images and captions with certain masked tokens, predict the tokens at masked positions.

Parameters

batch – A batch of images, ground truth caption tokens and masked labels. Possible set of keys: {"image_id", "image", "caption_tokens", "masked_labels", "caption_lengths"}.

Returns

A dict with the following structure, containing loss for optimization, loss components to log directly to tensorboard, and optionally predictions.

{
    "loss": torch.Tensor,
    "loss_components": {"masked_lm": torch.Tensor},
    "predictions": torch.Tensor
}
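Because only masked positions carry a label, the masked LM loss is a cross-entropy averaged over those positions alone. A minimal sketch, assuming masked_labels holds a designated ignore value at positions that were not masked; that convention and the function name are assumptions, not taken from virtex.

```python
import math
from typing import Sequence


def masked_lm_loss(
    logits: Sequence[Sequence[float]], masked_labels: Sequence[int], ignore_label: int
) -> float:
    # Cross-entropy averaged over masked positions only; positions whose label
    # equals ignore_label (tokens that were not masked) are skipped.
    total, count = 0.0, 0
    for row, label in zip(logits, masked_labels):
        if label == ignore_label:
            continue
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[label]  # equals -log softmax(row)[label]
        count += 1
    return total / count if count else 0.0
```

Skipping unmasked positions keeps the model from being rewarded for copying tokens it can already see in its input.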