|
17 | 17 | import torch.nn as nn |
18 | 18 |
|
19 | 19 | from monai.utils import optional_import |
| 20 | + |
20 | 21 | from monai.utils.enums import StrEnum |
21 | 22 |
|
| 23 | +# Valid model name to download from the repository |
| 24 | +HF_MONAI_MODELS = ( |
| 25 | + "medicalnet_resnet10_23datasets", |
| 26 | + "medicalnet_resnet50_23datasets", |
| 27 | + "radimagenet_resnet50", |
| 28 | +) |
| 29 | + |
22 | 30 | LPIPS, _ = optional_import("lpips", name="LPIPS") |
23 | 31 | torchvision, _ = optional_import("torchvision") |
24 | 32 |
|
25 | 33 |
|
| 34 | + |
26 | 35 | class PercetualNetworkType(StrEnum): |
27 | 36 | """Types of neural networks that are supported by perceptua loss.""" |
28 | 37 |
|
@@ -86,13 +95,18 @@ def __init__( |
86 | 95 | if spatial_dims not in [2, 3]: |
87 | 96 | raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.") |
88 | 97 |
|
89 | | - if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type: |
90 | | - raise ValueError( |
91 | | - "MedicalNet networks are only compatible with ``spatial_dims=3``." |
92 | | - "Argument is_fake_3d must be set to False." |
93 | | - ) |
94 | 98 |
|
95 | | - if channel_wise and "medicalnet_" not in network_type: |
| 99 | + # Strict validation for MedicalNet |
| 100 | + if "medicalnet_" in network_type: |
| 101 | + if spatial_dims == 2 or is_fake_3d: |
| 102 | + raise ValueError( |
| 103 | + "MedicalNet networks are only compatible with ``spatial_dims=3``. Argument is_fake_3d must be set to False." |
| 104 | + ) |
| 105 | + if not channel_wise: |
| 106 | + warnings.warn("MedicalNet networks support channel-wise loss. Consider setting channel_wise=True.") |
| 107 | + |
| 108 | + # Channel-wise only for MedicalNet |
| 109 | + elif channel_wise: |
96 | 110 | raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.") |
97 | 111 |
|
98 | 112 | if network_type.lower() not in list(PercetualNetworkType): |
@@ -219,8 +233,14 @@ def __init__( |
219 | 233 | ) -> None: |
220 | 234 | super().__init__() |
221 | 235 | torch.hub._validate_not_a_forked_repo = lambda a, b, c: True |
| 236 | + if net not in HF_MONAI_MODELS: |
| 237 | + raise ValueError( |
| 238 | + f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}." |
| 239 | + ) |
| 240 | + |
222 | 241 | self.model = torch.hub.load( |
223 | | - "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir |
| 242 | + "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, |
| 243 | + trust_repo=True, |
224 | 244 | ) |
225 | 245 | self.eval() |
226 | 246 |
|
@@ -309,7 +329,12 @@ class RadImageNetPerceptualSimilarity(nn.Module): |
309 | 329 |
|
310 | 330 | def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None: |
311 | 331 | super().__init__() |
312 | | - self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir) |
| 332 | + if net not in HF_MONAI_MODELS: |
| 333 | + raise ValueError( |
| 334 | + f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}." |
| 335 | + ) |
| 336 | + self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir, |
| 337 | + trust_repo=True) |
313 | 338 | self.eval() |
314 | 339 |
|
315 | 340 | for param in self.parameters(): |
|
0 commit comments