From bac3372f6db39189ba93faf6bb8933080ad68032 Mon Sep 17 00:00:00 2001 From: gaoxinxing <15031259256@163.com> Date: Fri, 3 Feb 2023 18:48:07 +0800 Subject: [PATCH] fix dataset copy error --- client/starwhale/base/bundle_copy.py | 8 ++--- client/starwhale/core/dataset/copy.py | 27 ++++++++++++-- client/tests/base/test_copy.py | 21 +++++++++++ client/tests/sdk/test_dataset.py | 52 +++++++++++++++++++++++++++ 4 files changed, 102 insertions(+), 6 deletions(-) diff --git a/client/starwhale/base/bundle_copy.py b/client/starwhale/base/bundle_copy.py index ad8f4d9019..e721a4a029 100644 --- a/client/starwhale/base/bundle_copy.py +++ b/client/starwhale/base/bundle_copy.py @@ -65,7 +65,8 @@ def __init__( force: bool = False, **kw: t.Any, ) -> None: - self.src_uri = Resource(src_uri, typ=ResourceType[typ]).to_uri() + self.src_resource = Resource(src_uri, typ=ResourceType[typ]) + self.src_uri = self.src_resource.to_uri() if self.src_uri.instance_type == InstanceType.CLOUD: p = kw.get("dest_local_project_uri") project = p and Project(p) or None @@ -73,9 +74,8 @@ def __init__( else: project = None - self.dest_uri = Resource( - dest_uri, typ=ResourceType[typ], project=project - ).to_uri() + self.dest_resource = Resource(dest_uri, typ=ResourceType[typ], project=project) + self.dest_uri = self.dest_resource.to_uri() self.typ = typ self.force = force diff --git a/client/starwhale/core/dataset/copy.py b/client/starwhale/core/dataset/copy.py index ce8f486bbe..83e1bf03b0 100644 --- a/client/starwhale/core/dataset/copy.py +++ b/client/starwhale/core/dataset/copy.py @@ -2,6 +2,7 @@ import os import copy +from http import HTTPStatus from typing import Iterator from pathlib import Path @@ -11,6 +12,7 @@ from starwhale.consts import ( FileDesc, FileNode, + HTTPMethod, STANDALONE_INSTANCE, DEFAULT_MANIFEST_NAME, ARCHIVED_SWDS_META_FNAME, @@ -18,8 +20,10 @@ from starwhale.utils.fs import ensure_dir from starwhale.base.bundle_copy import BundleCopy +from ... import URI from .store import DatasetStorage from .tabular import TabularDataset +from ...utils.error import NotFoundError class DatasetCopy(BundleCopy): @@ -106,7 +110,9 @@ def _do_ubd_datastore(self) -> None: ) as local, TabularDataset( name=self.bundle_name, version=self.bundle_version, - project=self.dest_uri.project, + project=self._get_remote_project_name( + self.dest_resource.project.instance.to_uri(), self.dest_uri.project + ), instance_name=self.dest_uri.instance, ) as remote: console.print( @@ -128,7 +134,9 @@ def _do_download_bundle_dir(self, progress: Progress) -> None: ) as local, TabularDataset( name=self.bundle_name, version=self.bundle_version, - project=self.src_uri.project, + project=self._get_remote_project_name( + self.src_resource.project.instance.to_uri(), self.src_uri.project + ), instance_name=self.src_uri.instance, ) as remote: console.print( @@ -141,3 +149,18 @@ def _do_download_bundle_dir(self, progress: Progress) -> None: local._info = copy.deepcopy(remote.info) super()._do_download_bundle_dir(progress) + + def _get_remote_project_name(self, instance: URI, project: str) -> str: + resp = self.do_http_request( + f"/project/{project}", + instance_uri=instance, + method=HTTPMethod.GET, + use_raise=True, + ) + if resp.status_code != HTTPStatus.OK: + raise NotFoundError(f"project:{project}") + + resp_body = resp.json() + + _project = resp_body.get("data") + return _project["name"] if _project else None diff --git a/client/tests/base/test_copy.py b/client/tests/base/test_copy.py index a4dc882db6..ec88b82579 100644 --- a/client/tests/base/test_copy.py +++ b/client/tests/base/test_copy.py @@ -437,6 +437,11 @@ def test_model_copy_l2c(self, rm: Mocker) -> None: @patch("starwhale.core.dataset.copy.TabularDataset.scan") def test_dataset_copy_c2l(self, rm: Mocker, m_td_scan: MagicMock) -> None: version = "ge3tkylgha2tenrtmftdgyjzni3dayq" + rm.request( + HTTPMethod.GET, + "http://1.1.1.1:8182/api/v1/project/myproject", + json={"data": {"id": 1, "name": "myproject"}}, + ) rm.request( HTTPMethod.HEAD, f"http://1.1.1.1:8182/api/v1/project/myproject/dataset/mnist/version/{version}", @@ -610,6 +615,12 @@ def test_dataset_copy_l2c(self, rm: Mocker, m_td_scan: MagicMock) -> None: }, ] + rm.request( + HTTPMethod.GET, + "http://1.1.1.1:8182/api/v1/project/mnist", + json={"data": {"id": 1, "name": "mnist"}}, + ) + for case in cases: head_request = rm.request( HTTPMethod.HEAD, @@ -721,6 +732,11 @@ def test_download_bundle_file(self, rm: Mocker) -> None: @Mocker() @patch("starwhale.core.dataset.copy.TabularDataset.scan") def test_upload_bundle_dir(self, rm: Mocker, m_td_scan: MagicMock) -> None: + rm.request( + HTTPMethod.GET, + "http://1.1.1.1:8182/api/v1/project/project", + json={"data": {"id": 1, "name": "project"}}, + ) rm.request( HTTPMethod.HEAD, "http://1.1.1.1:8182/api/v1/project/project/dataset/mnist/version/abcde", @@ -764,6 +780,11 @@ def test_upload_bundle_dir(self, rm: Mocker, m_td_scan: MagicMock) -> None: def test_download_bundle_dir(self, rm: Mocker, m_td_scan: MagicMock) -> None: hash_name1 = "bfa8805ddc2d43df098e43832c24e494ad" hash_name2 = "f954056e4324495ae5bec4e8e5e6d18f1b" + rm.request( + HTTPMethod.GET, + "http://1.1.1.1:8182/api/v1/project/1", + json={"data": {"id": 1, "name": "project"}}, + ) rm.request( HTTPMethod.HEAD, "http://1.1.1.1:8182/api/v1/project/1/dataset/mnist/version/latest", diff --git a/client/tests/sdk/test_dataset.py b/client/tests/sdk/test_dataset.py index 98399de91f..3093b7d9dc 100644 --- a/client/tests/sdk/test_dataset.py +++ b/client/tests/sdk/test_dataset.py @@ -81,6 +81,8 @@ create_generic_cls, SWDSBinBuildExecutor, ) +from starwhale.base.uricomponents.instance import Instance +from starwhale.base.uricomponents.resource import Resource, ResourceType from .test_base import BaseTestCase @@ -231,6 +233,12 @@ def test_upload(self, rm: Mocker) -> None: dataset_name = "complex_annotations" dataset_version = "123" cloud_project = "project" + + rm.request( + HTTPMethod.GET, + f"{instance_uri}/api/v1/project/project", + json={"data": {"id": 1, "name": "project"}}, + ) rm.request( HTTPMethod.POST, f"{instance_uri}/api/v1/project/{cloud_project}/dataset/{dataset_name}/version/{dataset_version}/file", @@ -364,6 +372,11 @@ def test_download(self, rm: Mocker) -> None: dataset_version = "123" cloud_project = "project" + rm.request( + HTTPMethod.GET, + f"{instance_uri}/api/v1/project/project", + json={"data": {"id": 1, "name": "project"}}, + ) rm.request( HTTPMethod.HEAD, f"{instance_uri}/api/v1/project/{cloud_project}/dataset/{dataset_name}/version/{dataset_version}", @@ -509,6 +522,45 @@ def test_download(self, rm: Mocker) -> None: assert bbox.x == 2 and bbox.y == 2 assert bbox.width == 3 and bbox.height == 4 + @patch("os.environ", {}) + @Mocker() + def test_get_remote_project(self, rm: Mocker) -> None: + instance_uri = "http://1.1.1.1:8182" + project = "1" + rm.request( + HTTPMethod.GET, + f"{instance_uri}/api/v1/project/{project}", + json={"data": {"id": 1, "name": "starwhale"}}, + ) + remote_uri = f"{instance_uri}/project/1/dataset/ds_test/version/v1" + + origin_conf = config.load_swcli_config().copy() + # patch config to pass instance alias check + with patch("starwhale.utils.config.load_swcli_config") as mock_conf: + origin_conf.update( + { + "current_instance": "local", + "instances": { + "foo": {"uri": "http://1.1.1.1:8182"}, + "local": {"uri": "local"}, + }, + } + ) + mock_conf.return_value = origin_conf + dc = DatasetCopy( + src_uri=remote_uri, + dest_uri="", + dest_local_project_uri="self", + typ=URIType.DATASET, + ) + remote_resource = Resource(remote_uri, typ=ResourceType[URIType.DATASET]) + assert ( + dc._get_remote_project_name( + instance=remote_resource.project.instance.to_uri(), project="1" + ) + == "starwhale" + ) + class TestDatasetBuildExecutor(BaseTestCase): def setUp(self) -> None: