Skip to content

Commit 316934a

Browse files
mingxin-zhengpre-commit-ci[bot]KumoLiu
authored
Add checks for monai bundles after download and warn if incompatible (#7938)
Fixes #7930 . ### Description Check the monai version in metadata JSON and warn if the version is newer than the package being used. ### Demonstration Warning when the version is hardcoded to 1.2 from monaihosting ``` root@MS-7D31:/workspace/MONAI# python -m monai.bundle download spleen_ct_segmentation 2024-07-23 11:00:31,286 - INFO - --- input summary of monai.bundle.scripts.download --- 2024-07-23 11:00:31,286 - INFO - > name: 'spleen_ct_segmentation' 2024-07-23 11:00:31,286 - INFO - > source: 'monaihosting' 2024-07-23 11:00:31,286 - INFO - > remove_prefix: 'monai_' 2024-07-23 11:00:31,286 - INFO - > progress: True 2024-07-23 11:00:31,286 - INFO - --- 2024-07-23 11:00:31,985 - INFO - Expected md5 is None, skip md5 check for file /root/.cache/torch/hub/bundle/spleen_ct_segmentation_v0.5.8.zip. 2024-07-23 11:00:31,986 - INFO - File exists: /root/.cache/torch/hub/bundle/spleen_ct_segmentation_v0.5.8.zip, skipped downloading. 2024-07-23 11:00:31,986 - INFO - Writing into directory: /root/.cache/torch/hub/bundle. 2024-07-23 11:00:32,176 - WARNING - Your MONAI version is 1.2, but the bundle is built on MONAI version 1.3.2. ``` Auto select version if the download src is from NGC ``` root@MS-7D31:/workspace/MONAI# BUNDLE_DOWNLOAD_SRC=ngc python -m monai.bundle download spleen_ct_segmentation 2024-07-23 11:02:12,277 - INFO - --- input summary of monai.bundle.scripts.download --- 2024-07-23 11:02:12,277 - INFO - > name: 'spleen_ct_segmentation' 2024-07-23 11:02:12,277 - INFO - > source: 'ngc' 2024-07-23 11:02:12,277 - INFO - > remove_prefix: 'monai_' 2024-07-23 11:02:12,277 - INFO - > progress: True 2024-07-23 11:02:12,277 - INFO - --- monai_spleen_ct_segmentation_v0.3.7.zip: 34.0MB [00:01, 24.1MB/s] 2024-07-23 11:02:17,953 - INFO - Downloaded: /root/.cache/torch/hub/bundle/monai_spleen_ct_segmentation_v0.3.7.zip 2024-07-23 11:02:17,954 - INFO - Expected md5 is None, skip md5 check for file /root/.cache/torch/hub/bundle/monai_spleen_ct_segmentation_v0.3.7.zip. 2024-07-23 11:02:17,954 - INFO - Writing into directory: /root/.cache/torch/hub/bundle/spleen_ct_segmentation. ``` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mingxin Zheng <mingxinz@nvidia.com> Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent 37917e0 commit 316934a

File tree

2 files changed

+162
-27
lines changed

2 files changed

+162
-27
lines changed

monai/bundle/scripts.py

Lines changed: 111 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import torch
2828
from torch.cuda import is_available
2929

30-
from monai.apps.mmars.mmars import _get_all_ngc_models
30+
from monai._version import get_versions
3131
from monai.apps.utils import _basename, download_url, extractall, get_logger
3232
from monai.bundle.config_item import ConfigComponent
3333
from monai.bundle.config_parser import ConfigParser
@@ -67,6 +67,9 @@
6767
DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "monaihosting")
6868
PPRINT_CONFIG_N = 5
6969

70+
MONAI_HOSTING_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
71+
NGC_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit"
72+
7073

