virtex.utils.beam_search
This Beam Search implementation is adapted with minor modifications from AllenNLP.
Thanks to the developers of AllenNLP!
Update (v1.2): The “backpointer” trick in Beam Search (as implemented in AllenNLP) does not work well with autoregressive models (transformers). It has been removed, which improves qualitative predictions and captioning metrics (CIDEr/SPICE) for VirTex. Updated captioning results are in arXiv v3. Refer to the CHANGELOG and Release Page for more details.
Huge thanks to Nicolas Carion (@alcinos) and Aishwarya Kamath (@ashkamath) for helping me fix this bug!
- class virtex.utils.beam_search.AutoRegressiveBeamSearch(eos_index: int, max_steps: int = 50, beam_size: int = 5, per_node_beam_size: int = 2)
Bases: object
Implements the beam search algorithm for decoding the most likely captions.
- Parameters
eos_index – The index of the end token ([EOS]) in the vocabulary.
max_steps – The maximum number of decoding steps.
beam_size – The width of the beam used.
per_node_beam_size – The maximum number of candidates to consider per node, at each step in the search. Setting this parameter to a number smaller than beam_size may give better results, as it can introduce more diversity into the search. See Beam Search Strategies for Neural Machine Translation (Freitag and Al-Onaizan, 2017).
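The effect of per_node_beam_size can be illustrated with a single expansion step. The following is a minimal pure-Python sketch, not the VirTex or AllenNLP implementation: the helper name expand_beams and the toy probabilities are assumptions made for illustration only. Each node (current beam) first keeps its best per_node_beam_size tokens, and only then are the best beam_size candidates selected across all nodes.

```python
import math
from typing import List, Tuple


def expand_beams(
    beam_logprobs: List[float],
    next_logprobs: List[List[float]],
    beam_size: int,
    per_node_beam_size: int,
) -> List[Tuple[float, int, int]]:
    """Hypothetical helper: one expansion step of beam search.

    Returns up to `beam_size` candidates as (score, parent_beam, token).
    """
    candidates = []
    for parent, (score, row) in enumerate(zip(beam_logprobs, next_logprobs)):
        # Rank this node's tokens and keep only its best `per_node_beam_size`.
        ranked = sorted(range(len(row)), key=lambda t: row[t], reverse=True)
        for token in ranked[:per_node_beam_size]:
            candidates.append((score + row[token], parent, token))
    # Keep the best `beam_size` candidates across all nodes.
    candidates.sort(key=lambda c: c[0], reverse=True)
    return candidates[:beam_size]


# Two beams with a 3-token vocabulary (toy numbers, chosen for illustration).
beam_logprobs = [math.log(0.7), math.log(0.3)]
next_logprobs = [
    [math.log(0.5), math.log(0.4), math.log(0.1)],
    [math.log(0.1), math.log(0.2), math.log(0.7)],
]

narrow = expand_beams(beam_logprobs, next_logprobs, beam_size=2, per_node_beam_size=1)
wide = expand_beams(beam_logprobs, next_logprobs, beam_size=2, per_node_beam_size=2)

print([(parent, token) for _, parent, token in narrow])
print([(parent, token) for _, parent, token in wide])
```

With per_node_beam_size=2 both surviving candidates descend from the same parent beam, while per_node_beam_size=1 forces the survivors to come from different parents. This is the kind of diversity the parameter description refers to.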
- search(start_predictions: torch.Tensor, step: Callable[[...], torch.Tensor], only_return_best: bool = True) -> Tuple[torch.Tensor, torch.Tensor]
Given a starting state and a step function, apply beam search to find the most likely target captions.
- Parameters
start_predictions – Tensor containing the initial predictions, shape (batch_size, ). Usually the initial predictions are just the index of the start token ([SOS]) in the vocabulary.
step – A function that is responsible for computing the next most likely tokens, given the past predictions. Predictions from all previous timesteps are required, not just the last timestep. The function is expected to return a tensor of shape (group_size, target_vocab_size) containing the token logits for the next step.
only_return_best – Whether to only return the best beam (with highest logprobs). Set this to False to return all the beams. If this is True, then the returned tensor is of shape (batch_size, sequence_length); otherwise it is of shape (batch_size, beam_size, sequence_length).
- Returns
Tuple of (predictions, logprobs), where predictions has shape (batch_size, beam_size, max_steps) and logprobs has shape (batch_size, beam_size).
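The step-function contract described above (full prediction history in, next-token logits out) can be sketched without any tensor library. The example below is a toy stand-in under stated assumptions, not the search method itself: toy_beam_search and toy_step are hypothetical names, it handles a single example instead of a batch, it uses Python lists in place of torch.Tensor, and finished beams are simply carried forward instead of being padded to max_steps.

```python
import math
from typing import Callable, List, Tuple

Sequence = List[int]


def toy_beam_search(
    start_token: int,
    step: Callable[[List[Sequence]], List[List[float]]],
    eos_index: int,
    max_steps: int = 5,
    beam_size: int = 2,
) -> Tuple[List[Sequence], List[float]]:
    """Hypothetical single-example beam search over a list-based step function."""
    beams = [([start_token], 0.0)]
    for _ in range(max_steps):
        live = [(seq, lp) for seq, lp in beams if seq[-1] != eos_index]
        done = [(seq, lp) for seq, lp in beams if seq[-1] == eos_index]
        if not live:
            break
        # The step function receives the full history of every live beam,
        # not just the most recent token, and returns one row of logits each.
        logits = step([seq for seq, _ in live])
        candidates = list(done)
        for (seq, lp), row in zip(live, logits):
            # Convert logits to log-probabilities (log-softmax).
            log_z = math.log(sum(math.exp(x) for x in row))
            for token, logit in enumerate(row):
                candidates.append((seq + [token], lp + logit - log_z))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return [seq for seq, _ in beams], [lp for _, lp in beams]


def toy_step(histories: List[Sequence]) -> List[List[float]]:
    """Toy model over the vocabulary 0=[SOS], 1=[EOS], 2, 3 (an assumption)."""
    rows = []
    for history in histories:
        if len(history) < 3:
            rows.append([-5.0, -5.0, 2.0, 1.0])  # prefer token 2, then 3
        else:
            rows.append([-5.0, 3.0, 0.0, 0.0])   # prefer [EOS]
    return rows


sequences, logprobs = toy_beam_search(start_token=0, step=toy_step, eos_index=1)
print(sequences[0], round(logprobs[0], 3))
```

Because toy_step decides based on the length of each history, it only behaves correctly when it is given all previous predictions, which mirrors the requirement stated for the step argument.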