Skip to content

Commit 373c003

Browse files
Pkaps25, Peter Kaplinsky, and KumoLiu
authored
Add norm param to ResNet (#7752)
Fixes #7294.

### Description

Adds a `norm` param to ResNet.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Peter Kaplinsky <peterkaplinsky@gmail.com>
Co-authored-by: Peter Kaplinsky <peterkaplinsky@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent 66a2fae commit 373c003

File tree

3 files changed

+62
-25
lines changed

3 files changed

+62
-25
lines changed

monai/networks/nets/daf3d.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from collections import OrderedDict
1515
from collections.abc import Callable, Sequence
16+
from functools import partial
1617

1718
import torch
1819
import torch.nn as nn
@@ -25,6 +26,7 @@
2526
from monai.networks.blocks.convolutions import Convolution
2627
from monai.networks.blocks.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork
2728
from monai.networks.layers.factories import Conv, Norm
29+
from monai.networks.layers.utils import get_norm_layer
2830
from monai.networks.nets.resnet import ResNet, ResNetBottleneck
2931

3032
__all__ = [
@@ -170,27 +172,31 @@ class Daf3dResNetBottleneck(ResNetBottleneck):
170172
spatial_dims: number of spatial dimensions of the input image.
171173
stride: stride to use for second conv layer.
172174
downsample: which downsample layer to use.
175+
norm: which normalization layer to use. Defaults to group norm with 32 groups.
173176
"""
174177

175178
expansion = 2
176179

177-
def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None):
178-
norm_type: Callable = Norm[Norm.GROUP, spatial_dims]
180+
def __init__(
181+
self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32})
182+
):
179183
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
180184

185+
norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)
186+
181187
# in case downsample uses batch norm, rebuild it with the configured norm layer (group norm by default)
182188
if isinstance(downsample, nn.Sequential):
183189
downsample = nn.Sequential(
184190
conv_type(in_planes, planes * self.expansion, kernel_size=1, stride=stride, bias=False),
185-
norm_type(num_groups=32, num_channels=planes * self.expansion),
191+
norm_layer(channels=planes * self.expansion),
186192
)
187193

188194
super().__init__(in_planes, planes, spatial_dims, stride, downsample)
189195

190196
# replace the norm layers created by the parent class with the configured norm (group norm by default)
191-
self.bn1 = norm_type(num_groups=32, num_channels=planes)
192-
self.bn2 = norm_type(num_groups=32, num_channels=planes)
193-
self.bn3 = norm_type(num_groups=32, num_channels=planes * self.expansion)
197+
self.bn1 = norm_layer(channels=planes)
198+
self.bn2 = norm_layer(channels=planes)
199+
self.bn3 = norm_layer(channels=planes * self.expansion)
194200

195201
# adapt second convolution to work with groups
196202
self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, stride=stride, groups=32, bias=False)
@@ -212,8 +218,10 @@ class Daf3dResNetDilatedBottleneck(Daf3dResNetBottleneck):
212218
downsample: which downsample layer to use.
213219
"""
214220

215-
def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None):
216-
super().__init__(in_planes, planes, spatial_dims, stride, downsample)
221+
def __init__(
222+
self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32})
223+
):
224+
super().__init__(in_planes, planes, spatial_dims, stride, downsample, norm)
217225

218226
# add dilation in second convolution
219227
conv_type: Callable = Conv[Conv.CONV, spatial_dims]

monai/networks/nets/resnet.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
import torch.nn as nn
2323

2424
from monai.networks.blocks.encoder import BaseEncoder
25-
from monai.networks.layers.factories import Conv, Norm, Pool
26-
from monai.networks.layers.utils import get_act_layer, get_pool_layer
25+
from monai.networks.layers.factories import Conv, Pool
26+
from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_pool_layer
2727
from monai.utils import ensure_tuple_rep
2828
from monai.utils.module import look_up_option, optional_import
2929

@@ -79,6 +79,7 @@ def __init__(
7979
stride: int = 1,
8080
downsample: nn.Module | partial | None = None,
8181
act: str | tuple = ("relu", {"inplace": True}),
82+
norm: str | tuple = "batch",
8283
) -> None:
8384
"""
8485
Args:
@@ -88,17 +89,18 @@ def __init__(
8889
stride: stride to use for first conv layer.
8990
downsample: which downsample layer to use.
9091
act: activation type and arguments. Defaults to relu.
92+
norm: feature normalization type and arguments. Defaults to batch norm.
9193
"""
9294
super().__init__()
9395

9496
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
95-
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
97+
norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes)
9698

9799
self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
98-
self.bn1 = norm_type(planes)
100+
self.bn1 = norm_layer
99101
self.act = get_act_layer(name=act)
100102
self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False)
101-
self.bn2 = norm_type(planes)
103+
self.bn2 = norm_layer
102104
self.downsample = downsample
103105
self.stride = stride
104106

