Skip to content

Commit

Permalink
fix dataset copy error
Browse files Browse the repository at this point in the history
  • Loading branch information
goldenxinxing committed Feb 6, 2023
1 parent 839e47b commit bac3372
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 6 deletions.
8 changes: 4 additions & 4 deletions client/starwhale/base/bundle_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,17 @@ def __init__(
force: bool = False,
**kw: t.Any,
) -> None:
self.src_uri = Resource(src_uri, typ=ResourceType[typ]).to_uri()
self.src_resource = Resource(src_uri, typ=ResourceType[typ])
self.src_uri = self.src_resource.to_uri()
if self.src_uri.instance_type == InstanceType.CLOUD:
p = kw.get("dest_local_project_uri")
project = p and Project(p) or None
dest_uri = self.src_uri.object.name if dest_uri.strip() == "." else dest_uri
else:
project = None

self.dest_uri = Resource(
dest_uri, typ=ResourceType[typ], project=project
).to_uri()
self.dest_resource = Resource(dest_uri, typ=ResourceType[typ], project=project)
self.dest_uri = self.dest_resource.to_uri()

self.typ = typ
self.force = force
Expand Down
27 changes: 25 additions & 2 deletions client/starwhale/core/dataset/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import copy
from http import HTTPStatus
from typing import Iterator
from pathlib import Path

Expand All @@ -11,15 +12,18 @@
from starwhale.consts import (
FileDesc,
FileNode,
HTTPMethod,
STANDALONE_INSTANCE,
DEFAULT_MANIFEST_NAME,
ARCHIVED_SWDS_META_FNAME,
)
from starwhale.utils.fs import ensure_dir
from starwhale.base.bundle_copy import BundleCopy

from ... import URI
from .store import DatasetStorage
from .tabular import TabularDataset
from ...utils.error import NotFoundError


class DatasetCopy(BundleCopy):
Expand Down Expand Up @@ -106,7 +110,9 @@ def _do_ubd_datastore(self) -> None:
) as local, TabularDataset(
name=self.bundle_name,
version=self.bundle_version,
project=self.dest_uri.project,
project=self._get_remote_project_name(
self.dest_resource.project.instance.to_uri(), self.dest_uri.project
),
instance_name=self.dest_uri.instance,
) as remote:
console.print(
Expand All @@ -128,7 +134,9 @@ def _do_download_bundle_dir(self, progress: Progress) -> None:
) as local, TabularDataset(
name=self.bundle_name,
version=self.bundle_version,
project=self.src_uri.project,
project=self._get_remote_project_name(
self.src_resource.project.instance.to_uri(), self.src_uri.project
),
instance_name=self.src_uri.instance,
) as remote:
console.print(
Expand All @@ -141,3 +149,18 @@ def _do_download_bundle_dir(self, progress: Progress) -> None:
local._info = copy.deepcopy(remote.info)

super()._do_download_bundle_dir(progress)

def _get_remote_project_name(self, instance: URI, project: str) -> str:
resp = self.do_http_request(
f"/project/{project}",
instance_uri=instance,
method=HTTPMethod.GET,
use_raise=True,
)
if resp.status_code != HTTPStatus.OK:
raise NotFoundError(f"project:{project}")

resp_body = resp.json()

_project = resp_body.get("data")
return _project["name"] if _project else None
21 changes: 21 additions & 0 deletions client/tests/base/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,11 @@ def test_model_copy_l2c(self, rm: Mocker) -> None:
@patch("starwhale.core.dataset.copy.TabularDataset.scan")
def test_dataset_copy_c2l(self, rm: Mocker, m_td_scan: MagicMock) -> None:
version = "ge3tkylgha2tenrtmftdgyjzni3dayq"
rm.request(
HTTPMethod.GET,
"http://1.1.1.1:8182/api/v1/project/myproject",
json={"data": {"id": 1, "name": "myproject"}},
)
rm.request(
HTTPMethod.HEAD,
f"http://1.1.1.1:8182/api/v1/project/myproject/dataset/mnist/version/{version}",
Expand Down Expand Up @@ -610,6 +615,12 @@ def test_dataset_copy_l2c(self, rm: Mocker, m_td_scan: MagicMock) -> None:
},
]

rm.request(
HTTPMethod.GET,
"http://1.1.1.1:8182/api/v1/project/mnist",
json={"data": {"id": 1, "name": "mnist"}},
)

