Skip to content

Commit

Permalink
Add checks for monai bundles after download and warn if incompatible (#…
Browse files Browse the repository at this point in the history
…7938)

Fixes #7930 .

### Description

After a bundle is downloaded, check the `monai_version` recorded in its
metadata JSON and log a warning if that version is newer than the MONAI
package being used.


### Demonstration

Warning shown when downloading from monaihosting with the installed MONAI version hard-coded to 1.2:
```
root@MS-7D31:/workspace/MONAI# python -m monai.bundle download spleen_ct_segmentation
2024-07-23 11:00:31,286 - INFO - --- input summary of monai.bundle.scripts.download ---
2024-07-23 11:00:31,286 - INFO - > name: 'spleen_ct_segmentation'
2024-07-23 11:00:31,286 - INFO - > source: 'monaihosting'
2024-07-23 11:00:31,286 - INFO - > remove_prefix: 'monai_'
2024-07-23 11:00:31,286 - INFO - > progress: True
2024-07-23 11:00:31,286 - INFO - ---


2024-07-23 11:00:31,985 - INFO - Expected md5 is None, skip md5 check for file /root/.cache/torch/hub/bundle/spleen_ct_segmentation_v0.5.8.zip.
2024-07-23 11:00:31,986 - INFO - File exists: /root/.cache/torch/hub/bundle/spleen_ct_segmentation_v0.5.8.zip, skipped downloading.
2024-07-23 11:00:31,986 - INFO - Writing into directory: /root/.cache/torch/hub/bundle.
2024-07-23 11:00:32,176 - WARNING - Your MONAI version is 1.2, but the bundle is built on MONAI version 1.3.2.
```

Automatic selection of a compatible bundle version when the download source is NGC:
```
root@MS-7D31:/workspace/MONAI# BUNDLE_DOWNLOAD_SRC=ngc python -m monai.bundle download spleen_ct_segmentation
2024-07-23 11:02:12,277 - INFO - --- input summary of monai.bundle.scripts.download ---
2024-07-23 11:02:12,277 - INFO - > name: 'spleen_ct_segmentation'
2024-07-23 11:02:12,277 - INFO - > source: 'ngc'
2024-07-23 11:02:12,277 - INFO - > remove_prefix: 'monai_'
2024-07-23 11:02:12,277 - INFO - > progress: True
2024-07-23 11:02:12,277 - INFO - ---


monai_spleen_ct_segmentation_v0.3.7.zip: 34.0MB [00:01, 24.1MB/s]                                                                                                                                                       
2024-07-23 11:02:17,953 - INFO - Downloaded: /root/.cache/torch/hub/bundle/monai_spleen_ct_segmentation_v0.3.7.zip
2024-07-23 11:02:17,954 - INFO - Expected md5 is None, skip md5 check for file /root/.cache/torch/hub/bundle/monai_spleen_ct_segmentation_v0.3.7.zip.
2024-07-23 11:02:17,954 - INFO - Writing into directory: /root/.cache/torch/hub/bundle/spleen_ct_segmentation.
```

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Mingxin Zheng <mingxinz@nvidia.com>
Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 24, 2024
1 parent 37917e0 commit 316934a
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 27 deletions.
138 changes: 111 additions & 27 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import torch
from torch.cuda import is_available

from monai.apps.mmars.mmars import _get_all_ngc_models
from monai._version import get_versions
from monai.apps.utils import _basename, download_url, extractall, get_logger
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
Expand Down Expand Up @@ -67,6 +67,9 @@
DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "monaihosting")
PPRINT_CONFIG_N = 5

MONAI_HOSTING_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
NGC_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit"


