Add include_fc and use_combined_linear arguments in the SABlock #7996

Merged: 36 commits merged on Aug 9, 2024
Changes from 7 commits
Commits
36 commits
37cd5cd
fix #7991
KumoLiu Aug 6, 2024
63ba16d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
7255a90
add docstring
KumoLiu Aug 6, 2024
ddbd336
Merge branch 'proj-atten' of https://github.com/KumoLiu/MONAI into pr…
KumoLiu Aug 6, 2024
0337d45
fix #7992
KumoLiu Aug 6, 2024
f198e2c
Merge branch 'linear' into proj-atten
KumoLiu Aug 6, 2024
814e61a
add tests
KumoLiu Aug 6, 2024
7dd22e0
Merge remote-tracking branch 'origin/dev' into proj-atten
KumoLiu Aug 6, 2024
de9eef0
remove transpose in sablock
KumoLiu Aug 7, 2024
2333351
fix docstring
KumoLiu Aug 7, 2024
f9eb6d8
use rearange
KumoLiu Aug 7, 2024
5aeccbe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2024
3154c7c
minor fix
KumoLiu Aug 7, 2024
81d3605
add in SpatialAttentionBlock
KumoLiu Aug 7, 2024
f47c2c6
Merge remote-tracking branch 'origin/dev' into proj-atten
KumoLiu Aug 7, 2024
754e7f2
fix format
KumoLiu Aug 7, 2024
3cf2124
add tests
KumoLiu Aug 7, 2024
8de91eb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2024
0b556a5
minor fix
KumoLiu Aug 7, 2024
9a59a15
minor fix
KumoLiu Aug 7, 2024
05e42ce
format fix
KumoLiu Aug 7, 2024
7dc0933
Merge branch 'proj-atten' of https://github.com/KumoLiu/MONAI into pr…
KumoLiu Aug 7, 2024
aae275d
minor fix
KumoLiu Aug 7, 2024
531a831
fix mypy
KumoLiu Aug 7, 2024
48319c0
fix ci
KumoLiu Aug 7, 2024
b854d7a
minor fix
KumoLiu Aug 7, 2024
32d0a5d
address comments
KumoLiu Aug 8, 2024
e5f2cb1
minor fix
KumoLiu Aug 8, 2024
818ba7e
Update tests/test_crossattention.py
KumoLiu Aug 9, 2024
4bef7f0
Update tests/test_selfattention.py
KumoLiu Aug 9, 2024
bfc8f29
minor fix
KumoLiu Aug 9, 2024
3d09b4a
Merge remote-tracking branch 'origin/dev' into proj-atten
KumoLiu Aug 9, 2024
0da115a
address comments
KumoLiu Aug 9, 2024
0d46a6b
Merge branch 'dev' into proj-atten
KumoLiu Aug 9, 2024
1c5599d
fix state dict
KumoLiu Aug 9, 2024
6ed765d
Merge branch 'dev' into proj-atten
KumoLiu Aug 9, 2024
40 changes: 30 additions & 10 deletions monai/networks/blocks/selfattention.py
@@ -11,7 +11,7 @@

from __future__ import annotations

from typing import Optional, Tuple
from typing import Tuple

import torch
import torch.nn as nn
@@ -39,9 +39,11 @@ def __init__(
hidden_input_size: int | None = None,
causal: bool = False,
sequence_length: int | None = None,
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
rel_pos_embedding: str | None = None,
input_size: Tuple | None = None,
attention_dtype: torch.dtype | None = None,
include_fc: bool = True,
use_combined_linear: bool = True,
) -> None:
"""
Args:
@@ -59,6 +61,8 @@
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
attention_dtype: cast attention operations to this dtype.
include_fc: whether to include the final linear layer. Default to True.
use_combined_linear: whether to use a single linear layer for qkv projection, default to True.

"""

@@ -86,9 +90,17 @@ def __init__(
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)

self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
if use_combined_linear:
self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
self.to_q = self.to_k = self.to_v = nn.Identity() # add to enable torchscript
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
else:
self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
self.qkv = nn.Identity() # add to enable torchscript
self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
self.out_rearrange = Rearrange("b l h d -> b h (l d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.scale = self.dim_head**-0.5
@@ -97,6 +109,8 @@ def __init__(
self.attention_dtype = attention_dtype
self.causal = causal
self.sequence_length = sequence_length
self.include_fc = include_fc
self.use_combined_linear = use_combined_linear

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
@@ -123,8 +137,13 @@ def forward(self, x):
Return:
torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
"""
output = self.input_rearrange(self.qkv(x))
q, k, v = output[0], output[1], output[2]
if self.use_combined_linear:
output = self.input_rearrange(self.qkv(x))
q, k, v = output[0], output[1], output[2]
else:
q = self.input_rearrange(self.to_q(x))
k = self.input_rearrange(self.to_k(x))
v = self.input_rearrange(self.to_v(x))

if self.attention_dtype is not None:
q = q.to(self.attention_dtype)
@@ -148,6 +167,7 @@ def forward(self, x):
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
if self.include_fc:
x = self.out_proj(x)
x = self.drop_output(x)
return x
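
As a quick usage sketch (illustrative only, not part of the PR; parameter values taken from the tests below), the block can now be built with three separate q/k/v projections and with the final linear layer disabled:

import torch
from monai.networks.blocks.selfattention import SABlock

# Default behaviour: a single combined qkv projection plus the final linear layer.
default_block = SABlock(hidden_size=360, num_heads=4, dropout_rate=0.0)

# New options: separate to_q/to_k/to_v projections, final out_proj skipped.
block = SABlock(
    hidden_size=360,
    num_heads=4,
    dropout_rate=0.0,
    include_fc=False,
    use_combined_linear=False,
)

x = torch.randn(2, 512, 360)  # B x sequence length x hidden_size, matching the test cases
out = block(x)  # same B x sequence length x channel layout as the input

Because the unused projection layers are replaced with nn.Identity in either branch, both configurations remain TorchScript-compatible, which is what the new test_script test below exercises.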
47 changes: 35 additions & 12 deletions tests/test_selfattention.py
@@ -22,6 +22,7 @@
from monai.networks.blocks.selfattention import SABlock
from monai.networks.layers.factories import RelPosEmbedding
from monai.utils import optional_import
from tests.utils import test_script_save

einops, has_einops = optional_import("einops")

@@ -31,18 +32,22 @@
for num_heads in [4, 6, 8, 12]:
for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
for input_size in [(16, 32), (8, 8, 8)]:
test_case = [
{
"hidden_size": hidden_size,
"num_heads": num_heads,
"dropout_rate": dropout_rate,
"rel_pos_embedding": rel_pos_embedding,
"input_size": input_size,
},
(2, 512, hidden_size),
(2, 512, hidden_size),
]
TEST_CASE_SABLOCK.append(test_case)
for include_fc in [True, False]:
for use_combined_linear in [True, False]:
test_case = [
{
"hidden_size": hidden_size,
"num_heads": num_heads,
"dropout_rate": dropout_rate,
"rel_pos_embedding": rel_pos_embedding,
"input_size": input_size,
"include_fc": include_fc,
"use_combined_linear": use_combined_linear,
},
(2, 512, hidden_size),
(2, 512, hidden_size),
]
TEST_CASE_SABLOCK.append(test_case)


class TestResBlock(unittest.TestCase):
@@ -138,6 +143,24 @@ def count_sablock_params(*args, **kwargs):
nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2)
self.assertEqual(nparams_default, nparams_default_more_heads)

@skipUnless(has_einops, "Requires einops")
def test_script(self):
for include_fc in [True, False]:
for use_combined_linear in [True, False]:
input_param = {
"hidden_size": 360,
"num_heads": 4,
"dropout_rate": 0.0,
"rel_pos_embedding": None,
"input_size": (16, 32),
"include_fc": include_fc,
"use_combined_linear": use_combined_linear,
}
net = SABlock(**input_param)
input_shape = (2, 512, 360)
test_data = torch.randn(input_shape)
test_script_save(net, test_data)


if __name__ == "__main__":
unittest.main()
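
A further sanity check (a sketch assuming both blocks share the remaining defaults; not part of the test file): a single nn.Linear(hidden_size, 3 * inner_dim) holds exactly as many parameters as three nn.Linear(hidden_size, inner_dim) layers, so toggling use_combined_linear changes the state-dict keys (qkv vs. to_q/to_k/to_v) but not the model size:

import torch
from monai.networks.blocks.selfattention import SABlock

def count_params(module: torch.nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

combined = SABlock(hidden_size=360, num_heads=4, dropout_rate=0.0, use_combined_linear=True)
separate = SABlock(hidden_size=360, num_heads=4, dropout_rate=0.0, use_combined_linear=False)

# Same total parameter count, different parameter layout.
assert count_params(combined) == count_params(separate)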