
Commit f0e52b8

Merge branch 'add-bundleworkflow-arg' of github.com:yiheng-wang-nv/MONAI into add-bundleworkflow-arg
2 parents: f793e80 + 7895b59

File tree

12 files changed: +709 −0 lines changed


docs/source/losses.rst

Lines changed: 10 additions & 0 deletions
@@ -73,6 +73,11 @@ Segmentation Losses
 .. autoclass:: ContrastiveLoss
     :members:

+`BarlowTwinsLoss`
+~~~~~~~~~~~~~~~~~
+.. autoclass:: BarlowTwinsLoss
+    :members:
+
 `HausdorffDTLoss`
 ~~~~~~~~~~~~~~~~~
 .. autoclass:: HausdorffDTLoss

@@ -134,6 +139,11 @@ Reconstruction Losses
 .. autoclass:: JukeboxLoss
     :members:

+`SURELoss`
+~~~~~~~~~~
+.. autoclass:: SURELoss
+    :members:
+

 Loss Wrappers
 -------------

docs/source/networks.rst

Lines changed: 5 additions & 0 deletions
@@ -408,6 +408,11 @@ Layers
 .. autoclass:: LLTM
     :members:

+`ConjugateGradient`
+~~~~~~~~~~~~~~~~~~~
+.. autoclass:: ConjugateGradient
+    :members:
+
 `Utilities`
 ~~~~~~~~~~~
 .. automodule:: monai.networks.layers.convutils

monai/losses/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 from __future__ import annotations

 from .adversarial_loss import PatchAdversarialLoss
+from .barlow_twins import BarlowTwinsLoss
 from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
 from .contrastive import ContrastiveLoss
 from .deform import BendingEnergyLoss, DiffusionLoss

@@ -40,5 +41,6 @@
 from .spatial_mask import MaskedLoss
 from .spectral_loss import JukeboxLoss
 from .ssim_loss import SSIMLoss
+from .sure_loss import SURELoss
 from .tversky import TverskyLoss
 from .unified_focal_loss import AsymmetricUnifiedFocalLoss

monai/losses/barlow_twins.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
from torch.nn.modules.loss import _Loss


class BarlowTwinsLoss(_Loss):
    """
    The Barlow Twins cost function takes the representations extracted by a neural network from two
    distorted views and seeks to make the cross-correlation matrix of the two representations tend
    towards identity. This encourages the neural network to learn similar representations with the least
    amount of redundancy. This cost function can be used in particular in multimodal learning to work on
    representations from two modalities. The most common use case is for unsupervised learning, where data
    augmentations are used to generate 2 distorted views of the same sample to force the encoder to
    extract useful features for downstream tasks.

    Zbontar, Jure, et al. "Barlow Twins: Self-Supervised Learning via Redundancy Reduction." International
    Conference on Machine Learning. PMLR, 2021. (http://proceedings.mlr.press/v139/zbontar21a/zbontar21a.pdf)

    Adapted from:
        https://github.com/facebookresearch/barlowtwins

    """

    def __init__(self, lambd: float = 5e-3) -> None:
        """
        Args:
            lambd: Can be any float to handle the informativeness and invariance trade-off. Ideally set to 5e-3.

        Raises:
            ValueError: When an input of dimension length > 2 is passed
            ValueError: When input and target are of different shapes
            ValueError: When batch size is less than or equal to 1

        """
        super().__init__()
        self.lambd = lambd

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be B[F].
            target: the shape should be B[F].
        """
        if len(target.shape) > 2 or len(input.shape) > 2:
            raise ValueError(
                f"Either target or input has dimensions greater than 2 where target "
                f"shape is ({target.shape}) and input shape is ({input.shape})"
            )

        if target.shape != input.shape:
            raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

        if target.size(0) <= 1:
            raise ValueError(
                f"Batch size must be greater than 1 to compute Barlow Twins Loss, but got {target.size(0)}"
            )

        lambd_tensor = torch.as_tensor(self.lambd).to(input.device)
        batch_size = input.shape[0]

        # normalize input and target along the batch dimension
        input_norm = (input - input.mean(0)) / input.std(0).add(1e-6)
        target_norm = (target - target.mean(0)) / target.std(0).add(1e-6)

        # cross-correlation matrix
        c = torch.mm(input_norm.t(), target_norm) / batch_size  # input_norm.t() is FxB, target_norm is BxF so c is FxF

        # loss: squared distance to the identity, with off-diagonal terms down-weighted by lambd
        c_diff = (c - torch.eye(c.size(0), device=c.device)).pow_(2)  # FxF
        c_diff[~torch.eye(c.size(0), device=c.device).bool()] *= lambd_tensor

        return c_diff.sum()
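
For context, a minimal usage sketch of the new loss (not part of this commit); it assumes a shared encoder has already produced two (batch, features) embedding batches:

import torch

from monai.losses import BarlowTwinsLoss

# embeddings of two distorted views of the same samples, shape (B, F) with B > 1
view1 = torch.randn(8, 32, requires_grad=True)
view2 = view1 + 0.05 * torch.randn(8, 32)  # a slightly perturbed second view

loss_fn = BarlowTwinsLoss(lambd=5e-3)
loss = loss_fn(view1, view2)  # scalar: small when the cross-correlation matrix is near identity
loss.backward()  # differentiable, so it can drive encoder training

The off-diagonal weight lambd is the only knob; the sketch keeps the 5e-3 default recommended in the docstring.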

monai/losses/sure_loss.py

Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Callable, Optional

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss


def complex_diff_abs_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    First compute the difference in the complex domain,
    then take the absolute value and compute the MSE.

    Args:
        x, y: (B, 2, H, W) real-valued tensors representing complex numbers,
            or (B, 1, H, W) complex-valued tensors.
    Returns:
        l2_loss: scalar
    """
    if not x.is_complex():
        x = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous())
    if not y.is_complex():
        y = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous())

    diff = torch.abs(x - y)
    return nn.functional.mse_loss(diff, torch.zeros_like(diff), reduction="mean")


def sure_loss_function(
    operator: Callable,
    x: torch.Tensor,
    y_pseudo_gt: torch.Tensor,
    y_ref: Optional[torch.Tensor] = None,
    eps: Optional[float] = -1.0,
    perturb_noise: Optional[torch.Tensor] = None,
    complex_input: Optional[bool] = False,
) -> torch.Tensor:
    """
    Args:
        operator (function): The operator function that takes in an input
            tensor x and returns an output tensor y. We will use this to compute
            the divergence. More specifically, we will perturb the input x by a
            small amount and compute the divergence between the perturbed output
            and the reference output.

        x (torch.Tensor): The input tensor of shape (B, C, H, W) to the
            operator. For complex input, the shape is (B, 2, H, W) aka C=2 real.
            For real input, the shape is (B, 1, H, W) real.

        y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape
            (B, C, H, W) used to compute the L2 loss. For complex input, the shape
            is (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W)
            real.

        y_ref (torch.Tensor, optional): The reference output tensor of shape
            (B, C, H, W) used to compute the divergence. Defaults to None. For
            complex input, the shape is (B, 2, H, W) aka C=2 real. For real input,
            the shape is (B, 1, H, W) real.

        eps (float, optional): The perturbation scalar. Set to -1 to have it
            automatically estimated from y_pseudo_gt.

        perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W).
            Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real.
            For real input, the shape is (B, 1, H, W) real.

        complex_input (bool, optional): Whether the input is complex or not.
            Defaults to False.

    Returns:
        sure_loss (torch.Tensor): The SURE loss scalar.
    """
    # perturb the input
    if perturb_noise is None:
        perturb_noise = torch.randn_like(x)
    if eps == -1.0:
        eps = float(torch.abs(y_pseudo_gt.max())) / 1000
    # get y_ref if not provided
    if y_ref is None:
        y_ref = operator(x)

    # get the perturbed output
    x_perturbed = x + eps * perturb_noise
    y_perturbed = operator(x_perturbed)
    # Monte Carlo estimate of the divergence from a single random perturbation
    divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref))  # type: ignore
    # l2 loss between y_ref and y_pseudo_gt
    if complex_input:
        l2_loss = complex_diff_abs_loss(y_ref, y_pseudo_gt)
    else:
        # real input
        l2_loss = nn.functional.mse_loss(y_ref, y_pseudo_gt, reduction="mean")

    # sure loss
    sure_loss = l2_loss * divergence / (x.shape[0] * x.shape[2] * x.shape[3])
    return sure_loss


class SURELoss(_Loss):
    """
    Calculate the Stein's Unbiased Risk Estimator (SURE) loss for a given operator.

    This is a differentiable loss function that can be used to train/guide an
    operator (e.g. neural network), where the pseudo ground truth is available
    but the reference ground truth is not. For example, in MRI
    reconstruction, the pseudo ground truth is the zero-filled reconstruction
    and the reference ground truth is the fully sampled reconstruction. Often,
    the reference ground truth is not available due to the lack of fully sampled
    data.

    The original SURE loss is proposed in [1]. The SURE loss used for guiding
    the diffusion model based MRI reconstruction is proposed in [2].

    Reference

    [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics

    [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models.
    (https://arxiv.org/pdf/2310.01799.pdf)
    """

    def __init__(self, perturb_noise: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> None:
        """
        Args:
            perturb_noise (torch.Tensor, optional): The noise vector of shape
                (B, C, H, W). Defaults to None. For complex input, the shape is
                (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) real.

            eps (float, optional): The perturbation scalar. Defaults to None.
        """
        super().__init__()
        self.perturb_noise = perturb_noise
        self.eps = eps

    def forward(
        self,
        operator: Callable,
        x: torch.Tensor,
        y_pseudo_gt: torch.Tensor,
        y_ref: Optional[torch.Tensor] = None,
        complex_input: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Args:
            operator (function): The operator function that takes in an input
                tensor x and returns an output tensor y. We will use this to compute
                the divergence. More specifically, we will perturb the input x by a
                small amount and compute the divergence between the perturbed output
                and the reference output.

            x (torch.Tensor): The input tensor of shape (B, C, H, W) to the
                operator. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka
                C=2 real. For real input, the shape is (B, 1, H, W) real.

            y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape
                (B, C, H, W) used to compute the L2 loss. C=1 or 2: For complex
                input, the shape is (B, 2, H, W) aka C=2 real. For real input, the
                shape is (B, 1, H, W) real.

            y_ref (torch.Tensor, optional): The reference output tensor of the
                same shape as y_pseudo_gt.

            complex_input (bool, optional): Whether the input is complex or not.
                Defaults to False.

        Returns:
            sure_loss (torch.Tensor): The SURE loss scalar.
        """

        # check the input shapes
        if x.dim() != 4:
            raise ValueError(f"Input tensor x should be 4D, got {x.dim()}.")
        if y_pseudo_gt.dim() != 4:
            raise ValueError(f"Input tensor y_pseudo_gt should be 4D, but got {y_pseudo_gt.dim()}.")
        if y_ref is not None and y_ref.dim() != 4:
            raise ValueError(f"Input tensor y_ref should be 4D, but got {y_ref.dim()}.")
        if x.shape != y_pseudo_gt.shape:
            raise ValueError(
                f"Input tensor x and y_pseudo_gt should have the same shape, but got x shape {x.shape}, "
                f"y_pseudo_gt shape {y_pseudo_gt.shape}."
            )
        if y_ref is not None and y_pseudo_gt.shape != y_ref.shape:
            raise ValueError(
                f"Input tensor y_pseudo_gt and y_ref should have the same shape, but got y_pseudo_gt shape "
                f"{y_pseudo_gt.shape}, y_ref shape {y_ref.shape}."
            )

        # compute the loss
        loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, self.eps, self.perturb_noise, complex_input)

        return loss
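
A minimal usage sketch (not part of this commit): the small convolution below is a hypothetical stand-in for the reconstruction network of [2], and the divergence inside the loss is a Monte Carlo estimate of the operator's Jacobian trace built from one random perturbation of the input.

import torch

from monai.losses import SURELoss

# a toy "operator"; illustrative only, any differentiable x -> y map works
conv = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)

def operator(x: torch.Tensor) -> torch.Tensor:
    return conv(x)

x = torch.randn(2, 1, 16, 16)            # real input: (B, 1, H, W)
y_pseudo_gt = torch.randn(2, 1, 16, 16)  # pseudo ground truth, e.g. a zero-filled recon

# eps=-1.0 asks sure_loss_function to estimate the perturbation scale from y_pseudo_gt
loss_fn = SURELoss(eps=-1.0)
loss = loss_fn(operator, x, y_pseudo_gt, complex_input=False)
loss.backward()  # gradients flow back into the operator's parameters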

monai/networks/layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@

 from __future__ import annotations

+from .conjugate_gradient import ConjugateGradient
 from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding
 from .drop_path import DropPath
 from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args
