Skip to content

Commit

Permalink
Support download bundles from ngc private registry (#7907)
Browse files Browse the repository at this point in the history
### Description

Support download from ngc private registry, this download option
requires ngc api key.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
KumoLiu and pre-commit-ci[bot] authored Jul 17, 2024
1 parent 4fbe800 commit 50d5180
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/blossom-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ jobs:
run: blossom-ci
env:
OPERATION: 'START-CI-JOB'
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
NGC_ORG: ${{ secrets.NGC_ORG }}
NGC_TEAM: ${{ secrets.NGC_TEAM }}
CI_SERVER: ${{ secrets.CI_SERVER }}
REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}

Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ jobs:
conda deactivate
- name: Test env (CPU ${{ runner.os }})
shell: bash -el {0}
env:
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
NGC_ORG: ${{ secrets.NGC_ORG }}
NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: |
conda activate monai
$(pwd)/runtests.sh --build --unittests
Expand Down
12 changes: 12 additions & 0 deletions .github/workflows/cron.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ jobs:
python -m pip install -r requirements-dev.txt
python -m pip list
- name: Run tests report coverage
env:
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
NGC_ORG: ${{ secrets.NGC_ORG }}
NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: |
export LAUNCH_DELAY=$[ $RANDOM % 16 * 60 ]
echo "Sleep $LAUNCH_DELAY"
Expand Down Expand Up @@ -94,6 +98,10 @@ jobs:
python -m pip install -r requirements-dev.txt
python -m pip list
- name: Run tests report coverage
env:
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
NGC_ORG: ${{ secrets.NGC_ORG }}
NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: |
export LAUNCH_DELAY=$[ $RANDOM % 16 * 60 ]
echo "Sleep $LAUNCH_DELAY"
Expand Down Expand Up @@ -196,6 +204,10 @@ jobs:
- name: Run tests report coverage
# The docker image process has done the compilation.
# BUILD_MONAI=1 is necessary for triggering the USE_COMPILED flag.
env:
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
NGC_ORG: ${{ secrets.NGC_ORG }}
NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: |
cd /opt/monai
nvidia-smi
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,6 @@ jobs:
shell: bash
env:
QUICKTEST: True
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
NGC_ORG: ${{ secrets.NGC_ORG }}
NGC_TEAM: ${{ secrets.NGC_TEAM }}
3 changes: 3 additions & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ jobs:
shell: bash
env:
BUILD_MONAI: 1
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
NGC_ORG: ${{ secrets.NGC_ORG }}
NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: ./runtests.sh --build --net

- name: Add reaction
Expand Down
9 changes: 9 additions & 0 deletions .github/workflows/pythonapp-min.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ jobs:
shell: bash
env:
QUICKTEST: True
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
NGC_ORG: ${{ secrets.NGC_ORG }}
NGC_TEAM: ${{ secrets.NGC_TEAM }}

min-dep-py3: # min dependencies installed tests for different python
runs-on: ubuntu-latest
Expand Down Expand Up @@ -112,6 +115,9 @@ jobs:
./runtests.sh --min
env:
QUICKTEST: True
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
NGC_ORG: ${{ secrets.NGC_ORG }}
NGC_TEAM: ${{ secrets.NGC_TEAM }}

min-dep-pytorch: # min dependencies installed tests for different pytorch
runs-on: ubuntu-latest
Expand Down Expand Up @@ -161,3 +167,6 @@ jobs:
./runtests.sh --min
env:
QUICKTEST: True
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
NGC_ORG: ${{ secrets.NGC_ORG }}
NGC_TEAM: ${{ secrets.NGC_TEAM }}
8 changes: 8 additions & 0 deletions .github/workflows/setupapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ jobs:
python -m pip install --upgrade torch torchvision
python -m pip install -r requirements-dev.txt
- name: Run unit tests report coverage
env:
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
NGC_ORG: ${{ secrets.NGC_ORG }}
NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: |
python -m pip list
git config --global --add safe.directory /__w/MONAI/MONAI
Expand Down Expand Up @@ -104,6 +108,10 @@ jobs:
python -m pip install --upgrade pip wheel
python -m pip install -r requirements-dev.txt
- name: Run quick tests CPU ubuntu
env:
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
NGC_ORG: ${{ secrets.NGC_ORG }}
NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: |
python -m pip list
python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
Expand Down
114 changes: 106 additions & 8 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
import re
import warnings
import zipfile
from collections.abc import Mapping, Sequence
from pathlib import Path
from pydoc import locate
Expand Down Expand Up @@ -171,6 +172,10 @@ def _get_ngc_bundle_url(model_name: str, version: str) -> str:
return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name.lower()}/versions/{version}/zip"


