diff --git a/client/starwhale/base/bundle_copy.py b/client/starwhale/base/bundle_copy.py index ad8f4d9019..e721a4a029 100644 --- a/client/starwhale/base/bundle_copy.py +++ b/client/starwhale/base/bundle_copy.py @@ -65,7 +65,8 @@ def __init__( force: bool = False, **kw: t.Any, ) -> None: - self.src_uri = Resource(src_uri, typ=ResourceType[typ]).to_uri() + self.src_resource = Resource(src_uri, typ=ResourceType[typ]) + self.src_uri = self.src_resource.to_uri() if self.src_uri.instance_type == InstanceType.CLOUD: p = kw.get("dest_local_project_uri") project = p and Project(p) or None @@ -73,9 +74,8 @@ def __init__( else: project = None - self.dest_uri = Resource( - dest_uri, typ=ResourceType[typ], project=project - ).to_uri() + self.dest_resource = Resource(dest_uri, typ=ResourceType[typ], project=project) + self.dest_uri = self.dest_resource.to_uri() self.typ = typ self.force = force diff --git a/client/starwhale/core/dataset/copy.py b/client/starwhale/core/dataset/copy.py index ce8f486bbe..f9837c6917 100644 --- a/client/starwhale/core/dataset/copy.py +++ b/client/starwhale/core/dataset/copy.py @@ -2,7 +2,7 @@ import os import copy -from typing import Iterator +from typing import Any, Iterator from pathlib import Path from rich.progress import Progress @@ -11,10 +11,12 @@ from starwhale.consts import ( FileDesc, FileNode, + HTTPMethod, STANDALONE_INSTANCE, DEFAULT_MANIFEST_NAME, ARCHIVED_SWDS_META_FNAME, ) +from starwhale.base.uri import URI from starwhale.utils.fs import ensure_dir from starwhale.base.bundle_copy import BundleCopy @@ -106,7 +108,9 @@ def _do_ubd_datastore(self) -> None: ) as local, TabularDataset( name=self.bundle_name, version=self.bundle_version, - project=self.dest_uri.project, + project=self._get_remote_project_name( + self.dest_resource.project.instance.to_uri(), self.dest_uri.project + ), instance_name=self.dest_uri.instance, ) as remote: console.print( @@ -128,7 +132,9 @@ def _do_download_bundle_dir(self, progress: Progress) -> None: ) as local, TabularDataset( name=self.bundle_name, version=self.bundle_version, - project=self.src_uri.project, + project=self._get_remote_project_name( + self.src_resource.project.instance.to_uri(), self.src_uri.project + ), instance_name=self.src_uri.instance, ) as remote: console.print( @@ -141,3 +147,13 @@ def _do_download_bundle_dir(self, progress: Progress) -> None: local._info = copy.deepcopy(remote.info) super()._do_download_bundle_dir(progress) + + def _get_remote_project_name(self, instance: URI, project: str) -> Any: + resp = self.do_http_request( + f"/project/{project}", + instance_uri=instance, + method=HTTPMethod.GET, + use_raise=True, + ) + + return resp.json().get("data", {}).get("name") diff --git a/client/tests/base/test_copy.py b/client/tests/base/test_copy.py index a4dc882db6..ec88b82579 100644 --- a/client/tests/base/test_copy.py +++ b/client/tests/base/test_copy.py @@ -437,6 +437,11 @@ def test_model_copy_l2c(self, rm: Mocker) -> None: @patch("starwhale.core.dataset.copy.TabularDataset.scan") def test_dataset_copy_c2l(self, rm: Mocker, m_td_scan: MagicMock) -> None: version = "ge3tkylgha2tenrtmftdgyjzni3dayq" + rm.request( + HTTPMethod.GET, + "http://1.1.1.1:8182/api/v1/project/myproject", + json={"data": {"id": 1, "name": "myproject"}}, + ) rm.request( HTTPMethod.HEAD, f"http://1.1.1.1:8182/api/v1/project/myproject/dataset/mnist/version/{version}", @@ -610,6 +615,12 @@ def test_dataset_copy_l2c(self, rm: Mocker, m_td_scan: MagicMock) -> None: }, ] + rm.request( + HTTPMethod.GET, + "http://1.1.1.1:8182/api/v1/project/mnist", + json={"data": {"id": 1, "name": "mnist"}}, + ) + for case in cases: head_request = rm.request( HTTPMethod.HEAD, @@ -721,6 +732,11 @@ def test_download_bundle_file(self, rm: Mocker) -> None: @Mocker() @patch("starwhale.core.dataset.copy.TabularDataset.scan") def test_upload_bundle_dir(self, rm: Mocker, m_td_scan: MagicMock) -> None: + rm.request( + HTTPMethod.GET, + "http://1.1.1.1:8182/api/v1/project/project", + json={"data": {"id": 1, "name": "project"}}, + ) rm.request( HTTPMethod.HEAD, "http://1.1.1.1:8182/api/v1/project/project/dataset/mnist/version/abcde", @@ -764,6 +780,11 @@ def test_upload_bundle_dir(self, rm: Mocker, m_td_scan: MagicMock) -> None: def test_download_bundle_dir(self, rm: Mocker, m_td_scan: MagicMock) -> None: hash_name1 = "bfa8805ddc2d43df098e43832c24e494ad" hash_name2 = "f954056e4324495ae5bec4e8e5e6d18f1b" + rm.request( + HTTPMethod.GET, + "http://1.1.1.1:8182/api/v1/project/1", + json={"data": {"id": 1, "name": "project"}}, + ) rm.request( HTTPMethod.HEAD, "http://1.1.1.1:8182/api/v1/project/1/dataset/mnist/version/latest", diff --git a/client/tests/sdk/test_dataset.py b/client/tests/sdk/test_dataset.py index e3edf63393..bb3c0eaad3 100644 --- a/client/tests/sdk/test_dataset.py +++ b/client/tests/sdk/test_dataset.py @@ -81,6 +81,7 @@ _header_struct, create_generic_cls, ) +from starwhale.base.uricomponents.resource import Resource, ResourceType from .test_base import BaseTestCase @@ -241,6 +242,12 @@ def test_upload(self, rm: Mocker) -> None: dataset_name = "complex_annotations" dataset_version = "123" cloud_project = "project" + + rm.request( + HTTPMethod.GET, + f"{instance_uri}/api/v1/project/project", + json={"data": {"id": 1, "name": "project"}}, + ) rm.request( HTTPMethod.POST, f"{instance_uri}/api/v1/project/{cloud_project}/dataset/{dataset_name}/version/{dataset_version}/file", @@ -383,6 +390,11 @@ def test_download(self, rm: Mocker) -> None: dataset_version = "123" cloud_project = "project" + rm.request( + HTTPMethod.GET, + f"{instance_uri}/api/v1/project/project", + json={"data": {"id": 1, "name": "project"}}, + ) rm.request( HTTPMethod.HEAD, f"{instance_uri}/api/v1/project/{cloud_project}/dataset/{dataset_name}/version/{dataset_version}", @@ -514,6 +526,45 @@ def test_download(self, rm: Mocker) -> None: assert bbox.x == 2 and bbox.y == 2 assert bbox.width == 3 and bbox.height == 4 + @patch("os.environ", {}) + @Mocker() + def test_get_remote_project(self, rm: Mocker) -> None: + instance_uri = "http://1.1.1.1:8182" + project = "1" + rm.request( + HTTPMethod.GET, + f"{instance_uri}/api/v1/project/{project}", + json={"data": {"id": 1, "name": "starwhale"}}, + ) + remote_uri = f"{instance_uri}/project/1/dataset/ds_test/version/v1" + + origin_conf = config.load_swcli_config().copy() + # patch config to pass instance alias check + with patch("starwhale.utils.config.load_swcli_config") as mock_conf: + origin_conf.update( + { + "current_instance": "local", + "instances": { + "foo": {"uri": "http://1.1.1.1:8182"}, + "local": {"uri": "local"}, + }, + } + ) + mock_conf.return_value = origin_conf + dc = DatasetCopy( + src_uri=remote_uri, + dest_uri="", + dest_local_project_uri="self", + typ=URIType.DATASET, + ) + remote_resource = Resource(remote_uri, typ=ResourceType[URIType.DATASET]) + assert ( + dc._get_remote_project_name( + instance=remote_resource.project.instance.to_uri(), project="1" + ) + == "starwhale" + ) + class MockBinWriter: def __init__(self) -> None: