Commit

Merge branch 'refactoring' into cb_extract_validator
thibaultdvx authored Oct 18, 2024
2 parents 4521ca4 + 1b49939 commit 6b0b178
Showing 79 changed files with 8,837 additions and 2,027 deletions.
4 changes: 2 additions & 2 deletions clinicadl/monai_networks/__init__.py
@@ -1,2 +1,2 @@
-from .config import ImplementedNetworks, NetworkConfig, create_network_config
-from .factory import get_network
+from .config import ImplementedNetworks, NetworkConfig
+from .factory import get_network, get_network_from_config
3 changes: 1 addition & 2 deletions clinicadl/monai_networks/config/__init__.py
@@ -1,3 +1,2 @@
-from .base import NetworkConfig
+from .base import ImplementedNetworks, NetworkConfig, NetworkType
 from .factory import create_network_config
-from .utils.enum import ImplementedNetworks
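
Taken together, the two __init__.py diffs above define the package's import surface after the merge: get_network_from_config is now re-exported at the top level, create_network_config stays in the config subpackage, and ImplementedNetworks/NetworkType now come from config.base. The sketch below only restates those import lines; the signatures of get_network, get_network_from_config and create_network_config are not part of this excerpt.

    # Hedged sketch restating the import surface implied by the two diffs above.
    from clinicadl.monai_networks import (
        ImplementedNetworks,
        NetworkConfig,
        get_network,
        get_network_from_config,
    )
    from clinicadl.monai_networks.config import NetworkType, create_network_config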
92 changes: 24 additions & 68 deletions clinicadl/monai_networks/config/autoencoder.py
@@ -1,89 +1,45 @@
-from typing import Optional, Tuple, Union
+from typing import Optional, Sequence, Union
 
-from pydantic import (
-    NonNegativeInt,
-    PositiveInt,
-    computed_field,
-    model_validator,
-)
+from pydantic import PositiveInt, computed_field
 
+from clinicadl.monai_networks.nn.layers.utils import (
+    ActivationParameters,
+    UnpoolingMode,
+)
 from clinicadl.utils.factories import DefaultFromLibrary
 
-from .base import VaryingDepthNetworkConfig
-from .utils.enum import ImplementedNetworks
-
-__all__ = ["AutoEncoderConfig", "VarAutoEncoderConfig"]
-
+from .base import ImplementedNetworks, NetworkConfig
+from .conv_encoder import ConvEncoderOptions
+from .mlp import MLPOptions
 
-class AutoEncoderConfig(VaryingDepthNetworkConfig):
-    """Config class for autoencoders."""
 
-    spatial_dims: PositiveInt
-    in_channels: PositiveInt
-    out_channels: PositiveInt
+class AutoEncoderConfig(NetworkConfig):
+    """Config class for AutoEncoder."""
 
-    inter_channels: Union[
-        Optional[Tuple[PositiveInt, ...]], DefaultFromLibrary
-    ] = DefaultFromLibrary.YES
-    inter_dilations: Union[
-        Optional[Tuple[PositiveInt, ...]], DefaultFromLibrary
+    in_shape: Sequence[PositiveInt]
+    latent_size: PositiveInt
+    conv_args: ConvEncoderOptions
+    mlp_args: Union[Optional[MLPOptions], DefaultFromLibrary] = DefaultFromLibrary.YES
+    out_channels: Union[
+        Optional[PositiveInt], DefaultFromLibrary
     ] = DefaultFromLibrary.YES
-    num_inter_units: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES
-    padding: Union[
-        Optional[Union[PositiveInt, Tuple[PositiveInt, ...]]], DefaultFromLibrary
+    output_act: Union[
+        Optional[ActivationParameters], DefaultFromLibrary
     ] = DefaultFromLibrary.YES
+    unpooling_mode: Union[UnpoolingMode, DefaultFromLibrary] = DefaultFromLibrary.YES
 
     @computed_field
     @property
-    def network(self) -> ImplementedNetworks:
+    def name(self) -> ImplementedNetworks:
         """The name of the network."""
         return ImplementedNetworks.AE
 
-    @computed_field
-    @property
-    def dim(self) -> int:
-        """Dimension of the images."""
-        return self.spatial_dims
-
-    @model_validator(mode="after")
-    def model_validator(self):
-        """Checks coherence between parameters."""
-        if self.padding != DefaultFromLibrary.YES:
-            assert self._check_dimensions(
-                self.padding
-            ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for padding. You passed {self.padding}."
-        if isinstance(self.inter_channels, tuple) and isinstance(
-            self.inter_dilations, tuple
-        ):
-            assert len(self.inter_channels) == len(
-                self.inter_dilations
-            ), "inter_channels and inter_dilations muust have the same size."
-        elif isinstance(self.inter_dilations, tuple) and not isinstance(
-            self.inter_channels, tuple
-        ):
-            raise ValueError(
-                "You passed inter_dilations but didn't pass inter_channels."
-            )
-        return self
-
-
-class VarAutoEncoderConfig(AutoEncoderConfig):
-    """Config class for variational autoencoders."""
 
-    in_shape: Tuple[PositiveInt, ...]
-    in_channels: Optional[int] = None
-    latent_size: PositiveInt
-    use_sigmoid: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
+class VAEConfig(AutoEncoderConfig):
+    """Config class for Variational AutoEncoder."""
 
     @computed_field
     @property
-    def network(self) -> ImplementedNetworks:
+    def name(self) -> ImplementedNetworks:
         """The name of the network."""
         return ImplementedNetworks.VAE
-
-    @model_validator(mode="after")
-    def model_validator_bis(self):
-        """Checks coherence between parameters."""
-        assert (
-            len(self.in_shape[1:]) == self.spatial_dims
-        ), f"You passed {self.spatial_dims} for spatial_dims, but in_shape suggests {len(self.in_shape[1:])} spatial dimensions."
210 changes: 70 additions & 140 deletions clinicadl/monai_networks/config/base.py
@@ -1,168 +1,98 @@
-from __future__ import annotations
-
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, Dict, Optional, Tuple, Union
-
-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    NonNegativeFloat,
-    NonNegativeInt,
-    PositiveInt,
-    computed_field,
-    field_validator,
-    model_validator,
-)
+from typing import Optional, Union
+
+from pydantic import BaseModel, ConfigDict, PositiveInt, computed_field
 
+from clinicadl.monai_networks.nn.layers.utils import ActivationParameters
 from clinicadl.utils.factories import DefaultFromLibrary
 
-from .utils.enum import (
-    ImplementedActFunctions,
-    ImplementedNetworks,
-    ImplementedNormLayers,
-)
 
+class ImplementedNetworks(str, Enum):
+    """Implemented neural networks in ClinicaDL."""
+
+    MLP = "MLP"
+    CONV_ENCODER = "ConvEncoder"
+    CONV_DECODER = "ConvDecoder"
+    CNN = "CNN"
+    GENERATOR = "Generator"
+    AE = "AutoEncoder"
+    VAE = "VAE"
+    DENSENET = "DenseNet"
+    DENSENET_121 = "DenseNet-121"
+    DENSENET_161 = "DenseNet-161"
+    DENSENET_169 = "DenseNet-169"
+    DENSENET_201 = "DenseNet-201"
+    RESNET = "ResNet"
+    RESNET_18 = "ResNet-18"
+    RESNET_34 = "ResNet-34"
+    RESNET_50 = "ResNet-50"
+    RESNET_101 = "ResNet-101"
+    RESNET_152 = "ResNet-152"
+    SE_RESNET = "SEResNet"
+    SE_RESNET_50 = "SEResNet-50"
+    SE_RESNET_101 = "SEResNet-101"
+    SE_RESNET_152 = "SEResNet-152"
+    UNET = "UNet"
+    ATT_UNET = "AttentionUNet"
+    VIT = "ViT"
+    VIT_B_16 = "ViT-B/16"
+    VIT_B_32 = "ViT-B/32"
+    VIT_L_16 = "ViT-L/16"
+    VIT_L_32 = "ViT-L/32"
+
+    @classmethod
+    def _missing_(cls, value):
+        raise ValueError(
+            f"{value} is not implemented. Implemented neural networks are: "
+            + ", ".join([repr(m.value) for m in cls])
+        )
 
+
+class NetworkType(str, Enum):
+    """
+    Useful to know where to look for the network.
+    See :py:func:`clinicadl.monai_networks.factory.get_network`
+    """
+
+    CUSTOM = "custom"  # our own networks
+    RESNET = "sota-ResNet"
+    DENSENET = "sota-DenseNet"
+    SE_RESNET = "sota-SEResNet"
+    VIT = "sota-ViT"
+
+
 class NetworkConfig(BaseModel, ABC):
     """Base config class to configure neural networks."""
 
-    kernel_size: Union[
-        PositiveInt, Tuple[PositiveInt, ...], DefaultFromLibrary
-    ] = DefaultFromLibrary.YES
-    up_kernel_size: Union[
-        PositiveInt, Tuple[PositiveInt, ...], DefaultFromLibrary
-    ] = DefaultFromLibrary.YES
-    num_res_units: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES
-    act: Union[
-        ImplementedActFunctions,
-        Tuple[ImplementedActFunctions, Dict[str, Any]],
-        DefaultFromLibrary,
-    ] = DefaultFromLibrary.YES
-    norm: Union[
-        ImplementedNormLayers,
-        Tuple[ImplementedNormLayers, Dict[str, Any]],
-        DefaultFromLibrary,
-    ] = DefaultFromLibrary.YES
-    bias: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
-    adn_ordering: Union[Optional[str], DefaultFromLibrary] = DefaultFromLibrary.YES
-    # pydantic config
     model_config = ConfigDict(
         validate_assignment=True,
         use_enum_values=True,
         validate_default=True,
         protected_namespaces=(),
     )
 
     @computed_field
     @property
     @abstractmethod
-    def network(self) -> ImplementedNetworks:
+    def name(self) -> ImplementedNetworks:
         """The name of the network."""
 
     @computed_field
     @property
-    @abstractmethod
-    def dim(self) -> int:
-        """Dimension of the images."""
+    def _type(self) -> NetworkType:
+        """
+        To know where to look for the network.
+        Default to 'custom'.
+        """
+        return NetworkType.CUSTOM
 
-    @classmethod
-    def base_validator_dropout(cls, v):
-        """Checks that dropout is between 0 and 1."""
-        if isinstance(v, float):
-            assert (
-                0 <= v <= 1
-            ), f"dropout must be between 0 and 1 but it has been set to {v}."
-        return v
-
-    @field_validator("kernel_size", "up_kernel_size")
-    @classmethod
-    def base_is_odd(cls, value, field):
-        """Checks if a field is odd."""
-        if value != DefaultFromLibrary.YES:
-            if isinstance(value, int):
-                value_ = (value,)
-            else:
-                value_ = value
-            for v in value_:
-                assert v % 2 == 1, f"{field.field_name} must be odd."
-        return value
-
-    @field_validator("adn_ordering", mode="after")
-    @classmethod
-    def base_adn_validator(cls, v):
-        """Checks ADN sequence."""
-        if v != DefaultFromLibrary.YES:
-            for letter in v:
-                assert (
-                    letter in {"A", "D", "N"}
-                ), f"adn_ordering must be composed by 'A', 'D' or/and 'N'. You passed {letter}."
-            assert len(v) == len(
-                set(v)
-            ), "adn_ordering cannot contain duplicated letter."
-
-        return v
-
-    @classmethod
-    def base_at_least_2d(cls, v, ctx):
-        """Checks that a tuple has at least a length of two."""
-        if isinstance(v, tuple):
-            assert (
-                len(v) >= 2
-            ), f"{ctx.field_name} should have at least two dimensions (with the first one for the channel)."
-        return v
-
-    @model_validator(mode="after")
-    def base_model_validator(self):
-        """Checks coherence between parameters."""
-        if self.kernel_size != DefaultFromLibrary.YES:
-            assert self._check_dimensions(
-                self.kernel_size
-            ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for kernel_size. You passed {self.kernel_size}."
-        if self.up_kernel_size != DefaultFromLibrary.YES:
-            assert self._check_dimensions(
-                self.up_kernel_size
-            ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for up_kernel_size. You passed {self.up_kernel_size}."
-        return self
-
-    def _check_dimensions(
-        self,
-        value: Union[float, Tuple[float, ...]],
-    ) -> bool:
-        """Checks if a tuple has the right dimension."""
-        if isinstance(value, tuple):
-            return len(value) == self.dim
-        return True
-
 
-class VaryingDepthNetworkConfig(NetworkConfig, ABC):
-    """
-    Base config class to configure neural networks.
-    More precisely, we refer to MONAI's networks with 'channels' and 'strides' parameters.
-    """
+class PreTrainedConfig(NetworkConfig):
    """Base config class for SOTA networks."""
 
-    channels: Tuple[PositiveInt, ...]
-    strides: Tuple[Union[PositiveInt, Tuple[PositiveInt, ...]], ...]
-    dropout: Union[
-        Optional[NonNegativeFloat], DefaultFromLibrary
+    num_outputs: Optional[PositiveInt]
+    output_act: Union[
+        Optional[ActivationParameters], DefaultFromLibrary
     ] = DefaultFromLibrary.YES
-
-    @field_validator("dropout")
-    @classmethod
-    def validator_dropout(cls, v):
-        """Checks that dropout is between 0 and 1."""
-        return cls.base_validator_dropout(v)
-
-    @model_validator(mode="after")
-    def channels_strides_validator(self):
-        """Checks coherence between parameters."""
-        n_layers = len(self.channels)
-        assert (
-            len(self.strides) == n_layers
-        ), f"There are {n_layers} layers but you passed {len(self.strides)} strides."
-        for s in self.strides:
-            assert self._check_dimensions(
-                s
-            ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for strides. You passed {s}."
-
-        return self
+    pretrained: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
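
The new base.py gathers the ImplementedNetworks and NetworkType enums and a much slimmer abstract NetworkConfig (the kernel/act/norm fields and their validators leave the base class), plus a PreTrainedConfig base for the SOTA networks. Below is a hedged sketch of the subclassing pattern this implies; ResNet18Config is a hypothetical name used only for illustration, and the real subclasses live in other config modules of this commit.

    # Hedged sketch of the subclassing pattern implied by the new base.py.
    # ResNet18Config is a hypothetical illustration, not a class from this diff.
    from pydantic import computed_field

    from clinicadl.monai_networks.config.base import (
        ImplementedNetworks,
        NetworkType,
        PreTrainedConfig,
    )


    class ResNet18Config(PreTrainedConfig):
        """Hypothetical config for a pretrained ResNet-18."""

        @computed_field
        @property
        def name(self) -> ImplementedNetworks:
            return ImplementedNetworks.RESNET_18

        @computed_field
        @property
        def _type(self) -> NetworkType:
            return NetworkType.RESNET  # routes get_network to the "sota-ResNet" branch


    config = ResNet18Config(num_outputs=2)
    print(config.name.value)   # "ResNet-18"
    print(config._type.value)  # "sota-ResNet"

    # Unknown names fail loudly thanks to ImplementedNetworks._missing_:
    # ImplementedNetworks("FooNet")  -> ValueError listing the implemented networks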
[Diffs for the remaining changed files are not shown in this excerpt.]