def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str:
return f"https://api.ngc.nvidia.com/v2/{repo}/models/{model_name.lower()}/versions/{version}/zip"


def _get_monaihosting_bundle_url(model_name: str, version: str) -> str:
monaihosting_root_path = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
return f"{monaihosting_root_path}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"
Expand Down Expand Up @@ -219,6 +224,48 @@ def _download_from_ngc(
extractall(filepath=filepath, output_dir=extract_path, has_base=True)


def _download_from_ngc_private(
download_path: Path, filename: str, version: str, remove_prefix: str | None, repo: str, headers: dict | None = None
) -> None:
# ensure prefix is contained
filename = _add_ngc_prefix(filename)
request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
if has_requests:
headers = {} if headers is None else headers
response = requests_get(request_url, headers=headers)
response.raise_for_status()
else:
raise ValueError("NGC API requires requests package. Please install it.")

zip_path = download_path / f"{filename}_v{version}.zip"
with open(zip_path, "wb") as f:
f.write(response.content)
logger.info(f"Downloading: {zip_path}.")
if remove_prefix:
filename = _remove_ngc_prefix(filename, prefix=remove_prefix)
extract_path = download_path / f"{filename}"
with zipfile.ZipFile(zip_path, "r") as z:
z.extractall(extract_path)
logger.info(f"Writing into directory: {extract_path}.")


def _get_ngc_token(api_key, retry=0):
"""Try to connect to NGC."""
url = "https://authn.nvidia.com/token?service=ngc"
headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key}
if has_requests:
response = requests_get(url, headers=headers)
if not response.ok:
# retry 3 times, if failed, raise an error.
if retry < 3:
logger.info(f"Retrying {retry} time(s) to GET {url}.")
return _get_ngc_token(url, retry + 1)
raise RuntimeError("NGC API response is not ok. Failed to get token.")
else:
token = response.json()["token"]
return token


def _get_latest_bundle_version_monaihosting(name):
url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
full_url = f"{url}/{name.lower()}"
Expand All @@ -227,12 +274,28 @@ def _get_latest_bundle_version_monaihosting(name):
resp = requests_get(full_url)
resp.raise_for_status()
else:
raise ValueError("NGC API requires requests package. Please install it.")
raise ValueError("NGC API requires requests package. Please install it.")
model_info = json.loads(resp.text)
return model_info["model"]["latestVersionIdStr"]


def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, list[str] | str] | Any | None:
def _get_latest_bundle_version_private_registry(name, repo, headers=None):
url = f"https://api.ngc.nvidia.com/v2/{repo}/models"
full_url = f"{url}/{name.lower()}"
requests_get, has_requests = optional_import("requests", name="get")
if has_requests:
headers = {} if headers is None else headers
resp = requests_get(full_url, headers=headers)
resp.raise_for_status()
else:
raise ValueError("NGC API requires requests package. Please install it.")
model_info = json.loads(resp.text)
return model_info["model"]["latestVersionIdStr"]


