@@ -22,8 +22,8 @@
 import torch.nn as nn
 
 from monai.networks.blocks.encoder import BaseEncoder
-from monai.networks.layers.factories import Conv, Norm, Pool
-from monai.networks.layers.utils import get_act_layer, get_pool_layer
+from monai.networks.layers.factories import Conv, Pool
+from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_pool_layer
 from monai.utils import ensure_tuple_rep
 from monai.utils.module import look_up_option, optional_import
 
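For context, a minimal sketch (assuming a MONAI install) of what this import swap enables: the removed `Norm` factory returns a class that the caller must instantiate, while `get_norm_layer` resolves a name-plus-arguments spec into a ready module in one call.

```python
# Illustrative only: contrasting the removed Norm factory with get_norm_layer.
from monai.networks.layers.factories import Norm
from monai.networks.layers.utils import get_norm_layer

norm_type = Norm[Norm.BATCH, 3]  # old style: returns the class nn.BatchNorm3d
bn_old = norm_type(64)           # caller instantiates it manually

# new style: name, arguments, and dimensionality resolved in one call
bn_new = get_norm_layer(name="batch", spatial_dims=3, channels=64)
gn = get_norm_layer(name=("group", {"num_groups": 8}), spatial_dims=3, channels=64)
```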
@@ -79,6 +79,7 @@ def __init__(
         stride: int = 1,
         downsample: nn.Module | partial | None = None,
         act: str | tuple = ("relu", {"inplace": True}),
+        norm: str | tuple = "batch",
     ) -> None:
         """
         Args:
@@ -88,17 +89,18 @@ def __init__(
             stride: stride to use for first conv layer.
             downsample: which downsample layer to use.
             act: activation type and arguments. Defaults to relu.
+            norm: feature normalization type and arguments. Defaults to batch norm.
         """
         super().__init__()
 
         conv_type: Callable = Conv[Conv.CONV, spatial_dims]
-        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
+        norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)
 
         self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
-        self.bn1 = norm_type(planes)
+        self.bn1 = norm_layer(channels=planes)
         self.act = get_act_layer(name=act)
         self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False)
-        self.bn2 = norm_type(planes)
+        self.bn2 = norm_layer(channels=planes)
         self.downsample = downsample
         self.stride = stride
 
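A hedged usage sketch of the block above; `ResNetBlock` is the basic block's name in `monai.networks.nets.resnet`, which this diff does not show, so treat the identifier as an assumption.

```python
import torch
from monai.networks.nets.resnet import ResNetBlock  # assumed class name, not shown in the hunk

# instance norm instead of the default batch norm; shapes are preserved for stride=1
block = ResNetBlock(in_planes=32, planes=32, spatial_dims=3, norm=("instance", {"affine": True}))
out = block(torch.randn(1, 32, 16, 16, 16))  # -> torch.Size([1, 32, 16, 16, 16])
```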
@@ -132,6 +134,7 @@ def __init__(
         stride: int = 1,
         downsample: nn.Module | partial | None = None,
         act: str | tuple = ("relu", {"inplace": True}),
+        norm: str | tuple = "batch",
     ) -> None:
         """
         Args:
@@ -141,19 +144,20 @@ def __init__(
             stride: stride to use for second conv layer.
             downsample: which downsample layer to use.
             act: activation type and arguments. Defaults to relu.
+            norm: feature normalization type and arguments. Defaults to batch norm.
         """
 
         super().__init__()
 
         conv_type: Callable = Conv[Conv.CONV, spatial_dims]
-        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
+        norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)
 
         self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False)
-        self.bn1 = norm_type(planes)
+        self.bn1 = norm_layer(channels=planes)
         self.conv2 = conv_type(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
-        self.bn2 = norm_type(planes)
+        self.bn2 = norm_layer(channels=planes)
         self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False)
-        self.bn3 = norm_type(planes * self.expansion)
+        self.bn3 = norm_layer(channels=planes * self.expansion)
         self.act = get_act_layer(name=act)
         self.downsample = downsample
         self.stride = stride
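The `partial` idiom above binds the norm spec once and then builds a fresh module per call, which matters here because the bottleneck's three normalization layers see different channel counts. A minimal sketch:

```python
from functools import partial
from monai.networks.layers.utils import get_norm_layer

norm_layer = partial(get_norm_layer, name="batch", spatial_dims=3)
bn1 = norm_layer(channels=64)   # nn.BatchNorm3d(64)
bn3 = norm_layer(channels=256)  # nn.BatchNorm3d(256), a distinct instance
assert bn1 is not bn3           # no parameter sharing between the two layers
```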
@@ -207,6 +211,7 @@ class ResNet(nn.Module):
         feed_forward: whether to add the FC layer for the output, default to `True`.
         bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.
         act: activation type and arguments. Defaults to relu.
+        norm: feature normalization type and arguments. Defaults to batch norm.
 
     """
 
@@ -226,6 +231,7 @@ def __init__(
         feed_forward: bool = True,
         bias_downsample: bool = True,  # for backwards compatibility (also see PR #5477)
         act: str | tuple = ("relu", {"inplace": True}),
+        norm: str | tuple = "batch",
     ) -> None:
         super().__init__()
 
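With the new keyword, callers choose the normalization at construction time. One caveat visible in the hunks below: the `self.layer1 = self._make_layer(...)` line remains unchanged context, so as this diff stands a non-default `norm` reaches only the stem's `bn1`, while `_make_layer` falls back to its own `"batch"` default unless call sites pass `norm=norm`. A hedged construction sketch:

```python
from monai.networks.nets.resnet import ResNet

# ResNet-18-style configuration; block_inplanes matches MONAI's get_inplanes() values
net = ResNet(
    block="basic",
    layers=[2, 2, 2, 2],
    block_inplanes=[64, 128, 256, 512],
    spatial_dims=3,
    n_input_channels=1,
    num_classes=2,
    norm=("group", {"num_groups": 8}),
)
print(type(net.bn1).__name__)            # GroupNorm
print(type(net.layer1[0].bn1).__name__)  # BatchNorm3d: _make_layer's default, since norm is not forwarded here
```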
@@ -238,7 +244,6 @@ def __init__(
             raise ValueError("Unknown block '%s', use basic or bottleneck" % block)
 
         conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
-        norm_type: type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]
         pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]
         avgp_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
             Pool.ADAPTIVEAVG, spatial_dims
@@ -262,7 +267,9 @@ def __init__(
             padding=tuple(k // 2 for k in conv1_kernel_size),
             bias=False,
         )
-        self.bn1 = norm_type(self.in_planes)
+
+        norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=self.in_planes)
+        self.bn1 = norm_layer
         self.act = get_act_layer(name=act)
         self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)
         self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type)
@@ -275,7 +282,7 @@ def __init__(
         for m in self.modules():
             if isinstance(m, conv_type):
                 nn.init.kaiming_normal_(torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu")
-            elif isinstance(m, norm_type):
+            elif isinstance(m, type(norm_layer)):
                 nn.init.constant_(torch.as_tensor(m.weight), 1)
                 nn.init.constant_(torch.as_tensor(m.bias), 0)
             elif isinstance(m, nn.Linear):
@@ -295,9 +302,9 @@ def _make_layer(
         spatial_dims: int,
         shortcut_type: str,
         stride: int = 1,
+        norm: str | tuple = "batch",
     ) -> nn.Sequential:
         conv_type: Callable = Conv[Conv.CONV, spatial_dims]
-        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
 
         downsample: nn.Module | partial | None = None
         if stride != 1 or self.in_planes != planes * block.expansion:
@@ -317,18 +324,23 @@ def _make_layer(
                         stride=stride,
                         bias=self.bias_downsample,
                     ),
-                    norm_type(planes * block.expansion),
+                    get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes * block.expansion),
                 )
 
         layers = [
             block(
-                in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample
+                in_planes=self.in_planes,
+                planes=planes,
+                spatial_dims=spatial_dims,
+                stride=stride,
+                downsample=downsample,
+                norm=norm,
             )
         ]
 
         self.in_planes = planes * block.expansion
         for _i in range(1, blocks):
-            layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims))
+            layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims, norm=norm))
 
         return nn.Sequential(*layers)
 
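Finally, a hedged end-to-end smoke test for the refactored path, building a small network and running a forward pass:

```python
import torch
from monai.networks.nets.resnet import ResNet

net = ResNet(
    block="basic",
    layers=[1, 1, 1, 1],
    block_inplanes=[64, 128, 256, 512],
    spatial_dims=3,
    n_input_channels=1,
    num_classes=2,
    norm="batch",
).eval()

with torch.no_grad():
    logits = net(torch.randn(2, 1, 64, 64, 64))
print(logits.shape)  # torch.Size([2, 2])
```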