7174
def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
7275
"""
@@ -169,16 +172,19 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam
169172

170173

171174
def _get_ngc_bundle_url(model_name: str, version: str) -> str:
172-
return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name.lower()}/versions/{version}/zip"
175+
return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip"
176+
177+
178+
def _get_ngc_private_base_url(repo: str) -> str:
179+
return f"https://api.ngc.nvidia.com/v2/{repo}/models"
173180

174181

175182
def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str:
176-
return f"https://api.ngc.nvidia.com/v2/{repo}/models/{model_name.lower()}/versions/{version}/zip"
183+
return f"{_get_ngc_private_base_url(repo)}/{model_name.lower()}/versions/{version}/zip"
177184

178185

179186
def _get_monaihosting_bundle_url(model_name: str, version: str) -> str:
180-
monaihosting_root_path = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
181-
return f"{monaihosting_root_path}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"
187+
return f"{MONAI_HOSTING_BASE_URL}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"
182188

183189

184190
def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None:
@@ -267,8 +273,7 @@ def _get_ngc_token(api_key, retry=0):
267273

268274

269275
def _get_latest_bundle_version_monaihosting(name):
270-
url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
271-
full_url = f"{url}/{name.lower()}"
276+
full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
272277
requests_get, has_requests = optional_import("requests", name="get")
273278
if has_requests:
274279
resp = requests_get(full_url)
@@ -279,36 +284,114 @@ def _get_latest_bundle_version_monaihosting(name):
279284
return model_info["model"]["latestVersionIdStr"]
280285

281286

282-
def _get_latest_bundle_version_private_registry(name, repo, headers=None):
283-
url = f"https://api.ngc.nvidia.com/v2/{repo}/models"
284-
full_url = f"{url}/{name.lower()}"
285-
requests_get, has_requests = optional_import("requests", name="get")
286-
if has_requests:
287-
headers = {} if headers is None else headers
288-
resp = requests_get(full_url, headers=headers)
289-
resp.raise_for_status()
290-
else:
291-
raise ValueError("NGC API requires requests package. Please install it.")
287+
def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
288+
"""Examine if the package version is compatible with the MONAI version in the metadata."""
289+
version_dict = get_versions()
290+
package_version = version_dict.get("version", "0+unknown")
291+
if package_version == "0+unknown":
292+
return False, "Package version is not available. Skipping version check."
293+
if monai_version == "0+unknown":
294+
return False, "MONAI version is not specified in the bundle. Skipping version check."
295+
# treat rc versions as the same as the release version
296+
package_version = re.sub(r"rc\d.*", "", package_version)
297+
monai_version = re.sub(r"rc\d.*", "", monai_version)
298+
if package_version < monai_version:
299+
return (
300+
False,
301+
f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}.",
302+
)
303+
return True, ""
304+
305+
306+
def _check_monai_version(bundle_dir: PathLike, name: str) -> None:
307+
"""Get the `monai_version` from the metadata.json and compare if it is smaller than the installed `monai` package version"""
308+
metadata_file = Path(bundle_dir) / name / "configs" / "metadata.json"
309+
if not metadata_file.exists():
310+
logger.warning(f"metadata file not found in {metadata_file}.")
311+
return
312+
with open(metadata_file) as f:
313+
metadata = json.load(f)
314+
is_compatible, msg = _examine_monai_version(metadata.get("monai_version", "0+unknown"))
315+
if not is_compatible:
316+
logger.warning(msg)
317+
318+
319+
def _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]:
320+
"""
321+
Extract the latest versions from the data dictionary.
322+
323+
Args:
324+
data: the data dictionary.
325+
max_versions: the maximum number of versions to return.
326+
327+
Returns:
328+
versions of the latest models in the reverse order of creation date, e.g. ['1.0.0', '0.9.0', '0.8.0'].
329+
"""
330+
# Check if the data is a dictionary and it has the key 'modelVersions'
331+
if not isinstance(data, dict) or "modelVersions" not in data:
332+
raise ValueError("The data is not a dictionary or it does not have the key 'modelVersions'.")
333+
334+
# Extract the list of model versions
335+
model_versions = data["modelVersions"]
336+
337+
if (
338+
not isinstance(model_versions, list)
339+
or len(model_versions) == 0
340+
or "createdDate" not in model_versions[0]
341+
or "versionId" not in model_versions[0]
342+
):
343+
raise ValueError(
344+
"The model versions are not a list or it is empty or it does not have the keys 'createdDate' and 'versionId'."
345+
)
346+
347+
# Sort the versions by the 'createdDate' in descending order
348+
sorted_versions = sorted(model_versions, key=lambda x: x["createdDate"], reverse=True)
349+
return [v["versionId"] for v in sorted_versions[:max_versions]]
350+
351+
352+
def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: dict | None = None) -> str:
353+
base_url = _get_ngc_private_base_url(repo) if repo else NGC_BASE_URL
354+
version_endpoint = base_url + f"/{name.lower()}/versions/"
355+
356+
if not has_requests:
357+
raise ValueError("requests package is required, please install it.")
358+
359+
version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements
360+
if headers:
361+
version_header.update(headers)
362+
resp = requests_get(version_endpoint, headers=version_header)
363+
resp.raise_for_status()
292364
model_info = json.loads(resp.text)
293-
return model_info["model"]["latestVersionIdStr"]
365+
latest_versions = _list_latest_versions(model_info)
366+
367+
for version in latest_versions:
368+
file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
369+
resp = requests_get(file_endpoint, headers=headers)
370+
metadata = json.loads(resp.text)
371+
resp.raise_for_status()
372+
# if the package version is not available or the model is compatible with the package version
373+
is_compatible, _ = _examine_monai_version(metadata["monai_version"])
374+
if is_compatible:
375+
if version != latest_versions[0]:
376+
logger.info(f"Latest version is {latest_versions[0]}, but the compatible version is {version}.")
377+
return version
378+
379+
# if no compatible version is found, return the latest version
380+
return latest_versions[0]
294381

295382

296383
def _get_latest_bundle_version(
297384
source: str, name: str, repo: str, **kwargs: Any
298385
) -> dict[str, list[str] | str] | Any | None:
299386
if source == "ngc":
300387
name = _add_ngc_prefix(name)
301-
model_dict = _get_all_ngc_models(name)
302-
for v in model_dict.values():
303-
if v["name"] == name:
304-
return v["latest"]
305-
return None
388+
return _get_latest_bundle_version_ngc(name)
306389
elif source == "monaihosting":
307390
return _get_latest_bundle_version_monaihosting(name)
308391
elif source == "ngc_private":
309392
headers = kwargs.pop("headers", {})
310393
name = _add_ngc_prefix(name)
311-
return _get_latest_bundle_version_private_registry(name, repo, headers)
394+
return _get_latest_bundle_version_ngc(name, repo=repo, headers=headers)
312395
elif source == "github":
313396
repo_owner, repo_name, tag_name = repo.split("/")
314397
return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"]
@@ -470,9 +553,8 @@ def download(
470553
if version_ is None:
471554
version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers)
472555
if source_ == "github":
473-
if version_ is not None:
474-
name_ = "_v".join([name_, version_])
475-
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
556+
name_ver = "_v".join([name_, version_]) if version_ is not None else name_
557+
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
476558
elif source_ == "monaihosting":
477559
_download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_)
478560
elif source_ == "ngc":
@@ -501,6 +583,8 @@ def download(
501583
f"got source: {source_}."
502584
)
503585

586+
_check_monai_version(bundle_dir_, name_)
587+
504588

505589
@deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
506590
@deprecated_arg("net_kwargs", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")

tests/test_bundle_download.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import tempfile
1717
import unittest
1818
from unittest.case import skipUnless
19+
from unittest.mock import patch
1920

2021
import numpy as np
2122
import torch
@@ -24,6 +25,7 @@
2425
import monai.networks.nets as nets
2526
from monai.apps import check_hash
2627
from monai.bundle import ConfigParser, create_workflow, load
28+
from monai.bundle.scripts import _examine_monai_version, _list_latest_versions, download
2729
from monai.utils import optional_import
2830
from tests.utils import (
2931
SkipIfBeforePyTorchVersion,
@@ -207,6 +209,55 @@ def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, ve
207209
file_path = os.path.join(tempdir, bundle_name, file)
208210
self.assertTrue(os.path.exists(file_path))
209211

212+
@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
213+
def test_examine_monai_version(self, mock_get_versions):
214+
self.assertTrue(_examine_monai_version("1.1")[0]) # Should return True, compatible
215+
self.assertTrue(_examine_monai_version("1.2rc1")[0]) # Should return True, compatible
216+
self.assertFalse(_examine_monai_version("1.3")[0]) # Should return False, not compatible
217+
218+
@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2rc1"})
219+
def test_examine_monai_version_rc(self, mock_get_versions):
220+
self.assertTrue(_examine_monai_version("1.2")[0]) # Should return True, compatible
221+
self.assertFalse(_examine_monai_version("1.3")[0]) # Should return False, not compatible
222+
223+
def test_list_latest_versions(self):
224+
"""Test listing of the latest versions."""
225+
data = {
226+
"modelVersions": [
227+
{"createdDate": "2021-01-01", "versionId": "1.0"},
228+
{"createdDate": "2021-01-02", "versionId": "1.1"},
229+
{"createdDate": "2021-01-03", "versionId": "1.2"},
230+
]
231+
}
232+
self.assertEqual(_list_latest_versions(data), ["1.2", "1.1", "1.0"])
233+
self.assertEqual(_list_latest_versions(data, max_versions=2), ["1.2", "1.1"])
234+
data = {
235+
"modelVersions": [
236+
{"createdDate": "2021-01-01", "versionId": "1.0"},
237+
{"createdDate": "2021-01-02", "versionId": "1.1"},
238+
]
239+
}
240+
self.assertEqual(_list_latest_versions(data), ["1.1", "1.0"])
241+
242+
@skip_if_quick
243+
@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
244+
def test_download_monaihosting(self, mock_get_versions):
245+
"""Test checking MONAI version from a metadata file."""
246+
with patch("monai.bundle.scripts.logger") as mock_logger:
247+
with tempfile.TemporaryDirectory() as tempdir:
248+
download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="monaihosting")
249+
# Should have a warning message because the latest version is using monai > 1.2
250+
mock_logger.warning.assert_called_once()
251+
252+
@skip_if_quick
253+
@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
254+
def test_download_ngc(self, mock_get_versions):
255+
"""Test checking MONAI version from a metadata file."""
256+
with patch("monai.bundle.scripts.logger") as mock_logger:
257+
with tempfile.TemporaryDirectory() as tempdir:
258+
download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="ngc")
259+
mock_logger.warning.assert_not_called()
260+
210261

211262
@skip_if_no_cuda
212263
class TestLoad(unittest.TestCase):

0 commit comments

Comments
 (0)