Fixing gradient in sincos positional encoding in monai/networks/blocks/patchembedding.py (#7564)

### Description

When you pass the argument `pos_embed_type='sincos'` to the
`PatchEmbeddingBlock` class, it still returns a learnable positional
encoding.

To reproduce:
```python
from monai.networks.blocks import PatchEmbeddingBlock
patcher = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(8, 8, 8),
    hidden_size=96,
    num_heads=8,
    pos_embed_type="sincos",
    dropout_rate=0.5,
)
print(patcher.position_embeddings.requires_grad) 
>>> True
```

In the literature, positional encodings are either fixed sinusoidal (sincos)
encodings, as in the original Attention Is All You Need
[paper](https://arxiv.org/abs/1706.03762), or learnable positional
embeddings, as in the ViT [paper](https://arxiv.org/abs/2010.11929).
If sincos is chosen, it should therefore be fixed, which is not the case here.
I'm not completely sure of the desired behavior in MONAI, since a learnable
option already exists, but choosing sincos should give gradient-free
parameters. Moreover, the docstring of `build_sincos_position_embedding` in
the `pos_embed_utils.py` file states "The sin-cos position embedding as a
learnable parameter", which seems a bit confusing, especially as the embedding
construction function appears to set `requires_grad` to False (see below):

```python
pos_embed = nn.Parameter(pos_emb)
pos_embed.requires_grad = False

return pos_embed
```
However, this setting is not preserved by torch's `copy_` function, which
copies values but not the `requires_grad` attribute (see the C++ code
https://github.com/pytorch/pytorch/blob/148a8de6397be6e4b4ca1508b03b82d117bfb03c/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp#L51).
This `copy_` is used in the `PatchEmbeddingBlock` class to initialize the
positional embedding; the small standalone sketch below illustrates the issue.
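A minimal sketch, outside MONAI and with hypothetical tensor shapes, showing that `copy_` transfers values but leaves the destination's `requires_grad` flag untouched:

```python
import torch
import torch.nn as nn

# Destination parameter, trainable by default (like position_embeddings in PatchEmbeddingBlock).
dst = nn.Parameter(torch.zeros(1, 64, 96))

# Source built as a frozen parameter (like the output of build_sincos_position_embedding).
src = nn.Parameter(torch.randn(1, 64, 96))
src.requires_grad = False

with torch.no_grad():
    dst.data.copy_(src)  # copies the values only

print(src.requires_grad)  # False
print(dst.requires_grad)  # still True: the frozen flag is not carried over
```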

I propose a small fix to overcome this problem, as well as test cases to
ensure that the positional embedding behaves correctly; the expected behavior
after the fix is sketched below.
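With the fix applied, the reproduction snippet above should report a frozen embedding (a sketch of the expected behavior, matching the new test added in this PR):

```python
from monai.networks.blocks import PatchEmbeddingBlock

patcher = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(8, 8, 8),
    hidden_size=96,
    num_heads=8,
    pos_embed_type="sincos",
    dropout_rate=0.5,
)
print(patcher.position_embeddings.requires_grad)
>>> False
```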


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Lucas Robinet <robinet.lucas@iuct-oncopole.fr>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Lucas-rbnt and KumoLiu authored Mar 27, 2024
1 parent 2716b6a commit c9fed96
Showing 3 changed files with 28 additions and 1 deletion.
1 change: 1 addition & 0 deletions monai/networks/blocks/patchembedding.py
```diff
@@ -123,6 +123,7 @@ def __init__(
             with torch.no_grad():
                 pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
                 self.position_embeddings.data.copy_(pos_embeddings.float())
+                self.position_embeddings.requires_grad = False
         else:
             raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")
```
2 changes: 1 addition & 1 deletion monai/networks/blocks/pos_embed_utils.py
```diff
@@ -46,7 +46,7 @@ def build_sincos_position_embedding(
         temperature (float): The temperature for the sin-cos position embedding.

     Returns:
-        pos_embed (nn.Parameter): The sin-cos position embedding as a learnable parameter.
+        pos_embed (nn.Parameter): The sin-cos position embedding as a fixed parameter.
     """

     if spatial_dims == 2:
```
26 changes: 26 additions & 0 deletions tests/test_patchembedding.py
```diff
@@ -93,6 +93,32 @@ def test_shape(self, input_param, input_shape, expected_shape):
         result = net(torch.randn(input_shape))
         self.assertEqual(result.shape, expected_shape)

+    def test_sincos_pos_embed(self):
+        net = PatchEmbeddingBlock(
+            in_channels=1,
+            img_size=(32, 32, 32),
+            patch_size=(8, 8, 8),
+            hidden_size=96,
+            num_heads=8,
+            pos_embed_type="sincos",
+            dropout_rate=0.5,
+        )
+
+        self.assertEqual(net.position_embeddings.requires_grad, False)
+
+    def test_learnable_pos_embed(self):
+        net = PatchEmbeddingBlock(
+            in_channels=1,
+            img_size=(32, 32, 32),
+            patch_size=(8, 8, 8),
+            hidden_size=96,
+            num_heads=8,
+            pos_embed_type="learnable",
+            dropout_rate=0.5,
+        )
+
+        self.assertEqual(net.position_embeddings.requires_grad, True)
+
     def test_ill_arg(self):
         with self.assertRaises(ValueError):
             PatchEmbeddingBlock(
```
