Skip to content

Commit a5dff5b

Browse files
authored
6668 fix hard-coded up_kernel_size in ViTAutoEnc (#6735)
Fixes #6668 . Fix hard-coded `up_kernel_size` in `ViTAutoEnc` without changing the network architecture. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <yunl@nvidia.com>
1 parent 4afa2ad commit a5dff5b

File tree

4 files changed

+37
-6
lines changed

4 files changed

+37
-6
lines changed

monai/networks/nets/vitautoenc.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import math
1415
from collections.abc import Sequence
1516

1617
import torch
@@ -19,7 +20,7 @@
1920
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
2021
from monai.networks.blocks.transformerblock import TransformerBlock
2122
from monai.networks.layers import Conv
22-
from monai.utils import ensure_tuple_rep
23+
from monai.utils import ensure_tuple_rep, is_sqrt
2324

2425
__all__ = ["ViTAutoEnc"]
2526

@@ -78,9 +79,14 @@ def __init__(
7879
"""
7980

8081
super().__init__()
81-
82+
if not is_sqrt(patch_size):
83+
raise ValueError(f"patch_size should be square number, got {patch_size}.")
8284
self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
85+
self.img_size = ensure_tuple_rep(img_size, spatial_dims)
8386
self.spatial_dims = spatial_dims
87+
for m, p in zip(self.img_size, self.patch_size):
88+
if m % p != 0:
89+
raise ValueError(f"patch_size={patch_size} should be divisible by img_size={img_size}.")
8490

8591
self.patch_embedding = PatchEmbeddingBlock(
8692
in_channels=in_channels,
@@ -100,12 +106,12 @@ def __init__(
100106
)
101107
self.norm = nn.LayerNorm(hidden_size)
102108

103-
new_patch_size = [4] * self.spatial_dims
104109
conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims]
105110
# self.conv3d_transpose* is to be compatible with existing 3d model weights.
106-
self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size)
111+
up_kernel_size = [int(math.sqrt(i)) for i in self.patch_size]
112+
self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=up_kernel_size, stride=up_kernel_size)
107113
self.conv3d_transpose_1 = conv_trans(
108-
in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size
114+
in_channels=deconv_chns, out_channels=out_channels, kernel_size=up_kernel_size, stride=up_kernel_size
109115
)
110116

111117
def forward(self, x):

monai/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
is_module_ver_at_least,
8181
is_scalar,
8282
is_scalar_tensor,
83+
is_sqrt,
8384
issequenceiterable,
8485
list_to_dict,
8586
path_to_uri,

monai/utils/misc.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import inspect
1515
import itertools
16+
import math
1617
import os
1718
import pprint
1819
import random
@@ -853,3 +854,13 @@ def run_cmd(cmd_list: list[str], **kwargs: Any) -> subprocess.CompletedProcess:
853854
output = str(e.stdout.decode(errors="replace"))
854855
errors = str(e.stderr.decode(errors="replace"))
855856
raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}.") from e
857+
858+
859+
def is_sqrt(num: Sequence[int] | int) -> bool:
860+
"""
861+
Determine if the input is a square number or a squence of square numbers.
862+
"""
863+
num = ensure_tuple(num)
864+
sqrt_num = [int(math.sqrt(_num)) for _num in num]
865+
ret = [_i * _j for _i, _j in zip(sqrt_num, sqrt_num)]
866+
return ensure_tuple(ret) == num

tests/test_vitautoenc.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
{
5050
"in_channels": 1,
5151
"img_size": (512, 512, 32),
52-
"patch_size": (16, 16, 16),
52+
"patch_size": (64, 64, 16),
5353
"hidden_size": 768,
5454
"mlp_dim": 3072,
5555
"num_layers": 4,
@@ -147,6 +147,19 @@ def test_ill_arg(self):
147147
dropout_rate=0.3,
148148
)
149149

150+
with self.assertRaises(ValueError):
151+
ViTAutoEnc(
152+
in_channels=4,
153+
img_size=(96, 96, 96),
154+
patch_size=(9, 9, 9),
155+
hidden_size=768,
156+
mlp_dim=3072,
157+
num_layers=12,
158+
num_heads=12,
159+
pos_embed="perc",
160+
dropout_rate=0.3,
161+
)
162+
150163

151164
if __name__ == "__main__":
152165
unittest.main()

0 commit comments

Comments
 (0)