Skip to content

Commit

Permalink
unified for upload
Browse files Browse the repository at this point in the history
  • Loading branch information
goldenxinxing committed Feb 8, 2023
1 parent 9a4e375 commit 83bc016
Show file tree
Hide file tree
Showing 16 changed files with 90 additions and 79 deletions.
20 changes: 7 additions & 13 deletions client/starwhale/base/bundle_copy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import typing as t
from copy import deepcopy
from http import HTTPStatus
from pathlib import Path
from concurrent.futures import wait, ThreadPoolExecutor
Expand Down Expand Up @@ -46,8 +45,6 @@
URIType.RUNTIME: "runtime",
}

_UPLOAD_ID_KEY = "X-SW-UPLOAD-ID"


class _UploadPhase:
MANIFEST = "MANIFEST"
Expand Down Expand Up @@ -267,9 +264,8 @@ def _download(_tid: TaskID, fd: FileNode) -> None:
dest_path=fd.path,
instance_uri=self.src_uri,
params={
# for ds download
"desc": fd.file_desc.name,
"name": fd.name,
"partName": fd.name,
"signature": fd.signature,
},
progress=progress,
Expand Down Expand Up @@ -317,10 +313,10 @@ def _do_ubd_bundle_prepare(
url_path=url_path,
file_path=manifest_path,
instance_uri=self.dest_uri,
headers={"X-SW-UPLOAD-TYPE": FileDesc.MANIFEST.name},
fields={
self.field_flag: self.field_value,
"phase": _UploadPhase.MANIFEST,
"desc": FileDesc.MANIFEST.name,
"project": self.dest_uri.project,
"force": "1" if self.force else "0",
},
Expand All @@ -339,15 +335,11 @@ def _do_ubd_blobs(
existed_files: t.Optional[t.List] = None,
) -> None:
existed_files = existed_files or []
_headers = {_UPLOAD_ID_KEY: str(upload_id)}

# TODO: add retry decorator
def _upload_blob(_tid: TaskID, fd: FileNode) -> None:
if not fd.path.exists():
raise NotFoundError(f"{fd.path} not found")
_upload_headers = deepcopy(_headers)
_upload_headers["X-SW-UPLOAD-TYPE"] = fd.file_desc.name
_upload_headers["X-SW-UPLOAD-OBJECT-HASH"] = fd.signature

if progress is not None:
progress.update(_tid, visible=True)
Expand All @@ -358,9 +350,12 @@ def _upload_blob(_tid: TaskID, fd: FileNode) -> None:
instance_uri=self.dest_uri,
fields={
self.field_flag: self.field_value,
"uploadId": upload_id,
"partName": fd.name,
"signature": fd.signature,
"desc": fd.file_desc.name,
"phase": _UploadPhase.BLOB,
},
headers=_upload_headers,
use_raise=True,
progress=progress,
task_id=_tid,
Expand All @@ -380,7 +375,6 @@ def _upload_blob(_tid: TaskID, fd: FileNode) -> None:
visible=False,
)
_p_map[_tid] = _f

with ThreadPoolExecutor(
max_workers=int(os.environ.get("SW_BUNDLE_COPY_THREAD_NUM", "5"))
) as executor:
Expand All @@ -399,11 +393,11 @@ def _do_ubd_end(self, upload_id: str, url_path: str, ok: bool) -> None:
path=url_path,
method=HTTPMethod.POST,
instance_uri=self.dest_uri,
headers={_UPLOAD_ID_KEY: str(upload_id)},
data={
self.field_flag: self.field_value,
"project": self.dest_uri.project,
"phase": phase,
"uploadId": upload_id,
},
use_raise=True,
disable_default_content_type=True,
Expand Down
4 changes: 2 additions & 2 deletions client/starwhale/core/dataset/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def upload_files(self, workdir: Path) -> Iterator[FileNode]:
_path = workdir / "data" / _hash[: DatasetStorage.short_sign_cnt]
yield FileNode(
path=_path,
name=os.path.basename(_path),
name=_hash,
size=_size,
file_desc=FileDesc.DATA,
signature=_hash,
Expand Down Expand Up @@ -79,7 +79,7 @@ def download_files(self, workdir: Path) -> Iterator[FileNode]:
path=_dest,
signature=_hash,
size=_size,
name=_hash[: DatasetStorage.short_sign_cnt],
name=_hash,
file_desc=FileDesc.DATA,
)

