Skip to content

Commit f579763

Browse files
authored
Update nacl_loss.py
Call contiguous after permute to avoid reshaping issue.
1 parent db9daeb commit f579763

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

monai/losses/nacl_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
9595
rmask: torch.Tensor
9696

9797
if self.dim == 2:
98-
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float()
98+
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 3, 1, 2).contiguous().float()
9999
rmask = self.svls_layer(oh_labels)
100100

101101
if self.dim == 3:
102-
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float()
102+
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 4, 1, 2, 3).contiguous().float()
103103
rmask = self.svls_layer(oh_labels)
104104

105105
return rmask

0 commit comments

Comments
 (0)