Commit

add transposed conv for unpooling

thibaultdvx committed Oct 16, 2024
1 parent 6a13dcb commit 1cd7597

Showing 8 changed files with 285 additions and 93 deletions.
212 changes: 153 additions & 59 deletions clinicadl/monai_networks/nn/autoencoder.py
@@ -1,17 +1,25 @@
from collections.abc import Sequence
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import torch.nn as nn

from .cnn import CNN
from .conv_encoder import ConvEncoder
from .generator import Generator
from .layers.utils import ActivationParameters, UnpoolingLayer, UpsamplingMode
from .layers.utils import (
ActivationParameters,
PoolingLayer,
SingleLayerPoolingParameters,
SingleLayerUnpoolingParameters,
UnpoolingLayer,
UnpoolingMode,
)
from .mlp import MLP
from .utils import (
calculate_conv_out_shape,
calculate_convtranspose_out_shape,
calculate_pool_out_shape,
)


@@ -24,7 +32,8 @@ class AutoEncoder(nn.Sequential):
symmetrical network.
More precisely, to build the decoder, the order of the encoding layers is reversed, convolutions are
replaced by transposed convolutions and pooling layers are replaced by upsampling layers.
replaced by transposed convolutions and pooling layers are replaced by either upsampling or transposed
convolution layers.
Please note that the order of `Activation`, `Dropout` and `Normalization`, defined with the
argument `adn_ordering` in `conv_args`, is the same for the encoder and the decoder.
@@ -57,9 +66,18 @@ class AutoEncoder(nn.Sequential):
`relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activation functions]
(https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional
arguments for each of them.
upsampling_mode : Union[str, UpsamplingMode] (optional, default=UpsamplingMode.NEAREST)
interpolation mode for upsampling (see: https://pytorch.org/docs/stable/generated/
torch.nn.Upsample.html).
unpooling_mode : Union[str, UnpoolingMode] (optional, default=UnpoolingMode.NEAREST)
type of unpooling. Can be either `"nearest"`, `"linear"`, `"bilinear"`, `"bicubic"`, `"trilinear"` or
`"convtranspose"`.\n
- `nearest`: unpooling is performed by upsampling with the :italic:`nearest` algorithm (see [PyTorch's Upsample layer]
(https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html)).
- `linear`: unpooling is performed by upsampling with the :italic:`linear` algorithm. Only works with 1D images (excluding the
channel dimension).
- `bilinear`: unpooling is performed by upsampling with the :italic:`bilinear` algorithm. Only works with 2D images.
- `bicubic`: unpooling is performed by upsampling with the :italic:`bicubic` algorithm. Only works with 2D images.
- `trilinear`: unpooling is performed by upsampling with the :italic:`trilinear` algorithm. Only works with 3D images.
- `convtranspose`: unpooling is performed with a transposed convolution, whose parameters (kernel size, stride, etc.) are
computed to reverse the pooling operation.
Examples
--------
@@ -74,7 +92,7 @@ class AutoEncoder(nn.Sequential):
mlp_args={"hidden_channels": [32], "output_act": "relu"},
out_channels=2,
output_act="sigmoid",
upsampling_mode="bilinear",
unpooling_mode="bilinear",
)
AutoEncoder(
(encoder): CNN(
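
For intuition about the two unpooling families the docstring describes, here is a minimal plain-PyTorch sketch (independent of ClinicaDL's API) showing that both an `Upsample` layer and a suitably parameterized transposed convolution recover the pre-pooling spatial size:

    import torch
    import torch.nn as nn

    x = torch.randn(1, 8, 15, 15)  # feature map before pooling (odd size on purpose)
    pooled = nn.MaxPool2d(kernel_size=2, stride=2)(x)  # -> (1, 8, 7, 7), flooring loses a pixel

    # interpolation modes ("nearest", "bilinear", ...): upsample back to the stored size
    up = nn.Upsample(size=(15, 15), mode="nearest")(pooled)

    # "convtranspose" mode: kernel_size/stride mirror the pooling; output_padding=1
    # compensates the pixel lost to flooring ((7 - 1) * 2 + (2 - 1) + 1 + 1 = 15)
    convt = nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2, output_padding=1)(pooled)

    assert up.shape[2:] == convt.shape[2:] == x.shape[2:]
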
@@ -149,13 +167,14 @@ def __init__(
mlp_args: Optional[Dict[str, Any]] = None,
out_channels: Optional[int] = None,
output_act: Optional[ActivationParameters] = None,
upsampling_mode: Union[str, UpsamplingMode] = UpsamplingMode.NEAREST,
unpooling_mode: Union[str, UnpoolingMode] = UnpoolingMode.NEAREST,
) -> None:
super().__init__()
self.in_shape = in_shape
self.upsampling_mode = self._check_upsampling_mode(upsampling_mode)
self.unpooling_mode = self._check_unpooling_mode(unpooling_mode)
self.out_channels = out_channels if out_channels else self.in_shape[0]
self.output_act = output_act
self.spatial_dims = len(in_shape[1:])

self.encoder = CNN(
in_shape=self.in_shape,
@@ -194,28 +213,27 @@ def _invert_conv_args(
part of the decoder.
"""
if len(args["channels"]) == 0:
return {"channels": []}

args["channels"] = self._invert_list_arg(conv.channels[:-1]) + [
self.out_channels
]
args["channels"] = []
else:
args["channels"] = self._invert_list_arg(conv.channels[:-1]) + [
self.out_channels
]
args["kernel_size"] = self._invert_list_arg(conv.kernel_size)
args["stride"] = self._invert_list_arg(conv.stride)
args["padding"] = self._invert_list_arg(conv.padding)
args["dilation"] = self._invert_list_arg(conv.dilation)
args["output_padding"] = self._get_output_padding_list(conv)
args["padding"], args["output_padding"] = self._get_paddings_list(conv)

args["unpooling_indices"] = (
conv.n_layers - np.array(conv.pooling_indices) - 2
).astype(int)
args["unpooling"] = []
size_before_pools = [
sizes_before_pooling = [
size
for size, (layer_name, _) in zip(conv.size_details, conv.named_children())
if "pool" in layer_name
]
for size in size_before_pools[::-1]:
args["unpooling"].append(self._invert_pooling_layer(size))
for size, pooling in zip(sizes_before_pooling[::-1], conv.pooling[::-1]):
args["unpooling"].append(self._invert_pooling_layer(size, pooling))

if "pooling" in args:
del args["pooling"]
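
As a worked example of the channel inversion above (hypothetical numbers, not from the commit): the deepest encoder width is dropped, the rest is reversed, and the list ends on the requested output channels.

    out_channels = 1
    channels = [16, 32, 64]  # encoder channels
    decoder_channels = list(channels[:-1][::-1]) + [out_channels]  # [32, 16, 1]
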
@@ -234,21 +252,80 @@ def _invert_list_arg(cls, arg: Union[Any, List[Any]]) -> Union[Any, List[Any]]:
return list(arg[::-1]) if isinstance(arg, Sequence) else arg

def _invert_pooling_layer(
self, size_before_pool: Sequence[int]
) -> Tuple[UnpoolingLayer, Dict[str, Any]]:
self,
size_before_pool: Sequence[int],
pooling: SingleLayerPoolingParameters,
) -> SingleLayerUnpoolingParameters:
"""
Gets the unpooling layer (always upsample).
Gets the unpooling layer.
"""
return (
UnpoolingLayer.UPSAMPLE,
{"size": size_before_pool, "mode": self.upsampling_mode},
)
if self.unpooling_mode == UnpoolingMode.CONV_TRANS:
return (
UnpoolingLayer.CONV_TRANS,
self._invert_pooling_with_convtranspose(size_before_pool, pooling),
)
else:
return (
UnpoolingLayer.UPSAMPLE,
{"size": size_before_pool, "mode": self.unpooling_mode},
)

@classmethod
def _get_output_padding_list(cls, conv: ConvEncoder) -> List[Tuple[int, ...]]:
def _invert_pooling_with_convtranspose(
cls,
size_before_pool: Sequence[int],
pooling: SingleLayerPoolingParameters,
) -> Dict[str, Any]:
"""
Computes the arguments of the transposed convolution, based on the pooling layer.
"""
pooling_mode, pooling_args = pooling
if (
pooling_mode == PoolingLayer.ADAPT_AVG
or pooling_mode == PoolingLayer.ADAPT_MAX
):
input_size_np = np.array(size_before_pool)
output_size_np = np.array(pooling_args["output_size"])
stride_np = input_size_np // output_size_np # adaptive pooling formulas
kernel_size_np = (
input_size_np - (output_size_np - 1) * stride_np
) # adaptive pooling formulas
args = {
"kernel_size": tuple(int(k) for k in kernel_size_np),
"stride": tuple(int(s) for s in stride_np),
}
padding, output_padding = cls._find_convtranspose_paddings(
pooling_mode,
size_before_pool,
output_size=pooling_args["output_size"],
**args,
)

elif pooling_mode == PoolingLayer.MAX or pooling_mode == PoolingLayer.AVG:
if "stride" not in pooling_args:
pooling_args["stride"] = pooling_args["kernel_size"]
args = {
arg: value
for arg, value in pooling_args.items()
if arg in ["kernel_size", "stride", "padding", "dilation"]
}
padding, output_padding = cls._find_convtranspose_paddings(
pooling_mode,
size_before_pool,
**pooling_args,
)

args["padding"] = padding # pylint: disable=possibly-used-before-assignment
args["output_padding"] = output_padding # pylint: disable=possibly-used-before-assignment

return args

@classmethod
def _get_paddings_list(
cls, conv: ConvEncoder
) -> Tuple[List[Tuple[int, ...]], List[Tuple[int, ...]]]:
"""
Finds the padding and output padding lists.
"""
padding = []
output_padding = []
size_before_convs = [
size
@@ -262,61 +339,78 @@ def _get_output_padding_list(cls, conv: ConvEncoder) -> List[Tuple[int, ...]]:
conv.padding,
conv.dilation,
):
out_p = cls._find_output_padding(size, k, s, p, d)
p, out_p = cls._find_convtranspose_paddings(
"conv", size, kernel_size=k, stride=s, padding=p, dilation=d
)
padding.append(p)
output_padding.append(out_p)

return cls._invert_list_arg(output_padding)
return cls._invert_list_arg(padding), cls._invert_list_arg(output_padding)

@classmethod
def _find_output_padding(
def _find_convtranspose_paddings(
cls,
layer_type: Union[Literal["conv"], PoolingLayer],
in_shape: Union[Sequence[int], int],
kernel_size: Union[Sequence[int], int],
stride: Union[Sequence[int], int],
padding: Union[Sequence[int], int],
dilation: Union[Sequence[int], int],
) -> Tuple[int, ...]:
padding: Union[Sequence[int], int] = 0,
**kwargs,
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
"""
Finds output padding necessary to recover the right image size after
Finds padding and output padding necessary to recover the right image size after
a transposed convolution.
"""
in_shape_np = np.atleast_1d(in_shape)
conv_out_shape = calculate_conv_out_shape(
in_shape_np, kernel_size, stride, padding, dilation
)
convt_out_shape = calculate_convtranspose_out_shape(
conv_out_shape, kernel_size, stride, padding, 0, dilation
)
output_padding = in_shape_np - np.atleast_1d(convt_out_shape)
if layer_type == "conv":
layer_out_shape = calculate_conv_out_shape(in_shape, **kwargs)
elif layer_type in list(PoolingLayer):
layer_out_shape = calculate_pool_out_shape(layer_type, in_shape, **kwargs)

convt_out_shape = calculate_convtranspose_out_shape(layer_out_shape, **kwargs) # pylint: disable=possibly-used-before-assignment
output_padding = np.atleast_1d(in_shape) - np.atleast_1d(convt_out_shape)

if (
output_padding < 0
).any(): # can happen with ceil_mode=True for maxpool. Then, add some padding
padding = np.atleast_1d(padding) * np.ones_like(
output_padding
) # to have the same shape as output_padding
padding[output_padding < 0] += np.maximum(np.abs(output_padding) // 2, 1)[
output_padding < 0
] # //2 because 2*padding pixels are removed

convt_out_shape = calculate_convtranspose_out_shape(
layer_out_shape, padding=padding, **kwargs
)
output_padding = np.atleast_1d(in_shape) - np.atleast_1d(convt_out_shape)
padding = tuple(int(s) for s in padding)

return tuple(int(s) for s in output_padding)
return padding, tuple(int(s) for s in output_padding)

def _check_upsampling_mode(
self, upsampling_mode: Union[str, UpsamplingMode]
) -> UpsamplingMode:
def _check_unpooling_mode(
self, unpooling_mode: Union[str, UnpoolingMode]
) -> UnpoolingMode:
"""
Checks consistency between data shape and upsampling mode.
Checks consistency between data shape and unpooling mode.
"""
upsampling_mode = UpsamplingMode(upsampling_mode)
if upsampling_mode == "linear" and len(self.in_shape) != 2:
unpooling_mode = UnpoolingMode(unpooling_mode)
if unpooling_mode == UnpoolingMode.LINEAR and len(self.in_shape) != 2:
raise ValueError(
f"upsampling mode `linear` only works with 2D data (counting the channel dimension). "
f"unpooling mode `linear` only works with 2D data (counting the channel dimension). "
f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data."
)
elif upsampling_mode == "bilinear" and len(self.in_shape) != 3:
elif unpooling_mode == UnpoolingMode.BILINEAR and len(self.in_shape) != 3:
raise ValueError(
f"upsampling mode `bilinear` only works with 3D data (counting the channel dimension). "
f"unpooling mode `bilinear` only works with 3D data (counting the channel dimension). "
f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data."
)
elif upsampling_mode == "bicubic" and len(self.in_shape) != 3:
elif unpooling_mode == UnpoolingMode.BICUBIC and len(self.in_shape) != 3:
raise ValueError(
f"upsampling mode `bicubic` only works with 3D data (counting the channel dimension). "
f"unpooling mode `bicubic` only works with 3D data (counting the channel dimension). "
f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data."
)
elif upsampling_mode == "trilinear" and len(self.in_shape) != 4:
elif unpooling_mode == UnpoolingMode.TRILINEAR and len(self.in_shape) != 4:
raise ValueError(
f"upsampling mode `trilinear` only works with 4D data (counting the channel dimension). "
f"unpooling mode `trilinear` only works with 4D data (counting the channel dimension). "
f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data."
)

return upsampling_mode
return unpooling_mode
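
The inversion logic above leans on the standard PyTorch shape formulas. The following self-contained sketch (with illustrative helper names, not ClinicaDL's internal `calculate_*` utilities) works through both cases: recovering `output_padding` for a strided convolution, and deriving `kernel_size`/`stride` from an adaptive pooling layer as the diff does.

    import numpy as np

    def conv_out(in_size, kernel_size, stride, padding=0, dilation=1):
        # floor((in + 2p - d(k - 1) - 1) / s) + 1
        return (np.atleast_1d(in_size) + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    def convtranspose_out(in_size, kernel_size, stride, padding=0, output_padding=0, dilation=1):
        # (in - 1)s - 2p + d(k - 1) + output_padding + 1
        return (np.atleast_1d(in_size) - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

    # 1) Inverting a strided convolution: flooring lost a pixel, output_padding restores it.
    in_size = np.array([16])
    down = conv_out(in_size, kernel_size=3, stride=2, padding=1)        # [8]
    back = convtranspose_out(down, kernel_size=3, stride=2, padding=1)  # [15]
    output_padding = in_size - back                                     # [1]

    # 2) Inverting an adaptive pooling: the "adaptive pooling formulas" used in the diff.
    in_size, out_size = np.array([10]), np.array([3])
    stride = in_size // out_size                      # [3]
    kernel_size = in_size - (out_size - 1) * stride   # [4]
    assert (conv_out(in_size, kernel_size=kernel_size, stride=stride) == out_size).all()
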
18 changes: 15 additions & 3 deletions clinicadl/monai_networks/nn/conv_encoder.py
@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Callable, Optional, Tuple
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch.nn as nn
@@ -305,10 +305,20 @@ def _get_pool_layer(self, pooling: SingleLayerPoolingParameters) -> nn.Module:
Gets the parametrized pooling layer and updates the current output size.
"""
pool_layer = get_pool_layer(pooling, spatial_dims=self.spatial_dims)
old_size = self.final_size
self.final_size = lambda size: calculate_pool_out_shape(
pool_mode=pooling[0], in_shape=size, **pool_layer.__dict__
)

if (
self.final_size is not None
and (np.array(old_size) < np.array(self.final_size)).any()
):
raise ValueError(
f"You passed {pooling} as a pooling layer. But before this layer, the size of the image "
f"was {old_size}. So, pooling can't be performed."
)

return pool_layer

def _check_size(self) -> None:
@@ -354,7 +364,9 @@ def _check_single_pool_layer(
f"Got {args}"
)

def _check_pool_layers(self, pooling: PoolingParameters) -> PoolingParameters:
def _check_pool_layers(
self, pooling: PoolingParameters
) -> List[SingleLayerPoolingParameters]:
"""
Check argument pooling.
"""
@@ -371,7 +383,7 @@ def _check_pool_layers(self, pooling: PoolingParameters) -> PoolingParameters:
)
elif isinstance(pooling, tuple):
self._check_single_pool_layer(pooling)
pooling = (pooling,) * len(self.pooling_indices)
pooling = [pooling] * len(self.pooling_indices)
else:
raise ValueError(
f"pooling can be either None, a double (string, dictionary) or a list of such doubles. Got {pooling}"
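
The new check in `_get_pool_layer` rejects pooling layers that would enlarge the feature map (for instance an adaptive pooling whose `output_size` exceeds the current size), since such a layer could not be inverted in the decoder. A hedged sketch of the condition, with made-up sizes:

    import numpy as np

    old_size = (4, 4)  # feature-map size before the pooling layer
    pooling = ("adaptivemax", {"output_size": (8, 8)})  # would *enlarge* the image
    new_size = pooling[1]["output_size"]

    if (np.array(old_size) < np.array(new_size)).any():
        raise ValueError(
            f"You passed {pooling} as a pooling layer. But before this layer, the size "
            f"of the image was {old_size}. So, pooling can't be performed."
        )
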
2 changes: 1 addition & 1 deletion clinicadl/monai_networks/nn/layers/utils/__init__.py
@@ -4,7 +4,7 @@
NormLayer,
PoolingLayer,
UnpoolingLayer,
UpsamplingMode,
UnpoolingMode,
)
from .types import (
ActivationParameters,
5 changes: 3 additions & 2 deletions clinicadl/monai_networks/nn/layers/utils/enum.py
@@ -54,11 +54,12 @@ class ConvNormLayer(CaseInsensitiveEnum):
INSTANCE = "instance"


class UpsamplingMode(CaseInsensitiveEnum):
"""Supported interpolation mode for Upsampling in ClinicaDL."""
class UnpoolingMode(CaseInsensitiveEnum):
"""Supported unpooling mode for AutoEncoders in ClinicaDL."""

NEAREST = "nearest"
LINEAR = "linear"
BILINEAR = "bilinear"
BICUBIC = "bicubic"
TRILINEAR = "trilinear"
CONV_TRANS = "convtranspose"
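
A minimal usage sketch of the renamed enum, assuming the case-insensitive lookup that `CaseInsensitiveEnum` is meant to provide:

    from clinicadl.monai_networks.nn.layers.utils import UnpoolingMode

    mode = UnpoolingMode("convtranspose")
    assert mode is UnpoolingMode.CONV_TRANS
    # being a CaseInsensitiveEnum, "ConvTranspose" should resolve the same way
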