Skip to content

Commit c99e16e

Browse files
author
Virginia Fernandez
committed
Add check of network name
1 parent b065de7 commit c99e16e

File tree

1 file changed

+33
-8
lines changed

1 file changed

+33
-8
lines changed

monai/losses/perceptual.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,21 @@
1717
import torch.nn as nn
1818

1919
from monai.utils import optional_import
20+
2021
from monai.utils.enums import StrEnum
2122

23+
# Valid model name to download from the repository
24+
HF_MONAI_MODELS = (
25+
"medicalnet_resnet10_23datasets",
26+
"medicalnet_resnet50_23datasets",
27+
"radimagenet_resnet50",
28+
)
29+
2230
LPIPS, _ = optional_import("lpips", name="LPIPS")
2331
torchvision, _ = optional_import("torchvision")
2432

2533

34+
2635
class PercetualNetworkType(StrEnum):
2736
"""Types of neural networks that are supported by perceptua loss."""
2837

@@ -86,13 +95,18 @@ def __init__(
8695
if spatial_dims not in [2, 3]:
8796
raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.")
8897

89-
if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type:
90-
raise ValueError(
91-
"MedicalNet networks are only compatible with ``spatial_dims=3``."
92-
"Argument is_fake_3d must be set to False."
93-
)
9498

95-
if channel_wise and "medicalnet_" not in network_type:
99+
# Strict validation for MedicalNet
100+
if "medicalnet_" in network_type:
101+
if spatial_dims == 2 or is_fake_3d:
102+
raise ValueError(
103+
"MedicalNet networks are only compatible with ``spatial_dims=3``. Argument is_fake_3d must be set to False."
104+
)
105+
if not channel_wise:
106+
warnings.warn("MedicalNet networks support channel-wise loss. Consider setting channel_wise=True.")
107+
108+
# Channel-wise only for MedicalNet
109+
elif channel_wise:
96110
raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.")
97111

98112
if network_type.lower() not in list(PercetualNetworkType):
@@ -219,8 +233,14 @@ def __init__(
219233
) -> None:
220234
super().__init__()
221235
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
236+
if net not in HF_MONAI_MODELS:
237+
raise ValueError(
238+
f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}."
239+
)
240+
222241
self.model = torch.hub.load(
223-
"Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir
242+
"Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir,
243+
trust_repo=True,
224244
)
225245
self.eval()
226246

@@ -309,7 +329,12 @@ class RadImageNetPerceptualSimilarity(nn.Module):
309329

310330
def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None:
311331
super().__init__()
312-
self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir)
332+
if net not in HF_MONAI_MODELS:
333+
raise ValueError(
334+
f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}."
335+
)
336+
self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir,
337+
trust_repo=True)
313338
self.eval()
314339

315340
for param in self.parameters():

0 commit comments

Comments
 (0)