Skip to content

Commit 6c8de35

Browse files
8394 Update bundle download API (#8403)
Fixes #8394 . ### Description Add support to download monaihosting bundles from Huggingface. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang <vennw@nvidia.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent 83dcd35 commit 6c8de35

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

monai/bundle/scripts.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,9 @@ def download(
528528
If source is "ngc_private", you need specify the NGC_API_KEY in the environment variable.
529529
repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
530530
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
531-
If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
531+
If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name". Please note that
532+
bundles for "monaihosting" source are also hosted on Hugging Face Hub, but the "repo_id" is always in the form
533+
of "MONAI/bundle_name", therefore, this argument is not required for "monaihosting" source.
532534
If `source` is "ngc_private", it should be in the form of "org/org_name" or "org/org_name/team/team_name",
533535
or you can specify the environment variable NGC_ORG and NGC_TEAM.
534536
url: url to download the data. If not `None`, data will be downloaded directly
@@ -600,11 +602,15 @@ def download(
600602
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
601603
elif source_ == "monaihosting":
602604
try:
605+
extract_path = os.path.join(bundle_dir_, name_)
606+
huggingface_hub.snapshot_download(repo_id=f"MONAI/{name_}", revision=version_, local_dir=extract_path)
607+
except (huggingface_hub.errors.RevisionNotFoundError, huggingface_hub.errors.RepositoryNotFoundError):
608+
# if bundle or version not found from huggingface, download from ngc monaihosting
603609
_download_from_monaihosting(
604610
download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
605611
)
606612
except urllib.error.HTTPError:
607-
# for monaihosting bundles, if cannot download from default host, download according to bundle_info
613+
# if also cannot download from ngc monaihosting, download according to bundle_info
608614
_download_from_bundle_info(
609615
download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
610616
)

tests/bundle/test_bundle_download.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363

6464
TEST_CASE_6 = [["models/model.pt", "configs/train.json"], "renalStructures_CECT_segmentation", "0.1.0"]
6565

66+
TEST_CASE_6_HF = [["models/model.pt", "configs/train.yaml"], "mednist_ddpm", "1.0.1"]
67+
6668
TEST_CASE_7 = [
6769
["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"],
6870
"test_bundle",
@@ -193,6 +195,7 @@ def test_ngc_private_source_download_bundle(self, bundle_files, bundle_name, _ur
193195

194196
@parameterized.expand([TEST_CASE_6])
195197
@skip_if_quick
198+
@skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.")
196199
def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, version):
197200
with skip_if_downloading_fails():
198201
# download a single file from url, also use `args_file`
@@ -239,6 +242,7 @@ def test_list_latest_versions(self):
239242
self.assertEqual(_list_latest_versions(data), ["1.1", "1.0"])
240243

241244
@skip_if_quick
245+
@skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.")
242246
@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
243247
def test_download_monaihosting(self, mock_get_versions):
244248
"""Test checking MONAI version from a metadata file."""
@@ -333,6 +337,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
333337

334338
@parameterized.expand([TEST_CASE_8])
335339
@skip_if_quick
340+
@skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.")
336341
def test_load_weights_with_net_override(self, bundle_name, device, net_override):
337342
with skip_if_downloading_fails():
338343
# download bundle, and load weights from the downloaded path

0 commit comments

Comments
 (0)