5 changes: 5 additions & 0 deletions docs/source/networks.rst
@@ -491,6 +491,11 @@ Nets
.. autoclass:: ResNet
:members:

`ResNetFeatures`
~~~~~~~~~~~~~~~~
.. autoclass:: ResNetFeatures
:members:

`SENet`
~~~~~~~
.. autoclass:: SENet
2 changes: 2 additions & 0 deletions monai/networks/nets/__init__.py
@@ -59,6 +59,8 @@
ResNet,
ResNetBlock,
ResNetBottleneck,
ResNetEncoder,
ResNetFeatures,
get_medicalnet_pretrained_resnet_args,
get_pretrained_resnet_medicalnet,
resnet10,
13 changes: 8 additions & 5 deletions monai/networks/nets/flexible_unet.py
@@ -24,6 +24,7 @@
from monai.networks.layers.utils import get_act_layer
from monai.networks.nets import EfficientNetEncoder
from monai.networks.nets.basic_unet import UpCat
from monai.networks.nets.resnet import ResNetEncoder
from monai.utils import InterpolateMode, optional_import

__all__ = ["FlexibleUNet", "FlexUNet", "FLEXUNET_BACKBONE", "FlexUNetEncoderRegister"]
@@ -78,6 +79,7 @@ def register_class(self, name: type[Any] | str):

FLEXUNET_BACKBONE = FlexUNetEncoderRegister()
FLEXUNET_BACKBONE.register_class(EfficientNetEncoder)
FLEXUNET_BACKBONE.register_class(ResNetEncoder)


class UNetDecoder(nn.Module):
@@ -238,7 +240,7 @@ def __init__(
) -> None:
"""
A flexible implement of UNet, in which the backbone/encoder can be replaced with
any efficient network. Currently the input must have a 2 or 3 spatial dimension
any efficient or residual network. Currently the input must have a 2 or 3 spatial dimension
and the spatial size of each dimension must be a multiple of 32 if is_pad parameter
is False.
Please notice each output of backbone must be 2x downsample in spatial dimension
@@ -248,10 +250,11 @@
Args:
in_channels: number of input channels.
out_channels: number of output channels.
backbone: name of backbones to initialize, only support efficientnet right now,
can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2].
pretrained: whether to initialize pretrained ImageNet weights, only available
for spatial_dims=2 and batch norm is used, default to False.
backbone: name of the backbone to initialize; only efficientnet and resnet are supported right now,
can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2, resnet10, ..., resnet200].
pretrained: whether to initialize pretrained weights. ImageNet weights are available for efficient networks
if spatial_dims=2 and batch norm is used. MedicalNet weights are available for residual networks
if spatial_dims=3 and in_channels=1. Defaults to False.
decoder_channels: number of output channels for all feature maps in decoder.
`len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default
to (256, 128, 64, 32, 16).
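For context, a minimal usage sketch (not part of this diff) of FlexibleUNet with one of the newly registered resnet backbones. The `spatial_dims` keyword of FlexibleUNet is assumed here (it is not shown in this diff), and "resnet10" is one of the backbone names registered by `ResNetEncoder` below.

import torch
from monai.networks.nets import FlexibleUNet

# Pretrained MedicalNet weights would additionally require spatial_dims=3 and in_channels=1;
# pretrained is left False so nothing is downloaded.
model = FlexibleUNet(in_channels=1, out_channels=2, backbone="resnet10", pretrained=False, spatial_dims=3)
out = model(torch.rand(1, 1, 64, 64, 64))  # each spatial size is a multiple of 32, as documented above
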
145 changes: 143 additions & 2 deletions monai/networks/nets/resnet.py
@@ -21,6 +21,7 @@
import torch
import torch.nn as nn

from monai.networks.blocks.encoder import BaseEncoder
from monai.networks.layers.factories import Conv, Norm, Pool
from monai.networks.layers.utils import get_pool_layer
from monai.utils import ensure_tuple_rep
@@ -45,6 +46,19 @@
"resnet200",
]


resnet_params = {
# model_name: (block, layers, shortcut_type, bias_downsample, datasets23)
"resnet10": ("basic", [1, 1, 1, 1], "B", False, True),
"resnet18": ("basic", [2, 2, 2, 2], "A", True, True),
"resnet34": ("basic", [3, 4, 6, 3], "A", True, True),
"resnet50": ("bottleneck", [3, 4, 6, 3], "B", False, True),
"resnet101": ("bottleneck", [3, 4, 23, 3], "B", False, False),
"resnet152": ("bottleneck", [3, 8, 36, 3], "B", False, False),
"resnet200": ("bottleneck", [3, 24, 36, 3], "B", False, False),
}
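A short illustration of how the per-model tuple above is consumed later in `ResNetFeatures.__init__` (values taken from the table above):

block, layers, shortcut_type, bias_downsample, datasets23 = resnet_params["resnet50"]
# -> "bottleneck", [3, 4, 6, 3], "B", False, True
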


logger = logging.getLogger(__name__)


@@ -335,6 +349,120 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class ResNetFeatures(ResNet):

def __init__(self, model_name: str, pretrained: bool = True, spatial_dims: int = 3, in_channels: int = 1) -> None:
"""Initialize resnet18 to resnet200 models as a backbone, the backbone can be used as an encoder for
segmentation and objection models.

Compared with the class `ResNet`, the only different place is the forward function.

Args:
model_name: name of model to initialize, can be from [resnet10, ..., resnet200].
pretrained: whether to initialize pretrained MedicalNet weights,
only available for spatial_dims=3 and in_channels=1.
spatial_dims: number of spatial dimensions of the input image.
in_channels: number of input channels for first convolutional layer.
"""
if model_name not in resnet_params:
model_name_string = ", ".join(resnet_params.keys())
raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ")

block, layers, shortcut_type, bias_downsample, datasets23 = resnet_params[model_name]

super().__init__(
block=block,
layers=layers,
block_inplanes=get_inplanes(),
spatial_dims=spatial_dims,
n_input_channels=in_channels,
conv1_t_stride=2,
shortcut_type=shortcut_type,
feed_forward=False,
bias_downsample=bias_downsample,
)
if pretrained:
if spatial_dims == 3 and in_channels == 1:
_load_state_dict(self, model_name, datasets23=datasets23)
else:
raise ValueError("Pretrained resnet models are only available for in_channels=1 and spatial_dims=3.")

def forward(self, inputs: torch.Tensor):
"""
Args:
inputs: the input should have N spatial dimensions,
``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, where N is defined by `spatial_dims`.

Returns:
a list of torch Tensors.
"""
x = self.conv1(inputs)
x = self.bn1(x)
x = self.relu(x)

features = []
features.append(x)

if not self.no_max_pool:
x = self.maxpool(x)

x = self.layer1(x)
features.append(x)

x = self.layer2(x)
features.append(x)

x = self.layer3(x)
features.append(x)

x = self.layer4(x)
features.append(x)

return features
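A hypothetical usage sketch of the new class; the channel counts follow `num_channels_per_output` defined below, and `pretrained=False` avoids any download.

import torch
from monai.networks.nets.resnet import ResNetFeatures

backbone = ResNetFeatures("resnet18", pretrained=False, spatial_dims=3, in_channels=1)
features = backbone(torch.rand(1, 1, 64, 64, 64))
print([f.shape[1] for f in features])  # [64, 64, 128, 256, 512]
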


class ResNetEncoder(ResNetFeatures, BaseEncoder):
"""Wrap the original resnet to an encoder for flexible-unet."""

backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]

@classmethod
def get_encoder_parameters(cls) -> list[dict]:
"""Get the initialization parameter for resnet backbones."""
parameter_list = []
for backbone_name in cls.backbone_names:
parameter_list.append(
{"model_name": backbone_name, "pretrained": True, "spatial_dims": 3, "in_channels": 1}
)
return parameter_list

@classmethod
def num_channels_per_output(cls) -> list[tuple[int, ...]]:
"""Get number of resnet backbone output feature maps channel."""
return [
(64, 64, 128, 256, 512),
(64, 64, 128, 256, 512),
(64, 64, 128, 256, 512),
(64, 256, 512, 1024, 2048),
(64, 256, 512, 1024, 2048),
(64, 256, 512, 1024, 2048),
(64, 256, 512, 1024, 2048),
]

@classmethod
def num_outputs(cls) -> list[int]:
"""Get number of resnet backbone output feature maps.

Since every backbone contains the same 5 output feature maps, the number list should be `[5] * 7`.
"""
return [5] * 7

@classmethod
def get_encoder_names(cls) -> list[str]:
"""Get names of resnet backbones."""
return cls.backbone_names
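A small sketch (class methods exactly as defined above) showing how the registered backbones can be introspected:

from monai.networks.nets.resnet import ResNetEncoder

print(ResNetEncoder.get_encoder_names())            # ["resnet10", ..., "resnet200"]
print(ResNetEncoder.num_outputs())                  # [5, 5, 5, 5, 5, 5, 5]
print(ResNetEncoder.num_channels_per_output()[0])   # (64, 64, 128, 256, 512) for resnet10
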


def _resnet(
arch: str,
block: type[ResNetBlock | ResNetBottleneck],
@@ -477,7 +605,7 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->

def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True):
"""
Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet
Download resnet pretrained weights from https://huggingface.co/TencentMedicalNet

Args:
resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200
@@ -533,11 +661,24 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat
def get_medicalnet_pretrained_resnet_args(resnet_depth: int):
"""
Return correct shortcut_type and bias_downsample
for pretrained MedicalNet weights according to resnet depth
for pretrained MedicalNet weights according to resnet depth.
"""
# After testing
# False: 10, 50, 101, 152, 200
# Any: 18, 34
bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34
shortcut_type = "A" if resnet_depth in [18, 34] else "B"
return bias_downsample, shortcut_type
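An illustrative check (not part of the diff) of the mapping implemented above:

for depth in (10, 18, 34, 50):
    print(depth, get_medicalnet_pretrained_resnet_args(depth))
# 10 (0, 'B')   18 (-1, 'A')   34 (-1, 'A')   50 (0, 'B')
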


def _load_state_dict(model: nn.Module, model_name: str, datasets23: bool = True) -> None:
search_res = re.search(r"resnet(\d+)", model_name)
if search_res:
resnet_depth = int(search_res.group(1))
datasets23 = model_name.endswith("_23datasets")
else:
raise ValueError("model_name argument should contain resnet depth. Example: resnet18 or resnet18_23datasets.")

model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device="cpu", datasets23=datasets23)
model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()}
model.load_state_dict(model_state_dict)
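
For completeness, a hedged sketch of loading a MedicalNet checkpoint directly through the helpers in this file; stripping the "module." prefix mirrors `_load_state_dict` above, and the weights are downloaded from https://huggingface.co/TencentMedicalNet.

from monai.networks.nets.resnet import ResNetFeatures, get_pretrained_resnet_medicalnet

model = ResNetFeatures("resnet10", pretrained=False, spatial_dims=3, in_channels=1)
state = get_pretrained_resnet_medicalnet(10, device="cpu", datasets23=True)
model.load_state_dict({k.replace("module.", ""): v for k, v in state.items()})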