Description
The dot-product attention can be formulated as:

$$\text{Attention}(Q, K, V) = \text{softmax}(QK^T)\,V$$

Above,
- $Q$ has a shape of $[\text{bs} \times N \times D]$, where $\text{bs}$ is the batch size, $N$ is the max target sequence length in a mini-batch (sequences in a mini-batch are all padded to the same length), and $D$ is the hidden size.
- $K$ has a shape of $[\text{bs} \times M \times D]$, where $\text{bs}$ is the batch size, $M$ is the max source sequence length in a mini-batch (sequences in a mini-batch are all padded to the same length), and $D$ is the hidden size.
- $V$ has a shape of $[\text{bs} \times M \times D]$, where $\text{bs}$ is the batch size, $M$ is the max source sequence length in a mini-batch (sequences in a mini-batch are all padded to the same length), and $D$ is the hidden size.
With the above notation in hand, we have:
- Suppose $W = \text{softmax}(QK^T)$. $W$ is the attention weight with a shape $[\text{bs} \times N \times M]$.
- Suppose $C = WV$. $C$ is the context vector with a shape $[\text{bs} \times N \times D]$.
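
For concreteness, here is a minimal PyTorch sketch of the batched computation above. The names and shapes ($\text{bs}$, $N$, $M$, $D$) follow the notation in this issue, the concrete sizes are made up for illustration, and `torch.bmm` is used for the batched matrix multiplications.

```python
import torch
import torch.nn.functional as F

bs, N, M, D = 4, 7, 9, 16      # batch size, target length, source length, hidden size (made up)

Q = torch.randn(bs, N, D)      # queries, [bs x N x D]
K = torch.randn(bs, M, D)      # keys,    [bs x M x D]
V = torch.randn(bs, M, D)      # values,  [bs x M x D]

scores = torch.bmm(Q, K.transpose(1, 2))   # QK^T, [bs x N x M]
W = F.softmax(scores, dim=-1)              # attention weights, [bs x N x M]
C = torch.bmm(W, V)                        # context vectors,   [bs x N x D]

assert W.shape == (bs, N, M)
assert C.shape == (bs, N, D)
```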
From the above computation, it follows that using batched matrix multiplication (which can potentially be optimized for better computational efficiency) requires all source sequences in one mini-batch to have the same length, and likewise for all target sequences (the source length and the target length may differ from each other).
This requires padding sequences in one mini-batch to have the same length:
- The padding has to be fixed to zeros so that it does not affect the softmax normalization.
- The padding should not be changed during training and does not need gradients.
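
As a sketch of that padding step (the choice of `0` as the padding index is an assumption for illustration; any reserved id would do):

```python
import torch

PAD = 0   # assumed: index 0 is reserved as the padding token

# variable-length source sequences (token ids) in one mini-batch
seqs = [torch.LongTensor([5, 2, 9]),
        torch.LongTensor([7, 1]),
        torch.LongTensor([3, 8, 4, 6])]

bs = len(seqs)
M = max(len(s) for s in seqs)                # max source length in this mini-batch

padded = torch.LongTensor(bs, M).fill_(PAD)  # [bs x M], filled with PAD
for i, s in enumerate(seqs):
    padded[i, :len(s)] = s

# padded:
# [[5, 2, 9, 0],
#  [7, 1, 0, 0],
#  [3, 8, 4, 6]]
```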
PyTorch implements this by making `padding_idx` a special token for the embedding lookup (the `look_up_table_op`): http://pytorch.org/docs/0.3.0/nn.html?highlight=embedding#torch.nn.Embedding.
Maybe it can also be implemented as a mask.
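
A minimal sketch of both options, with made-up sizes and the padding index assumed to be `0`: the `padding_idx` row of `torch.nn.Embedding` is initialized to zeros and receives no gradient, while the mask alternative sets the scores of padded source positions to $-\infty$ before the softmax so they get zero attention weight.

```python
import torch
import torch.nn as nn

PAD = 0   # assumed padding index
emb = nn.Embedding(num_embeddings=100, embedding_dim=8, padding_idx=PAD)

tokens = torch.LongTensor([[5, 2, 9, PAD],
                           [7, 1, PAD, PAD]])    # padded mini-batch, [bs x M]
embedded = emb(tokens)                           # [bs x M x 8]

print(embedded[0, 3])      # zero vector at the padded position
print(emb.weight[PAD])     # the padding row of the table; its gradient is always zero

# Mask alternative: exclude padded source positions from the softmax.
scores = torch.randn(2, 3, 4)                     # dummy QK^T scores, [bs x N x M]
mask = tokens.eq(PAD).unsqueeze(1)                # [bs x 1 x M], True at padding
masked = scores.masked_fill(mask, float('-inf'))  # padded positions get zero weight after softmax
```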
On the other hand:
- If we do not pad sequences in the mini-batch to the same length, the dot-product attention has to be computed in a `for` loop over the sequences. I am not sure about the difference in computation speed between padding and no padding.
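
For comparison, a minimal sketch of that no-padding alternative: each example keeps its own target/source length and the attention is computed one sequence at a time in a Python loop, which avoids padding but gives up the single batched matrix multiplication (the lengths below are made up).

```python
import torch
import torch.nn.functional as F

D = 16                                     # hidden size
lengths = [(5, 7), (3, 9), (6, 4)]         # per-example (target, source) lengths, no padding

contexts = []
for n, m in lengths:
    q = torch.randn(n, D)                  # [n x D]
    k = torch.randn(m, D)                  # [m x D]
    v = torch.randn(m, D)                  # [m x D]
    w = F.softmax(q @ k.t(), dim=-1)       # [n x m]
    contexts.append(w @ v)                 # per-example context, [n x D]
```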