Skip to content

Commit 6cb4ced

Browse files
authored
Fix constructors for DenseNet derived classes (#5846)
### Description We noted it was possible to instantiate classes derived from DenseNet only if spatial_dims, in_channels, and out_channels parameters were passed by keywords. Passing them via positional scheme was not working. This small bug should be fixed now. ### Example: Before my fix: ``` import monai net = monai.networks.nets.DenseNet(3,1,2) # Working net = monai.networks.nets.DenseNet121(3,1,2) # NOT woking net = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2) # Woking ``` After my fix: ``` import monai net = monai.networks.nets.DenseNet121(3,1,2) # Woking ``` Thanks to @robsver for pointing this issue out. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder.
1 parent 33d41d7 commit 6cb4ced

File tree

1 file changed

+51
-7
lines changed

1 file changed

+51
-7
lines changed

monai/networks/nets/densenet.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -296,16 +296,27 @@ class DenseNet121(DenseNet):
296296

297297
def __init__(
298298
self,
299+
spatial_dims: int,
300+
in_channels: int,
301+
out_channels: int,
299302
init_features: int = 64,
300303
growth_rate: int = 32,
301304
block_config: Sequence[int] = (6, 12, 24, 16),
302305
pretrained: bool = False,
303306
progress: bool = True,
304307
**kwargs,
305308
) -> None:
306-
super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs)
309+
super().__init__(
310+
spatial_dims=spatial_dims,
311+
in_channels=in_channels,
312+
out_channels=out_channels,
313+
init_features=init_features,
314+
growth_rate=growth_rate,
315+
block_config=block_config,
316+
**kwargs,
317+
)
307318
if pretrained:
308-
if kwargs["spatial_dims"] > 2:
319+
if spatial_dims > 2:
309320
raise NotImplementedError(
310321
"Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not"
311322
"provide pretrained models for more than two spatial dimensions."
@@ -318,16 +329,27 @@ class DenseNet169(DenseNet):
318329

319330
def __init__(
320331
self,
332+
spatial_dims: int,
333+
in_channels: int,
334+
out_channels: int,
321335
init_features: int = 64,
322336
growth_rate: int = 32,
323337
block_config: Sequence[int] = (6, 12, 32, 32),
324338
pretrained: bool = False,
325339
progress: bool = True,
326340
**kwargs,
327341
) -> None:
328-
super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs)
342+
super().__init__(
343+
spatial_dims=spatial_dims,
344+
in_channels=in_channels,
345+
out_channels=out_channels,
346+
init_features=init_features,
347+
growth_rate=growth_rate,
348+
block_config=block_config,
349+
**kwargs,
350+
)
329351
if pretrained:
330-
if kwargs["spatial_dims"] > 2:
352+
if spatial_dims > 2:
331353
raise NotImplementedError(
332354
"Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not"
333355
"provide pretrained models for more than two spatial dimensions."
@@ -340,16 +362,27 @@ class DenseNet201(DenseNet):
340362

341363
def __init__(
342364
self,
365+
spatial_dims: int,
366+
in_channels: int,
367+
out_channels: int,
343368
init_features: int = 64,
344369
growth_rate: int = 32,
345370
block_config: Sequence[int] = (6, 12, 48, 32),
346371
pretrained: bool = False,
347372
progress: bool = True,
348373
**kwargs,
349374
) -> None:
350-
super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs)
375+
super().__init__(
376+
spatial_dims=spatial_dims,
377+
in_channels=in_channels,
378+
out_channels=out_channels,
379+
init_features=init_features,
380+
growth_rate=growth_rate,
381+
block_config=block_config,
382+
**kwargs,
383+
)
351384
if pretrained:
352-
if kwargs["spatial_dims"] > 2:
385+
if spatial_dims > 2:
353386
raise NotImplementedError(
354387
"Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not"
355388
"provide pretrained models for more than two spatial dimensions."
@@ -362,14 +395,25 @@ class DenseNet264(DenseNet):
362395

363396
def __init__(
364397
self,
398+
spatial_dims: int,
399+
in_channels: int,
400+
out_channels: int,
365401
init_features: int = 64,
366402
growth_rate: int = 32,
367403
block_config: Sequence[int] = (6, 12, 64, 48),
368404
pretrained: bool = False,
369405
progress: bool = True,
370406
**kwargs,
371407
) -> None:
372-
super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs)
408+
super().__init__(
409+
spatial_dims=spatial_dims,
410+
in_channels=in_channels,
411+
out_channels=out_channels,
412+
init_features=init_features,
413+
growth_rate=growth_rate,
414+
block_config=block_config,
415+
**kwargs,
416+
)
373417
if pretrained:
374418
raise NotImplementedError("Currently PyTorch Hub does not provide densenet264 pretrained models.")
375419

0 commit comments

Comments
 (0)