
Add dimensionality of heads argument to SABlock #7664

Merged: 13 commits merged into Project-MONAI:dev on May 8, 2024

Conversation

@NabJa (Contributor) commented Apr 18, 2024

Fixes #7661.

Description

This change adds a parameter (dim_head) to set the output dimension of each head in the self-attention block (SABlock). Currently the output dimension is fixed to hidden_size, and when the number of heads is increased this dimension is distributed equally among the heads.

Example

The original implementation automatically determines equally_distributed_head_dim as hidden_size // num_heads, so that
qkv * num_heads * equally_distributed_head_dim = 3 * hidden_size
(in this example: 3 * 8 * 16 = 384).

import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=8)
x = torch.zeros(1, 256, 128)
x = block.qkv(x)
print(x.shape)
x = block.input_rearrange(x)
print(x.shape)

> torch.Size([1, 256, 384])
> torch.Size([3, 1, 8, 256, 16]) # <- This corresponds to (qkv batch num_heads sequence_length equally_distributed_head_dim)

The proposed implementation addresses this by adding a new argument, dim_head:

block_new = SABlock(hidden_size=128, num_heads=8, dim_head=32)
x = torch.zeros(1, 256, 128)
x = block_new.qkv(x)
print(x.shape)
x = block_new.input_rearrange(x)
print(x.shape)

> torch.Size([1, 256, 768]) # <- 3 * 8 * 32 = 768, consistent with the rearranged shape below
> torch.Size([3, 1, 8, 256, 32]) # <- This corresponds to (qkv batch num_heads sequence_length dim_head)
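
As a quick check of the added capacity, the trainable parameters of the two blocks above can be compared (a rough sketch; exact counts depend on the SABlock version and its projection/bias settings):

n_default = sum(p.numel() for p in block.parameters())
n_custom = sum(p.numel() for p in block_new.parameters())
print(n_default, n_custom)  # the dim_head=32 block has more trainable parameters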

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@ericspod (Member)

Hi @NabJa thanks for the contribution. You'll have to fix some of the issues that you're seeing but some others may be related to our testing system. Please follow the DCO instructions for doing a remediation commit, and please run ./runtests.sh --autofix to fix other issues. @marksgraham is there anything here that would affect merging the generative code?

@KumoLiu (Contributor) commented Apr 24, 2024

Hi @NabJa, I guess it's expected behavior.

Increasing the num_heads in the self-attention block (SABlock) does not increase the number of trainable parameters.
The original input embeddings are divided into smaller chunks across a specified number of attention heads in a multi-head attention mechanism. Each of these heads then independently performs an attention mechanism on their allocated chunk of the data.
References: https://github.com/pytorch/pytorch/blob/81740fd1f6fcd70c6ba4812c1289fe7efcc82908/torch/nn/modules/activation.py#L1010
https://discuss.huggingface.co/t/what-does-increasing-number-of-heads-do-in-the-multi-head-attention/1847
https://www.mathworks.com/matlabcentral/answers/2068031-attention-layer-number-of-parameters-doesn-t-change-when-changing-number-of-heads
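
For instance, a minimal check (a sketch, assuming only torch is installed) shows that the parameter count of torch.nn.MultiheadAttention does not depend on num_heads:

import torch.nn as nn

# The embedding dimension is split across the heads, so the projection
# weights keep the same shape regardless of num_heads.
for num_heads in (1, 4, 8):
    mha = nn.MultiheadAttention(embed_dim=128, num_heads=num_heads)
    print(num_heads, sum(p.numel() for p in mha.parameters()))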

But @ericspod and @marksgraham might have more expertise on this. What are your thoughts?
If we do want to make this change, perhaps we need to keep the original implementation as the default.
Thanks.

@NabJa (Contributor, Author) commented Apr 24, 2024

Hi @KumoLiu ,

thank you for the references! Indeed, the official PyTorch implementation splits the embeddings across all heads, resulting in a head dimension of embedding dimension // number of heads.
However, I still think having the option of setting this manually makes a lot of sense, because I might want to increase the number of heads without losing representational power in every attention map. Other frequently used attention implementations also set the head dimension manually (e.g. lucidrains).
I will make another commit so that the default behaviour stays as it is, while adding the option to manually set the head dimension (a rough sketch of the idea follows below).
Looking forward to your opinions / reviews.
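
For illustration, a minimal sketch of the idea (not the exact MONAI code; the class and attribute names below are placeholders): the per-head dimension defaults to hidden_size // num_heads but can be overridden, which changes the size of the qkv and output projections.

from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange

class SABlockSketch(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, dim_head: Optional[int] = None):
        super().__init__()
        self.num_heads = num_heads
        # Default: split hidden_size equally across the heads (original behaviour).
        self.dim_head = dim_head if dim_head is not None else hidden_size // num_heads
        inner_dim = self.dim_head * num_heads
        self.qkv = nn.Linear(hidden_size, inner_dim * 3)
        self.out_proj = nn.Linear(inner_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, sequence_length, hidden_size)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=self.num_heads) for t in (q, k, v))
        attn = (q @ k.transpose(-2, -1) * self.dim_head ** -0.5).softmax(dim=-1)
        out = rearrange(attn @ v, "b h n d -> b n (h d)")
        return self.out_proj(out)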

@marksgraham (Contributor)

We need to be careful not to break backwards compatibility here, so the default behaviour must stay the same. If that is the case, then I don't see anything that would affect the generative code merge.
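
A quick way to confirm the unchanged default (a sketch reusing the imports and example from the PR description) is to check that an SABlock constructed without dim_head still yields the original per-head dimension:

block_default = SABlock(hidden_size=128, num_heads=8)  # dim_head not set
y = block_default.input_rearrange(block_default.qkv(torch.zeros(1, 256, 128)))
print(y.shape)  # expected: torch.Size([3, 1, 8, 256, 16]), same as the original behaviour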

Signed-off-by: NabJa <nabil.jabareen@gmail.com>

DCO Remediation Commit for NabJa <nabil.jabareen@gmail.com>

I, NabJa <nabil.jabareen@gmail.com>, hereby add my Signed-off-by to this commit: 139182e

Signed-off-by: NabJa <nabil.jabareen@gmail.com>
@NabJa (Contributor, Author) commented Apr 24, 2024

@marksgraham, complete backward compatibility should be guaranteed with 1ccb5de.
@ericspod DCO is updated and linting passes the checks.

@marksgraham (Contributor)

Would it be possible to hold off merging this for a couple of days? I'm working on some changes for self-attention/cross-attention for the MONAI Generative merge, and it would be good to get a little bit further into them so I can work out if I need to make any further changes to accommodate this PR.

@ericspod (Member)

> Would it be possible to hold off merging this for a couple of days? I'm working on some changes for self-attention/cross-attention for the MONAI Generative merge, and it would be good to get a little bit further into them so I can work out if I need to make any further changes to accommodate this PR.

I think we're good to delay until you're set, thanks.

@KumoLiu (Contributor) commented May 8, 2024

/build

@KumoLiu merged commit d83fa56 into Project-MONAI:dev on May 8, 2024
28 checks passed
Successfully merging this pull request may close these issues: SABlock parameters when using more heads (#7661).