Skip to content

Commit cd01b59

Browse files
committed
Fixes more typing errors
1 parent 4b1d801 commit cd01b59

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

monai/losses/adversarial_loss.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
# Depending on the criterion, a different activation layer is used.
6565
self.real_label = 1.0
6666
self.fake_label = 0.0
67-
self.loss_fct : _Loss
67+
self.loss_fct: _Loss
6868
if criterion == AdversarialCriterions.BCE.value:
6969
self.activation = get_act_layer("SIGMOID")
7070
self.loss_fct = torch.nn.BCELoss(reduction=reduction)
@@ -153,16 +153,18 @@ def forward(
153153
loss_ = self.forward_single(disc_out, target_[disc_ind])
154154
loss_list.append(loss_)
155155

156+
loss: torch.Tensor | list[torch.Tensor]
156157
if loss_list is not None:
157158
if self.reduction == LossReduction.MEAN.value:
158159
loss = torch.mean(torch.stack(loss_list))
159160
elif self.reduction == LossReduction.SUM.value:
160161
loss = torch.sum(torch.stack(loss_list))
161-
162+
else:
163+
loss = loss_list
162164
return loss
163165

164166
def forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
165-
forward : torch.Tensor
167+
forward: torch.Tensor
166168
if (
167169
self.criterion == AdversarialCriterions.BCE.value
168170
or self.criterion == AdversarialCriterions.LEAST_SQUARE.value

monai/losses/perceptual.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
torchvision, _ = optional_import("torchvision")
2121

2222

23-
2423
class PerceptualLoss(nn.Module):
2524
"""
2625
Perceptual loss using features from pretrained deep neural networks trained. The function supports networks
@@ -78,7 +77,7 @@ def __init__(
7877
torch.hub.set_dir(cache_dir)
7978

8079
self.spatial_dims = spatial_dims
81-
self.perceptual_function : nn.Module
80+
self.perceptual_function: nn.Module
8281
if spatial_dims == 3 and is_fake_3d is False:
8382
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
8483
elif "radimagenet_" in network_type:
@@ -196,7 +195,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
196195
feats_input = normalize_tensor(outs_input)
197196
feats_target = normalize_tensor(outs_target)
198197

199-
results : torch.Tensor = (feats_input - feats_target) ** 2
198+
results: torch.Tensor = (feats_input - feats_target) ** 2
200199
results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True)
201200

202201
return results
@@ -345,7 +344,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
345344
feats_input = normalize_tensor(outs_input)
346345
feats_target = normalize_tensor(outs_target)
347346

348-
results : torch.Tensor = (feats_input - feats_target) ** 2
347+
results: torch.Tensor = (feats_input - feats_target) ** 2
349348
results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)
350349

351350
return results

0 commit comments

Comments
 (0)