Skip to content

Commit

Permalink
fix(client): copy remote rc without version fail (#2288)
Browse files Browse the repository at this point in the history
  • Loading branch information
jialeicui authored May 30, 2023
1 parent 1e53ce7 commit db079ee
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 10 deletions.
9 changes: 5 additions & 4 deletions client/starwhale/base/uri/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def __init__(
self.typ = ResourceType(info["rc_type"][:-1]) # remove the last 's'
self.name = info.get("rc_id") or ""
self.version = info.get("rc_version") or ""
self._refine_remote_rc_info()
if refine:
self._refine_remote_rc_info()
return

if project:
Expand Down Expand Up @@ -226,17 +227,17 @@ def _parse_by_version(self, ver: str) -> None:
def _refine_remote_rc_info(self) -> None:
if self.project.instance.is_local:
raise VerifyException("only used for remote resources")
if not self.name or not self.version:
# TODO guess by name or version only
if not self.name or self.typ in {ResourceType.job, ResourceType.evaluation}:
return
ver = self.version or "latest"
if self._remote_info:
# have remote info, assume it is already refined
return

base_path = f"{self.instance.url}/api/{SW_API_VERSION}/project/{self.project.name}/{self.typ.value}/{self.name}"
headers = {"Authorization": self.instance.token}
resp = requests.get(
base_path, timeout=60, params={"versionUrl": self.version}, headers=headers
base_path, timeout=60, params={"versionUrl": ver}, headers=headers
)
resp.raise_for_status()
self._remote_info = resp.json().get("data", {})
Expand Down
23 changes: 17 additions & 6 deletions client/tests/base/uri/test_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,25 @@ def __eq__(self, other: Resource):
for url, expect in tests.items():
assert expect == Resource(url)

get.return_value.json.return_value = {
"data": {"name": "name in server", "versionName": "version in server"}
}
def response_of_get(*args, **kwargs):
ret = type("", (), {})()
ret.raise_for_status = lambda: None
ver = "version in server"
if kwargs.get("params", {}).get("versionUrl", "") == "latest":
ver = "latest of the version"
ret.json = lambda: {"data": {"name": "name in server", "versionName": ver}}
return ret

get.side_effect = response_of_get

for url, expect in tests.items():
# only the resource with name and version will be parsed
if expect.name and expect.version:
if expect.name:
expect.name = "name in server"
expect.version = "version in server"
if expect.version:
expect.version = "version in server"
else:
expect.version = "latest of the version"
assert expect == Resource(url)

with self.assertRaises(Exception):
Expand All @@ -221,7 +232,7 @@ def test_short_uri(self, load_conf: MagicMock) -> None:
}

for uri, expect in tests.items():
p = Resource(uri, typ=ResourceType.runtime)
p = Resource(uri, typ=ResourceType.runtime, refine=False)
assert p.name == expect[0]
assert p.project.name == expect[1]
assert p.instance.alias == expect[2]
Expand Down
1 change: 1 addition & 0 deletions client/tests/sdk/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,7 @@ def test_put_for_cloud(self, rm: Mocker, m_conf: MagicMock) -> None:
"mnist",
project=Project("cloud://foo/project/self"),
typ=ResourceType.dataset,
refine=False,
),
)
mdb.put(
Expand Down
1 change: 1 addition & 0 deletions client/tests/sdk/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ def test_remote_batch_sign(
"local": {"uri": "local"},
},
}
rm.get("http://localhost/api/v1/project/x/dataset/mnist", json={})
m_summary.return_value = DatasetSummary(rows=4)
tdsc = m_sc()
tdsc.get_scan_range.side_effect = [["a", "d"], None]
Expand Down

0 comments on commit db079ee

Please sign in to comment.