def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
"""
Expand Down Expand Up @@ -169,16 +172,19 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam


def _get_ngc_bundle_url(model_name: str, version: str) -> str:
    """
    Build the NGC URL of the zip archive for one version of a bundle.

    Args:
        model_name: name of the bundle; it is lower-cased before being placed in the URL.
        version: version identifier of the bundle.

    Returns:
        the URL ``{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip``.
    """
    # single return built from the shared constant (previously the URL was spelled out a second time)
    return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip"


def _get_ngc_private_base_url(repo: str) -> str:
return f"https://api.ngc.nvidia.com/v2/{repo}/models"


def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str:
    """
    Build the URL of the zip archive for one version of a bundle in a private NGC registry.

    Args:
        model_name: name of the bundle; it is lower-cased before being placed in the URL.
        version: version identifier of the bundle.
        repo: registry path, forwarded to ``_get_ngc_private_base_url`` to obtain the base URL.

    Returns:
        the URL ``<private base url>/{model_name.lower()}/versions/{version}/zip``.
    """
    # single return that reuses the base-URL helper instead of spelling the URL out a second time
    return f"{_get_ngc_private_base_url(repo)}/{model_name.lower()}/versions/{version}/zip"


def _get_monaihosting_bundle_url(model_name: str, version: str) -> str:
    """
    Build the monaihosting URL of the zip file for one version of a bundle.

    Args:
        model_name: name of the bundle. It is lower-cased in the URL path, but used with its
            original casing in the zip file name.
        version: version identifier of the bundle.

    Returns:
        the URL ``{MONAI_HOSTING_BASE_URL}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip``.
    """
    # single return built from the shared constant (a local copy of the base URL is no longer kept)
    return f"{MONAI_HOSTING_BASE_URL}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"


def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None:
Expand Down Expand Up @@ -267,8 +273,7 @@ def _get_ngc_token(api_key, retry=0):


def _get_latest_bundle_version_monaihosting(name):
url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
full_url = f"{url}/{name.lower()}"
full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
requests_get, has_requests = optional_import("requests", name="get")
if has_requests:
resp = requests_get(full_url)
Expand All @@ -279,36 +284,114 @@ def _get_latest_bundle_version_monaihosting(name):
return model_info["model"]["latestVersionIdStr"]


def _get_latest_bundle_version_private_registry(name, repo, headers=None):
url = f"https://api.ngc.nvidia.com/v2/{repo}/models"
full_url = f"{url}/{name.lower()}"
requests_get, has_requests = optional_import("requests", name="get")
if has_requests:
headers = {} if headers is None else headers
resp = requests_get(full_url, headers=headers)
resp.raise_for_status()
else:
raise ValueError("NGC API requires requests package. Please install it.")
def _parse_release_version(version: str) -> tuple[int, ...]:
    """Return the leading dotted-integer release of ``version`` as a tuple of ints, or ``()`` if there is none."""
    matched = re.match(r"\d+(?:\.\d+)*", version.strip())
    if matched is None:
        return ()
    return tuple(int(part) for part in matched.group().split("."))


def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
    """
    Examine if the package version is compatible with the MONAI version in the metadata.

    Release candidates are treated as the corresponding release. Versions are compared by their
    numeric release components (so that e.g. ``1.10`` is newer than ``1.9``) rather than as plain
    strings; plain string ordering is only used as a fallback when no numeric release can be parsed.

    Args:
        monai_version: the MONAI version recorded in the bundle metadata, or ``"0+unknown"`` if unspecified.

    Returns:
        a tuple ``(is_compatible, message)``. ``is_compatible`` is ``False`` when either version is
        ``"0+unknown"`` or when the package version is older than ``monai_version``; ``message`` then
        explains why. Otherwise ``(True, "")`` is returned.
    """
    version_dict = get_versions()
    package_version = version_dict.get("version", "0+unknown")
    if package_version == "0+unknown":
        return False, "Package version is not available. Skipping version check."
    if monai_version == "0+unknown":
        return False, "MONAI version is not specified in the bundle. Skipping version check."
    # treat rc versions as the same as the release version
    package_version = re.sub(r"rc\d.*", "", package_version)
    monai_version = re.sub(r"rc\d.*", "", monai_version)

    package_release = _parse_release_version(package_version)
    bundle_release = _parse_release_version(monai_version)
    if package_release and bundle_release:
        # pad with zeros so that e.g. "1.2" and "1.2.0" compare as equal
        width = max(len(package_release), len(bundle_release))
        package_release += (0,) * (width - len(package_release))
        bundle_release += (0,) * (width - len(bundle_release))
        is_older = package_release < bundle_release
    else:
        # no numeric release could be parsed from one of the versions; fall back to string ordering
        is_older = package_version < monai_version

    if is_older:
        return (
            False,
            f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}.",
        )
    return True, ""


def _check_monai_version(bundle_dir: PathLike, name: str) -> None:
    """
    Log a warning if a downloaded bundle does not pass the MONAI version check.

    The ``monai_version`` entry of ``<bundle_dir>/<name>/configs/metadata.json`` is handed to
    ``_examine_monai_version`` (``"0+unknown"`` is used when the entry is absent), and the message
    it returns is logged as a warning whenever the check does not pass. A missing metadata file
    is reported with a warning as well.

    Args:
        bundle_dir: directory that contains the bundle folder.
        name: name of the bundle folder inside ``bundle_dir``.
    """
    metadata_path = Path(bundle_dir) / name / "configs" / "metadata.json"
    if metadata_path.exists():
        with open(metadata_path) as stream:
            bundle_metadata = json.load(stream)
        bundle_monai_version = bundle_metadata.get("monai_version", "0+unknown")
        compatible, message = _examine_monai_version(bundle_monai_version)
        if not compatible:
            logger.warning(message)
    else:
        logger.warning(f"metadata file not found in {metadata_path}.")


def _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]:
"""
Extract the latest versions from the data dictionary.
Args:
data: the data dictionary.
max_versions: the maximum number of versions to return.
Returns:
versions of the latest models in the reverse order of creation date, e.g. ['1.0.0', '0.9.0', '0.8.0'].
"""
# Check if the data is a dictionary and it has the key 'modelVersions'
if not isinstance(data, dict) or "modelVersions" not in data:
raise ValueError("The data is not a dictionary or it does not have the key 'modelVersions'.")

# Extract the list of model versions
model_versions = data["modelVersions"]

if (
not isinstance(model_versions, list)
or len(model_versions) == 0
or "createdDate" not in model_versions[0]
or "versionId" not in model_versions[0]
):
raise ValueError(
"The model versions are not a list or it is empty or it does not have the keys 'createdDate' and 'versionId'."
)

# Sort the versions by the 'createdDate' in descending order
sorted_versions = sorted(model_versions, key=lambda x: x["createdDate"], reverse=True)
return [v["versionId"] for v in sorted_versions[:max_versions]]


def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: dict | None = None) -> str:
    """
    Select the version of an NGC bundle to download, preferring one that passes the MONAI version check.

    The most recent versions (as returned by ``_list_latest_versions``) are examined from newest to
    oldest. For each of them ``configs/metadata.json`` is requested and its ``monai_version`` is
    checked with ``_examine_monai_version``. The first version that passes is returned; if none
    passes, the newest version is returned.

    Args:
        name: bundle name; it is lower-cased when building the request URLs.
        repo: optional registry path. When given, the base URL comes from
            ``_get_ngc_private_base_url(repo)``, otherwise ``NGC_BASE_URL`` is used.
        headers: optional HTTP headers. They are merged into the headers of the version-listing
            request and passed as-is to the metadata requests.

    Returns:
        the selected version identifier.

    Raises:
        ValueError: if the ``requests`` package is not available, or if the version listing is
            rejected by ``_list_latest_versions``.
        Any error raised by ``raise_for_status()`` on a failed HTTP request is propagated.
    """
    base_url = _get_ngc_private_base_url(repo) if repo else NGC_BASE_URL
    version_endpoint = base_url + f"/{name.lower()}/versions/"

    if not has_requests:
        raise ValueError("requests package is required, please install it.")

    version_header = {"Accept-Encoding": "gzip, deflate"}  # Excluding 'zstd' to fit NGC requirements
    if headers:
        version_header.update(headers)
    resp = requests_get(version_endpoint, headers=version_header)
    resp.raise_for_status()
    model_info = json.loads(resp.text)
    latest_versions = _list_latest_versions(model_info)

    for version in latest_versions:
        file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
        resp = requests_get(file_endpoint, headers=headers)
        # check the HTTP status before decoding: the body of an error response need not be JSON
        resp.raise_for_status()
        metadata = json.loads(resp.text)
        # a bundle without `monai_version` is handled like in `_check_monai_version`: treated as unknown
        is_compatible, _ = _examine_monai_version(metadata.get("monai_version", "0+unknown"))
        if is_compatible:
            if version != latest_versions[0]:
                logger.info(f"Latest version is {latest_versions[0]}, but the compatible version is {version}.")
            return version

    # if no compatible version is found, return the latest version
    return latest_versions[0]


def _get_latest_bundle_version(
source: str, name: str, repo: str, **kwargs: Any
) -> dict[str, list[str] | str] | Any | None:
if source == "ngc":
name = _add_ngc_prefix(name)
model_dict = _get_all_ngc_models(name)
for v in model_dict.values():
if v["name"] == name:
return v["latest"]
return None
return _get_latest_bundle_version_ngc(name)
elif source == "monaihosting":
return _get_latest_bundle_version_monaihosting(name)
elif source == "ngc_private":
headers = kwargs.pop("headers", {})
name = _add_ngc_prefix(name)
return _get_latest_bundle_version_private_registry(name, repo, headers)
return _get_latest_bundle_version_ngc(name, repo=repo, headers=headers)
elif source == "github":
repo_owner, repo_name, tag_name = repo.split("/")
return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"]
Expand Down Expand Up @@ -470,9 +553,8 @@ def download(
if version_ is None:
version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers)
if source_ == "github":
if version_ is not None:
name_ = "_v".join([name_, version_])
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
name_ver = "_v".join([name_, version_]) if version_ is not None else name_
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
elif source_ == "monaihosting":
_download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_)
elif source_ == "ngc":
Expand Down Expand Up @@ -501,6 +583,8 @@ def download(
f"got source: {source_}."
)

_check_monai_version(bundle_dir_, name_)


@deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
@deprecated_arg("net_kwargs", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
Expand Down
51 changes: 51 additions & 0 deletions tests/test_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import tempfile
import unittest
from unittest.case import skipUnless
from unittest.mock import patch

import numpy as np
import torch
Expand All @@ -24,6 +25,7 @@
import monai.networks.nets as nets
from monai.apps import check_hash
from monai.bundle import ConfigParser, create_workflow, load
from monai.bundle.scripts import _examine_monai_version, _list_latest_versions, download
from monai.utils import optional_import
from tests.utils import (
SkipIfBeforePyTorchVersion,
Expand Down Expand Up @@ -207,6 +209,55 @@ def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, ve
file_path = os.path.join(tempdir, bundle_name, file)
self.assertTrue(os.path.exists(file_path))

@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
def test_examine_monai_version(self, mock_get_versions):
    """With the package version patched to 1.2, bundle versions 1.1 and 1.2rc1 pass and 1.3 does not."""
    for bundle_version in ("1.1", "1.2rc1"):
        is_compatible, _ = _examine_monai_version(bundle_version)
        self.assertTrue(is_compatible)  # not newer than the package (rc suffix ignored)

    is_compatible, _ = _examine_monai_version("1.3")
    self.assertFalse(is_compatible)  # bundle is newer than the package

@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2rc1"})
def test_examine_monai_version_rc(self, mock_get_versions):
    """With the package version patched to 1.2rc1, bundle version 1.2 passes and 1.3 does not."""
    expectations = {"1.2": True, "1.3": False}
    for bundle_version, should_pass in expectations.items():
        is_compatible, _ = _examine_monai_version(bundle_version)
        if should_pass:
            self.assertTrue(is_compatible)
        else:
            self.assertFalse(is_compatible)

def test_list_latest_versions(self):
    """Versions are listed newest first and capped by ``max_versions``."""
    all_versions = [
        {"createdDate": "2021-01-01", "versionId": "1.0"},
        {"createdDate": "2021-01-02", "versionId": "1.1"},
        {"createdDate": "2021-01-03", "versionId": "1.2"},
    ]

    three_versions = {"modelVersions": all_versions}
    self.assertEqual(_list_latest_versions(three_versions), ["1.2", "1.1", "1.0"])
    self.assertEqual(_list_latest_versions(three_versions, max_versions=2), ["1.2", "1.1"])

    # fewer entries than the default limit
    two_versions = {"modelVersions": all_versions[:2]}
    self.assertEqual(_list_latest_versions(two_versions), ["1.1", "1.0"])

@skip_if_quick
@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
def test_download_monaihosting(self, mock_get_versions):
    """Downloading from monaihosting with the package version patched to 1.2 logs exactly one warning."""
    with patch("monai.bundle.scripts.logger") as mock_logger, tempfile.TemporaryDirectory() as tempdir:
        download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="monaihosting")
        # Should have a warning message because the latest version is using monai > 1.2
        mock_logger.warning.assert_called_once()

@skip_if_quick
@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
def test_download_ngc(self, mock_get_versions):
    """Downloading from NGC with the package version patched to 1.2 logs no warning."""
    with patch("monai.bundle.scripts.logger") as mock_logger, tempfile.TemporaryDirectory() as tempdir:
        download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="ngc")
        # presumably a version that passes the check is selected for NGC, hence no warning
        mock_logger.warning.assert_not_called()


@skip_if_no_cuda
class TestLoad(unittest.TestCase):
Expand Down

0 comments on commit 316934a

Please sign in to comment.