monai/networks/blocks/selfattention.py: 19 changes (16 additions & 3 deletions)
@@ -154,10 +154,12 @@ def __init__(
         )
         self.input_size = input_size
 
-    def forward(self, x):
+    def forward(self, x, attn_mask: torch.Tensor | None = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+            attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
+                B x (s_dim_1 * ... * s_dim_n). Defaults to None.
 
         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
@@ -176,7 +178,13 @@ def forward(self, x):

         if self.use_flash_attention:
             x = F.scaled_dot_product_attention(
-                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+                query=q,
+                key=k,
+                value=v,
+                attn_mask=attn_mask,
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
@@ -186,10 +194,15 @@
                 att_mat = self.rel_positional_embedding(x, att_mat, q)
 
             if self.causal:
+                assert attn_mask is None, "Causal attention does not support attention masks."
                 att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
 
-            att_mat = att_mat.softmax(dim=-1)
+            if attn_mask is not None:
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
+                attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
+                att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf"))
 
+            att_mat = att_mat.softmax(dim=-1)
             if self.save_attn:
                 # no gradients and new tensor;
                 # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
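To make the non-flash branch above easier to follow, here is a small standalone sketch of the shape handling (sizes, variable names, and the padding pattern are illustrative, not taken from the PR): a boolean mask of shape (B, S) is expanded to (B, num_heads, 1, S) and then broadcast across every query row of the score matrix before the softmax.

import torch

B, num_heads, S = 2, 4, 16
att_mat = torch.randn(B, num_heads, S, S)          # raw attention scores
attn_mask = torch.ones(B, S, dtype=torch.bool)     # True = attend, False = mask out
attn_mask[:, -3:] = False                          # e.g. three padded positions per sequence

# (B, S) -> (B, 1, 1, S) -> (B, num_heads, 1, S); masked_fill then broadcasts this
# over the query axis of the (B, num_heads, S, S) score matrix
expanded = attn_mask.unsqueeze(1).unsqueeze(2).expand(-1, num_heads, -1, -1)
att_mat = att_mat.masked_fill(expanded == 0, float("-inf")).softmax(dim=-1)

# columns belonging to masked keys receive zero attention weight
masked_cols = att_mat[0, 0][:, ~attn_mask[0]]
assert torch.allclose(masked_cols, torch.zeros_like(masked_cols))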
monai/networks/blocks/transformerblock.py: 6 changes (4 additions & 2 deletions)
@@ -90,8 +90,10 @@ def __init__(
             use_flash_attention=use_flash_attention,
         )
 
-    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x))
+    def forward(
+        self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
         if self.with_cross_attention:
             x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
         x = x + self.mlp(self.norm2(x))
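For an end-to-end view of the forwarded argument, a minimal usage sketch (the length-based mask construction and all sizes below are assumptions for illustration, not part of the PR): the mask passed to TransformerBlock.forward goes straight to its self-attention sub-block, while cross-attention and the MLP are unchanged.

import torch
from monai.networks.blocks.transformerblock import TransformerBlock

seq_len, hidden = 16, 128
lengths = torch.tensor([16, 11])                               # valid tokens per sequence
attn_mask = torch.arange(seq_len)[None, :] < lengths[:, None]  # (B, seq_len) boolean padding mask

block = TransformerBlock(hidden_size=hidden, mlp_dim=hidden * 4, num_heads=8)
x = torch.randn(2, seq_len, hidden)
out = block(x, attn_mask=attn_mask)                            # (2, seq_len, hidden)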
tests/test_selfattention.py: 18 changes (18 additions & 0 deletions)
@@ -122,6 +122,24 @@ def test_causal(self):
         # check upper triangular part of the attention matrix is zero
         assert torch.triu(block.att_mat, diagonal=1).sum() == 0
 
+    def test_masked_selfattention(self):
+        n = 64
+        block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True)
+        input_shape = (1, n, 128)
+        # generate a mask randomly with zeros and ones of shape (1, n)
+        mask = torch.randint(0, 2, (1, n)).bool()
+        block(torch.randn(input_shape), attn_mask=mask)
+        att_mat = block.att_mat.squeeze()
+        # ensure all masked columns are zeros
+        assert torch.allclose(att_mat[:, ~mask.squeeze(0)], torch.zeros_like(att_mat[:, ~mask.squeeze(0)]))
+
+    def test_causal_and_mask(self):
+        with self.assertRaises(AssertionError):
+            block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)
+            inputs = torch.randn(2, 64, 128)
+            mask = torch.randint(0, 2, (2, 64)).bool()
+            block(inputs, attn_mask=mask)
+
     @skipUnless(has_einops, "Requires einops")
     def test_access_attn_matrix(self):
         # input format
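For completeness, the guard exercised by test_causal_and_mask looks like this from the caller's side (a sketch only; the sizes are illustrative): a block built with causal=True rejects an explicit attn_mask through the assert added in SABlock.forward.

import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)
x = torch.randn(2, 64, 128)
attn_mask = torch.ones(2, 64, dtype=torch.bool)

try:
    block(x, attn_mask=attn_mask)
except AssertionError:
    pass  # expected: "Causal attention does not support attention masks."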