From c9fed96f8f09302de98bf1b5ff73a81fe85158a7 Mon Sep 17 00:00:00 2001
From: Lucas Robinet
Date: Wed, 27 Mar 2024 06:10:45 +0100
Subject: [PATCH] Fixing gradient in sincos positional encoding in monai/networks/blocks/patchembedding.py (#7564)

### Description

When the argument `pos_embed_type='sincos'` is passed to the `PatchEmbeddingBlock` class, it still returns a learnable positional encoding.

To reproduce:

```python
from monai.networks.blocks import PatchEmbeddingBlock

patcher = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(8, 8, 8),
    hidden_size=96,
    num_heads=8,
    pos_embed_type="sincos",
    dropout_rate=0.5,
)
print(patcher.position_embeddings.requires_grad)
>>> True
```

In the literature, positional encodings are either sin-cos based, fixed and non-trainable, as in the original Attention Is All You Need [paper](https://arxiv.org/abs/1706.03762), or learnable embeddings, as in the ViT [paper](https://arxiv.org/abs/2010.11929). If sin-cos is chosen, it should therefore be fixed, which is not the case here.

I'm not completely sure of the desired result in MONAI since there is already a learnable option, so if we choose sincos we would expect gradient-free parameters. However, the documentation of `build_sincos_position_embedding` in the `pos_embed_utils.py` file states: "The sin-cos position embedding as a learnable parameter", which seems a bit confusing, especially as the function that builds the encoding appears to intend to set `requires_grad` to False (see below):

```python
pos_embed = nn.Parameter(pos_emb)
pos_embed.requires_grad = False

return pos_embed
```

But this setting is not preserved by torch's `copy_` function, which does not copy the `requires_grad` attribute (see the cpp code https://github.com/pytorch/pytorch/blob/148a8de6397be6e4b4ca1508b03b82d117bfb03c/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp#L51). This `copy_` is used in the `PatchEmbeddingBlock` class to instantiate the positional embedding (a small standalone illustration of this behaviour is sketched after the checklist below).

I propose a small fix to overcome this problem, as well as test cases to ensure that the positional embedding behaves correctly.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
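For context, here is a minimal standalone sketch (plain PyTorch, independent of MONAI; the `source`/`dest` names are only illustrative) showing why the explicit `requires_grad = False` after the `copy_` is needed: copying values into an existing `nn.Parameter` leaves its `requires_grad` flag untouched, so the flag set on the source embedding is lost.

```python
# Minimal sketch (plain PyTorch, not MONAI code) showing that
# Tensor.copy_ transfers values only, not the requires_grad flag.
import torch
import torch.nn as nn

# Source embedding built as a fixed (non-trainable) parameter.
source = nn.Parameter(torch.randn(1, 64, 96))
source.requires_grad = False

# Destination parameter, trainable by default.
dest = nn.Parameter(torch.zeros(1, 64, 96))

with torch.no_grad():
    dest.data.copy_(source.float())  # copies values only

print(dest.requires_grad)   # True: the flag from `source` was not carried over
dest.requires_grad = False  # the explicit fix applied in this patch
print(dest.requires_grad)   # False
```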
Signed-off-by: Lucas Robinet
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/networks/blocks/patchembedding.py  |  1 +
 monai/networks/blocks/pos_embed_utils.py |  2 +-
 tests/test_patchembedding.py             | 26 ++++++++++++++++++++++++
 3 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py
index 7d56045814..44774ce5da 100644
--- a/monai/networks/blocks/patchembedding.py
+++ b/monai/networks/blocks/patchembedding.py
@@ -123,6 +123,7 @@ def __init__(
             with torch.no_grad():
                 pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
                 self.position_embeddings.data.copy_(pos_embeddings.float())
+                self.position_embeddings.requires_grad = False
         else:
             raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")
 
diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py
index e03553307e..21586e56da 100644
--- a/monai/networks/blocks/pos_embed_utils.py
+++ b/monai/networks/blocks/pos_embed_utils.py
@@ -46,7 +46,7 @@ def build_sincos_position_embedding(
         temperature (float): The temperature for the sin-cos position embedding.
 
     Returns:
-        pos_embed (nn.Parameter): The sin-cos position embedding as a learnable parameter.
+        pos_embed (nn.Parameter): The sin-cos position embedding as a fixed parameter.
     """
 
     if spatial_dims == 2:
diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py
index f8610d9214..d059145033 100644
--- a/tests/test_patchembedding.py
+++ b/tests/test_patchembedding.py
@@ -93,6 +93,32 @@ def test_shape(self, input_param, input_shape, expected_shape):
         result = net(torch.randn(input_shape))
         self.assertEqual(result.shape, expected_shape)
 
+    def test_sincos_pos_embed(self):
+        net = PatchEmbeddingBlock(
+            in_channels=1,
+            img_size=(32, 32, 32),
+            patch_size=(8, 8, 8),
+            hidden_size=96,
+            num_heads=8,
+            pos_embed_type="sincos",
+            dropout_rate=0.5,
+        )
+
+        self.assertEqual(net.position_embeddings.requires_grad, False)
+
+    def test_learnable_pos_embed(self):
+        net = PatchEmbeddingBlock(
+            in_channels=1,
+            img_size=(32, 32, 32),
+            patch_size=(8, 8, 8),
+            hidden_size=96,
+            num_heads=8,
+            pos_embed_type="learnable",
+            dropout_rate=0.5,
+        )
+
+        self.assertEqual(net.position_embeddings.requires_grad, True)
+
     def test_ill_arg(self):
         with self.assertRaises(ValueError):
             PatchEmbeddingBlock(