Skip to content

Commit 6f5cea3

Browse files
authored
6676 port losses from monai-generative (#6729)
Work towards addressing issue #6676 ### Description This PR ports spectral, perceptual and patch adversial losses from [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels). ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham <markgraham539@gmail.com>
1 parent d6bafc9 commit 6f5cea3

File tree

12 files changed

+948
-3
lines changed

12 files changed

+948
-3
lines changed

docs/source/installation.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,11 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
254254
- The options are
255255

256256
```
257-
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr]
257+
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips]
258258
```
259259

260-
which correspond to `nibabel`, `scikit-image`, `scipy`, `pillow`, `tensorboard`,
261-
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, and `zarr` respectively.
260+
which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`,
261+
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr` and `lpips` respectively.
262+
262263

263264
- `pip install 'monai[all]'` installs all the optional dependencies.

docs/source/losses.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,21 @@ Reconstruction Losses
9999
.. autoclass:: monai.losses.ssim_loss.SSIMLoss
100100
:members:
101101

102+
`PatchAdversarialLoss`
103+
~~~~~~~~~~~~~~~~~~~~~~
104+
.. autoclass:: PatchAdversarialLoss
105+
:members:
106+
107+
`PerceptualLoss`
108+
~~~~~~~~~~~~~~~~~
109+
.. autoclass:: PerceptualLoss
110+
:members:
111+
112+
`JukeboxLoss`
113+
~~~~~~~~~~~~~~
114+
.. autoclass:: JukeboxLoss
115+
:members:
116+
102117

103118
Loss Wrappers
104119
-------------

monai/losses/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
from .adversarial_loss import PatchAdversarialLoss
1415
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
1516
from .contrastive import ContrastiveLoss
1617
from .deform import BendingEnergyLoss
@@ -34,7 +35,9 @@
3435
from .giou_loss import BoxGIoULoss, giou
3536
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
3637
from .multi_scale import MultiScaleLoss
38+
from .perceptual import PerceptualLoss
3739
from .spatial_mask import MaskedLoss
40+
from .spectral_loss import JukeboxLoss
3841
from .ssim_loss import SSIMLoss
3942
from .tversky import TverskyLoss
4043
from .unified_focal_loss import AsymmetricUnifiedFocalLoss

monai/losses/adversarial_loss.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import warnings
15+
16+
import torch
17+
from torch.nn.modules.loss import _Loss
18+
19+
from monai.networks.layers.utils import get_act_layer
20+
from monai.utils import LossReduction
21+
from monai.utils.enums import StrEnum
22+
23+
24+
class AdversarialCriterions(StrEnum):
25+
BCE = "bce"
26+
HINGE = "hinge"
27+
LEAST_SQUARE = "least_squares"
28+
29+
30+
class PatchAdversarialLoss(_Loss):
31+
"""
32+
Calculates an adversarial loss on a Patch Discriminator or a Multi-scale Patch Discriminator.
33+
Warning: due to the possibility of using different criterions, the output of the discrimination
34+
mustn't be passed to a final activation layer. That is taken care of internally within the loss.
35+
36+
Args:
37+
reduction: {``"none"``, ``"mean"``, ``"sum"``}
38+
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
39+
40+
- ``"none"``: no reduction will be applied.
41+
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
42+
- ``"sum"``: the output will be summed.
43+
44+
criterion: which criterion (hinge, least_squares or bce) you want to use on the discriminators outputs.
45+
Depending on the criterion, a different activation layer will be used. Make sure you don't run the outputs
46+
through an activation layer prior to calling the loss.
47+
no_activation_leastsq: if True, the activation layer in the case of least-squares is removed.
48+
"""
49+
50+
def __init__(
51+
self,
52+
reduction: LossReduction | str = LossReduction.MEAN,
53+
criterion: str = AdversarialCriterions.LEAST_SQUARE,
54+
no_activation_leastsq: bool = False,
55+
) -> None:
56+
super().__init__(reduction=LossReduction(reduction))
57+
58+
if criterion.lower() not in list(AdversarialCriterions):
59+
raise ValueError(
60+
"Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
61+
% ", ".join(AdversarialCriterions)
62+
)
63+
64+
# Depending on the criterion, a different activation layer is used.
65+
self.real_label = 1.0
66+
self.fake_label = 0.0
67+
self.loss_fct: _Loss
68+
if criterion == AdversarialCriterions.BCE:
69+
self.activation = get_act_layer("SIGMOID")
70+
self.loss_fct = torch.nn.BCELoss(reduction=reduction)
71+
elif criterion == AdversarialCriterions.HINGE:
72+
self.activation = get_act_layer("TANH")
73+
self.fake_label = -1.0
74+
elif criterion == AdversarialCriterions.LEAST_SQUARE:
75+
if no_activation_leastsq:
76+
self.activation = None
77+
else:
78+
self.activation = get_act_layer(name=("LEAKYRELU", {"negative_slope": 0.05}))
79+
self.loss_fct = torch.nn.MSELoss(reduction=reduction)
80+
81+
self.criterion = criterion
82+
self.reduction = reduction
83+
84+
def get_target_tensor(self, input: torch.Tensor, target_is_real: bool) -> torch.Tensor:
85+
"""
86+
Gets the ground truth tensor for the discriminator depending on whether the input is real or fake.
87+
88+
Args:
89+
input: input tensor from the discriminator (output of discriminator, or output of one of the multi-scale
90+
discriminator). This is used to match the shape.
91+
target_is_real: whether the input is real or wannabe-real (1s) or fake (0s).
92+
Returns:
93+
"""
94+
filling_label = self.real_label if target_is_real else self.fake_label
95+
label_tensor = torch.tensor(1).fill_(filling_label).type(input.type()).to(input[0].device)
96+
label_tensor.requires_grad_(False)
97+
return label_tensor.expand_as(input)
98+
99+
def get_zero_tensor(self, input: torch.Tensor) -> torch.Tensor:
100+
"""
101+
Gets a zero tensor.
102+
103+
Args:
104+
input: tensor which shape you want the zeros tensor to correspond to.
105+
Returns:
106+
"""
107+
108+
zero_label_tensor = torch.tensor(0).type(input[0].type()).to(input[0].device)
109+
zero_label_tensor.requires_grad_(False)
110+
return zero_label_tensor.expand_as(input)
111+
112+
def forward(
113+
self, input: torch.Tensor | list, target_is_real: bool, for_discriminator: bool
114+
) -> torch.Tensor | list[torch.Tensor]:
115+
"""
116+
117+
Args:
118+
input: output of Multi-Scale Patch Discriminator or Patch Discriminator; being a list of tensors
119+
or a tensor; they shouldn't have gone through an activation layer.
120+
target_is_real: whereas the input corresponds to discriminator output for real or fake images
121+
for_discriminator: whereas this is being calculated for discriminator or generator loss. In the last
122+
case, target_is_real is set to True, as the generator wants the input to be dimmed as real.
123+
Returns: if reduction is None, returns a list with the loss tensors of each discriminator if multi-scale
124+
discriminator is active, or the loss tensor if there is just one discriminator. Otherwise, it returns the
125+
summed or mean loss over the tensor and discriminator/s.
126+
127+
"""
128+
129+
if not for_discriminator and not target_is_real:
130+
target_is_real = True # With generator, we always want this to be true!
131+
warnings.warn(
132+
"Variable target_is_real has been set to False, but for_discriminator is set"
133+
"to False. To optimise a generator, target_is_real must be set to True."
134+
)
135+
136+
if type(input) is not list:
137+
input = [input]
138+
target_ = []
139+
for _, disc_out in enumerate(input):
140+
if self.criterion != AdversarialCriterions.HINGE:
141+
target_.append(self.get_target_tensor(disc_out, target_is_real))
142+
else:
143+
target_.append(self.get_zero_tensor(disc_out))
144+
145+
# Loss calculation
146+
loss_list = []
147+
for disc_ind, disc_out in enumerate(input):
148+
if self.activation is not None:
149+
disc_out = self.activation(disc_out)
150+
if self.criterion == AdversarialCriterions.HINGE and not target_is_real:
151+
loss_ = self._forward_single(-disc_out, target_[disc_ind])
152+
else:
153+
loss_ = self._forward_single(disc_out, target_[disc_ind])
154+
loss_list.append(loss_)
155+
156+
loss: torch.Tensor | list[torch.Tensor]
157+
if loss_list is not None:
158+
if self.reduction == LossReduction.MEAN:
159+
loss = torch.mean(torch.stack(loss_list))
160+
elif self.reduction == LossReduction.SUM:
161+
loss = torch.sum(torch.stack(loss_list))
162+
else:
163+
loss = loss_list
164+
return loss
165+
166+
def _forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
167+
forward: torch.Tensor
168+
if self.criterion == AdversarialCriterions.BCE or self.criterion == AdversarialCriterions.LEAST_SQUARE:
169+
forward = self.loss_fct(input, target)
170+
elif self.criterion == AdversarialCriterions.HINGE:
171+
minval = torch.min(input - 1, self.get_zero_tensor(input))
172+
forward = -torch.mean(minval)
173+
return forward

0 commit comments

Comments
 (0)