🐛 Fix packed_linear functional when bias is None
+improve coverage
alafage committed Dec 26, 2024
1 parent ec2d921 commit 1a53be2
Showing 3 changed files with 101 additions and 58 deletions.
68 changes: 41 additions & 27 deletions tests/layers/test_packed.py
@@ -6,6 +6,7 @@
PackedConv1d,
PackedConv2d,
PackedConv3d,
PackedLayerNorm,
PackedLinear,
PackedMultiheadAttention,
)
@@ -46,30 +47,6 @@ def voxels_input() -> torch.Tensor:
return torch.rand((5, 6, 3, 3, 3))


@pytest.fixture()
def unbatched_sequence() -> torch.Tensor:
return torch.rand((3, 6)) # (L, Hin)


@pytest.fixture()
def batched_sequence() -> torch.Tensor:
return torch.rand((2, 3, 6)) # (B, L, Hin)


@pytest.fixture()
def unbatched_sequences() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return torch.rand((3, 6)), torch.rand((4, 2)), torch.rand((4, 4)) # (L, Eq), (S, Ek), (S, Ev)


@pytest.fixture()
def batched_sequences() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return (
torch.rand((2, 3, 6)),
torch.rand((2, 4, 2)),
torch.rand((2, 4, 4)),
) # (B, L, Eq), (B, S, Ek), (B, S, Ev)


@pytest.fixture()
def unbatched_qkv() -> torch.Tensor:
return torch.rand((3, 6))
@@ -336,9 +313,18 @@ def test_conv3_failures(self):
_ = PackedConv3d(5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=-1)


class TestPackedGroupNorm:
class TestPackedLayerNorm:
"""Testing the PackedGroupNorm layer class."""

def test_one_estimator_forward(self, batched_qkv: torch.Tensor):
packed_layer_norm = PackedLayerNorm(
embed_dim=6,
num_estimators=1,
alpha=1,
)
out = packed_layer_norm(batched_qkv)
assert out.shape == torch.Size([2, 3, 6])


class TestPackedMultiheadAttention:
"""Testing the PackedMultiheadAttention layer class."""
@@ -371,6 +357,7 @@ def test_one_estimator_qkv(self, unbatched_qkv: torch.Tensor, batched_qkv: torch
alpha=1,
num_estimators=1,
batch_first=True,
bias=False,
)
out, _ = layer(
query=batched_qkv,
@@ -379,14 +366,19 @@
)
assert out.shape == torch.Size([2, 3, 6])

def test_one_estimator_q_kv(self, unbatched_q_kv: torch.Tensor, batched_q_kv: torch.Tensor):
def test_one_estimator_q_kv(
self,
unbatched_q_kv: tuple[torch.Tensor, torch.Tensor],
batched_q_kv: tuple[torch.Tensor, torch.Tensor],
):
layer = PackedMultiheadAttention(
embed_dim=6,
num_heads=2,
alpha=1,
num_estimators=1,
kdim=2,
vdim=2,
add_zero_attn=True,
)
out, _ = layer(
query=unbatched_q_kv[0],
@@ -418,14 +410,19 @@ def test_one_estimator_q_kv(self, unbatched_q_kv: torch.Tensor, batched_q_kv: to
)
assert out.shape == torch.Size([2, 3, 6])

def test_one_estimator_q_k_v(self, unbatched_q_k_v: torch.Tensor, batched_q_k_v: torch.Tensor):
def test_one_estimator_q_k_v(
self,
unbatched_q_k_v: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
batched_q_k_v: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
):
layer = PackedMultiheadAttention(
embed_dim=6,
num_heads=2,
alpha=1,
num_estimators=1,
kdim=2,
vdim=4,
add_bias_kv=True,
)
out, _ = layer(
query=unbatched_q_k_v[0],
@@ -452,9 +449,26 @@ def test_one_estimator_q_k_v(self, unbatched_q_k_v: torch.Tensor, batched_q_k_v:
vdim=4,
batch_first=True,
)

layer.eval()

attn_mask = torch.zeros(3, 4, dtype=torch.bool)
key_padding_mask = torch.zeros(2, 4, dtype=torch.bool)

out, _ = layer(
query=batched_q_k_v[0],
key=batched_q_k_v[1],
value=batched_q_k_v[2],
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
)
assert out.shape == torch.Size([2, 3, 6])
assert out.isfinite().all()


class TestPackedTransformerEncoderLayer:
"""Testing the PackedTransformerEncoderLayer class."""


class TestPackedTransformerDecoderLayer:
"""Testing the PackedTransformerDecoderLayer class."""
21 changes: 12 additions & 9 deletions torch_uncertainty/layers/functional/packed.py
@@ -34,16 +34,19 @@ def packed_linear(
block_diag = torch.block_diag(*weight)
return F.linear(inputs, block_diag, bias)
if implementation == "sparse":
return (inputs @ weight.transpose(0, 1)) + bias
out = inputs @ weight.transpose(0, 1)
if bias is not None:
out += bias
return out
if implementation == "einsum":
return (
torch.einsum(
"...ki,kij->...kj",
rearrange(inputs, "... (m d) -> ... m d", m=num_groups),
weight.transpose(1, 2),
).flatten(start_dim=-2)
+ bias
)
out = torch.einsum(
"...ki,kij->...kj",
rearrange(inputs, "... (m d) -> ... m d", m=num_groups),
weight.transpose(1, 2),
).flatten(start_dim=-2)
if bias is not None:
out += bias
return out
raise ValueError(f"Unknown implementation: {implementation}")


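The hunk above shows only the body of packed_linear, so its exact signature is not visible here; the keyword names below (inputs, weight, num_groups, implementation, bias) are inferred from the identifiers used in the body and should be treated as assumptions. A minimal sketch of the case this commit fixes:

```python
import torch

# Import path taken from the file header above.
from torch_uncertainty.layers.functional.packed import packed_linear

num_groups = 2
inputs = torch.rand(5, 6)              # (batch, num_groups * in_per_group) with in_per_group = 3
weight = torch.rand(num_groups, 4, 3)  # (num_groups, out_per_group, in_per_group)

# Before this fix, the "sparse" and "einsum" paths added the bias unconditionally,
# so bias=None raised a TypeError; the addition is now guarded by `if bias is not None`.
out = packed_linear(
    inputs=inputs,
    weight=weight,
    num_groups=num_groups,
    implementation="einsum",
    bias=None,
)
print(out.shape)  # expected: torch.Size([5, 8])
```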
70 changes: 48 additions & 22 deletions torch_uncertainty/layers/packed.py
@@ -566,17 +566,48 @@ def bias(self) -> Tensor | None:


class PackedLayerNorm(nn.GroupNorm):
"""Packed-Ensembles-style LayerNorm layer.
Args:
embed_dim (int): the number of features in the input tensor.
num_estimators (int): the number of estimators in the ensemble.
alpha (float): the width multiplier of the layer.
eps (float, optional): a value added to the denominator for numerical stability. Defaults
to 1e-5.
affine (bool, optional): a boolean value that when set to ``True``, this module has
learnable per-channel affine parameters initialized to ones (for weights) and zeros
(for biases). Defaults to ``True``.
Shape:
- Input: :math:`(N, *)` where :math:`*` means any number of additional dimensions.
- Output: :math:`(N, *)` (same shape as input)
"""

def __init__(
self,
embed_dim: int,
num_estimators: int,
alpha: float,
eps: float = 1e-5,
affine: bool = True,
) -> None:
super().__init__(
num_groups=num_estimators,
num_channels=int(embed_dim * alpha),
eps=eps,
affine=affine,
)

def forward(self, inputs: Tensor) -> Tensor:
b, _, _ = inputs.size()
x = rearrange(inputs, "b s h -> (b s) h")
x = rearrange(inputs, "b ... h -> b h ...")
x = F.group_norm(
x,
self.num_groups,
self.weight,
self.bias,
self.eps,
)
return rearrange(x, "(b s) h -> b s h", b=b)
return rearrange(x, "b h ... -> b ... h")


class PackedMultiheadAttention(nn.Module):
@@ -683,6 +714,12 @@ def __init__(
else:
self.register_parameter("in_proj_bias", None)

if add_bias_kv:
self.bias_k = nn.Parameter(torch.empty((1, 1, self.embed_dim), **factory_kwargs))
self.bias_v = nn.Parameter(torch.empty((1, 1, self.embed_dim), **factory_kwargs))
else:
self.bias_k = self.bias_v = None

self.out_proj = PackedLinear(
in_features=embed_dim,
out_features=embed_dim,
@@ -696,6 +733,8 @@ def __init__(
**factory_kwargs,
)

self.add_zero_attn = add_zero_attn

self._reset_parameters()

def _reset_parameters(self):
@@ -712,19 +751,6 @@ def _reset_parameters(self):
nn.init.constant_(self.in_proj_bias, 0.0)
nn.init.constant_(self.out_proj.bias, 0.0)

def __setstate__(self, state):
"""Support loading old MultiheadAttention checkpoints generated by
v1.1.0.
Args:
state (_type_): _description_
"""
#
if "_qkv_same_embed_dim" not in state:
state["_qkv_same_embed_dim"] = True

super().__setstate__(state)

def forward(
self,
query: Tensor,
@@ -779,9 +805,9 @@ def forward(
self.num_groups,
self.in_proj_weight,
self.in_proj_bias,
None,
None,
False,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
@@ -809,9 +835,9 @@ def forward(
self.num_groups,
self.in_proj_weight,
self.in_proj_bias,
None,
None,
False,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
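With bias_k, bias_v, and add_zero_attn now forwarded to the attention call instead of being hard-coded to None and False, the add_bias_kv and add_zero_attn constructor options exercised by the new tests actually take effect. A rough usage sketch, reusing the hyperparameters and unbatched fixture shapes from the test file above; the flags are assumed to follow the standard nn.MultiheadAttention semantics.

```python
import torch

from torch_uncertainty.layers.packed import PackedMultiheadAttention

# Hyperparameters mirror test_one_estimator_q_k_v above; enabling both flags at
# once is an assumption, since no single test in this commit combines them.
layer = PackedMultiheadAttention(
    embed_dim=6,
    num_heads=2,
    alpha=1,
    num_estimators=1,
    kdim=2,
    vdim=4,
    add_bias_kv=True,    # learnable bias rows appended to the key/value sequences
    add_zero_attn=True,  # plus an all-zero attention slot
)

# Unbatched (L, E_q), (S, E_k), (S, E_v) inputs, matching the removed fixture shapes.
query, key, value = torch.rand(3, 6), torch.rand(4, 2), torch.rand(4, 4)
out, _ = layer(query=query, key=key, value=value)
print(out.shape)  # expected: torch.Size([3, 6])
```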
