Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit baa5ee5

Browse files
authored
Add error when "medicalnet_..." network_type used with spatial_dims==2 (#167)
Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
1 parent 221260a commit baa5ee5

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

generative/losses/perceptual.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ class PerceptualLoss(nn.Module):
2121
Perceptual loss using features from pretrained deep neural networks trained. The function supports networks
2222
pretrained on: ImageNet that use the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep
2323
features as a perceptual metric." https://arxiv.org/abs/1801.03924 ; RadImagenet from Mei, et al. "RadImageNet: An
24-
Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"; and MedicalNet from Chen et al.
25-
"Med3D: Transfer Learning for 3D Medical Image Analysis" .
24+
Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"
25+
https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; and MedicalNet from Chen et al. "Med3D: Transfer Learning for
26+
3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 .
2627
2728
The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual on slices from the
2829
three axis.
@@ -48,11 +49,14 @@ def __init__(
4849
if spatial_dims not in [2, 3]:
4950
raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.")
5051

52+
if spatial_dims == 2 and "medicalnet_" in network_type:
53+
raise ValueError("MedicalNet networks are only compatible with ``spatial_dims=3``.")
54+
5155
self.spatial_dims = spatial_dims
5256
if spatial_dims == 3 and is_fake_3d is False:
53-
self.perceptual_function = MedicalNetPerceptualComponent(net=network_type, verbose=False)
57+
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
5458
elif "radimagenet_" in network_type:
55-
self.perceptual_function = RadImageNetPerceptualComponent(net=network_type, verbose=False)
59+
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
5660
else:
5761
self.perceptual_function = LPIPS(
5862
pretrained=True,
@@ -134,7 +138,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
134138
return torch.mean(loss)
135139

136140

137-
class MedicalNetPerceptualComponent(nn.Module):
141+
class MedicalNetPerceptualSimilarity(nn.Module):
138142
"""
139143
Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer
140144
Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
@@ -200,7 +204,7 @@ def medicalnet_intensity_normalisation(volume):
200204
return (volume - mean) / std
201205

202206

203-
class RadImageNetPerceptualComponent(nn.Module):
207+
class RadImageNetPerceptualSimilarity(nn.Module):
204208
"""
205209
Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et
206210
al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class

tests/test_perceptual_loss.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ def test_1d(self):
7676
with self.assertRaises(NotImplementedError):
7777
PerceptualLoss(spatial_dims=1)
7878

79+
def test_medicalnet_on_2d_data(self):
80+
with self.assertRaises(ValueError):
81+
PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet10_23datasets")
82+
83+
with self.assertRaises(ValueError):
84+
PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet50_23datasets")
85+
7986

8087
if __name__ == "__main__":
8188
unittest.main()

0 commit comments

Comments
 (0)