Skip to content

Commit 393d415

Browse files
committed
Add Barlow Twins loss for representation learning
Signed-off-by: Lucas Robinet <robinet.lucas@iuct-oncopole.fr>
1 parent 95f69de commit 393d415

File tree

4 files changed

+198
-0
lines changed

4 files changed

+198
-0
lines changed

docs/source/losses.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ Segmentation Losses
7373
.. autoclass:: ContrastiveLoss
7474
:members:
7575

76+
`BarlowTwinsLoss`
77+
~~~~~~~~~~~~~~~~~
78+
.. autoclass:: BarlowTwinsLoss
79+
:members:
80+
7681
`HausdorffDTLoss`
7782
~~~~~~~~~~~~~~~~~
7883
.. autoclass:: HausdorffDTLoss

monai/losses/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
from .adversarial_loss import PatchAdversarialLoss
15+
from .barlow_twins import BarlowTwinsLoss
1516
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
1617
from .contrastive import ContrastiveLoss
1718
from .deform import BendingEnergyLoss, DiffusionLoss

monai/losses/barlow_twins.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from warnings import warn
15+
16+
import torch
17+
from torch.nn.modules.loss import _Loss
18+
19+
20+
class BarlowTwinsLoss(_Loss):
21+
"""
22+
Compute the Barlow Twins loss defined in:
23+
24+
Zbontar, Jure, et al. "Barlow Twins: Self-Supervised Learning via Redundancy Reduction" International
25+
conference on machine learning. PMLR, 2020. (http://proceedings.mlr.press/v139/zbontar21a/zbontar21a.pdf)
26+
27+
Adapted from:
28+
https://github.com/facebookresearch/barlowtwins
29+
30+
"""
31+
32+
def __init__(self, lambd: float = 5e-3, batch_size: int = -1) -> None:
33+
"""
34+
Args:
35+
lamb: Can be any float to handle the informativeness and invariance trade-off. Ideally set to 5e-3.
36+
37+
Raises:
38+
ValueError: When an input of dimension length > 2 is passed
39+
ValueError: When input and target are of different shapes
40+
ValueError: When batch size is less than or equal to 1
41+
42+
"""
43+
super().__init__()
44+
self.lambd = lambd
45+
46+
if batch_size != -1:
47+
warn("batch_size is no longer required to be set. It will be estimated dynamically in the forward call")
48+
49+
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
50+
"""
51+
Args:
52+
input: the shape should be B[F].
53+
target: the shape should be B[F].
54+
"""
55+
if len(target.shape) > 2 or len(input.shape) > 2:
56+
raise ValueError(
57+
f"Either target or input has dimensions greater than 2 where target "
58+
f"shape is ({target.shape}) and input shape is ({input.shape})"
59+
)
60+
61+
if target.shape != input.shape:
62+
raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
63+
64+
if target.size(0) <= 1:
65+
raise ValueError(
66+
f"Batch size must be greater than 1 to compute Barlow Twins Loss, but got {target.size(0)}"
67+
)
68+
69+
lambd_tensor = torch.as_tensor(self.lambd).to(input.device)
70+
batch_size = input.shape[0]
71+
72+
# normalize input and target
73+
input_norm = (input - input.mean(0)) / input.std(0).add(1e-6)
74+
target_norm = (target - target.mean(0)) / target.std(0).add(1e-6)
75+
76+
# cross-correlation matrix
77+
c = torch.mm(input_norm.t(), target_norm) / batch_size # input_norm.t() is FxB, target_norm is BxF so c is FxF
78+
79+
# loss
80+
c_diff = (c - torch.eye(c.size(0), device=c.device)).pow_(2) # FxF
81+
c_diff[~torch.eye(c.size(0), device=c.device).bool()] *= lambd_tensor
82+
83+
return c_diff.sum()

tests/test_barlow_twins_loss.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import numpy as np
17+
import torch
18+
from parameterized import parameterized
19+
20+
from monai.losses import BarlowTwinsLoss
21+
22+
TEST_CASES = [
23+
[ # shape: (2, 4), (2, 4)
24+
{"lambd": 5e-3},
25+
{
26+
"input": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
27+
"target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
28+
},
29+
4.0,
30+
],
31+
[ # shape: (2, 4), (2, 4)
32+
{"lambd": 5e-3},
33+
{
34+
"input": torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]),
35+
"target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
36+
},
37+
4.0,
38+
],
39+
[ # shape: (2, 4), (2, 4)
40+
{"lambd": 5e-3},
41+
{
42+
"input": torch.tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]]),
43+
"target": torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 1.0]]),
44+
},
45+
5.2562,
46+
],
47+
[ # shape: (2, 4), (2, 4)
48+
{"lambd": 5e-4},
49+
{
50+
"input": torch.tensor([[2.0, 3.0, 1.0, 2.0], [0.0, 1.0, 2.0, 5.0]]),
51+
"target": torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]),
52+
},
53+
5.0015,
54+
],
55+
[ # shape: (4, 4), (4, 4)
56+
{"lambd": 5e-3},
57+
{
58+
"input": torch.tensor(
59+
[[1.0, 2.0, 1.0, 1.0], [3.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 1.0], [2.0, 1.0, 1.0, 0.0]]
60+
),
61+
"target": torch.tensor(
62+
[
63+
[0.0, 1.0, -1.0, 0.0],
64+
[1 / 3, 0.0, -2 / 3, 1 / 3],
65+
[-2 / 3, -1.0, 7 / 3, 1 / 3],
66+
[1 / 3, 0.0, 1 / 3, -2 / 3],
67+
]
68+
),
69+
},
70+
1.4736,
71+
],
72+
]
73+
74+
75+
class TestBarlowTwinsLoss(unittest.TestCase):
76+
77+
@parameterized.expand(TEST_CASES)
78+
def test_result(self, input_param, input_data, expected_val):
79+
barlowtwinsloss = BarlowTwinsLoss(**input_param)
80+
result = barlowtwinsloss(**input_data)
81+
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
82+
83+
def test_ill_shape(self):
84+
loss = BarlowTwinsLoss(lambd=5e-3)
85+
with self.assertRaisesRegex(ValueError, ""):
86+
loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
87+
88+
def test_ill_batch_size(self):
89+
loss = BarlowTwinsLoss(lambd=5e-3, batch_size=1)
90+
with self.assertRaisesRegex(ValueError, ""):
91+
loss(torch.ones((1, 2)), torch.ones((1, 2)))
92+
93+
def test_with_cuda(self):
94+
loss = BarlowTwinsLoss(lambd=5e-3)
95+
i = torch.ones((2, 10))
96+
j = torch.ones((2, 10))
97+
if torch.cuda.is_available():
98+
i = i.cuda()
99+
j = j.cuda()
100+
output = loss(i, j)
101+
np.testing.assert_allclose(output.detach().cpu().numpy(), 10.0, atol=1e-4, rtol=1e-4)
102+
103+
def check_warning_rasied(self):
104+
with self.assertWarns(Warning):
105+
BarlowTwinsLoss(lambd=5e-3, batch_size=1)
106+
107+
108+
if __name__ == "__main__":
109+
unittest.main()

0 commit comments

Comments
 (0)