
Commit 1ccb5de

Add splitting embedding dim across head as default
Signed-off-by: NabJa <nabil.jabareen@gmail.com>

DCO Remediation Commit for NabJa <nabil.jabareen@gmail.com>
I, NabJa <nabil.jabareen@gmail.com>, hereby add my Signed-off-by to this commit: 139182e

Signed-off-by: NabJa <nabil.jabareen@gmail.com>
1 parent f1fd5c8 commit 1ccb5de


monai/networks/blocks/selfattention.py

Lines changed: 4 additions & 4 deletions
@@ -32,7 +32,7 @@ def __init__(
         dropout_rate: float = 0.0,
         qkv_bias: bool = False,
         save_attn: bool = False,
-        dim_head: int = 64
+        dim_head: int | None = None,
     ) -> None:
         """
         Args:
@@ -41,7 +41,7 @@ def __init__(
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
-            dim_head (int, optional): dimension of each head. Defaults to 64.
+            dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
 
         """
@@ -54,8 +54,8 @@ def __init__(
             raise ValueError("hidden size should be divisible by num_heads.")
 
         self.num_heads = num_heads
-        self.dim_head = dim_head
-        self.inner_dim = dim_head * num_heads
+        self.dim_head = hidden_size // num_heads if dim_head is None else dim_head
+        self.inner_dim = self.dim_head * num_heads
 
         self.out_proj = nn.Linear(self.inner_dim, hidden_size)
         self.qkv = nn.Linear(hidden_size, self.inner_dim * 3, bias=qkv_bias)
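
With this change, leaving dim_head unset splits the embedding dimension evenly across heads (hidden_size // num_heads), so inner_dim equals hidden_size and the qkv and output projections keep their previous shapes; passing dim_head explicitly still allows a wider or narrower per-head projection. A minimal sketch of the effect, assuming the class in this file is SABlock (as in MONAI's selfattention.py) and using illustrative sizes:

import torch
from monai.networks.blocks.selfattention import SABlock

# Default: dim_head is None, so each head gets hidden_size // num_heads channels
# and inner_dim == hidden_size.
block = SABlock(hidden_size=384, num_heads=6)
print(block.dim_head, block.inner_dim)  # 64 384

x = torch.randn(2, 16, 384)  # (batch, sequence length, hidden_size)
print(block(x).shape)        # torch.Size([2, 16, 384])

# An explicit dim_head still overrides the even split, e.g. a wider per-head projection:
wide = SABlock(hidden_size=384, num_heads=6, dim_head=128)
print(wide.dim_head, wide.inner_dim)  # 128 768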