Expand Down
10 changes: 5 additions & 5 deletions client/tests/base/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,12 @@ def test_model_copy_c2l(
)
rm.request(
HTTPMethod.GET,
f"http://1.1.1.1:8182/api/v1/project/myproject/model/mnist/version/{version}/file?desc=MANIFEST&name=_manifest.yaml&signature=",
f"http://1.1.1.1:8182/api/v1/project/myproject/model/mnist/version/{version}/file?desc=MANIFEST&partName=_manifest.yaml&signature=",
json={"resources": []},
)
rm.request(
HTTPMethod.GET,
f"http://1.1.1.1:8182/api/v1/project/myproject/model/mnist/version/{version}/file?desc=SRC_TAR&name=src.tar&signature=",
f"http://1.1.1.1:8182/api/v1/project/myproject/model/mnist/version/{version}/file?desc=SRC_TAR&partName=src.tar&signature=",
content=b"mnist model content",
)
# m_load_yaml.return_value = {"resources": []}
Expand Down Expand Up @@ -440,14 +440,14 @@ def test_dataset_copy_c2l(self, rm: Mocker, m_td_scan: MagicMock) -> None:
)
rm.request(
HTTPMethod.GET,
f"http://1.1.1.1:8182/api/v1/project/myproject/dataset/mnist/version/{version}/file?desc=MANIFEST&name=_manifest.yaml&signature=",
f"http://1.1.1.1:8182/api/v1/project/myproject/dataset/mnist/version/{version}/file?desc=MANIFEST&partName=_manifest.yaml&signature=",
json={
"signature": [],
},
)
rm.request(
HTTPMethod.GET,
f"http://1.1.1.1:8182/api/v1/project/myproject/dataset/mnist/version/{version}/file?desc=SRC_TAR&name=archive.swds_meta&signature=",
f"http://1.1.1.1:8182/api/v1/project/myproject/dataset/mnist/version/{version}/file?desc=SRC_TAR&partName=archive.swds_meta&signature=",
content=b"mnist dataset content",
)
rm.request(
Expand Down Expand Up @@ -627,7 +627,7 @@ def test_dataset_copy_l2c(self, rm: Mocker, m_td_scan: MagicMock) -> None:
src_uri=case["src_uri"], dest_uri=case["dest_uri"], typ=URIType.DATASET
).do()
assert head_request.call_count == 1
assert upload_request.call_count == 3
assert upload_request.call_count == 2

