@@ -21,8 +21,9 @@ class PerceptualLoss(nn.Module):
2121 Perceptual loss using features from pretrained deep neural networks trained. The function supports networks
2222 pretrained on: ImageNet that use the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep
2323 features as a perceptual metric." https://arxiv.org/abs/1801.03924 ; RadImagenet from Mei, et al. "RadImageNet: An
24- Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"; and MedicalNet from Chen et al.
25- "Med3D: Transfer Learning for 3D Medical Image Analysis" .
24+ Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"
25+ https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; and MedicalNet from Chen et al. "Med3D: Transfer Learning for
26+ 3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 .
2627
2728 The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual on slices from the
2829 three axis.
@@ -48,11 +49,14 @@ def __init__(
4849 if spatial_dims not in [2 , 3 ]:
4950 raise NotImplementedError ("Perceptual loss is implemented only in 2D and 3D." )
5051
52+ if spatial_dims == 2 and "medicalnet_" in network_type :
53+ raise ValueError ("MedicalNet networks are only compatible with ``spatial_dims=3``." )
54+
5155 self .spatial_dims = spatial_dims
5256 if spatial_dims == 3 and is_fake_3d is False :
53- self .perceptual_function = MedicalNetPerceptualComponent (net = network_type , verbose = False )
57+ self .perceptual_function = MedicalNetPerceptualSimilarity (net = network_type , verbose = False )
5458 elif "radimagenet_" in network_type :
55- self .perceptual_function = RadImageNetPerceptualComponent (net = network_type , verbose = False )
59+ self .perceptual_function = RadImageNetPerceptualSimilarity (net = network_type , verbose = False )
5660 else :
5761 self .perceptual_function = LPIPS (
5862 pretrained = True ,
@@ -134,7 +138,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
134138 return torch .mean (loss )
135139
136140
137- class MedicalNetPerceptualComponent (nn .Module ):
141+ class MedicalNetPerceptualSimilarity (nn .Module ):
138142 """
139143 Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer
140144 Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
@@ -200,7 +204,7 @@ def medicalnet_intensity_normalisation(volume):
200204 return (volume - mean ) / std
201205
202206
203- class RadImageNetPerceptualComponent (nn .Module ):
207+ class RadImageNetPerceptualSimilarity (nn .Module ):
204208 """
205209 Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et
206210 al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class
0 commit comments