virtex.models
- class virtex.models.classification.ClassificationModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead, ignore_indices: List[int])[source]
Bases: torch.nn.modules.module.Module
A model to perform classification (generally, with multiple targets). It is composed of a VisualBackbone and a TextualHead on top of it.
Note: Of the currently available textual heads, only LinearTextualHead is supported here.
During training, it minimizes the KL-divergence loss against a K-hot target vector with values 1/K, where K is the number of unique labels to classify; a minimal sketch of this objective appears after this class entry.
- Parameters
visual – A VisualBackbone which computes visual features from an input image.
textual – A TextualHead which makes final predictions conditioned on visual features.
ignore_indices – Token indices to ignore while computing the KL-divergence loss. These are special tokens such as [SOS], [EOS] etc.
- forward(batch: Dict[str, torch.Tensor])[source]
Given a batch of images and a set of labels, perform classification with multiple targets by minimizing a KL-divergence loss.
- Parameters
batch – A batch of images and labels. Possible set of keys: {"image_id", "image", "labels"}
- Returns
A dict with the following structure, containing the loss for optimization, loss components to log directly to tensorboard, and optionally predictions.
{
    "loss": torch.Tensor,
    "loss_components": {"classification": torch.Tensor},
    "predictions": torch.Tensor
}
- class virtex.models.classification.TokenClassificationModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead, ignore_indices: List[int])[source]
Bases: virtex.models.classification.ClassificationModel
Convenient extension of ClassificationModel for better readability (this only modifies the tensorboard logging logic). Ground truth targets here are the set of unique caption tokens (ignoring special tokens like [SOS], [EOS] etc.).
- class virtex.models.classification.MultiLabelClassificationModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead, ignore_indices: List[int])[source]
Bases: virtex.models.classification.ClassificationModel
Convenient extension of ClassificationModel for better readability (this only modifies the tensorboard logging logic). Ground truth targets here are the set of unique instances in images (ignoring the background class, category id = 0 in COCO).
- class virtex.models.captioning.CaptioningModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead, caption_backward: bool = False, sos_index: int = 1, eos_index: int = 2, decoder: Optional[Any] = None)[source]
Bases: torch.nn.modules.module.Module
A model to perform image captioning, either in both forward and backward directions independently, or only in the forward direction. It is composed of a VisualBackbone and a TextualHead on top of it.
During training, it maximizes the likelihood of the ground truth caption conditioned on image features. During inference, it predicts a caption for an input image through beam search decoding.
- Parameters
visual – A VisualBackbone which computes visual features from an input image.
textual – A TextualHead which makes final predictions conditioned on visual features.
sos_index – Index of the start token ([SOS]) in the vocabulary.
eos_index – Index of the end token ([EOS]) in the vocabulary.
caption_backward – Whether to also perform captioning in the backward direction. Default is False – only forward captioning is performed. When True, a clone of the textual head is created which does not share weights with the "forward" model, except input/output embeddings.
decoder – An AutoRegressiveBeamSearch or AutoRegressiveNucleusSampling object for decoding captions during inference (unused during training).
- forward(batch: Dict[str, torch.Tensor]) → Dict[str, Any][source]
Given a batch of images and captions, compute the log-likelihood loss per caption token during training. During inference (given only images), predict a caption through either beam search decoding or nucleus sampling; a sketch of the training objective follows this entry.
- Parameters
batch – A batch of images and (optionally) ground truth caption tokens. Possible set of keys: {"image_id", "image", "caption_tokens", "noitpac_tokens", "caption_lengths"}. Here noitpac_tokens are simply the caption tokens reversed, used for backward captioning when caption_backward=True.
- Returns
A dict with the following structure, containing the loss for optimization, loss components to log directly to tensorboard, and optionally predictions.
{
    "loss": torch.Tensor,
    "loss_components": {
        "captioning_forward": torch.Tensor,
        "captioning_backward": torch.Tensor,  (optional)
    },
    "predictions": torch.Tensor
}
- decoding_step(visual_features: torch.Tensor, partial_captions: torch.Tensor) → torch.Tensor[source]
Given visual features and a batch of (assumed) partial captions, predict the logits over output vocabulary tokens for the next timestep. This method is used by AutoRegressiveBeamSearch and AutoRegressiveNucleusSampling; a sketch of a decoding loop built on it follows this entry.
Note: For nucleus sampling, beam_size will always be 1; it is not relevant there.
- Parameters
visual_features – A tensor of shape (batch_size, ..., textual_feature_size) with visual features already projected to textual_feature_size.
partial_captions – A tensor of shape (batch_size * beam_size, timesteps) containing the tokens predicted so far – one sequence per beam. We need all prior predictions because our model is auto-regressive.
- Returns
A tensor of shape (batch_size * beam_size, vocab_size) – logits over output vocabulary tokens for the next timestep.
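For intuition, a hypothetical greedy decoding loop built on decoding_step() might look like the sketch below. The library's own AutoRegressiveBeamSearch and AutoRegressiveNucleusSampling decoders follow the same call pattern but track multiple candidates; the function name and defaults here are assumptions, not part of the API.

    import torch

    @torch.no_grad()
    def greedy_decode(model, visual_features, sos_index=1, eos_index=2, max_steps=30):
        # Start every sequence with the [SOS] token.
        batch_size = visual_features.size(0)
        partial_captions = torch.full(
            (batch_size, 1), sos_index, dtype=torch.long, device=visual_features.device
        )
        for _ in range(max_steps):
            # (batch_size, vocab_size) logits for the next timestep.
            logits = model.decoding_step(visual_features, partial_captions)
            next_token = logits.argmax(dim=-1, keepdim=True)
            partial_captions = torch.cat([partial_captions, next_token], dim=1)
            # Stop once every sequence has emitted [EOS].
            if (next_token.squeeze(1) == eos_index).all():
                break
        return partial_captions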
- class virtex.models.captioning.ForwardCaptioningModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead, sos_index: int = 1, eos_index: int = 2, decoder: Optional[Any] = None)[source]
Bases: virtex.models.captioning.CaptioningModel
Convenient extension of CaptioningModel for better readability: this passes caption_backward=False to the super class.
- class virtex.models.captioning.BidirectionalCaptioningModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead, sos_index: int = 1, eos_index: int = 2, decoder: Optional[Any] = None)[source]
Bases: virtex.models.captioning.CaptioningModel
Convenient extension of CaptioningModel for better readability: this passes caption_backward=True to the super class.
- virtex.models.captioning.VirTexModel[source]
alias of virtex.models.captioning.BidirectionalCaptioningModel
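In other words, the two names refer to the same class; a quick check:

    from virtex.models.captioning import BidirectionalCaptioningModel, VirTexModel

    # VirTexModel is just an alias, so both names resolve to one class.
    assert VirTexModel is BidirectionalCaptioningModel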
- class virtex.models.masked_lm.MaskedLMModel(visual: virtex.modules.visual_backbones.VisualBackbone, textual: virtex.modules.textual_heads.TextualHead)[source]
Bases: torch.nn.modules.module.Module
A model to perform BERT-like masked language modeling. It is composed of a VisualBackbone and a TextualHead on top of it.
During training, the model receives caption tokens with some tokens replaced by the [MASK] token, and it predicts these masked tokens based on the surrounding context.
- Parameters
visual – A VisualBackbone which computes visual features from an input image.
textual – A TextualHead which makes final predictions conditioned on visual features.
- forward(batch: Dict[str, torch.Tensor]) → Dict[str, Any][source]
Given a batch of images and captions with certain masked tokens, predict the tokens at masked positions.
- Parameters
batch – A batch of images, ground truth caption tokens and masked labels. Possible set of keys: {"image_id", "image", "caption_tokens", "masked_labels", "caption_lengths"}.
- Returns
A dict with the following structure, containing the loss for optimization, loss components to log directly to tensorboard, and optionally predictions.
{
    "loss": torch.Tensor,
    "loss_components": {"masked_lm": torch.Tensor},
    "predictions": torch.Tensor
}