📚 Add documentation for Packed Transformer Layers
alafage committed Jan 6, 2025
1 parent f484403 commit a09fdd9
Showing 4 changed files with 275 additions and 25 deletions.
4 changes: 4 additions & 0 deletions docs/source/api.rst
@@ -125,6 +125,10 @@ Ensemble layers

PackedLinear
PackedConv2d
PackedMultiheadAttention
PackedLayerNorm
PackedTransformerEncoderLayer
PackedTransformerDecoderLayer
BatchLinear
BatchConv2d
MaskedLinear
11 changes: 10 additions & 1 deletion torch_uncertainty/layers/__init__.py
@@ -4,4 +4,13 @@
from .channel_layer_norm import ChannelLayerNorm
from .masksembles import MaskedConv2d, MaskedLinear
from .modules import Identity
from .packed import PackedConv1d, PackedConv2d, PackedConv3d, PackedLinear
from .packed import (
PackedConv1d,
PackedConv2d,
PackedConv3d,
PackedLayerNorm,
PackedLinear,
PackedMultiheadAttention,
PackedTransformerDecoderLayer,
PackedTransformerEncoderLayer,
)
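
These layers are re-exported from ``torch_uncertainty.layers``, so they can be imported alongside the existing packed layers, e.g.:

from torch_uncertainty.layers import (
    PackedLayerNorm,
    PackedMultiheadAttention,
    PackedTransformerDecoderLayer,
    PackedTransformerEncoderLayer,
)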
248 changes: 230 additions & 18 deletions torch_uncertainty/layers/packed.py
@@ -566,23 +566,6 @@ def bias(self) -> Tensor | None:


class PackedLayerNorm(nn.GroupNorm):
"""Packed-Ensembles-style LayerNorm layer.
Args:
embed_dim (int): the number of features in the input tensor.
num_estimators (int): the number of estimators in the ensemble.
alpha (float): the width multiplier of the layer.
eps (float, optional): a value added to the denominator for numerical stability. Defaults
to 1e-5.
affine (bool, optional): a boolean value that when set to ``True``, this module has
learnable per_channel affine parameters initialized to ones (for weights) and zeros
(for biases). Defaults to ``True``.
Shape:
- Input: :math:`(N, *)` where :math:`*` means any number of additional dimensions.
- Output: :math:`(N, *)` (same shape as input)
"""

def __init__(
self,
embed_dim: int,
@@ -593,6 +576,26 @@ def __init__(
device=None,
dtype=None,
) -> None:
r"""Packed-Ensembles-style LayerNorm layer.
Args:
embed_dim (int): the number of features in the input tensor.
num_estimators (int): the number of estimators in the ensemble.
alpha (float): the width multiplier of the layer.
eps (float, optional): a value added to the denominator for numerical stability. Defaults
to 1e-5.
affine (bool, optional): a boolean value that when set to ``True``, this module has
learnable per-channel affine parameters initialized to ones (for weights) and zeros
(for biases). Defaults to ``True``.
device (torch.device, optional): the device to use for the layer's parameters. Defaults
to ``None``.
dtype (torch.dtype, optional): the dtype to use for the layer's parameters. Defaults to
``None``.
Shape:
- Input: :math:`(B, *)` where :math:`*` means any number of additional dimensions.
- Output: :math:`(B, *)` (same shape as input)
"""
super().__init__(
num_groups=num_estimators,
num_channels=int(embed_dim * alpha),
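
A minimal usage sketch for ``PackedLayerNorm`` based on the docstring above; the argument values and the input shape are illustrative assumptions, not taken from this commit:

import torch

from torch_uncertainty.layers import PackedLayerNorm

# Assumed values: embed_dim=64 and alpha=2 give a packed width of int(64 * 2) = 128.
norm = PackedLayerNorm(embed_dim=64, num_estimators=4, alpha=2)
x = torch.randn(8, 128)  # (B, *); the packed feature width is an assumption
out = norm(x)            # same shape as the input, per the documented Shape section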
@@ -638,6 +641,42 @@ def __init__(
device=None,
dtype=None,
) -> None:
r"""Packed-Ensembles-style MultiheadAttention layer.
Args:
embed_dim (int): Size of the embedding dimension.
num_heads (int): Number of parallel attention heads.
alpha (float): The width multiplier of the embedding dimension.
num_estimators (int): The number of estimators packed in the layer.
gamma (int, optional): Number of groups within each estimator. Defaults to ``1``.
dropout (float, optional): Dropout probability on ``attn_output_weights``. Defaults to ``0.0``
(no dropout).
bias (bool, optional): If specified, adds bias to input / output projection layers.
Defaults to ``True``.
add_bias_kv (bool, optional): If specified, adds bias to the key and value sequences at
``dim=0``. Defaults to ``False``.
add_zero_attn (bool, optional): If specified, adds a new batch of zeros to the key and
value sequences at ``dim=1``. Defaults to ``False``.
kdim (int | None, optional): Total number of features for keys. Defaults to ``None``
(uses ``kdim=embed_dim``).
vdim (int | None, optional): Total number of features for values. Defaults to ``None``
(uses ``vdim=embed_dim``).
batch_first (bool, optional): If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Defaults to ``False`` (seq, batch, feature).
first (bool, optional): Whether this is the first layer of the network. Defaults to
``False``.
last (bool, optional): Whether this is the last layer of the network. Defaults to
``False``.
device (torch.device, optional): The device to use for the layer's parameters. Defaults
to ``None``.
dtype (torch.dtype, optional): The dtype to use for the layer's parameters. Defaults to
``None``.
Reference:
- `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_: Original Multihead Attention formulation.
- `Hierarchical Light Transformer Ensembles for Multimodal Trajectory Forecasting <https://arxiv.org/abs/2403.17678>`_:
  Packed-Ensembles-style Multihead Attention formulation.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()

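A hedged construction sketch for ``PackedMultiheadAttention``; the hyperparameter values below are illustrative only:

from torch_uncertainty.layers import PackedMultiheadAttention

# Four estimators packed in a single module; the embedding width is scaled by alpha internally.
mha = PackedMultiheadAttention(
    embed_dim=64,
    num_heads=4,
    alpha=2,
    num_estimators=4,
    batch_first=True,  # inputs and outputs as (batch, seq, feature)
)
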
@@ -765,7 +804,61 @@ def forward(
attn_mask: Tensor | None = None,
average_attn_weights: bool = True,
is_causal: bool = False,
) -> tuple[Tensor, Tensor | None]:
) -> tuple[Tensor, None]:
r"""Computes attention outputs given query, key, and value tensors.
Args:
query (Tensor): Query embeddings of shape :math:`(L, E_q)` for unbatched input,
:math:`(L, B, E_q)` when ``batch_first=False`` or :math:`(B, L, E_q)` when
``batch_first=True``, where :math:`L` is the target sequence length, :math:`B` is
the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
key (Tensor): Key embeddings of shape :math:`(S, E_k)` for unbatched input,
:math:`(S, B, E_k)` when ``batch_first=False`` or :math:`(B, S, E_k)` when
``batch_first=True``, where :math:`S` is the source sequence length, :math:`B` is
the batch size and :math:`E_k` is the key embedding dimension ``kdim``.
value (Tensor): Value embeddings of shape :math:`(S, E_v)` for unbatched input,
:math:`(S, B, E_v)` when ``batch_first=False`` or :math:`(B, S, E_v)` when
``batch_first=True``, where :math:`S` is the source sequence length, :math:`B` is
the batch size and :math:`E_v` is the value embedding dimension ``vdim``.
key_padding_mask (Tensor | None, optional): If specified, a mask of shape
:math:`(B, S)` indicating which elements within ``key`` to ignore for the purpose
of attention (i.e. treat as "padding"). For unbatched `query`, shape should be
:math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True``
value indicates that the corresponding ``key`` value will be ignored for the
purpose of attention. For a float mask, it will be directly added to the
corresponding ``key`` value. Defaults to ``None``.
need_weights (bool, optional): If specified, returns ``attn_output_weights`` in
addition to ``attn_outputs``. Set ``need_weights=False`` to use the optimized
``scaled_dot_product_attention`` and achieve the best performance for MHA.
Defaults to ``False``.
attn_mask (Tensor | None, optional): If specified, a 2D or 3D mask preventing attention
to certain positions. Must be of shape :math:`(L,S)` or
:math:`(B \times \text{num\_heads}, L, S)`, where :math:`B` is the batch size, :math:`L`
is the target sequence length, and :math:`S` is the source sequence length. A 2D mask
will be broadcasted across the batch while a 3D mask allows for a different mask for
each entry in the batch. Binary and float masks are supported. For a binary mask, a
``True`` value indicates that the corresponding position is not allowed to attend to.
For a float mask, the mask values will be added to the attention weight. If both
``attn_mask`` and ``key_padding_mask`` are provided, their types should match.
Defaults to ``None``.
average_attn_weights (bool, optional): If ``True``, indicates that the returned
``attn_weights`` should be averaged across heads. Otherwise, ``attn_weights`` are
provided separately per head. Note that this flag only has an effect when
``need_weights=True``. Defaults to ``True``.
is_causal (bool, optional): If specified, applies a causal mask as attention mask. Defaults
    to ``False``.
Warning:
``need_weights=True`` is not supported yet; as a result, ``need_weights`` and
``average_attn_weights`` currently have no effect.
Returns:
tuple[Tensor, None]:
- *attn_output* (Tensor): The output tensor of shape :math:`(L, E_q)`, :math:`(L, B, E_q)`
or :math:`(B, L, E_q)` where :math:`L` is the target sequence length, :math:`B` is
the batch size, and :math:`E_q` is the embedding dimension ``embed_dim``.
- *attn_output_weights* (None): Always ``None``, as ``need_weights=True`` is not
    supported yet.
"""
is_batched = query.dim() == 3

key_padding_mask = F._canonical_mask(
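
A possible call pattern for the ``forward`` documented above, assuming ``batch_first=True`` and inputs already carrying the packed width ``int(embed_dim * alpha)``; these shapes are assumptions, and the second return value is always ``None``:

import torch

from torch_uncertainty.layers import PackedMultiheadAttention

mha = PackedMultiheadAttention(embed_dim=64, num_heads=4, alpha=2, num_estimators=4, batch_first=True)
B, L = 8, 16
x = torch.randn(B, L, int(64 * 2))              # (B, L, E); packed width is an assumption
pad_mask = torch.zeros(B, L, dtype=torch.bool)  # True would mark key positions to ignore
attn_out, attn_weights = mha(x, x, x, key_padding_mask=pad_mask)
assert attn_weights is None                     # need_weights=True is not supported yet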
@@ -879,6 +972,44 @@ def __init__(
device=None,
dtype=None,
) -> None:
r"""Packed-Ensembles-style TransformerEncoderLayer (made up of self-attention followed by a
feedforward network).
Args:
d_model (int): the number of expected features in the input.
nhead (int): the number of heads in the multiheadattention models.
alpha (float): the width multiplier of the layer.
num_estimators (int): the number of estimators packed in the layer.
gamma (int, optional): the number of groups within each estimator. Defaults to ``1``.
dim_feedforward (int, optional): the dimension of the feedforward network model. Defaults
to ``2048``.
dropout (float, optional): the dropout value. Defaults to ``0.1``.
activation (Callable[[Tensor], Tensor], optional): the activation function of the
intermediate layer, that is a unary callable. Defaults to ``F.relu``.
layer_norm_eps (float, optional): the eps value in layer normalization components. Defaults
to ``1e-5``.
bias (bool, optional): If ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an
additive bias. Defaults to ``True``.
batch_first (bool, optional): If ``True``, then the input and output tensors are provided
as :math:`(\text{batch}, \text{seq}, \text{d\_model})`. Defaults to ``False``, i.e.
:math:`(\text{seq}, \text{batch}, \text{d\_model})`.
norm_first (bool, optional): If ``True``, the layer norm is done prior to attention and
feedforward operations, respectively. Otherwise, it is done after. Defaults to
``False``.
first (bool, optional): Whether this is the first layer of the network. Defaults to
``False``.
last (bool, optional): Whether this is the last layer of the network. Defaults to
``False``.
device (torch.device, optional): The device to use for the layer's parameters. Defaults
to ``None``.
dtype (torch.dtype, optional): The dtype to use for the layer's parameters. Defaults to
``None``.
Reference:
- `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_: Original Multihead Attention formulation.
- `Hierarchical Light Transformer Ensembles for Multimodal Trajectory Forecasting <https://arxiv.org/abs/2403.17678>`_:
  Packed-Ensembles-style Multihead Attention formulation.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()

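A construction sketch mirroring the docstring above; all values are illustrative, not taken from this commit:

from torch_uncertainty.layers import PackedTransformerEncoderLayer

encoder_layer = PackedTransformerEncoderLayer(
    d_model=64,
    nhead=4,
    alpha=2,
    num_estimators=4,
    dim_feedforward=256,
    batch_first=True,
)
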
@@ -965,6 +1096,22 @@ def forward(
src_key_padding_mask: Tensor | None = None,
is_causal: bool = False,
) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src (Tensor): The sequence to the encoder layer. Shape: :math:`(B, L, E)` or
:math:`(L, B, E)`.
src_mask (Tensor | None, optional): The mask for the ``src`` sequence. Defaults to ``None``.
src_key_padding_mask (Tensor | None, optional): The mask for the ``src`` keys per
batch. Defaults to ``None``.
is_causal (bool, optional): If specified, applies a causal mask as ``src_mask``.
Defaults to ``False``. Warning: ``is_causal`` provides a hint that ``src_mask`` is
a causal mask. Providing incorrect hints can result in incorrect execution,
including forward and backward compatibility.
Returns:
Tensor: The output of the encoder layer. Shape: :math:`(B, L, E)` or :math:`(L, B, E)`.
"""
src_key_padding_mask = F._canonical_mask(
mask=src_key_padding_mask,
mask_name="src_key_padding_mask",
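
A hedged forward sketch for the encoder layer, assuming ``batch_first=True`` and an input already at the packed width ``int(d_model * alpha)``; the values are assumptions:

import torch

from torch_uncertainty.layers import PackedTransformerEncoderLayer

layer = PackedTransformerEncoderLayer(d_model=64, nhead=4, alpha=2, num_estimators=4, batch_first=True)
src = torch.randn(8, 16, int(64 * 2))  # (B, L, E); the packed width is an assumption
out = layer(src)                       # same shape as src, per the docstring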
@@ -1045,6 +1192,44 @@ def __init__(
device=None,
dtype=None,
) -> None:
r"""Packed-Ensembles-style TransformerDecoderLayer (made up of self-attention, multi-head
attention, and feedforward network).
Args:
d_model (int): the number of expected features in the input.
nhead (int): the number of heads in the multiheadattention models.
alpha (float): the width multiplier of the layer.
num_estimators (int): the number of estimators packed in the layer.
gamma (int, optional): the number of groups within each estimator. Defaults to ``1``.
dim_feedforward (int, optional): the dimension of the feedforward network model. Defaults
to ``2048``.
dropout (float, optional): the dropout value. Defaults to ``0.1``.
activation (Callable[[Tensor], Tensor], optional): the activation function of the
intermediate layer, that is a unary callable. Defaults to ``F.relu``.
layer_norm_eps (float, optional): the eps value in layer normalization components. Defaults
to ``1e-5``.
bias (bool, optional): If ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an
additive bias. Defaults to ``True``.
batch_first (bool, optional): If ``True``, then the input and output tensors are provided
as :math:`(\text{batch}, \text{seq}, \text{d\_model})`. Defaults to ``False``, i.e.
:math:`(\text{seq}, \text{batch}, \text{d\_model})`.
norm_first (bool, optional): If ``True``, the layer norm is done prior to attention and
feedforward operations, respectively. Otherwise, it is done after. Defaults to
``False``.
first (bool, optional): Whether this is the first layer of the network. Defaults to
``False``.
last (bool, optional): Whether this is the last layer of the network. Defaults to
``False``.
device (torch.device, optional): The device to use for the layer's parameters. Defaults
to ``None``.
dtype (torch.dtype, optional): The dtype to use for the layer's parameters. Defaults to
``None``.
Reference:
- `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_: Original Multihead Attention formulation.
- `Hierarchical Light Transformer Ensembles for Multimodal Trajectory Forecasting <https://arxiv.org/abs/2403.17678>`_:
  Packed-Ensembles-style Multihead Attention formulation.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()

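A construction sketch for the decoder layer; the hyperparameters are illustrative assumptions:

from torch_uncertainty.layers import PackedTransformerDecoderLayer

decoder_layer = PackedTransformerDecoderLayer(
    d_model=64,
    nhead=4,
    alpha=2,
    num_estimators=4,
    dim_feedforward=256,
    batch_first=True,
)
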
@@ -1156,6 +1341,33 @@ def forward(
tgt_is_causal: bool = False,
memory_is_causal: bool = False,
) -> Tensor:
r"""Pass the input (and mask) through the decoder layer.
Args:
tgt (Tensor): The sequence to the decoder layer. Shape: :math:`(B, L, E)` or
:math:`(L, B, E)`.
memory (Tensor): The sequence from the last layer of the encoder. Shape:
:math:`(B, S, E)` or :math:`(S, B, E)`.
tgt_mask (Tensor | None, optional): The mask for the ``tgt`` sequence. Defaults to
``None``.
memory_mask (Tensor | None, optional): The mask for the ``memory`` sequence. Defaults
to ``None``.
tgt_key_padding_mask (Tensor | None, optional): The mask for the ``tgt`` keys per
batch. Defaults to ``None``.
memory_key_padding_mask (Tensor | None, optional): The mask for the ``memory`` keys per
batch. Defaults to ``None``.
tgt_is_causal (bool, optional): If specified, applies a causal mask as ``tgt_mask``.
Defaults to ``False``. Warning: ``tgt_is_causal`` provides a hint that ``tgt_mask``
is a causal mask. Providing incorrect hints can result in incorrect execution,
including forward and backward compatibility.
memory_is_causal (bool, optional): If specified, applies a causal mask as ``memory_mask``.
Defaults to ``False``. Warning: ``memory_is_causal`` provides a hint that ``memory_mask``
is a causal mask. Providing incorrect hints can result in incorrect execution,
including forward and backward compatibility.
Returns:
Tensor: The output of the encoder layer. Shape: :math:`(B, L, E)` or :math:`(L, B, E)`.
"""
x = tgt
if self.norm_first:
x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
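
A hedged forward sketch for the decoder layer, assuming ``batch_first=True`` and packed-width inputs; the shapes follow the docstring above but are otherwise assumptions:

import torch

from torch_uncertainty.layers import PackedTransformerDecoderLayer

layer = PackedTransformerDecoderLayer(d_model=64, nhead=4, alpha=2, num_estimators=4, batch_first=True)
tgt = torch.randn(8, 16, int(64 * 2))     # (B, L, E); packed width assumed
memory = torch.randn(8, 24, int(64 * 2))  # (B, S, E), e.g. from an encoder stack
out = layer(tgt, memory)                  # (B, L, E)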