virtex
1.4

virtex.utils.beam_search


This Beam Search implementation is adapted with minor modifications from AllenNLP.

Thanks to the developers of AllenNLP!

Update (v1.2): The “backpointer” trick in Beam Search (as implemented in AllenNLP) does not work well with autoregressive models (transformers). It has been removed, which improves qualitative predictions and captioning metrics (CIDEr/SPICE) for VirTex. Updated captioning results are on ArXiv v3. Refer to the CHANGELOG and the Release Page for more details.

Huge thanks to Nicolas Carion (@alcinos) and Aishwarya Kamath (@ashkamath) for helping me fix this bug!

class virtex.utils.beam_search.AutoRegressiveBeamSearch(eos_index: int, max_steps: int = 50, beam_size: int = 5, per_node_beam_size: int = 2)[source]

Bases: object

Implements the beam search algorithm for decoding the most likely captions.

Parameters
  • eos_index – The index of the end token ([EOS]) in vocabulary.

  • max_steps – The maximum number of decoding steps.

  • beam_size – The width of the beam used.

  • per_node_beam_size – The maximum number of candidates to consider per node, at each step in the search. Setting this parameter to a number smaller than beam_size may give better results, as it can introduce more diversity into the search. See Beam Search Strategies for Neural Machine Translation. Freitag and Al-Onaizan, 2017.

search(start_predictions: torch.Tensor, step: Callable[[...], torch.Tensor], only_return_best: bool = True) → Tuple[torch.Tensor, torch.Tensor][source]

Given a starting state and a step function, apply beam search to find the most likely target captions.

Parameters
  • start_predictions – Tensor containing the initial predictions, shape (batch_size, ). Usually the initial predictions are just the index of the start token ([SOS]) in the vocabulary.

  • step – A function that is responsible for computing the next most likely tokens, given the past predictions. Predictions from all previous timesteps are required, not just the last timestep. The function is expected to return a tensor of shape (group_size, target_vocab_size) containing the token logits for the next step.

  • only_return_best – Whether to return only the best beam (the one with the highest logprobs). Set this to False to return all beams. If this is True, the returned tensor has shape (batch_size, sequence_length); otherwise it will be (batch_size, beam_size, sequence_length).

Returns

Tuple of (predictions, logprobs), where predictions has shape (batch_size, beam_size, max_steps) and logprobs has shape (batch_size, beam_size).
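The interface above can be illustrated with a small, dependency-free sketch. The toy step function below returns fixed logprobs over a 4-token vocabulary regardless of history (a hypothetical stand-in for a real forward pass of the captioning model), and the beam search keeps the beam_size highest-scoring hypotheses at each step. All names here are illustrative, not part of the virtex API:

```python
import math

EOS_INDEX = 3  # hypothetical index of the [EOS] token in the toy vocabulary

def step(prev_tokens):
    # Toy "model": returns logprobs over a 4-token vocabulary. In VirTex,
    # this would be a forward pass over all previous predictions, returning
    # scores for the next token.
    probs = [0.1, 0.5, 0.3, 0.1]
    return [math.log(p) for p in probs]

def beam_search(start_token, step, beam_size=2, max_steps=3, eos_index=EOS_INDEX):
    # Each beam is a (token_sequence, cumulative_logprob) pair.
    beams = [([start_token], 0.0)]
    for _ in range(max_steps):
        candidates = []
        for tokens, score in beams:
            if tokens[-1] == eos_index:
                # Finished beams carry over unchanged.
                candidates.append((tokens, score))
                continue
            for tok, lp in enumerate(step(tokens)):
                candidates.append((tokens + [tok], score + lp))
        # Keep only the beam_size highest-scoring hypotheses.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
    return beams

best_tokens, best_logprob = beam_search(start_token=0, step=step)[0]
# With this toy step function, the best beam repeatedly picks token 1:
# best_tokens == [0, 1, 1, 1], best_logprob == 3 * log(0.5)
```

Unlike this sketch, AutoRegressiveBeamSearch operates on batched tensors and additionally limits the candidates taken from each beam to per_node_beam_size before the global top-beam_size selection.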


© Copyright 2021, Karan Desai and Justin Johnson.