@@ -132,6 +134,7 @@ def __init__(
132134
stride: int = 1,
133135
downsample: nn.Module | partial | None = None,
134136
act: str | tuple = ("relu", {"inplace": True}),
137+
norm: str | tuple = "batch",
135138
) -> None:
136139
"""
137140
Args:
@@ -141,19 +144,20 @@ def __init__(
141144
stride: stride to use for second conv layer.
142145
downsample: which downsample layer to use.
143146
act: activation type and arguments. Defaults to relu.
147+
norm: feature normalization type and arguments. Defaults to batch norm.
144148
"""
145149

146150
super().__init__()
147151

148152
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
149-
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
153+
norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)
150154

151155
self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False)
152-
self.bn1 = norm_type(planes)
156+
self.bn1 = norm_layer(channels=planes)
153157
self.conv2 = conv_type(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
154-
self.bn2 = norm_type(planes)
158+
self.bn2 = norm_layer(channels=planes)
155159
self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False)
156-
self.bn3 = norm_type(planes * self.expansion)
160+
self.bn3 = norm_layer(channels=planes * self.expansion)
157161
self.act = get_act_layer(name=act)
158162
self.downsample = downsample
159163
self.stride = stride
@@ -207,6 +211,7 @@ class ResNet(nn.Module):
207211
feed_forward: whether to add the FC layer for the output, default to `True`.
208212
bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.
209213
act: activation type and arguments. Defaults to relu.
214+
norm: feature normalization type and arguments. Defaults to batch norm.
210215
211216
"""
212217

@@ -226,6 +231,7 @@ def __init__(
226231
feed_forward: bool = True,
227232
bias_downsample: bool = True, # for backwards compatibility (also see PR #5477)
228233
act: str | tuple = ("relu", {"inplace": True}),
234+
norm: str | tuple = "batch",
229235
) -> None:
230236
super().__init__()
231237

@@ -238,7 +244,6 @@ def __init__(
238244
raise ValueError("Unknown block '%s', use basic or bottleneck" % block)
239245

240246
conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
241-
norm_type: type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]
242247
pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]
243248
avgp_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
244249
Pool.ADAPTIVEAVG, spatial_dims
@@ -262,7 +267,9 @@ def __init__(
262267
padding=tuple(k // 2 for k in conv1_kernel_size),
263268
bias=False,
264269
)
265-
self.bn1 = norm_type(self.in_planes)
270+
271+
norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=self.in_planes)
272+
self.bn1 = norm_layer
266273
self.act = get_act_layer(name=act)
267274
self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)
268275
self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type)
@@ -275,7 +282,7 @@ def __init__(
275282
for m in self.modules():
276283
if isinstance(m, conv_type):
277284
nn.init.kaiming_normal_(torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu")
278-
elif isinstance(m, norm_type):
285+
elif isinstance(m, type(norm_layer)):
279286
nn.init.constant_(torch.as_tensor(m.weight), 1)
280287
nn.init.constant_(torch.as_tensor(m.bias), 0)
281288
elif isinstance(m, nn.Linear):
@@ -295,9 +302,9 @@ def _make_layer(
295302
spatial_dims: int,
296303
shortcut_type: str,
297304
stride: int = 1,
305+
norm: str | tuple = "batch",
298306
) -> nn.Sequential:
299307
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
300-
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
301308

302309
downsample: nn.Module | partial | None = None
303310
if stride != 1 or self.in_planes != planes * block.expansion:
@@ -317,18 +324,23 @@ def _make_layer(
317324
stride=stride,
318325
bias=self.bias_downsample,
319326
),
320-
norm_type(planes * block.expansion),
327+
get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes * block.expansion),
321328
)
322329

323330
layers = [
324331
block(
325-
in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample
332+
in_planes=self.in_planes,
333+
planes=planes,
334+
spatial_dims=spatial_dims,
335+
stride=stride,
336+
downsample=downsample,
337+
norm=norm,
326338
)
327339
]
328340

329341
self.in_planes = planes * block.expansion
330342
for _i in range(1, blocks):
331-
layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims))
343+
layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims, norm=norm))
332344

333345
return nn.Sequential(*layers)
334346

tests/test_resnet.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,30 @@
202202
(1, 3),
203203
]
204204

205+
TEST_CASE_9 = [ # Layer norm
206+
{
207+
"block": ResNetBlock,
208+
"layers": [3, 4, 6, 3],
209+
"block_inplanes": [64, 128, 256, 512],
210+
"spatial_dims": 1,
211+
"n_input_channels": 2,
212+
"num_classes": 3,
213+
"conv1_t_size": [3],
214+
"conv1_t_stride": 1,
215+
"act": ("relu", {"inplace": False}),
216+
"norm": ("layer", {"normalized_shape": (64, 32)}),
217+
},
218+
(1, 2, 32),
219+
(1, 3),
220+
]
221+
205222
TEST_CASES = []
206223
PRETRAINED_TEST_CASES = []
207224
for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
208225
for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
209226
TEST_CASES.append([model, *case])
210227
PRETRAINED_TEST_CASES.append([model, *case])
211-
for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]:
228+
for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]:
212229
TEST_CASES.append([ResNet, *case])
213230

214231
TEST_SCRIPT_CASES = [

0 commit comments

Comments
 (0)