# TODO: support the following case
with self.assertRaises(NoMockAddress):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
Expand Down Expand Up @@ -215,8 +214,6 @@ ResponseEntity<ResponseMessage<PageInfo<DatasetVersionVo>>> listDatasetVersion(
produces = MediaType.APPLICATION_JSON_VALUE)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER')")
ResponseEntity<ResponseMessage<UploadResult>> uploadDs(
@RequestHeader(name = "X-SW-UPLOAD-ID", required = false) String uploadId,
@RequestHeader(name = "X-SW-UPLOAD-DATA-URI", required = false) String uri,
@PathVariable(name = "projectUrl") String projectUrl,
@Pattern(regexp = BUNDLE_NAME_REGEX, message = "Dataset name is invalid")
@PathVariable(name = "datasetName") String datasetName,
Expand All @@ -235,8 +232,8 @@ void pullDs(
@PathVariable(name = "projectUrl") String projectUrl,
@PathVariable(name = "datasetUrl") String datasetUrl,
@PathVariable(name = "versionUrl") String versionUrl,
@Parameter(name = "signature", description = "optional, _manifest.yaml is used if not specified")
@RequestParam(name = "signature", required = false) String partName,
@Parameter(name = "partName", description = "optional, _manifest.yaml is used if not specified")
@RequestParam(name = "partName", required = false) String partName,
HttpServletResponse httpResponse);

@Operation(summary = "Pull Dataset uri file contents",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,13 @@ public ResponseEntity<ResponseMessage<PageInfo<DatasetVersionVo>>> listDatasetVe
}

@Override
public ResponseEntity<ResponseMessage<UploadResult>> uploadDs(String uploadId, String uri,
public ResponseEntity<ResponseMessage<UploadResult>> uploadDs(
String projectUrl, String datasetUrl, String versionUrl,
MultipartFile dsFile, DatasetUploadRequest uploadRequest) {
uploadRequest.setProject(projectUrl);
uploadRequest.setSwds(datasetUrl + ":" + versionUrl);
Long uploadId = uploadRequest.getUploadId();
String uri = uploadRequest.getUri();
switch (uploadRequest.getPhase()) {
case MANIFEST:
String text;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
Expand Down Expand Up @@ -291,9 +290,6 @@ ResponseEntity<ResponseMessage<String>> manageModelTag(
produces = MediaType.APPLICATION_JSON_VALUE)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER')")
ResponseEntity<ResponseMessage<Object>> upload(
@RequestHeader(name = "X-SW-UPLOAD-TYPE", required = false) FileDesc fileDesc,
@RequestHeader(name = "X-SW-UPLOAD-OBJECT-HASH", required = false) String signature,
@RequestHeader(name = "X-SW-UPLOAD-ID", required = false) Long uploadId,
@Parameter(
in = ParameterIn.PATH,
description = "Project url",
Expand All @@ -317,7 +313,7 @@ ResponseEntity<ResponseMessage<Object>> upload(
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER')")
void pull(
@RequestParam(name = "desc", required = false) FileDesc fileDesc,
@RequestParam(name = "name", required = false) String name,
@RequestParam(name = "partName", required = false) String name,
@RequestParam(name = "path", required = false) String path,
@RequestParam(name = "signature", required = false) String signature,
@PathVariable("projectUrl") String projectUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,13 @@ public ResponseEntity<ResponseMessage<String>> manageModelTag(String projectUrl,

@Override
public ResponseEntity<ResponseMessage<Object>> upload(
FileDesc fileDesc, String signature, Long uploadId,
String projectUrl, String modelUrl, String versionUrl,
MultipartFile file, ModelUploadRequest uploadRequest) {
uploadRequest.setProject(projectUrl);
uploadRequest.setSwmp(modelUrl + ":" + versionUrl);
FileDesc fileDesc = uploadRequest.getDesc();
String signature = uploadRequest.getSignature();
Long uploadId = uploadRequest.getUploadId();
switch (uploadRequest.getPhase()) {
case MANIFEST:
return ResponseEntity.ok(Code.success.asResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package ai.starwhale.mlops.api.protocol.upload;

import ai.starwhale.mlops.api.protocol.storage.FileDesc;
import javax.validation.constraints.NotNull;
import lombok.Data;
import org.springframework.validation.annotation.Validated;
Expand All @@ -26,6 +27,12 @@ public abstract class UploadRequest {

protected static final String SEPARATOR = ":";

Long uploadId;
String partName;
String signature;
String uri;
FileDesc desc;

@NotNull
UploadPhase phase;
String force;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@
public class UploadResult {

@JsonProperty("upload_id")
String uploadId;
Long uploadId;

}
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public DatasetUploader(HotDatasetHolder hotDatasetHolder, DatasetMapper datasetM
this.versionAliasConvertor = versionAliasConvertor;
}

public void cancel(String uploadId) {
public void cancel(Long uploadId) {
final DatasetVersionWithMeta swDatasetVersionEntityWithMeta = getDatasetVersion(uploadId);
datasetVersionMapper.delete(swDatasetVersionEntityWithMeta.getDatasetVersion().getId());
hotDatasetHolder.cancel(uploadId);
Expand All @@ -145,7 +145,7 @@ private void clearDatasetStorageData(DatasetVersion datasetVersion) {
}
}

public void uploadBody(String uploadId, MultipartFile file, String uri) {
public void uploadBody(Long uploadId, MultipartFile file, String uri) {
final DatasetVersionWithMeta swDatasetVersionWithMeta = getDatasetVersion(uploadId);
String filename = file.getOriginalFilename();
try (InputStream inputStream = file.getInputStream()) {
Expand Down Expand Up @@ -187,7 +187,7 @@ private boolean fileUploaded(DatasetVersionWithMeta datasetVersionWithMeta, Stri
return digest.equals(uploadedFileBlake2bs.get(filename));
}

DatasetVersionWithMeta getDatasetVersion(String uploadId) {
DatasetVersionWithMeta getDatasetVersion(Long uploadId) {
final Optional<DatasetVersionWithMeta> swDatasetVersionEntityOpt = hotDatasetHolder.of(uploadId);
return swDatasetVersionEntityOpt
.orElseThrow(
Expand All @@ -212,7 +212,7 @@ void reUploadManifest(DatasetVersionEntity datasetVersionEntity, String fileName
}

@Transactional
public String create(String yamlContent, String fileName, DatasetUploadRequest uploadRequest) {
public Long create(String yamlContent, String fileName, DatasetUploadRequest uploadRequest) {
Manifest manifest;
try {
manifest = yamlMapper.readValue(yamlContent, Manifest.class);
Expand Down Expand Up @@ -284,7 +284,7 @@ public String create(String yamlContent, String fileName, DatasetUploadRequest u

hotDatasetHolder.manifest(DatasetVersion.fromEntity(datasetEntity, datasetVersionEntity));

return datasetVersionEntity.getVersionName();
return datasetVersionEntity.getId();
}

private DatasetVersionEntity from(String projectName, DatasetEntity datasetEntity, Manifest manifest) {
Expand Down Expand Up @@ -320,7 +320,7 @@ private Long getOwner() {
return currentUserDetail.getIdTableKey();
}

public void end(String uploadId) {
public void end(Long uploadId) {
final DatasetVersionWithMeta datasetVersionWithMeta = getDatasetVersion(uploadId);
datasetVersionMapper.updateStatus(datasetVersionWithMeta.getDatasetVersion().getId(),
DatasetVersion.STATUS_AVAILABLE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
@Component
public class HotDatasetHolder {

Map<String, DatasetVersionWithMeta> datasetHolder;
Map<Long, DatasetVersionWithMeta> datasetHolder;

final DatasetVersionWithMetaConverter datasetVersionWithMetaConverter;

Expand All @@ -36,19 +36,19 @@ public HotDatasetHolder(DatasetVersionWithMetaConverter datasetVersionWithMetaCo
}

public void manifest(DatasetVersion datasetVersion) {
datasetHolder.put(datasetVersion.getVersionName(),
datasetHolder.put(datasetVersion.getId(),
datasetVersionWithMetaConverter.from(datasetVersion));
}

public void cancel(String datasetId) {
public void cancel(Long datasetId) {
datasetHolder.remove(datasetId);
}

public void end(String datasetId) {
public void end(Long datasetId) {
datasetHolder.remove(datasetId);
}

public Optional<DatasetVersionWithMeta> of(String id) {
public Optional<DatasetVersionWithMeta> of(Long id) {
return Optional.ofNullable(datasetHolder.get(id));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ public ModelUploadResult uploadManifest(MultipartFile multipartFile, ModelUpload
.revertVersionTo(modelVersionEntity.getModelId(), modelVersionEntity.getId());
}
return ModelUploadResult.builder()
.uploadId(modelVersionEntity.getId().toString())
.uploadId(modelVersionEntity.getId())
.existed(existed)
.build();
}
Expand Down Expand Up @@ -579,6 +579,7 @@ public void pull(FileDesc fileDesc, String name, String path, String signature,
// update correct attributes
name = Objects.isNull(name) ? file.getName() : name;
path = Objects.isNull(path) ? file.getPath() : path;
signature = Objects.isNull(signature) ? file.getSignature() : signature;
break;
}
}
Expand Down
Loading

0 comments on commit 83bc016

Please sign in to comment.