def _get_latest_bundle_version(
source: str, name: str, repo: str, **kwargs: Any
) -> dict[str, list[str] | str] | Any | None:
if source == "ngc":
name = _add_ngc_prefix(name)
model_dict = _get_all_ngc_models(name)
Expand All @@ -242,6 +305,10 @@ def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, l
return None
elif source == "monaihosting":
return _get_latest_bundle_version_monaihosting(name)
elif source == "ngc_private":
headers = kwargs.pop("headers", {})
name = _add_ngc_prefix(name)
return _get_latest_bundle_version_private_registry(name, repo, headers)
elif source == "github":
repo_owner, repo_name, tag_name = repo.split("/")
return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"]
Expand Down Expand Up @@ -308,6 +375,9 @@ def download(
# Execute this module as a CLI entry, and download bundle via URL:
python -m monai.bundle download --name <bundle_name> --url <url>
# Execute this module as a CLI entry, and download bundle from ngc_private with latest version:
python -m monai.bundle download --name <bundle_name> --source "ngc_private" --bundle_dir "./" --repo "org/org_name"
# Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
# Other args still can override the default args at runtime.
# The content of the JSON / YAML file is a dictionary. For example:
Expand All @@ -328,10 +398,13 @@ def download(
Default is `bundle` subfolder under `torch.hub.get_dir()`.
source: storage location name. This argument is used when `url` is `None`.
In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
it should be "ngc", "monaihosting", "github", or "huggingface_hub".
it should be "ngc", "monaihosting", "github", "ngc_private", or "huggingface_hub".
If source is "ngc_private", you need specify the NGC_API_KEY in the environment variable.
repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
If `source` is "ngc_private", it should be in the form of "org/org_name" or "org/org_name/team/team_name",
or you can specify the environment variable NGC_ORG and NGC_TEAM.
url: url to download the data. If not `None`, data will be downloaded directly
and `source` will not be checked.
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
Expand Down Expand Up @@ -363,11 +436,18 @@ def download(

bundle_dir_ = _process_bundle_dir(bundle_dir_)
if repo_ is None:
repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
if len(repo_.split("/")) != 3 and source_ != "huggingface_hub":
raise ValueError("repo should be in the form of `repo_owner/repo_name/release_tag`.")
org_ = os.getenv("NGC_ORG", None)
team_ = os.getenv("NGC_TEAM", None)
if org_ is not None:
repo_ = f"org/{org_}/team/{team_}" if team_ is not None else f"org/{org_}"
else:
repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
if len(repo_.split("/")) not in (2, 4) and source_ == "ngc_private":
raise ValueError(f"repo should be in the form of `org/org_name/team/team_name` or `org/org_name`, got {repo_}.")
if len(repo_.split("/")) != 3 and source_ == "github":
raise ValueError(f"repo should be in the form of `repo_owner/repo_name/release_tag`, got {repo_}.")
elif len(repo_.split("/")) != 2 and source_ == "huggingface_hub":
raise ValueError("Hugging Face Hub repo should be in the form of `repo_owner/repo_name`")
raise ValueError(f"Hugging Face Hub repo should be in the form of `repo_owner/repo_name`, got {repo_}.")
if url_ is not None:
if name_ is not None:
filepath = bundle_dir_ / f"{name_}.zip"
Expand All @@ -376,10 +456,19 @@ def download(
download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)
else:
headers = {}
if name_ is None:
raise ValueError(f"To download from source: {source_}, `name` must be provided.")
if source == "ngc_private":
api_key = os.getenv("NGC_API_KEY", None)
if api_key is None:
raise ValueError("API key is required for ngc_private source.")
else:
token = _get_ngc_token(api_key)
headers = {"Authorization": f"Bearer {token}"}

if version_ is None:
version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_)
version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers)
if source_ == "github":
if version_ is not None:
name_ = "_v".join([name_, version_])
Expand All @@ -394,6 +483,15 @@ def download(
remove_prefix=remove_prefix_,
progress=progress_,
)
elif source_ == "ngc_private":
_download_from_ngc_private(
download_path=bundle_dir_,
filename=name_,
version=version_,
remove_prefix=remove_prefix_,
repo=repo_,
headers=headers,
)
elif source_ == "huggingface_hub":
extract_path = os.path.join(bundle_dir_, name_)
huggingface_hub.snapshot_download(repo_id=repo_, revision=version_, local_dir=extract_path)
Expand Down
19 changes: 18 additions & 1 deletion tests/test_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
TEST_CASE_5 = [
["models/model.pt", "models/model.ts", "configs/train.json"],
"brats_mri_segmentation",
"https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/brats_mri_segmentation/versions/0.3.9/files/brats_mri_segmentation_v0.3.9.zip",
"https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/brats_mri_segmentation/versions/0.4.0/files/brats_mri_segmentation_v0.4.0.zip",
]

TEST_CASE_6 = [["models/model.pt", "configs/train.json"], "renalStructures_CECT_segmentation", "0.1.0"]
Expand Down Expand Up @@ -173,6 +173,23 @@ def test_monaihosting_url_download_bundle(self, bundle_files, bundle_name, url):
file_path = os.path.join(tempdir, bundle_name, file)
self.assertTrue(os.path.exists(file_path))

@parameterized.expand([TEST_CASE_5])
@skip_if_quick
def test_ngc_private_source_download_bundle(self, bundle_files, bundle_name, _url):
with skip_if_downloading_fails():
# download a single file from url, also use `args_file`
with tempfile.TemporaryDirectory() as tempdir:
def_args = {"name": bundle_name, "bundle_dir": tempdir}
def_args_file = os.path.join(tempdir, "def_args.json")
parser = ConfigParser()
parser.export_config_file(config=def_args, filepath=def_args_file)
cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--args_file", def_args_file]
cmd += ["--progress", "False", "--source", "ngc_private"]
command_line_tests(cmd)
for file in bundle_files:
file_path = os.path.join(tempdir, bundle_name, file)
self.assertTrue(os.path.exists(file_path))

@parameterized.expand([TEST_CASE_6])
@skip_if_quick
def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, version):
Expand Down

0 comments on commit 50d5180

Please sign in to comment.