virtex.modules.embedding
- class virtex.modules.embedding.WordAndPositionalEmbedding(vocab_size: int, hidden_size: int, dropout: float = 0.0, max_caption_length: int = 30, padding_idx: int = 0)[source]
Bases: torch.nn.modules.module.Module
A Module for learned word embeddings and positional embeddings for input tokens. Each token is mapped to a fixed-dimensional word embedding and a corresponding positional embedding based on its index. These are summed together, followed by layer normalization and an optional dropout (see the sketch after the parameter list).
- Parameters
vocab_size – Size of token vocabulary.
hidden_size – Size of token embedding vectors.
dropout – Probability for final dropout applied after layer normalization.
max_caption_length – Maximum length of input captions; this is used to create a fixed positional embedding lookup table.
padding_idx – Token index of the [PAD] token; word embeddings for these tokens will be a vector of zeroes (and not trainable).
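The following is a minimal sketch of how such a module could be assembled from standard PyTorch components. It illustrates the word-plus-position sum, layer normalization, dropout, and padding handling described above; the class name and exact layer choices here are illustrative assumptions, not the exact virtex implementation.

import torch
from torch import nn

class WordAndPositionalEmbeddingSketch(nn.Module):
    """Illustrative sketch: learned word + positional embeddings,
    summed, then layer norm and optional dropout."""

    def __init__(self, vocab_size: int, hidden_size: int, dropout: float = 0.0,
                 max_caption_length: int = 30, padding_idx: int = 0):
        super().__init__()
        self.padding_idx = padding_idx
        # Word embeddings; the row at padding_idx stays zero and gets no gradient.
        self.words = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx)
        # One learned vector per position, up to max_caption_length.
        self.positions = nn.Embedding(max_caption_length, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch_size, max_caption_length), values in [0, vocab_size)
        batch_size, caption_length = tokens.size()
        position_indices = torch.arange(caption_length, device=tokens.device)
        position_indices = position_indices.unsqueeze(0).expand(batch_size, -1)

        # Sum word and positional embeddings, then normalize and apply dropout.
        embeddings = self.words(tokens) + self.positions(position_indices)
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        # Zero out embeddings at [PAD] positions so padding carries no signal.
        token_mask = (tokens != self.padding_idx).unsqueeze(-1)
        return embeddings * token_mask.type(embeddings.dtype)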
- forward(tokens: torch.Tensor) → torch.Tensor[source]
Get combined word and positional embeddings for input tokens.
- Parameters
tokens – A tensor of shape (batch_size, max_caption_length) containing a batch of caption tokens, with values in [0, vocab_size).
- Returns
A tensor of shape (batch_size, max_caption_length, hidden_size) containing corresponding token embeddings.
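For example, assuming the virtex package is importable, the shape contract documented above can be exercised as follows; the vocabulary size, hidden size, and batch size are arbitrary illustrative values.

import torch
from virtex.modules.embedding import WordAndPositionalEmbedding

# Build the embedding module and a batch of random caption tokens.
embedding = WordAndPositionalEmbedding(vocab_size=10000, hidden_size=512)
tokens = torch.randint(0, 10000, (8, 30))  # (batch_size, max_caption_length)

output = embedding(tokens)
print(output.shape)  # expected: torch.Size([8, 30, 512])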