Skip to content

Commit 34e84a3

Browse files
Lucas RobinetLucas-rbnt
authored andcommitted
Removing L2-norm in contrastive loss (L2-norm already present in cosine-similarity computation)
Signed-off-by: Lucas Robinet <robinet.lucas@iuct-oncopole.fr>
1 parent 960249f commit 34e84a3

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

monai/losses/contrastive.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
6868
temperature_tensor = torch.as_tensor(self.temperature).to(input.device)
6969
batch_size = input.shape[0]
7070

71-
norm_i = F.normalize(input, dim=1)
72-
norm_j = F.normalize(target, dim=1)
73-
7471
negatives_mask = ~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool)
7572
negatives_mask = torch.clone(negatives_mask.type(torch.float)).to(input.device)
7673

77-
repr = torch.cat([norm_i, norm_j], dim=0)
74+
repr = torch.cat([input, target], dim=0)
7875
sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2)
7976
sim_ij = torch.diag(sim_matrix, batch_size)
8077
sim_ji = torch.diag(sim_matrix, -batch_size)

0 commit comments

Comments
 (0)