for case in cases:
head_request = rm.request(
HTTPMethod.HEAD,
Expand Down Expand Up @@ -721,6 +732,11 @@ def test_download_bundle_file(self, rm: Mocker) -> None:
@Mocker()
@patch("starwhale.core.dataset.copy.TabularDataset.scan")
def test_upload_bundle_dir(self, rm: Mocker, m_td_scan: MagicMock) -> None:
rm.request(
HTTPMethod.GET,
"http://1.1.1.1:8182/api/v1/project/project",
json={"data": {"id": 1, "name": "project"}},
)
rm.request(
HTTPMethod.HEAD,
"http://1.1.1.1:8182/api/v1/project/project/dataset/mnist/version/abcde",
Expand Down Expand Up @@ -764,6 +780,11 @@ def test_upload_bundle_dir(self, rm: Mocker, m_td_scan: MagicMock) -> None:
def test_download_bundle_dir(self, rm: Mocker, m_td_scan: MagicMock) -> None:
hash_name1 = "bfa8805ddc2d43df098e43832c24e494ad"
hash_name2 = "f954056e4324495ae5bec4e8e5e6d18f1b"
rm.request(
HTTPMethod.GET,
"http://1.1.1.1:8182/api/v1/project/1",
json={"data": {"id": 1, "name": "project"}},
)
rm.request(
HTTPMethod.HEAD,
"http://1.1.1.1:8182/api/v1/project/1/dataset/mnist/version/latest",
Expand Down
52 changes: 52 additions & 0 deletions client/tests/sdk/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
create_generic_cls,
SWDSBinBuildExecutor,
)
from starwhale.base.uricomponents.instance import Instance
from starwhale.base.uricomponents.resource import Resource, ResourceType

from .test_base import BaseTestCase

Expand Down Expand Up @@ -231,6 +233,12 @@ def test_upload(self, rm: Mocker) -> None:
dataset_name = "complex_annotations"
dataset_version = "123"
cloud_project = "project"

rm.request(
HTTPMethod.GET,
f"{instance_uri}/api/v1/project/project",
json={"data": {"id": 1, "name": "project"}},
)
rm.request(
HTTPMethod.POST,
f"{instance_uri}/api/v1/project/{cloud_project}/dataset/{dataset_name}/version/{dataset_version}/file",
Expand Down Expand Up @@ -364,6 +372,11 @@ def test_download(self, rm: Mocker) -> None:
dataset_version = "123"
cloud_project = "project"

rm.request(
HTTPMethod.GET,
f"{instance_uri}/api/v1/project/project",
json={"data": {"id": 1, "name": "project"}},
)
rm.request(
HTTPMethod.HEAD,
f"{instance_uri}/api/v1/project/{cloud_project}/dataset/{dataset_name}/version/{dataset_version}",
Expand Down Expand Up @@ -509,6 +522,45 @@ def test_download(self, rm: Mocker) -> None:
assert bbox.x == 2 and bbox.y == 2
assert bbox.width == 3 and bbox.height == 4

@patch("os.environ", {})
@Mocker()
def test_get_remote_project(self, rm: Mocker) -> None:
instance_uri = "http://1.1.1.1:8182"
project = "1"
rm.request(
HTTPMethod.GET,
f"{instance_uri}/api/v1/project/{project}",
json={"data": {"id": 1, "name": "starwhale"}},
)
remote_uri = f"{instance_uri}/project/1/dataset/ds_test/version/v1"

origin_conf = config.load_swcli_config().copy()
# patch config to pass instance alias check
with patch("starwhale.utils.config.load_swcli_config") as mock_conf:
origin_conf.update(
{
"current_instance": "local",
"instances": {
"foo": {"uri": "http://1.1.1.1:8182"},
"local": {"uri": "local"},
},
}
)
mock_conf.return_value = origin_conf
dc = DatasetCopy(
src_uri=remote_uri,
dest_uri="",
dest_local_project_uri="self",
typ=URIType.DATASET,
)
remote_resource = Resource(remote_uri, typ=ResourceType[URIType.DATASET])
assert (
dc._get_remote_project_name(
instance=remote_resource.project.instance.to_uri(), project="1"
)
== "starwhale"
)


class TestDatasetBuildExecutor(BaseTestCase):
def setUp(self) -> None:
Expand Down

0 comments on commit bac3372

Please sign in to comment.