|
20 | 20 | torchvision, _ = optional_import("torchvision") |
21 | 21 |
|
22 | 22 |
|
23 | | - |
24 | 23 | class PerceptualLoss(nn.Module): |
25 | 24 | """ |
26 | 25 | Perceptual loss using features from pretrained deep neural networks trained. The function supports networks |
@@ -78,7 +77,7 @@ def __init__( |
78 | 77 | torch.hub.set_dir(cache_dir) |
79 | 78 |
|
80 | 79 | self.spatial_dims = spatial_dims |
81 | | - self.perceptual_function : nn.Module |
| 80 | + self.perceptual_function: nn.Module |
82 | 81 | if spatial_dims == 3 and is_fake_3d is False: |
83 | 82 | self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False) |
84 | 83 | elif "radimagenet_" in network_type: |
@@ -196,7 +195,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
196 | 195 | feats_input = normalize_tensor(outs_input) |
197 | 196 | feats_target = normalize_tensor(outs_target) |
198 | 197 |
|
199 | | - results : torch.Tensor = (feats_input - feats_target) ** 2 |
| 198 | + results: torch.Tensor = (feats_input - feats_target) ** 2 |
200 | 199 | results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True) |
201 | 200 |
|
202 | 201 | return results |
@@ -345,7 +344,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
345 | 344 | feats_input = normalize_tensor(outs_input) |
346 | 345 | feats_target = normalize_tensor(outs_target) |
347 | 346 |
|
348 | | - results : torch.Tensor = (feats_input - feats_target) ** 2 |
| 347 | + results: torch.Tensor = (feats_input - feats_target) ** 2 |
349 | 348 | results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True) |
350 | 349 |
|
351 | 350 | return results |
|
0 commit comments