Skip to content

Commit

Permalink
refactor client model copy
Browse files Browse the repository at this point in the history
  • Loading branch information
goldenxinxing committed May 15, 2023
1 parent 1a0b725 commit 8d250c6
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 45 deletions.
70 changes: 34 additions & 36 deletions client/starwhale/base/bundle_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,6 @@ class _UploadPhase:


class BundleCopy(CloudRequestMixed):
progress: Progress = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
TotalFileSizeColumn(),
TransferSpeedColumn(),
console=console.rich_console,
refresh_per_second=0.2,
)

def __init__(
self,
src_uri: str | Resource,
Expand Down Expand Up @@ -258,34 +246,44 @@ def do(self) -> None:

console.print(f":construction: start to copy {self.src_uri} -> {self.dest_uri}")

if self.src_uri.instance.is_local:
if self.typ == ResourceType.model:
self._do_upload_bundle_dir(BundleCopy.progress)
elif self.typ == ResourceType.runtime:
self._do_upload_bundle_tar(BundleCopy.progress)
else:
raise NoSupportError(
f"no support to copy {self.typ} from standalone to server"
)
else:
if self.typ == ResourceType.model:
self._do_download_bundle_dir(BundleCopy.progress)
elif self.typ == ResourceType.runtime:
self._do_download_bundle_tar(BundleCopy.progress)
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
TotalFileSizeColumn(),
TransferSpeedColumn(),
console=console.rich_console,
refresh_per_second=0.2,
) as progress:
if self.src_uri.instance.is_local:
if self.typ == ResourceType.model:
self._do_upload_bundle_dir(progress)
elif self.typ == ResourceType.runtime:
self._do_upload_bundle_tar(progress)
else:
raise NoSupportError(
f"no support to copy {self.typ} from standalone to server"
)
else:
raise NoSupportError(
f"no support to copy {self.typ} from server to standalone"
if self.typ == ResourceType.model:
self._do_download_bundle_dir(progress)
elif self.typ == ResourceType.runtime:
self._do_download_bundle_tar(progress)
else:
raise NoSupportError(
f"no support to copy {self.typ} from server to standalone"
)
StandaloneTag(self.dest_uri).add_fast_tag()
self._update_manifest(
self._get_versioned_resource_path(self.dest_uri),
{CREATED_AT_KEY: now_str()},
)

StandaloneTag(self.dest_uri).add_fast_tag()
self._update_manifest(
self._get_versioned_resource_path(self.dest_uri),
{CREATED_AT_KEY: now_str()},
)
self.final_steps()
self.final_steps(progress)
console.print(f":tea: console url of the remote bundle: {remote_url}")

def final_steps(self) -> None:
def final_steps(self, progress: Progress) -> None:
pass

def upload_files(self, workdir: Path) -> t.Iterator[FileNode]:
Expand Down
56 changes: 48 additions & 8 deletions client/starwhale/core/model/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import typing as t
from pathlib import Path

from rich.progress import TaskID, Progress

from starwhale.utils import load_yaml
from starwhale.consts import (
FileDesc,
Expand All @@ -16,9 +18,9 @@
)
from starwhale.utils.fs import extract_tar
from starwhale.utils.retry import http_retry
from starwhale.base.bundle_copy import BundleCopy
from starwhale.base.bundle_copy import BundleCopy, _query_param_map
from starwhale.base.uri.instance import Instance
from starwhale.base.uri.resource import ResourceType
from starwhale.base.uri.resource import Resource, ResourceType


class ModelCopy(BundleCopy):
Expand Down Expand Up @@ -88,7 +90,7 @@ def download_files(self, workdir: Path) -> t.Iterator[FileNode]:
# _dest # the unify dir
# )

def final_steps(self) -> None:
def final_steps(self, progress: Progress) -> None:
if self.src_uri.instance.is_local:
manifest_file = (
self._get_versioned_resource_path(self.src_uri) / DEFAULT_MANIFEST_NAME
Expand All @@ -97,17 +99,51 @@ def final_steps(self) -> None:
manifest = load_yaml(manifest_file)
packaged_runtime = manifest.get("packaged_runtime", None)
if packaged_runtime:
_tid = progress.add_task(
f":arrow_up: synchronize the built-in runtime..."
)
rt_version = packaged_runtime["manifest"]["version"]
rt_file_path = (
self._get_versioned_resource_path(self.src_uri)
/ packaged_runtime["path"]
)

runtime_copy = BundleCopy(
src_uri=f'{packaged_runtime["name"]}/version/{rt_version}',
dest_uri=f"{self.dest_uri.project}/{SW_BUILT_IN}/version/{rt_version}",
dest_uri = Resource(
f"{self.dest_uri.project}/{SW_BUILT_IN}/version/{rt_version}",
typ=ResourceType.runtime,
refine=True,
)
runtime_copy.do()

def upload_runtime_tar(file_path: Path, progress: Progress) -> None:
task_id = progress.add_task(
f":synchronize the built-in runtime {file_path.name}",
total=file_path.stat().st_size,
)
self.do_multipart_upload_file(
url_path=f"/project/{dest_uri.project.name}/{ResourceType.runtime.value}/{SW_BUILT_IN}/version/{rt_version}/file",
file_path=file_path,
instance=dest_uri.instance,
fields={
_query_param_map[
ResourceType.runtime
]: f"{SW_BUILT_IN}:{rt_version}",
"project": dest_uri.project.name,
"force": "1" if self.force else "0",
},
use_raise=True,
progress=progress,
task_id=task_id,
)

upload_runtime_tar(rt_file_path, progress)

@http_retry
def sync_built_in_runtime(path: str, instance: Instance) -> None:
def sync_built_in_runtime(
path: str,
instance: Instance,
progress: t.Optional[Progress] = None,
task_id: TaskID = TaskID(0),
) -> None:
self.do_http_request(
path=path,
method=HTTPMethod.PUT,
Expand All @@ -120,8 +156,12 @@ def sync_built_in_runtime(path: str, instance: Instance) -> None:
use_raise=True,
disable_default_content_type=False,
)
progress.update(task_id, completed=100)

print(f"final link:{self._get_remote_bundle_api_url(for_head=True)}")
sync_built_in_runtime(
path=self._get_remote_bundle_api_url(for_head=True),
instance=self.dest_uri.instance,
progress=progress,
task_id=_tid,
)
44 changes: 43 additions & 1 deletion client/tests/base/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from starwhale.utils import NoSupportError
from starwhale.consts import (
HTTPMethod,
SW_BUILT_IN,
VERSION_PREFIX_CNT,
RESOURCE_FILES_NAME,
DEFAULT_MANIFEST_NAME,
Expand Down Expand Up @@ -326,6 +327,7 @@ def test_model_copy_c2l(self, rm: Mocker, *args: MagicMock) -> None:
@Mocker()
def test_model_copy_l2c(self, rm: Mocker) -> None:
version = "ge3tkylgha2tenrtmftdgyjzni3dayq"
built_in_version = "abcdefg1234"
swmp_path = (
self._sw_config.rootdir
/ "self"
Expand All @@ -345,7 +347,14 @@ def test_model_copy_l2c(self, rm: Mocker) -> None:
swmp_manifest_path,
yaml.safe_dump(
{
"version": "ge3tkylgha2tenrtmftdgyjzni3dayq",
"version": version,
"packaged_runtime": {
"manifest": {
"version": built_in_version,
},
"name": "other",
"path": "src/.starwhale/runtime/packaged.swrt",
},
}
),
parents=True,
Expand All @@ -357,6 +366,12 @@ def test_model_copy_l2c(self, rm: Mocker) -> None:
parents=True,
)

ensure_file(
swmp_path / "src" / ".starwhale" / "runtime" / "packaged.swrt",
"",
parents=True,
)

ensure_file(
tag_manifest_path,
yaml.safe_dump(
Expand Down Expand Up @@ -446,13 +461,27 @@ def test_model_copy_l2c(self, rm: Mocker) -> None:
headers={"X-SW-UPLOAD-TYPE": FileDesc.MANIFEST.name},
json={"data": {"uploadId": "123"}},
)

rt_upload_request = rm.request(
HTTPMethod.POST,
f"http://1.1.1.1:8182/api/v1/project/mnist/runtime/{SW_BUILT_IN}/version/{built_in_version}/file",
headers={"X-SW-UPLOAD-TYPE": FileDesc.MANIFEST.name},
json={"data": {"uploadId": "126"}},
)
link_rt_request = rm.request(
HTTPMethod.PUT,
f"http://1.1.1.1:8182/api/v1/project/mnist/model/{case['dest_model']}/version/{version}",
json={"built_in_runtime": built_in_version},
)
ModelCopy(
src_uri=case["src_uri"],
dest_uri=case["dest_uri"],
typ=ResourceType.model,
).do()
assert head_request.call_count == 1
assert upload_request.call_count == 3
assert rt_upload_request.call_count == 1
assert link_rt_request.call_count == 1

head_request = rm.request(
HTTPMethod.HEAD,
Expand All @@ -465,6 +494,17 @@ def test_model_copy_l2c(self, rm: Mocker) -> None:
f"http://1.1.1.1:8182/api/v1/project/mnist/model/mnist-alias/version/{version}/file",
json={"data": {"uploadId": "123"}},
)
rt_upload_request = rm.request(
HTTPMethod.POST,
f"http://1.1.1.1:8182/api/v1/project/mnist/runtime/{SW_BUILT_IN}/version/{built_in_version}/file",
headers={"X-SW-UPLOAD-TYPE": FileDesc.MANIFEST.name},
json={"data": {"uploadId": "126"}},
)
link_rt_request = rm.request(
HTTPMethod.PUT,
f"http://1.1.1.1:8182/api/v1/project/mnist/model/mnist-alias/version/{version}",
json={"built_in_runtime": built_in_version},
)
ModelCopy(
src_uri="mnist/v1",
dest_uri="cloud://pre-bare/project/mnist/model/mnist-alias",
Expand All @@ -473,6 +513,8 @@ def test_model_copy_l2c(self, rm: Mocker) -> None:

assert head_request.call_count == 1
assert upload_request.call_count == 3
assert rt_upload_request.call_count == 1
assert link_rt_request.call_count == 1

def _prepare_local_dataset(self) -> t.Tuple[str, str]:
name = "mnist"
Expand Down

0 comments on commit 8d250c6

Please sign in to comment.