From 13a0a040e5e1892b4ec3152c66a898398dad451f Mon Sep 17 00:00:00 2001 From: tianwei Date: Thu, 24 Nov 2022 11:02:01 +0800 Subject: [PATCH] add dataset sdk interface --- client/starwhale/__init__.py | 5 + client/starwhale/api/_impl/data_store.py | 62 +- .../starwhale/api/_impl/dataset/__init__.py | 2 + client/starwhale/api/_impl/dataset/builder.py | 227 ++++-- client/starwhale/api/_impl/dataset/loader.py | 63 +- client/starwhale/api/_impl/dataset/model.py | 673 ++++++++++++++++ client/starwhale/api/_impl/model.py | 2 +- client/starwhale/api/_impl/wrapper.py | 26 +- client/starwhale/api/dataset.py | 2 + client/starwhale/base/bundle.py | 26 +- client/starwhale/base/cloud.py | 7 +- client/starwhale/base/tag.py | 14 +- client/starwhale/core/dataset/cli.py | 11 +- client/starwhale/core/dataset/model.py | 165 ++-- client/starwhale/core/dataset/store.py | 2 +- client/starwhale/core/dataset/tabular.py | 9 +- client/starwhale/core/dataset/type.py | 24 +- client/starwhale/core/dataset/view.py | 29 +- client/starwhale/core/model/model.py | 16 +- client/starwhale/core/model/view.py | 13 +- client/starwhale/core/runtime/model.py | 17 +- client/starwhale/core/runtime/view.py | 12 +- client/starwhale/utils/load.py | 6 +- client/tests/base/test_tag.py | 4 +- client/tests/core/test_dataset.py | 66 +- client/tests/core/test_model.py | 2 +- client/tests/core/test_runtime.py | 15 +- client/tests/sdk/test_data_store.py | 34 + client/tests/sdk/test_dataset.py | 261 ++++++ client/tests/sdk/test_dataset_sdk.py | 747 ++++++++++++++++++ client/tests/sdk/test_loader.py | 37 +- client/tests/sdk/test_metric.py | 7 + scripts/client_test/cmds/artifacts_cmd.py | 3 +- 33 files changed, 2334 insertions(+), 255 deletions(-) create mode 100644 client/starwhale/api/_impl/dataset/model.py create mode 100644 client/tests/sdk/test_dataset_sdk.py diff --git a/client/starwhale/__init__.py b/client/starwhale/__init__.py index 2aa963a92b..11bf2713f2 100644 --- a/client/starwhale/__init__.py +++ b/client/starwhale/__init__.py @@ -10,6 +10,7 @@ Image, Video, Binary, + Dataset, LinkAuth, LinkType, MIMEType, @@ -28,10 +29,14 @@ from starwhale.api.evaluation import Evaluation from starwhale.core.dataset.tabular import get_dataset_consumption +dataset = Dataset.dataset + __all__ = [ "__version__", "PipelineHandler", "multi_classification", + "Dataset", + "dataset", "URI", "URIType", "step", diff --git a/client/starwhale/api/_impl/data_store.py b/client/starwhale/api/_impl/data_store.py index 299055c7c1..3904d61599 100644 --- a/client/starwhale/api/_impl/data_store.py +++ b/client/starwhale/api/_impl/data_store.py @@ -11,7 +11,19 @@ import threading from abc import ABCMeta, abstractmethod from http import HTTPStatus -from typing import Any, Set, cast, Dict, List, Type, Tuple, Union, Iterator, Optional +from typing import ( + Any, + Set, + cast, + Dict, + List, + Type, + Tuple, + Union, + Callable, + Iterator, + Optional, +) import dill import numpy as np @@ -665,6 +677,7 @@ def _scan_parquet_file( start: Optional[Any] = None, end: Optional[Any] = None, keep_none: bool = False, + end_inclusive: bool = False, ) -> Iterator[dict]: f = pq.ParquetFile(path) schema_arrow = f.schema_arrow @@ -708,7 +721,10 @@ def _scan_parquet_file( n_cols = len(names) for j in range(n_rows): key = types[0].deserialize(table[0][j].as_py()) - if (start is not None and key < start) or (end is not None and key >= end): + _end_check: Callable = lambda x, y: x > y if end_inclusive else x >= y + if (start is not None and key < start) or ( + end is not None and 
_end_check(key, end) + ): continue d = {"*": key} if key_alias is not None: @@ -828,6 +844,7 @@ def _scan_table( start: Optional[Any] = None, end: Optional[Any] = None, keep_none: bool = False, + end_inclusive: bool = False, ) -> Iterator[dict]: iters = [] for file in _get_table_files(path): @@ -835,7 +852,7 @@ def _scan_table( keep = True else: keep = keep_none - iters.append(_scan_parquet_file(file, columns, start, end, keep)) + iters.append(_scan_parquet_file(file, columns, start, end, keep, end_inclusive)) return _merge_scan(iters, keep_none) @@ -911,16 +928,24 @@ def scan( start: Optional[Any] = None, end: Optional[Any] = None, keep_none: bool = False, + end_inclusive: bool = False, ) -> Iterator[Dict[str, Any]]: + _end_check: Callable = lambda x, y: x <= y if end_inclusive else x < y + with self.lock: schema = self.schema.copy() - records = [ - {self.schema.key_column: key, "-": True} - for key in self.deletes - if (start is None or key >= start) and (end is None or key < end) - ] + + records = [] + for key in self.deletes: + if (start is None or key >= start) and ( + end is None or _end_check(key, end) + ): + records.append({self.schema.key_column: key, "-": True}) + for k, v in self.records.items(): - if (start is None or k >= start) and (end is None or k < end): + if (start is None or k >= start) and ( + end is None or _end_check(k, end) + ): records.append(v) records.sort(key=lambda x: cast(str, x[self.schema.key_column])) for r in records: @@ -1105,6 +1130,7 @@ def scan_tables( start: Optional[Any] = None, end: Optional[Any] = None, keep_none: bool = False, + end_inclusive: bool = False, ) -> Iterator[Dict[str, Any]]: class TableInfo: def __init__( @@ -1163,9 +1189,14 @@ def __init__( start, end, info.keep_none, + end_inclusive, ), self.tables[info.name].scan( - info.columns, start, end, True + info.columns, + start, + end, + True, + end_inclusive, ), ], info.keep_none, @@ -1174,7 +1205,11 @@ def __init__( else: iters.append( self.tables[info.name].scan( - info.columns, start, end, info.keep_none + info.columns, + start, + end, + info.keep_none, + end_inclusive, ) ) else: @@ -1185,6 +1220,7 @@ def __init__( start, end, info.keep_none, + end_inclusive, ) ) for record in _merge_scan(iters, keep_none): @@ -1284,6 +1320,7 @@ def scan_tables( start: Optional[Any] = None, end: Optional[Any] = None, keep_none: bool = False, + end_inclusive: bool = False, ) -> Iterator[Dict[str, Any]]: post_data: Dict[str, Any] = {"tables": [table.to_dict() for table in tables]} key_type = _get_type(start) @@ -1294,6 +1331,8 @@ def scan_tables( post_data["limit"] = 1000 if keep_none: post_data["keepNone"] = True + if end_inclusive: + post_data["endInclusive"] = True assert self.token is not None while True: resp_json = self._do_scan_table_request(post_data) @@ -1338,6 +1377,7 @@ def scan_tables( start: Optional[Any] = None, end: Optional[Any] = None, keep_none: bool = False, + end_inclusive: bool = False, ) -> Iterator[Dict[str, Any]]: ... 
diff --git a/client/starwhale/api/_impl/dataset/__init__.py b/client/starwhale/api/_impl/dataset/__init__.py index 8f8ed79bcf..0c5144f65e 100644 --- a/client/starwhale/api/_impl/dataset/__init__.py +++ b/client/starwhale/api/_impl/dataset/__init__.py @@ -17,6 +17,7 @@ COCOObjectAnnotation, ) +from .model import Dataset from .loader import get_data_loader, SWDSBinDataLoader, UserRawDataLoader from .builder import BuildExecutor, SWDSBinBuildExecutor, UserRawBuildExecutor @@ -43,4 +44,5 @@ "BoundingBox", "GrayscaleImage", "COCOObjectAnnotation", + "Dataset", ] diff --git a/client/starwhale/api/_impl/dataset/builder.py b/client/starwhale/api/_impl/dataset/builder.py index c668bb0040..23536e8364 100644 --- a/client/starwhale/api/_impl/dataset/builder.py +++ b/client/starwhale/api/_impl/dataset/builder.py @@ -1,16 +1,22 @@ +from __future__ import annotations + import os +import time +import queue import struct import typing as t import inspect import tempfile +import threading from abc import ABCMeta, abstractmethod from types import TracebackType from pathlib import Path from binascii import crc32 import jsonlines +from loguru import logger -from starwhale.consts import AUTH_ENV_FNAME, SWDS_DATA_FNAME_FMT +from starwhale.consts import AUTH_ENV_FNAME, DEFAULT_PROJECT, SWDS_DATA_FNAME_FMT from starwhale.base.uri import URI from starwhale.utils.fs import empty_dir, ensure_dir from starwhale.base.type import DataFormatType, DataOriginType, ObjectStoreType @@ -29,6 +35,7 @@ from starwhale.core.dataset.store import DatasetStorage from starwhale.api._impl.data_store import SwObject from starwhale.core.dataset.tabular import TabularDataset, TabularDatasetRow +from starwhale.api._impl.dataset.loader import DataRow # TODO: tune header size _header_magic = struct.unpack(">I", b"SWDS")[0] @@ -98,9 +105,11 @@ def __exit__( value: t.Optional[BaseException], trace: TracebackType, ) -> None: - if value: + if value: # pragma: no cover print(f"type:{type}, exception:{value}, traceback:{trace}") + self.close() + def close(self) -> None: try: self.tabular_dataset.close() except Exception as e: @@ -137,6 +146,31 @@ def _merge_forked_summary(self, s: DatasetSummary) -> DatasetSummary: def data_format_type(self) -> DataFormatType: raise NotImplementedError + def _unpack_row_content( + self, row_content: t.Union[t.Tuple, DataRow], append_seq_id: int + ) -> t.Tuple[t.Union[str, int], BaseArtifact, t.Dict]: + if isinstance(row_content, DataRow): + idx, row_data, row_annotations = row_content + elif isinstance(row_content, tuple): + if len(row_content) == 2: + idx = append_seq_id + row_data, row_annotations = row_content + elif len(row_content) == 3: + idx, row_data, row_annotations = row_content + else: + raise FormatError( + f"iter_item must return (data, annotations) or (id, data, annotations): {row_content}" + ) + else: + raise FormatError( + f"row content not return tuple or DataRow type: {row_content}" + ) + + if not isinstance(row_annotations, dict): + raise FormatError(f"annotations({row_annotations}) must be dict type") + + return idx, row_data, row_annotations + class SWDSBinBuildExecutor(BaseBuildExecutor): """ @@ -206,21 +240,9 @@ def make_swds(self) -> DatasetSummary: for append_seq_id, item_content in enumerate( self.iter_item(), start=self._forked_last_seq_id + 1 ): - if not isinstance(item_content, tuple): - raise FormatError(f"iter_item not return tuple type: {item_content}") - - if len(item_content) == 2: - idx = append_seq_id - row_data, row_annotations = item_content - elif len(item_content) == 3: - 
idx, row_data, row_annotations = item_content - else: - raise FormatError( - f"iter_item must return (data, annotations) or (id, data, annotations): {item_content}" - ) - - if not isinstance(row_annotations, dict): - raise FormatError(f"annotations({row_annotations}) must be dict type") + idx, row_data, row_annotations = self._unpack_row_content( + item_content, append_seq_id + ) _artifact: BaseArtifact if isinstance(row_data, bytes): @@ -332,7 +354,7 @@ class UserRawBuildExecutor(BaseBuildExecutor): def make_swds(self) -> DatasetSummary: increased_rows = 0 total_data_size = 0 - auth_candidates = {} + auth_candidates: t.Dict[str, LinkAuth] = {} include_link = False map_path_sign: t.Dict[str, t.Tuple[str, Path]] = {} @@ -342,18 +364,9 @@ def make_swds(self) -> DatasetSummary: self.iter_item(), start=self._forked_last_seq_id + 1, ): - if len(item_content) == 2: - idx = append_seq_id - row_data, row_annotations = item_content - elif len(item_content) == 3: - idx, row_data, row_annotations = item_content - else: - raise FormatError( - f"iter_item must return (data, annotations) or (id, data, annotations): {item_content}" - ) - - if not isinstance(row_annotations, dict): - raise FormatError(f"annotations({row_annotations}) must be dict type") + idx, row_data, row_annotations = self._unpack_row_content( + item_content, append_seq_id + ) if not dataset_annotations: # TODO: check annotations type and name @@ -377,7 +390,7 @@ def _travel_link(obj: t.Any) -> None: if isinstance(obj, Link): if not obj.with_local_fs_data: raise NoSupportError( - f"Local Link only suuports local link annotations: {obj}" + f"Local Link only supports local link annotations: {obj}" ) if obj.uri not in map_path_sign: map_path_sign[obj.uri] = DatasetStorage.save_data_file( @@ -482,26 +495,152 @@ def _do_iter_item(self: t.Any) -> t.Generator: for _item in items_iter: yield _item - attrs = {"iter_item": _do_iter_item} - - if len(item) == 2: - data = item[0] - elif len(item) == 3: - data = item[1] + if isinstance(item, DataRow): + data = item.data + elif isinstance(item, (tuple, list)): + if len(item) == 2: + data = item[0] + elif len(item) == 3: + data = item[1] + else: + raise FormatError(f"wrong item format: {item}") else: - raise FormatError(f"wrong item format: {item}") + raise TypeError(f"item only supports tuple, list or DataRow type: {item}") - if isinstance(data, Link): + use_swds_bin = not isinstance(data, Link) + return create_generic_cls_by_mode(use_swds_bin, _do_iter_item) + + +def create_generic_cls_by_mode( + use_swds_bin: bool, iter_func: t.Callable +) -> t.Type[BaseBuildExecutor]: + attrs = {"iter_item": iter_func} + if use_swds_bin: _cls = type( - "GenericUserRawHandler", - (UserRawBuildExecutor,), + "GenericSWDSBinHandler", + (SWDSBinBuildExecutor,), attrs, ) else: _cls = type( - "GenericSWDSBinHandler", - (SWDSBinBuildExecutor,), + "GenericUserRawHandler", + (UserRawBuildExecutor,), attrs, ) - return _cls + + +class RowWriter(threading.Thread): + def __init__( + self, + dataset_name: str, + dataset_version: str, + project_name: str = DEFAULT_PROJECT, + workdir: Path = Path(".dataset_tmp"), + alignment_bytes_size: int = D_ALIGNMENT_SIZE, + volume_bytes_size: int = D_FILE_VOLUME_SIZE, + append: bool = False, + append_from_version: str = "", + append_from_uri: t.Optional[URI] = None, + append_with_swds_bin: bool = True, + ) -> None: + super().__init__( + name=f"RowWriter-{dataset_name}-{dataset_version}-{project_name}" + ) + + self._kw = { + "dataset_name": dataset_name, + "dataset_version": dataset_version, + 
"project_name": project_name, + "workdir": workdir, + "alignment_bytes_size": alignment_bytes_size, + "volume_bytes_size": volume_bytes_size, + "append": append, + "append_from_version": append_from_version, + "append_from_uri": append_from_uri, + } + + self._queue: queue.Queue[t.Optional[DataRow]] = queue.Queue() + self._summary = DatasetSummary() + self._lock = threading.Lock() + + self._run_exception: t.Optional[Exception] = None + + self.setDaemon(True) + self._builder: t.Optional[BaseBuildExecutor] = None + if append and append_from_version: + _cls = create_generic_cls_by_mode(append_with_swds_bin, self.__iter__) + self._builder = _cls(**self._kw) # type: ignore + self.start() + + def _raise_run_exception(self) -> None: + if self._run_exception is not None: + _e = self._run_exception + self._run_exception = None + raise threading.ThreadError(f"RowWriter Thread raise exception: {_e}") + + @property + def summary(self) -> DatasetSummary: + return self._summary + + def __enter__(self) -> RowWriter: + return self + + def __exit__( + self, + type: t.Optional[t.Type[BaseException]], + value: t.Optional[BaseException], + trace: TracebackType, + ) -> None: + if value: # pragma: no cover + logger.warning(f"type:{type}, exception:{value}, traceback:{trace}") + + self.close() + + def flush(self) -> None: + while not self._queue.empty(): + # TODO: tune flush with thread condition + time.sleep(0.1) + + def close(self) -> None: + self._queue.put(None) + + self.join() + if self._builder: + self._builder.close() + + self._raise_run_exception() + + def update(self, row_item: DataRow) -> None: + self._raise_run_exception() + self._queue.put(row_item) + + with self._lock: + if self._builder is None: + _cls = create_generic_cls(self.__iter__) + self._builder = _cls(**self._kw) # type: ignore + self.start() + + def __iter__(self) -> t.Generator[DataRow, None, None]: + while True: + item = self._queue.get(block=True, timeout=None) + if item is None: + if self._queue.qsize() > 0: + continue + else: + break + + if not isinstance(item, DataRow): + continue + + yield item + + def run(self) -> None: + try: + if self._builder is None: + raise RuntimeError("dataset builder object wasn't initialized") + self._summary = self._builder.make_swds() + except Exception as e: + logger.exception(e) + self._run_exception = e + raise diff --git a/client/starwhale/api/_impl/dataset/loader.py b/client/starwhale/api/_impl/dataset/loader.py index de7cfb83aa..e93137ca71 100644 --- a/client/starwhale/api/_impl/dataset/loader.py +++ b/client/starwhale/api/_impl/dataset/loader.py @@ -3,6 +3,7 @@ import os import typing as t from abc import ABCMeta, abstractmethod +from functools import total_ordering import loguru from loguru import logger as _logger @@ -23,7 +24,55 @@ TabularDatasetSessionConsumption, ) -_LDType = t.Tuple[t.Union[str, int], t.Any, t.Dict] + +@total_ordering +class DataRow: + def __init__( + self, + index: t.Union[str, int], + data: t.Optional[t.Union[BaseArtifact, Link]], + annotations: t.Dict, + ) -> None: + self.index = index + self.data = data + self.annotations = annotations + + self._do_validate() + + def __str__(self) -> str: + return f"{self.index}" + + def __repr__(self) -> str: + return f"index:{self.index}, data:{self.data}, annotations:{self.annotations}" + + def __iter__(self) -> t.Iterator: + return iter((self.index, self.data, self.annotations)) + + def __getitem__(self, i: int) -> t.Any: + return (self.index, self.data, self.annotations)[i] + + def __len__(self) -> int: + return len(self.__dict__) + 
+ def _do_validate(self) -> None: + if not isinstance(self.index, (str, int)): + raise TypeError(f"index({self.index}) is not int or str type") + + if self.data is not None and not isinstance(self.data, (BaseArtifact, Link)): + raise TypeError(f"data({self.data}) is not BaseArtifact or Link type") + + if not isinstance(self.annotations, dict): + raise TypeError(f"annotations({self.annotations}) is not dict type") + + def __lt__(self, obj: DataRow) -> bool: + return str(self.index) < str(obj.index) + + def __eq__(self, obj: t.Any) -> bool: + return bool( + self.index == obj.index + and self.data == obj.data + and self.annotations == obj.annotations + ) class DataLoader(metaclass=ABCMeta): @@ -134,6 +183,7 @@ def _sign_uris(self, uris: t.List[str]) -> dict: def _iter_row(self) -> t.Generator[TabularDatasetRow, None, None]: if not self.session_consumption: + # TODO: refactor for batch-signed urls for row in self.tabular_dataset.scan(): yield row else: @@ -164,17 +214,22 @@ def _iter_row(self) -> t.Generator[TabularDatasetRow, None, None]: for row in self.tabular_dataset.scan(rt[0], rt[1]): yield row - def _unpack_row(self, row: TabularDatasetRow) -> _LDType: + def _unpack_row( + self, row: TabularDatasetRow, skip_fetch_data: bool = False + ) -> DataRow: + if skip_fetch_data: + return DataRow(index=row.id, data=None, annotations=row.annotations) + store = self._get_store(row) key_compose = self._get_key_compose(row, store) file = store.backend._make_file(store.bucket, key_compose) data_content, _ = self._read_data(file, row) data = BaseArtifact.reflect(data_content, row.data_type) - return row.id, data, row.annotations + return DataRow(index=row.id, data=data, annotations=row.annotations) def __iter__( self, - ) -> t.Generator[_LDType, None, None]: + ) -> t.Generator[DataRow, None, None]: for row in self._iter_row(): # TODO: tune performance by fetch in batch yield self._unpack_row(row) diff --git a/client/starwhale/api/_impl/dataset/model.py b/client/starwhale/api/_impl/dataset/model.py new file mode 100644 index 0000000000..a6973a83db --- /dev/null +++ b/client/starwhale/api/_impl/dataset/model.py @@ -0,0 +1,673 @@ +from __future__ import annotations + +import typing as t +import threading +from http import HTTPStatus +from types import TracebackType +from pathlib import Path +from functools import wraps + +from loguru import logger + +from starwhale.utils import gen_uniq_version +from starwhale.consts import HTTPMethod, DEFAULT_PAGE_IDX, DEFAULT_PAGE_SIZE +from starwhale.base.uri import URI, URIType +from starwhale.base.type import InstanceType +from starwhale.base.cloud import CloudRequestMixed +from starwhale.utils.error import ExistedError, NotFoundError, NoSupportError +from starwhale.core.dataset.type import DatasetConfig, DatasetSummary +from starwhale.core.dataset.model import Dataset as CoreDataset +from starwhale.core.dataset.model import StandaloneDataset +from starwhale.core.dataset.store import DatasetStorage +from starwhale.core.dataset.tabular import ( + get_dataset_consumption, + DEFAULT_CONSUMPTION_BATCH_SIZE, + TabularDatasetSessionConsumption, +) + +from .loader import DataRow, DataLoader, get_data_loader +from .builder import RowWriter, BaseBuildExecutor + +_DType = t.TypeVar("_DType", bound="Dataset") +_ItemType = t.Union[str, int, slice] +_HandlerType = t.Optional[t.Union[t.Callable, BaseBuildExecutor]] +_GItemType = t.Optional[t.Union[DataRow, t.List[DataRow]]] + + +class _Tags: + def __init__(self, core_dataset: CoreDataset) -> None: + self.__core_dataset = 
core_dataset + + def add(self, tags: t.Union[str, t.List[str]], ignore_errors: bool = False) -> None: + if isinstance(tags, str): + tags = [tags] + self.__core_dataset.add_tags(tags, ignore_errors) + + def remove( + self, tags: t.Union[str, t.List[str]], ignore_errors: bool = False + ) -> None: + if isinstance(tags, str): + tags = [tags] + self.__core_dataset.remove_tags(tags, ignore_errors) + + def __iter__(self) -> t.Generator[str, None, None]: + for tag in self.__core_dataset.list_tags(): + yield tag + + def __str__(self) -> str: + return f"Dataset Tag: {self.__core_dataset}" + + __repr__ = __str__ + + +class Dataset: + def __init__( + self, + name: str, + version: str, + project_uri: URI, + create: bool = False, + ) -> None: + self.name = name + self.project_uri = project_uri + + _origin_uri = URI.capsulate_uri( + self.project_uri.instance, + self.project_uri.project, + URIType.DATASET, + self.name, + version, + ) + + if create: + self.version = gen_uniq_version() + else: + self.version = self._auto_complete_version(version) + + if not self.version: + raise ValueError("version field is empty") + + self.uri = URI.capsulate_uri( + self.project_uri.instance, + self.project_uri.project, + URIType.DATASET, + self.name, + self.version, + ) + + self.__readonly = not create + self.__core_dataset = CoreDataset.get_dataset(self.uri) + if create: + setattr(self.__core_dataset, "_version", self.version) + + self._append_use_swds_bin = False + _summary = None + if self._check_uri_exists(_origin_uri): + if create: + self._append_from_version = version + self._create_by_append = True + self._fork_dataset() + _summary = CoreDataset.get_dataset(_origin_uri).summary() + if _summary: + self._append_use_swds_bin = not ( + _summary.include_link or _summary.include_user_raw + ) + else: + self._append_from_version = "" + self._create_by_append = False + else: + if create: + self._append_from_version = "" + self._create_by_append = False + else: + raise ExistedError(f"{self.uri} was not found fo load") + + self._summary = _summary or self.__core_dataset.summary() + + self._rows_cnt = self._summary.rows if self._summary else 0 + self._consumption: t.Optional[TabularDatasetSessionConsumption] = None + self._lock = threading.Lock() + self.__data_loaders: t.Dict[str, DataLoader] = {} + self.__build_handler: _HandlerType = None + self._trigger_handler_build = False + self._trigger_icode_build = False + self._writer_lock = threading.Lock() + self._row_writer: t.Optional[RowWriter] = None + self.__keys_cache: t.Set[t.Union[int, str]] = set() + self._enable_copy_src = False + + def _fork_dataset(self) -> None: + # TODO: support cloud dataset prepare in the tmp dir + # TODO: lazy fork dataset + self.__core_dataset._prepare_snapshot() + self.__core_dataset._fork_swds( + self._create_by_append, self._append_from_version + ) + + def _auto_complete_version(self, version: str) -> str: + version = version.strip() + if not version: + return version + + if self.project_uri.instance_type == InstanceType.CLOUD: + return version + + _uri = URI.capsulate_uri( + instance=self.project_uri.instance, + project=self.project_uri.project, + obj_type=URIType.DATASET, + obj_name=self.name, + obj_ver=version, + ) + store = DatasetStorage(_uri) + if not store.snapshot_workdir.exists(): + return version + else: + return store.id + + def __str__(self) -> str: + return f"Dataset: {self.name}-{self.version}" + + def __repr__(self) -> str: + return f"Dataset: uri-{self.uri}" + + def __len__(self) -> int: + return self._rows_cnt + + def 
__enter__(self: _DType) -> _DType: + return self + + def __bool__(self) -> bool: + return True + + def __exit__( + self, + type: t.Optional[t.Type[BaseException]], + value: t.Optional[BaseException], + trace: TracebackType, + ) -> None: + if value: # pragma: no cover + logger.warning(f"type:{type}, exception:{value}, traceback:{trace}") + + self.close() + + def make_distributed_consumption( + self, session_id: str, batch_size: int = DEFAULT_CONSUMPTION_BATCH_SIZE + ) -> Dataset: + if self._consumption is not None: + raise RuntimeError( + f"distributed consumption has already been created ({self._consumption})" + ) + + with self._lock: + self._consumption = get_dataset_consumption( + self.uri, session_id=session_id, batch_size=batch_size + ) + return self + + def _get_data_loader( + self, recreate: bool = False, disable_consumption: bool = False + ) -> DataLoader: + with self._lock: + key = f"consumption-{disable_consumption}" + + _loader = self.__data_loaders.get(key) + if _loader is None or recreate: + if disable_consumption: + consumption = None + else: + consumption = self._consumption + + _loader = get_data_loader(self.uri, session_consumption=consumption) + self.__data_loaders[key] = _loader + + return _loader + + def __iter__(self) -> t.Iterator[DataRow]: + for row in self._get_data_loader(): + yield row + + def __getitem__( + self, + item: _ItemType, + ) -> _GItemType: + """ + Example: + self["str_key"] # get the DataRow by the "str_key" string key + self[1] # get the DataRow by the 1 int key + self["start":"end"] # get a slice of the dataset by the range ("start", "end") + self[1:10:2] # get a slice of the dataset by the range (1, 10), step is 2 + """ + # TODO: tune datastore performance for getitem + return self._getitem(item) + + def _getitem( + self, + item: _ItemType, + skip_fetch_data: bool = False, + ) -> _GItemType: + def _run() -> _GItemType: + loader = self._get_data_loader(disable_consumption=True) + if isinstance(item, (int, str)): + row = next(loader.tabular_dataset.scan(item, item, end_inclusive=True)) + return loader._unpack_row(row, skip_fetch_data) + elif isinstance(item, slice): + step = item.step or 1 + if step <= 0: + raise ValueError( + f"Dataset slice step({step}) cannot be zero or negative number" + ) + cnt = 0 + # TODO: batch signed urls + rows = [] + for row in loader.tabular_dataset.scan(item.start, item.stop): + if cnt % step == 0: + rows.append(loader._unpack_row(row, skip_fetch_data)) + cnt += 1 + return rows + else: + raise ValueError(f"{item} type is not int, str or slice") + + try: + return _run() + except RuntimeError as e: + if str(e).startswith("table is empty"): + return None + raise + except StopIteration: + return None + + @property + def readonly(self) -> bool: + return self.__readonly + + def _check_readonly(func: t.Callable): # type: ignore + @wraps(func) + def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: + self: Dataset = args[0] + if self.readonly: + raise RuntimeError(f"{func} does not work in the readonly mode") + return func(*args, **kwargs) + + return _wrapper + + def _forbid_handler_build(func: t.Callable): # type: ignore + @wraps(func) + def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: + self: Dataset = args[0] + if self._trigger_handler_build: + raise NoSupportError( + "no support build from handler and from cache code at the same time, build from handler has already been activated" + ) + return func(*args, **kwargs) + + return _wrapper + + def _forbid_icode_build(func: t.Callable): # type: ignore + @wraps(func) + def 
_wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: + self: Dataset = args[0] + if self._trigger_icode_build: + raise NoSupportError( + "no support build from handler and from cache code at the same time, build from interactive code has already been activated" + ) + return func(*args, **kwargs) + + return _wrapper + + @property + def build_handler(self) -> _HandlerType: + return self.__build_handler + + @build_handler.setter + def build_handler(self, handler: _HandlerType) -> None: + if self._trigger_icode_build: + raise RuntimeError( + "dataset append by interactive code has already been called" + ) + self._trigger_handler_build = True + self.__build_handler = handler + + @property + def tags(self) -> _Tags: + return _Tags(self.__core_dataset) + + @staticmethod + def _check_uri_exists(uri: t.Optional[URI]) -> bool: + if uri is None or uri.object.version == "": + return False + + if uri.instance_type == InstanceType.CLOUD: + crm = CloudRequestMixed() + ok, _ = crm.do_http_request_simple_ret( + path=f"/project/{uri.project}/{URIType.DATASET}/{uri.object.name}/version/{uri.object.version}/file", + method=HTTPMethod.HEAD, + instance_uri=uri, + ignore_status_codes=[HTTPStatus.NOT_FOUND], + ) + return ok + else: + _store = DatasetStorage(uri) + return _store.manifest_path.exists() + + def exists(self) -> bool: + return self._check_uri_exists(self.uri) + + @_check_readonly + def flush(self) -> None: + loader = self._get_data_loader(disable_consumption=True) + loader.tabular_dataset.flush() + self.__keys_cache = set() + + if self._row_writer: + self._row_writer.flush() + + @_check_readonly + def rehash(self) -> None: + # TODO: rehash for swds-bin format dataset with append/delete items to reduce volumes size + raise NotImplementedError + + def remove(self, force: bool = False) -> None: + ok, reason = self.__core_dataset.remove(force) + if not ok: + raise RuntimeError(f"failed to remove dataset: {reason}") + + def recover(self, force: bool = False) -> None: + ok, reason = self.__core_dataset.recover(force) + if not ok: + raise RuntimeError(f"failed to recover dataset: {reason}") + + def summary(self) -> t.Optional[DatasetSummary]: + return self._summary + + def history(self) -> t.List[t.Dict]: + return self.__core_dataset.history() + + def close(self) -> None: + self.__keys_cache = set() + for _loader in self.__data_loaders.values(): + if not _loader: + continue # pragma: no cover + + _loader.tabular_dataset.close() + + if self._row_writer: + self._row_writer.close() + + # TODO: flush raw data into disk + + def diff(self, cmp: Dataset) -> t.Dict: + return self.__core_dataset.diff(cmp.uri) + + def info(self) -> t.Dict[str, t.Any]: + return self.__core_dataset.info() + + def head(self, n: int = 3, show_raw_data: bool = False) -> t.List[t.Dict]: + # TODO: render artifact in JupyterNotebook + return self.__core_dataset.head(n, show_raw_data) + + def to_pytorch(self) -> t.Any: + raise NotImplementedError + + def to_tensorflow(self) -> t.Any: + raise NotImplementedError + + @_check_readonly + @_forbid_handler_build + def __setitem__( + self, key: t.Union[str, int], value: t.Union[DataRow, t.Tuple] + ) -> None: + # TODO: tune the performance of getitem by cache + self._trigger_icode_build = True + if not isinstance(self.__core_dataset, StandaloneDataset): + raise NoSupportError( + f"setitem only supports for standalone dataset: {self.__core_dataset}" + ) + + _row_writer = self._get_row_writer() + + if not isinstance(key, (int, str)): + raise TypeError(f"key must be str or int type: {key}") + + if 
isinstance(value, DataRow): + value.index = key + row = value + elif isinstance(value, (tuple, list)): + if len(value) == 2: + data, annotations = value + elif len(value) == 3: + _, data, annotations = value + else: + raise ValueError(f"{value} cannot unpack") + + row = DataRow(index=key, data=data, annotations=annotations) + else: + raise TypeError(f"value only supports tuple or DataRow type: {value}") + + if key not in self.__keys_cache: + self.__keys_cache.add(key) + _item = self._getitem(key, skip_fetch_data=True) + if _item is None or len(_item) == 0: + self._rows_cnt += 1 + + _row_writer.update(row) + + def _get_row_writer(self) -> RowWriter: + if self._row_writer is not None: + return self._row_writer + + with self._writer_lock: + if self._row_writer is None: + if self._create_by_append and self._append_from_version: + append_from_uri = URI.capsulate_uri( + instance=self.project_uri.instance, + project=self.project_uri.project, + obj_type=URIType.DATASET, + obj_name=self.name, + obj_ver=self._append_from_version, + ) + store = DatasetStorage(append_from_uri) + if not store.snapshot_workdir.exists(): + raise NotFoundError(f"dataset uri: {append_from_uri}") + append_from_version = store.id + else: + append_from_uri = None + append_from_version = "" + + # TODO: support alignment_bytes_size, volume_bytes_size arguments + if not isinstance(self.__core_dataset, StandaloneDataset): + raise NoSupportError( + f"setitem only supports for standalone dataset: {self.__core_dataset}" + ) + + self._row_writer = RowWriter( + dataset_name=self.name, + dataset_version=self.version, + project_name=self.project_uri.project, + workdir=self.__core_dataset.store.snapshot_workdir, # TODO: use tmpdir which is same as dataset build command + append=self._create_by_append, + append_from_version=append_from_version, + append_from_uri=append_from_uri, + append_with_swds_bin=self._append_use_swds_bin, + ) + return self._row_writer + + _init_row_writer = _get_row_writer + + @_check_readonly + @_forbid_handler_build + def __delitem__(self, key: _ItemType) -> None: + self._trigger_icode_build = True + self._init_row_writer() # hack for del item as the first operation + + items: t.List + if isinstance(key, (str, int)): + items = [self._getitem(key, skip_fetch_data=True)] + elif isinstance(key, slice): + items = self._getitem(key, skip_fetch_data=True) # type: ignore + else: + raise TypeError(f"key({key}) is not str, int or slice type") + + # TODO: raise not-found key error? 
+ loader = self._get_data_loader(disable_consumption=True) + for item in items: + if not item or not isinstance(item, DataRow): + continue # pragma: no cover + if item.index in self.__keys_cache: + self.__keys_cache.remove(item.index) + loader.tabular_dataset.delete(item.index) + self._rows_cnt -= 1 + + @_check_readonly + @_forbid_handler_build + def append(self, item: t.Any) -> None: + if isinstance(item, DataRow): + self.__setitem__(item.index, item) + elif isinstance(item, (list, tuple)): + if len(item) == 2: + row = DataRow(self._rows_cnt, item[0], item[1]) + elif len(item) == 3: + row = DataRow(item[0], item[1], item[2]) + else: + raise ValueError( + f"cannot unpack value({item}), expected sequence is (index, data, annotations) or (data, annotations)" + ) + + self.__setitem__(row.index, row) + else: + raise TypeError(f"value({item}) is not DataRow, list or tuple type") + + @_check_readonly + @_forbid_handler_build + def extend(self, items: t.Sequence[t.Any]) -> None: + for item in items: + self.append(item) + + @_check_readonly + def build_with_copy_src( + self, + src_dir: t.Union[str, Path], + include_files: t.Optional[t.List[str]] = None, + exclude_files: t.Optional[t.List[str]] = None, + ) -> Dataset: + self._enable_copy_src = True + self._build_src_dir = Path(src_dir) + self._build_include_files = include_files or [] + self._build_exclude_files = exclude_files or [] + return self + + commit_with_copy_src = build_with_copy_src + + @_check_readonly + @_forbid_handler_build + def _do_build_from_interactive_code(self) -> None: + ds = self.__core_dataset + if isinstance(ds, StandaloneDataset): + if self._row_writer is None: + raise RuntimeError("row writer is none, no data was written") + + self.flush() + self._row_writer.close() + # TODO: use the elegant method to refactor manifest update + self._summary = self._row_writer.summary + self._summary.rows = len(self) + ds._manifest["dataset_summary"] = self._summary.asdict() + ds._calculate_signature() + ds._render_manifest() + ds._make_swds_meta_tar() + ds._make_auto_tags() + else: + # TODO: support cloud dataset build + raise NoSupportError("only support standalone dataset build") + + @_check_readonly + @_forbid_icode_build + def _do_build_from_handler(self) -> None: + self._trigger_icode_build = True + config = DatasetConfig( + name=self.name, + handler=self.build_handler, + project_uri=self.project_uri.full_uri, + append=self._create_by_append, + append_from=self._append_from_version, + ) + + kw: t.Dict[str, t.Any] = {"disable_copy_src": not self._enable_copy_src} + if self._enable_copy_src: + config.pkg_data = self._build_include_files + config.exclude_pkg_data = self._build_exclude_files + kw["workdir"] = self._build_src_dir + + # TODO: support DatasetAttr config for SDK + config.do_validate() + kw["config"] = config + # TODO: support build tmpdir, follow the swcli dataset build command behavior + self.__core_dataset.buildImpl(**kw) + _summary = self.__core_dataset.summary() + self._rows_cnt = _summary.rows if _summary else 0 + + @_check_readonly + def build(self) -> None: + # TODO: support build dataset for cloud uri directly + if self.project_uri.instance_type == InstanceType.CLOUD: + raise NoSupportError("no support to build cloud dataset directly") + + if self._trigger_icode_build: + self._do_build_from_interactive_code() + elif self._trigger_handler_build and self.build_handler: + self._do_build_from_handler() + else: + raise RuntimeError("no data to build dataset") + + commit = build + + def copy( + self, dest_uri: str, force: 
bool = False, dest_local_project_uri: str = "" + ) -> None: + CoreDataset.copy( + str(self.uri), + dest_uri, + force=force, + dest_local_project_uri=dest_local_project_uri, + ) + + @staticmethod + def list( + project_uri: t.Union[str, URI] = "", + fullname: bool = False, + show_removed: bool = False, + page_index: int = DEFAULT_PAGE_IDX, + page_size: int = DEFAULT_PAGE_SIZE, + ) -> t.Tuple[t.List[t.Dict[str, t.Any]], t.Dict[str, t.Any]]: + from starwhale.core.dataset.view import DatasetTermView + + return DatasetTermView.list( + project_uri, fullname, show_removed, page_index, page_size + ) + + @staticmethod + def dataset( + uri: t.Union[str, URI], + create: bool = False, + create_from_handler: t.Optional[_HandlerType] = None, + ) -> Dataset: + if isinstance(uri, str): + _uri = URI(uri, expected_type=URIType.DATASET) + elif isinstance(uri, URI) and uri.object.typ == URIType.DATASET: + _uri = uri + else: + raise TypeError( + f"uri({uri}) argument type is not expected, dataset uri or str is ok" + ) + + ds = Dataset( + name=_uri.object.name, + version=_uri.object.version, + project_uri=_uri, # TODO: cut off dataset resource info? + create=create or bool(create_from_handler), + ) + + if create_from_handler: + ds.build_handler = create_from_handler + + return ds diff --git a/client/starwhale/api/_impl/model.py b/client/starwhale/api/_impl/model.py index 2bc3b5ef69..7515f8b56f 100644 --- a/client/starwhale/api/_impl/model.py +++ b/client/starwhale/api/_impl/model.py @@ -161,7 +161,7 @@ def __exit__( self._sw_logger.debug( f"execute {self.context.step}-{self.context.index} exit func..." ) - if value: + if value: # pragma: no cover print(f"type:{type}, exception:{value}, traceback:{trace}") if self._stdout_changed: diff --git a/client/starwhale/api/_impl/wrapper.py b/client/starwhale/api/_impl/wrapper.py index bc131b640e..bd04bb5102 100644 --- a/client/starwhale/api/_impl/wrapper.py +++ b/client/starwhale/api/_impl/wrapper.py @@ -35,7 +35,7 @@ def close(self) -> None: if exceptions: raise Exception(*exceptions) - def _log(self, table_name: str, record: Dict[str, Any]) -> None: + def _fetch_writer(self, table_name: str) -> data_store.TableWriter: with self._lock: if table_name not in self._writers: self._writers.setdefault(table_name, None) @@ -44,7 +44,10 @@ def _log(self, table_name: str, record: Dict[str, Any]) -> None: _store = getattr(self, "_data_store", None) writer = data_store.TableWriter(table_name, data_store=_store) self._writers[table_name] = writer + return writer + def _log(self, table_name: str, record: Dict[str, Any]) -> None: + writer = self._fetch_writer(table_name) writer.insert(record) def _flush(self, table_name: str) -> None: @@ -54,6 +57,10 @@ def _flush(self, table_name: str) -> None: return writer.flush() + def _delete(self, table_name: str, key: Any) -> None: + writer = self._fetch_writer(table_name) + writer.delete(key) + def _serialize(data: Any) -> Any: return dill.dumps(data) @@ -174,16 +181,27 @@ def put(self, data_id: Union[str, int], **kwargs: Any) -> None: record[k.lower()] = v self._log(self._meta_table_name, record) - def scan(self, start: Any, end: Any) -> Iterator[Dict[str, Any]]: + def delete(self, data_id: Union[str, int]) -> None: + self._delete(self._meta_table_name, data_id) + + def scan( + self, start: Any, end: Any, end_inclusive: bool = False + ) -> Iterator[Dict[str, Any]]: return self._data_store.scan_tables( - [data_store.TableDesc(self._meta_table_name)], start=start, end=end + [data_store.TableDesc(self._meta_table_name)], + start=start, + 
end=end, + end_inclusive=end_inclusive, ) - def scan_id(self, start: Any, end: Any) -> Iterator[Any]: + def scan_id( + self, start: Any, end: Any, end_inclusive: bool = False + ) -> Iterator[Any]: return self._data_store.scan_tables( [data_store.TableDesc(self._meta_table_name, columns=["id"])], start=start, end=end, + end_inclusive=end_inclusive, ) def flush(self) -> None: diff --git a/client/starwhale/api/dataset.py b/client/starwhale/api/dataset.py index 2539acb498..e5636851f5 100644 --- a/client/starwhale/api/dataset.py +++ b/client/starwhale/api/dataset.py @@ -5,6 +5,7 @@ Image, Video, Binary, + Dataset, LinkAuth, LinkType, MIMEType, @@ -49,4 +50,5 @@ "BoundingBox", "GrayscaleImage", "COCOObjectAnnotation", + "Dataset", ] diff --git a/client/starwhale/base/bundle.py b/client/starwhale/base/bundle.py index d84ad52e22..a6ee83efbf 100644 --- a/client/starwhale/base/bundle.py +++ b/client/starwhale/base/bundle.py @@ -51,11 +51,15 @@ def recover(self, force: bool = False) -> t.Tuple[bool, str]: raise NotImplementedError @abstractmethod - def add_tags(self, tags: t.List[str], quiet: bool = False) -> None: + def list_tags(self) -> t.List[str]: raise NotImplementedError @abstractmethod - def remove_tags(self, tags: t.List[str], quiet: bool = False) -> None: + def add_tags(self, tags: t.List[str], ignore_errors: bool = False) -> None: + raise NotImplementedError + + @abstractmethod + def remove_tags(self, tags: t.List[str], ignore_errors: bool = False) -> None: raise NotImplementedError @abstractmethod @@ -87,7 +91,7 @@ def copy(cls, src_uri: str, dest_uri: str, force: bool = False) -> None: def extract(self, force: bool = False, target: t.Union[str, Path] = "") -> Path: raise NotImplementedError - def build(self, workdir: Path, yaml_name: str = "", **kw: t.Any) -> None: + def build(self, **kw: t.Any) -> None: # TODO: remove yaml_name in build function self.store.building = True # type: ignore @@ -103,12 +107,16 @@ def when_exit() -> None: with ExitStack() as stack: stack.callback(when_exit) - kw["yaml_name"] = yaml_name or self.yaml_name - self.buildImpl(workdir, **kw) + kw["yaml_name"] = kw.get("yaml_name", self.yaml_name) + self.buildImpl(**kw) - def buildImpl(self, workdir: Path, **kw: t.Any) -> None: + def buildImpl(self, **kw: t.Any) -> None: raise NotImplementedError + @property + def version(self) -> str: + return getattr(self, "_version", "") or self.uri.object.version + class LocalStorageBundleMixin: def __init__(self) -> None: @@ -119,6 +127,8 @@ def _render_manifest(self) -> None: os=platform.system(), sw_version=STARWHALE_VERSION, ) + self._manifest["version"] = self._version # type: ignore + self._manifest["created_at"] = now_str() # TODO: add signature for import files: model, config _fpath = self.store.snapshot_workdir / DEFAULT_MANIFEST_NAME # type: ignore @@ -131,13 +141,11 @@ def _gen_version(self) -> None: self._version = gen_uniq_version() self.uri.object.version = self._version # type:ignore - self._manifest["version"] = self._version # type: ignore - self._manifest["created_at"] = now_str() logger.info(f"[step:version]version: {self._version}") console.print(f":new: version {self._version[:SHORT_VERSION_CNT]}") # type: ignore def _make_auto_tags(self) -> None: - self.tag.add([LATEST_TAG], quiet=True) # type: ignore + self.tag.add([LATEST_TAG], ignore_errors=True) # type: ignore self.tag.add_fast_tag() # type: ignore def _make_tar(self, ftype: str = "") -> None: diff --git a/client/starwhale/base/cloud.py b/client/starwhale/base/cloud.py index b789957ac6..5f96159f3e 
100644 --- a/client/starwhale/base/cloud.py +++ b/client/starwhale/base/cloud.py @@ -319,8 +319,11 @@ def recover(self, force: bool = False) -> t.Tuple[bool, str]: instance_uri=uri, ) - def add_tags(self, tags: t.List[str], quiet: bool = False) -> None: + def list_tags(self) -> t.List[str]: + raise NoSupportError("no support list tags for dataset in the cloud instance") + + def add_tags(self, tags: t.List[str], ignore_errors: bool = False) -> None: raise NoSupportError("no support add tags for dataset in the cloud instance") - def remove_tags(self, tags: t.List[str], quiet: bool = False) -> None: + def remove_tags(self, tags: t.List[str], ignore_errors: bool = False) -> None: raise NoSupportError("no support remove tags for dataset in the cloud instance") diff --git a/client/starwhale/base/tag.py b/client/starwhale/base/tag.py index 21f0aed502..5f5807b5f7 100644 --- a/client/starwhale/base/tag.py +++ b/client/starwhale/base/tag.py @@ -74,13 +74,13 @@ def add_fast_tag(self) -> None: def add( self, tags: t.List[str], - quiet: bool = False, + ignore_errors: bool = False, manifest: t.Optional[t.Dict] = None, ) -> None: _manifest = manifest or self._get_manifest() _version = self.uri.object.version - if not _version and not quiet: + if not _version and not ignore_errors: raise MissingFieldError(f"uri version, {self.uri}") for _t in tags: @@ -90,7 +90,7 @@ def add( _ok, _reason = validate_obj_name(_t) if not _ok: - if quiet: + if ignore_errors: continue else: raise FormatError(f"{_t}, reason:{_reason}") @@ -109,13 +109,13 @@ def add( self._save_manifest(_manifest) - def remove(self, tags: t.List[str], quiet: bool = False) -> None: + def remove(self, tags: t.List[str], ignore_errors: bool = False) -> None: _manifest = self._get_manifest() for _t in tags: _version = _manifest["tags"].pop(_t, "") if _version not in _manifest["versions"]: - if quiet: + if ignore_errors: continue else: raise NotFoundError(f"tag:{_t}, version:{_version}") @@ -126,6 +126,10 @@ def remove(self, tags: t.List[str], quiet: bool = False) -> None: self._save_manifest(_manifest) + def __iter__(self) -> t.Generator[str, None, None]: + for tag in self.list(): + yield tag + def list(self) -> t.List[str]: _manifest = self._get_manifest() _version = self.uri.object.version diff --git a/client/starwhale/core/dataset/cli.py b/client/starwhale/core/dataset/cli.py index 44f4af1a9b..3899e18f69 100644 --- a/client/starwhale/core/dataset/cli.py +++ b/client/starwhale/core/dataset/cli.py @@ -8,6 +8,7 @@ from starwhale.base.uri import URI from starwhale.base.type import URIType from starwhale.utils.cli import AliasedGroup +from starwhale.utils.load import import_object from starwhale.utils.error import NotFoundError from starwhale.core.dataset.type import MIMEType, DatasetAttr, DatasetConfig @@ -54,6 +55,7 @@ def dataset_cmd(ctx: click.Context) -> None: @click.option("-a", "--append", is_flag=True, default=None, help="Only append new data") @click.option("-af", "--append-from", help="Append from dataset version") @click.option("-r", "--runtime", help="runtime uri") +@click.option("-dcs", "--disable-copy-src", help="disable copy src dir") @click.pass_obj def _build( view: DatasetTermView, @@ -69,6 +71,7 @@ def _build( append: bool, append_from: str, runtime: str, + disable_copy_src: bool, ) -> None: # TODO: add dry-run # TODO: add compress args @@ -81,7 +84,8 @@ def _build( config = DatasetConfig.create_by_yaml(yaml_path) config.name = name or config.name or Path(workdir).absolute().name - config.handler = handler or config.handler + 
handler = handler or config.handler + config.handler = import_object(workdir, handler) config.runtime_uri = runtime or config.runtime_uri config.project_uri = project or config.project_uri # TODO: support README.md as the default desc @@ -97,11 +101,8 @@ def _build( if append is not None: config.append = append - print(config.name) - print(config.handler) - config.do_validate() - view.build(workdir, config) + view.build(workdir, config, disable_copy_src) @dataset_cmd.command("diff", help="Dataset version diff") diff --git a/client/starwhale/core/dataset/model.py b/client/starwhale/core/dataset/model.py index 3645a2df53..e752d18c5f 100644 --- a/client/starwhale/core/dataset/model.py +++ b/client/starwhale/core/dataset/model.py @@ -28,11 +28,9 @@ from starwhale.base.type import URIType, BundleType, InstanceType from starwhale.base.cloud import CloudRequestMixed, CloudBundleModelMixin from starwhale.utils.http import ignore_error -from starwhale.utils.load import import_object from starwhale.base.bundle import BaseBundle, LocalStorageBundleMixin from starwhale.utils.error import NotFoundError, NoSupportError from starwhale.utils.progress import run_with_progress_bar -from starwhale.api._impl.dataset import get_data_loader from starwhale.core.dataset.copy import DatasetCopy from .type import DatasetConfig, DatasetSummary @@ -44,6 +42,14 @@ class Dataset(BaseBundle, metaclass=ABCMeta): def __str__(self) -> str: return f"Starwhale Dataset: {self.uri}" + def _prepare_snapshot(self) -> None: + raise NotImplementedError + + def _fork_swds( + self, append: bool = False, append_from_version: t.Optional[str] = None + ) -> None: + raise NotImplementedError + @abstractmethod def summary(self) -> t.Optional[DatasetSummary]: raise NotImplementedError @@ -51,6 +57,8 @@ def summary(self) -> t.Optional[DatasetSummary]: def head( self, rows: int = 5, show_raw_data: bool = False ) -> t.List[t.Dict[str, t.Any]]: + from starwhale.api._impl.dataset import get_data_loader + ret = [] loader = get_data_loader(self.uri) for idx, row in enumerate(loader._iter_row()): @@ -67,8 +75,10 @@ def head( }, } if show_raw_data: - _, raw, _ = loader._unpack_row(row) - info["data"]["raw"] = raw.to_bytes() + _un_row = loader._unpack_row(row) + info["data"]["raw"] = ( + _un_row.data.to_bytes() if _un_row and _un_row.data else b"" + ) info["data"]["size"] = len(info["data"]["raw"]) ret.append(info) @@ -123,11 +133,14 @@ def __init__(self, uri: URI) -> None: self.yaml_name = DefaultYAMLName.DATASET self._version = uri.object.version - def add_tags(self, tags: t.List[str], quiet: bool = False) -> None: - self.tag.add(tags, quiet) + def list_tags(self) -> t.List[str]: + return self.tag.list() - def remove_tags(self, tags: t.List[str], quiet: bool = False) -> None: - self.tag.remove(tags, quiet) + def add_tags(self, tags: t.List[str], ignore_errors: bool = False) -> None: + self.tag.add(tags, ignore_errors) + + def remove_tags(self, tags: t.List[str], ignore_errors: bool = False) -> None: + self.tag.remove(tags, ignore_errors) def diff(self, compare_uri: URI) -> t.Dict[str, t.Any]: # TODO: support cross-instance diff: standalone <--> cloud @@ -268,26 +281,8 @@ def list( return rs, {} - def buildImpl(self, workdir: Path, **kw: t.Any) -> None: + def buildImpl(self, **kw: t.Any) -> None: config = kw["config"] - append = config.append - if append: - append_from_uri = URI.capsulate_uri( - instance=self.uri.instance, - project=self.uri.project, - obj_type=self.uri.object.typ, - obj_name=self.uri.object.name, - obj_ver=config.append_from, - ) 
- append_from_store = DatasetStorage(append_from_uri) - if not append_from_store.snapshot_workdir.exists(): - raise NotFoundError(f"dataset uri: {append_from_uri}") - else: - append_from_uri = None - append_from_store = None - - # TODO: design uniq build steps for model build, swmp build - operations = [ (self._gen_version, 5, "gen version"), (self._prepare_snapshot, 5, "prepare snapshot"), @@ -299,23 +294,13 @@ def buildImpl(self, workdir: Path, **kw: t.Any) -> None: swds_config=config, ), ), - ( - self._copy_src, - 15, - "copy src", - dict( - workdir=workdir, - pkg_data=config.pkg_data, - exclude_pkg_data=config.exclude_pkg_data, - ), - ), ( self._fork_swds, 10, "fork swds", dict( - append=append, - append_from_store=append_from_store, + append=config.append, + append_from_version=config.append_from, ), ), ( @@ -323,11 +308,7 @@ def buildImpl(self, workdir: Path, **kw: t.Any) -> None: 30, "make swds", dict( - workdir=workdir, swds_config=config, - append=append, - append_from_uri=append_from_uri, - append_from_store=append_from_store, ), ), (self._calculate_signature, 5, "calculate signature"), @@ -339,18 +320,46 @@ def buildImpl(self, workdir: Path, **kw: t.Any) -> None: (self._make_swds_meta_tar, 15, "make meta tar"), (self._make_auto_tags, 5, "make auto tags"), ] + + if not kw.get("disable_copy_src", False): + operations.append( + ( + self._copy_src, + 15, + "copy src", + dict( + workdir=kw["workdir"], + pkg_data=config.pkg_data, + exclude_pkg_data=config.exclude_pkg_data, + ), + ) + ) run_with_progress_bar("swds building...", operations) def _fork_swds( - self, append: bool, append_from_store: t.Optional[DatasetStorage] + self, append: bool = False, append_from_version: t.Optional[str] = None ) -> None: - if not append or not append_from_store: + if not append or not append_from_version: return - console.print( - f":articulated_lorry: fork dataset data from {append_from_store.id}" + uri = URI.capsulate_uri( + instance=self.uri.instance, + project=self.uri.project, + obj_type=self.uri.object.typ, + obj_name=self.uri.object.name, + obj_ver=append_from_version, ) - src_data_dir = append_from_store.data_dir + store = DatasetStorage(uri) + if not store.snapshot_workdir.exists(): + raise NotFoundError(f"dataset uri: {uri}") + + self._manifest["from"] = { + "version": store.id, + "append": append, + } + + console.print(f":articulated_lorry: fork dataset data from {store.id}") + src_data_dir = store.data_dir for src in src_data_dir.rglob("*"): if not src.is_symlink(): continue @@ -363,11 +372,7 @@ def _fork_swds( def _call_make_swds( self, - workdir: Path, swds_config: DatasetConfig, - append: bool, - append_from_uri: t.Optional[URI], - append_from_store: t.Optional[DatasetStorage], ) -> None: from starwhale.api._impl.dataset.builder import ( BaseBuildExecutor, @@ -375,29 +380,41 @@ def _call_make_swds( ) logger.info("[step:swds]try to gen swds...") - append_from_version = ( - append_from_store.id if append and append_from_store else "" + + if swds_config.append: + append_from_uri = URI.capsulate_uri( + instance=self.uri.instance, + project=self.uri.project, + obj_type=self.uri.object.typ, + obj_name=self.uri.object.name, + obj_ver=swds_config.append_from, + ) + _store = DatasetStorage(append_from_uri) + if not _store.snapshot_workdir.exists(): + raise NotFoundError(f"dataset uri: {append_from_uri}") + swds_config.append_from = append_from_version = _store.id + append_from_version = _store.id + else: + append_from_version = "" + append_from_uri = None + + _handler_name = 
getattr(swds_config.handler, "__name__", None) or str( + swds_config.handler ) self._manifest.update( { "dataset_attr": swds_config.attr.asdict(), - "handler": swds_config.handler, - "from": { - "version": append_from_version, - "append": append, - }, + "handler": _handler_name, } ) - # TODO: add more import format support, current is module:class - logger.info(f"[info:swds]try to import {swds_config.handler} @ {workdir}") - _handler = import_object(workdir, swds_config.handler) - _cls: t.Type[BaseBuildExecutor] - if inspect.isclass(_handler) and issubclass(_handler, BaseBuildExecutor): - _cls = _handler - elif inspect.isfunction(_handler): - _cls = create_generic_cls(_handler) + if inspect.isclass(swds_config.handler) and issubclass( + swds_config.handler, BaseBuildExecutor + ): + _cls = swds_config.handler + elif inspect.isfunction(swds_config.handler): + _cls = create_generic_cls(swds_config.handler) else: raise RuntimeError( f"{swds_config.handler} not BaseBuildExecutor or generator function" @@ -410,14 +427,12 @@ def _call_make_swds( workdir=self.store.snapshot_workdir, alignment_bytes_size=swds_config.attr.alignment_size, volume_bytes_size=swds_config.attr.volume_size, - append=append, + append=swds_config.append, append_from_version=append_from_version, append_from_uri=append_from_uri, data_mime_type=swds_config.attr.data_mime_type, ) as _obj: - console.print( - f":ghost: import [red]{swds_config.handler}@{workdir.resolve()}[/] to make swds..." - ) + console.print(f":ghost: import [red]{_handler_name}[/] to make swds...") _summary: DatasetSummary = _obj.make_swds() self._manifest["dataset_summary"] = _summary.asdict() @@ -444,9 +459,11 @@ def _make_swds_meta_tar(self) -> None: out = self.store.snapshot_workdir / ARCHIVED_SWDS_META_FNAME logger.info(f"[step:tar]try to tar for swmp meta(NOT INCLUDE DATASET){out}") with tarfile.open(out, "w:") as tar: - tar.add(str(self.store.src_dir), arcname="src") + if self.store.src_dir.exists(): + tar.add(str(self.store.src_dir), arcname="src") + if (self.store.snapshot_workdir / DefaultYAMLName.DATASET).exists(): + tar.add(str(self.store.snapshot_workdir / DefaultYAMLName.DATASET)) tar.add(str(self.store.snapshot_workdir / DEFAULT_MANIFEST_NAME)) - tar.add(str(self.store.snapshot_workdir / DefaultYAMLName.DATASET)) console.print( ":hibiscus: congratulation! 
you can run " @@ -516,5 +533,5 @@ def summary(self) -> t.Optional[DatasetSummary]: _summary = _manifest.get("dataset_summary", {}) return DatasetSummary(**_summary) if _summary else None - def build(self, workdir: Path, yaml_name: str = "", **kw: t.Any) -> None: + def build(self, **kw: t.Any) -> None: raise NoSupportError("no support build dataset in the cloud instance") diff --git a/client/starwhale/core/dataset/store.py b/client/starwhale/core/dataset/store.py index fb5a2c2e77..c07cb33536 100644 --- a/client/starwhale/core/dataset/store.py +++ b/client/starwhale/core/dataset/store.py @@ -469,7 +469,7 @@ def __exit__( value: t.Optional[BaseException], trace: TracebackType, ) -> None: - if value: + if value: # pragma: no cover print(f"type:{type}, exception:{value}, traceback:{trace}") self.close() diff --git a/client/starwhale/core/dataset/tabular.py b/client/starwhale/core/dataset/tabular.py index b90a342fa6..bd3554ced0 100644 --- a/client/starwhale/core/dataset/tabular.py +++ b/client/starwhale/core/dataset/tabular.py @@ -224,6 +224,9 @@ def update( def put(self, row: TabularDatasetRow) -> None: self._ds_wrapper.put(row.id, **row.asdict()) + def delete(self, row_id: t.Union[str, int]) -> None: + self._ds_wrapper.delete(row_id) + def flush(self) -> None: self._ds_wrapper.flush() @@ -231,6 +234,7 @@ def scan( self, start: t.Optional[t.Any] = None, end: t.Optional[t.Any] = None, + end_inclusive: bool = False, ) -> t.Generator[TabularDatasetRow, None, None]: if start is None or (self.start is not None and start < self.start): start = self.start @@ -238,7 +242,7 @@ def scan( if end is None or (self.end is not None and end > self.end): end = self.end - for _d in self._ds_wrapper.scan(start, end): + for _d in self._ds_wrapper.scan(start, end, end_inclusive): for k, v in self._map_types.items(): if k not in _d: continue @@ -261,6 +265,7 @@ def scan_batch( yield batch def close(self) -> None: + self.flush() self._ds_wrapper.close() def __enter__(self: _TDType) -> _TDType: @@ -272,7 +277,7 @@ def __exit__( value: t.Optional[BaseException], trace: TracebackType, ) -> None: - if value: + if value: # pragma: no cover logger.warning(f"type:{type}, exception:{value}, traceback:{trace}") self.close() diff --git a/client/starwhale/core/dataset/type.py b/client/starwhale/core/dataset/type.py index b1972bf1ce..0319f2e2cc 100644 --- a/client/starwhale/core/dataset/type.py +++ b/client/starwhale/core/dataset/type.py @@ -11,6 +11,7 @@ from urllib.parse import urlparse import requests +from numpy import ndarray from starwhale.utils import load_yaml, convert_to_bytes, validate_obj_name from starwhale.consts import ( @@ -268,6 +269,20 @@ def to_bytes(self, encoding: str = "utf-8") -> bytes: else: raise NoSupportError(f"read raw for type:{type(self.fp)}") + def to_numpy(self) -> ndarray: + ... + + def to_json(self) -> str: + ... + + def to_tensor(self) -> t.Any: + ... + + to_pt_tensor = to_tensor + + def to_tf_tensor(self) -> t.Any: + ... 
+ def carry_raw_data(self: _TBAType) -> _TBAType: self._raw_base64_data = base64.b64encode(self.to_bytes()).decode() return self @@ -654,7 +669,7 @@ class DatasetConfig(ASDictMixin): def __init__( self, name: str = "", - handler: str = "", + handler: t.Any = "", pkg_data: t.List[str] = [], exclude_pkg_data: t.List[str] = [], desc: str = "", @@ -685,7 +700,7 @@ def do_validate(self) -> None: if not _ok: raise FieldTypeOrValueError(f"name field:({self.name}) error: {_reason}") - if ":" not in self.handler: + if isinstance(self.handler, str) and ":" not in self.handler: raise Exception( f"please use module:class format, current is: {self.handler}" ) @@ -702,3 +717,8 @@ def create_by_yaml(cls, fpath: t.Union[str, Path]) -> DatasetConfig: c = load_yaml(fpath) return cls(**c) + + def asdict(self, ignore_keys: t.Optional[t.List[str]] = None) -> t.Dict: + d = super().asdict(["handler"]) + d["handler"] = getattr(self.handler, "__name__", None) or str(self.handler) + return d diff --git a/client/starwhale/core/dataset/view.py b/client/starwhale/core/dataset/view.py index d1bd2e5a93..334ad9624c 100644 --- a/client/starwhale/core/dataset/view.py +++ b/client/starwhale/core/dataset/view.py @@ -134,13 +134,18 @@ def _str_row(row: t.Dict) -> str: @classmethod def list( cls, - project_uri: str = "", + project_uri: t.Union[str, URI] = "", fullname: bool = False, show_removed: bool = False, page: int = DEFAULT_PAGE_IDX, size: int = DEFAULT_PAGE_SIZE, ) -> t.Tuple[t.List[t.Dict[str, t.Any]], t.Dict[str, t.Any]]: - _uri = URI(project_uri, expected_type=URIType.PROJECT) + + if isinstance(project_uri, str): + _uri = URI(project_uri, expected_type=URIType.PROJECT) + else: + _uri = project_uri + fullname = fullname or (_uri.instance_type == InstanceType.CLOUD) _datasets, _pager = Dataset.list(_uri, page, size) _data = BaseTermView.list_data(_datasets, show_removed, fullname) @@ -152,21 +157,23 @@ def build( cls, workdir: str, config: DatasetConfig, + disable_copy_src: bool = False, ) -> URI: dataset_uri = cls.prepare_build_bundle( project=config.project_uri, bundle_name=config.name, typ=URIType.DATASET ) ds = Dataset.get_dataset(dataset_uri) + kwargs = dict( + workdir=Path(workdir), config=config, disable_copy_src=disable_copy_src + ) + if config.runtime_uri: RuntimeProcess.from_runtime_uri( - uri=config.runtime_uri, - target=ds.build, - args=(Path(workdir),), - kwargs=dict(config=config), + uri=config.runtime_uri, target=ds.build, kwargs=kwargs ).run() else: - ds.build(Path(workdir), config=config) + ds.build(**kwargs) return dataset_uri @classmethod @@ -181,13 +188,15 @@ def copy( console.print(":clap: copy done") @BaseTermView._header - def tag(self, tags: t.List[str], remove: bool = False, quiet: bool = False) -> None: + def tag( + self, tags: t.List[str], remove: bool = False, ignore_errors: bool = False + ) -> None: if remove: console.print(f":golfer: remove tags {tags} @ {self.uri}...") - self.dataset.remove_tags(tags, quiet) + self.dataset.remove_tags(tags, ignore_errors) else: console.print(f":surfer: add tags {tags} @ {self.uri}...") - self.dataset.add_tags(tags, quiet) + self.dataset.add_tags(tags, ignore_errors) @BaseTermView._header def head(self, rows: int, show_raw_data: bool = False) -> None: diff --git a/client/starwhale/core/model/model.py b/client/starwhale/core/model/model.py index 33731de07e..6a504f9ef9 100644 --- a/client/starwhale/core/model/model.py +++ b/client/starwhale/core/model/model.py @@ -183,11 +183,14 @@ def __init__(self, uri: URI) -> None: self.yaml_name = DefaultYAMLName.MODEL 
self._version = uri.object.version - def add_tags(self, tags: t.List[str], quiet: bool = False) -> None: - self.tag.add(tags, quiet) + def list_tags(self) -> t.List[str]: + return self.tag.list() - def remove_tags(self, tags: t.List[str], quiet: bool = False) -> None: - self.tag.remove(tags, quiet) + def add_tags(self, tags: t.List[str], ignore_errors: bool = False) -> None: + self.tag.add(tags, ignore_errors) + + def remove_tags(self, tags: t.List[str], ignore_errors: bool = False) -> None: + self.tag.remove(tags, ignore_errors) def _gen_steps(self, typ: str, ppl: str) -> None: if typ == EvalHandlerType.DEFAULT: @@ -430,7 +433,8 @@ def list( ) return rs, {} - def buildImpl(self, workdir: Path, **kw: t.Any) -> None: + def buildImpl(self, **kw: t.Any) -> None: + workdir = kw["workdir"] yaml_name = kw.get("yaml_name", DefaultYAMLName.MODEL) _mp = workdir / yaml_name _model_config = self.load_model_config(_mp) @@ -555,5 +559,5 @@ def list( crm = CloudRequestMixed() return crm._fetch_bundle_all_list(project_uri, URIType.MODEL, page, size) - def build(self, workdir: Path, yaml_name: str = "", **kw: t.Any) -> None: + def build(self, **kw: t.Any) -> None: raise NoSupportError("no support build model in the cloud instance") diff --git a/client/starwhale/core/model/view.py b/client/starwhale/core/model/view.py index 42a1ee8ae0..9a76bbb8c4 100644 --- a/client/starwhale/core/model/view.py +++ b/client/starwhale/core/model/view.py @@ -116,14 +116,15 @@ def build( project=project, bundle_name=_config.get("name"), typ=URIType.MODEL ) _m = Model.get_model(_model_uri) + kwargs = {"workdir": Path(workdir), "yaml_name": yaml_name} if runtime_uri: RuntimeProcess.from_runtime_uri( uri=runtime_uri, target=_m.build, - args=(Path(workdir), yaml_name), + kwargs=kwargs, ).run() else: - _m.build(Path(workdir), yaml_name) + _m.build(**kwargs) return _model_uri @classmethod @@ -138,13 +139,15 @@ def copy( console.print(":clap: copy done.") @BaseTermView._header - def tag(self, tags: t.List[str], remove: bool = False, quiet: bool = False) -> None: + def tag( + self, tags: t.List[str], remove: bool = False, ignore_errors: bool = False + ) -> None: if remove: console.print(f":golfer: remove tags [red]{tags}[/] @ {self.uri}...") - self.model.remove_tags(tags, quiet) + self.model.remove_tags(tags, ignore_errors) else: console.print(f":surfer: add tags [red]{tags}[/] @ {self.uri}...") - self.model.add_tags(tags, quiet) + self.model.add_tags(tags, ignore_errors) class ModelTermViewRich(ModelTermView): diff --git a/client/starwhale/core/runtime/model.py b/client/starwhale/core/runtime/model.py index d970fbe193..9d362ff5cd 100644 --- a/client/starwhale/core/runtime/model.py +++ b/client/starwhale/core/runtime/model.py @@ -702,11 +702,14 @@ def __init__(self, uri: URI) -> None: def info(self) -> t.Dict[str, t.Any]: return self._get_bundle_info() - def add_tags(self, tags: t.List[str], quiet: bool = False) -> None: - self.tag.add(tags, quiet) + def list_tags(self) -> t.List[str]: + return self.tag.list() - def remove_tags(self, tags: t.List[str], quiet: bool = False) -> None: - self.tag.remove(tags, quiet) + def add_tags(self, tags: t.List[str], ignore_errors: bool = False) -> None: + self.tag.add(tags, ignore_errors) + + def remove_tags(self, tags: t.List[str], ignore_errors: bool = False) -> None: + self.tag.remove(tags, ignore_errors) def remove(self, force: bool = False) -> t.Tuple[bool, str]: return self._do_remove(force) @@ -748,10 +751,10 @@ def history( def build( self, - workdir: Path, - yaml_name: str = 
DefaultYAMLName.RUNTIME, **kw: t.Any, ) -> None: + workdir = kw["workdir"] + yaml_name = kw.get("yaml_name", DefaultYAMLName.RUNTIME) disable_env_lock = kw.get("disable_env_lock", False) env_name = kw.get("env_name", "") env_prefix_path = kw.get("env_prefix_path", "") @@ -1693,5 +1696,5 @@ def list( crm = CloudRequestMixed() return crm._fetch_bundle_all_list(project_uri, URIType.RUNTIME, page, size) - def build(self, workdir: Path, yaml_name: str = "", **kw: t.Any) -> None: + def build(self, **kw: t.Any) -> None: raise NoSupportError("no support build runtime in the cloud instance") diff --git a/client/starwhale/core/runtime/view.py b/client/starwhale/core/runtime/view.py index 186ae557fb..6acc7cc2a2 100644 --- a/client/starwhale/core/runtime/view.py +++ b/client/starwhale/core/runtime/view.py @@ -125,8 +125,8 @@ def build( _rt = Runtime.get_runtime(_runtime_uri) _rt.build( - Path(workdir), - yaml_name, + workdir=Path(workdir), + yaml_name=yaml_name, gen_all_bundles=gen_all_bundles, include_editable=include_editable, disable_env_lock=disable_env_lock, @@ -226,14 +226,16 @@ def copy( console.print(":clap: copy done.") @BaseTermView._header - def tag(self, tags: t.List[str], remove: bool = False, quiet: bool = False) -> None: + def tag( + self, tags: t.List[str], remove: bool = False, ignore_errors: bool = False + ) -> None: # TODO: refactor model/runtime/dataset tag view-model if remove: console.print(f":golfer: remove tags [red]{tags}[/] @ {self.uri}...") - self.runtime.remove_tags(tags, quiet) + self.runtime.remove_tags(tags, ignore_errors) else: console.print(f":surfer: add tags [red]{tags}[/] @ {self.uri}...") - self.runtime.add_tags(tags, quiet) + self.runtime.add_tags(tags, ignore_errors) class RuntimeTermViewRich(RuntimeTermView): diff --git a/client/starwhale/utils/load.py b/client/starwhale/utils/load.py index df9c5825d0..ce8c99a7f4 100644 --- a/client/starwhale/utils/load.py +++ b/client/starwhale/utils/load.py @@ -14,8 +14,10 @@ ) -def import_object(workdir: Path, handler_path: str, py_env: str = "") -> t.Any: - workdir_path = str(workdir.absolute()) +def import_object( + workdir: t.Union[Path, str], handler_path: str, py_env: str = "" +) -> t.Any: + workdir_path = str(Path(workdir).absolute()) external_paths = [workdir_path] py_env = py_env or guess_current_py_env() _ok, _cur_py, _ex_py = check_python_interpreter_consistency(py_env) diff --git a/client/tests/base/test_tag.py b/client/tests/base/test_tag.py index 5fa814aceb..f54948a7fc 100644 --- a/client/tests/base/test_tag.py +++ b/client/tests/base/test_tag.py @@ -49,7 +49,7 @@ def test_tag_workflow(self) -> None: assert set(st.list()) == {"test", "latest", "me3"} - st.remove(["latest", "notfound"], quiet=True) + st.remove(["latest", "notfound"], ignore_errors=True) assert set(st.list()) == {"test", "me3"} _manifest = st._get_manifest() assert "latest" not in _manifest["tags"] @@ -96,7 +96,7 @@ def test_auto_fast_tag(self) -> None: assert st._get_manifest()["fast_tag_seq"] == 7 assert st._get_manifest()["tags"]["v7"] == version - st.remove(["v7", "v6"], quiet=True) + st.remove(["v7", "v6"], ignore_errors=True) st.add_fast_tag() assert st._get_manifest()["fast_tag_seq"] == 8 assert st._get_manifest()["tags"]["v8"] == version diff --git a/client/tests/core/test_dataset.py b/client/tests/core/test_dataset.py index 7ed8d774d2..a7556e0707 100644 --- a/client/tests/core/test_dataset.py +++ b/client/tests/core/test_dataset.py @@ -60,30 +60,17 @@ def setUp(self) -> None: 
@patch("starwhale.api._impl.dataset.builder.UserRawBuildExecutor.make_swds") @patch("starwhale.api._impl.dataset.builder.SWDSBinBuildExecutor.make_swds") - @patch("starwhale.core.dataset.model.import_object") def test_function_handler_make_swds( - self, m_import: MagicMock, m_swds_bin: MagicMock, m_user_raw: MagicMock + self, m_swds_bin: MagicMock, m_user_raw: MagicMock ) -> None: name = "mnist" dataset_uri = URI(name, expected_type=URIType.DATASET) sd = StandaloneDataset(dataset_uri) sd._version = "112233" - workdir = "/home/starwhale/myproject" - config = DatasetConfig(name, handler="mnist:handler") - - kwargs = dict( - workdir=Path(workdir), - swds_config=config, - append=False, - append_from_uri=None, - append_from_store=None, - ) + swds_config = DatasetConfig(name=name, handler=lambda: 1) - m_import.return_value = lambda: 1 with self.assertRaises(RuntimeError): - sd._call_make_swds(**kwargs) # type: ignore - - m_import.reset_mock() + sd._call_make_swds(swds_config) def _iter_swds_bin_item() -> t.Generator: yield b"", {} @@ -91,15 +78,16 @@ def _iter_swds_bin_item() -> t.Generator: def _iter_user_raw_item() -> t.Generator: yield Link(""), {} - m_import.return_value = _iter_swds_bin_item - sd._call_make_swds(**kwargs) # type: ignore + swds_config.handler = _iter_swds_bin_item + sd._call_make_swds(swds_config) assert m_swds_bin.call_count == 1 - m_import.return_value = _iter_user_raw_item - sd._call_make_swds(**kwargs) # type: ignore + swds_config.handler = _iter_user_raw_item + sd._call_make_swds(swds_config) assert m_user_raw.call_count == 1 - def test_build_only_cli(self) -> None: + @patch("starwhale.core.dataset.cli.import_object") + def test_build_only_cli(self, m_import: MagicMock) -> None: workdir = "/tmp/workdir" ensure_dir(workdir) @@ -118,10 +106,11 @@ def test_build_only_cli(self) -> None: call_args = mock_obj.build.call_args[0] assert call_args[0] == workdir assert call_args[1].name == "mnist" - assert call_args[1].handler == "mnist:test" assert call_args[1].append is not None + assert m_import.call_args[0][1] == "mnist:test" - def test_build_only_yaml(self) -> None: + @patch("starwhale.core.dataset.cli.import_object") + def test_build_only_yaml(self, m_import: MagicMock) -> None: workdir = "/tmp/workdir" ensure_dir(workdir) @@ -144,9 +133,9 @@ def test_build_only_yaml(self) -> None: assert mock_obj.build.call_count == 1 call_args = mock_obj.build.call_args[0] assert call_args[1].name == "mnist" - assert call_args[1].handler == "dataset:build" assert call_args[1].append assert call_args[1].append_from == "112233" + assert m_import.call_args[0][1] == "dataset:build" new_workdir = "/tmp/workdir-new" ensure_dir(new_workdir) @@ -155,28 +144,26 @@ def test_build_only_yaml(self) -> None: assert new_yaml_path.exists() and not yaml_path.exists() mock_obj.reset_mock() - result = runner.invoke(build_cli, [new_workdir], obj=mock_obj) - - assert result.exit_code == 1 - assert result.exception - - mock_obj.reset_mock() + m_import.reset_mock() result = runner.invoke( build_cli, [new_workdir, "-f", "dataset-new.yaml"], obj=mock_obj ) assert result.exit_code == 0 assert mock_obj.build.call_count == 1 assert call_args[1].name == "mnist" - assert call_args[1].handler == "dataset:build" assert call_args[1].append assert call_args[1].append_from == "112233" + assert m_import.call_args[0][1] == "dataset:build" - def test_build_mixed_cli_yaml(self) -> None: + @patch("starwhale.core.dataset.cli.import_object") + def test_build_mixed_cli_yaml(self, m_import: MagicMock) -> None: + handler_func = 
lambda: 1 + m_import.return_value = handler_func workdir = "/tmp/workdir" ensure_dir(workdir) config = DatasetConfig( name="mnist-error", - handler="dataset:buildClass", + handler="dataset:not_found", append=True, append_from="112233", ) @@ -205,23 +192,19 @@ def test_build_mixed_cli_yaml(self) -> None: assert mock_obj.build.call_count == 1 call_args = mock_obj.build.call_args[0] assert call_args[1].name == "mnist" - assert call_args[1].handler == "dataset:buildFunction" + assert call_args[1].handler == handler_func assert call_args[1].append assert call_args[1].append_from == "112233" assert call_args[1].attr.data_mime_type == MIMEType.MP4 assert call_args[1].attr.volume_size == D_FILE_VOLUME_SIZE @patch("starwhale.core.dataset.model.copy_fs") - @patch("starwhale.core.dataset.model.import_object") def test_build_workflow( self, - m_import: MagicMock, m_copy_fs: MagicMock, ) -> None: sw = SWCliConfigMixed() - m_import.return_value = MockBuildExecutor - workdir = "/home/starwhale/myproject" name = "mnist" @@ -229,9 +212,10 @@ def test_build_workflow( ensure_file(os.path.join(workdir, "mnist.py"), " ") config = DatasetConfig(**yaml.safe_load(_dataset_yaml)) + config.handler = MockBuildExecutor dataset_uri = URI(name, expected_type=URIType.DATASET) sd = StandaloneDataset(dataset_uri) - sd.build(Path(workdir), config=config) + sd.build(workdir=Path(workdir), config=config) build_version = sd.uri.object.version snapshot_workdir = ( @@ -243,10 +227,6 @@ def test_build_workflow( / f"{build_version}{BundleType.DATASET}" ) - assert m_import.call_count == 1 - assert m_import.call_args[0][0] == Path(workdir) - assert m_import.call_args[0][1] == "mnist.dataset:DatasetProcessExecutor" - assert snapshot_workdir.exists() assert (snapshot_workdir / "data").exists() assert (snapshot_workdir / "src").exists() diff --git a/client/tests/core/test_model.py b/client/tests/core/test_model.py index be0595ca57..c99290f3cf 100644 --- a/client/tests/core/test_model.py +++ b/client/tests/core/test_model.py @@ -55,7 +55,7 @@ def setUp(self) -> None: def test_build_workflow(self, m_copy_fs: MagicMock, m_copy_file: MagicMock) -> None: model_uri = URI(self.name, expected_type=URIType.MODEL) sm = StandaloneModel(model_uri) - sm.build(Path(self.workdir)) + sm.build(workdir=Path(self.workdir)) build_version = sm.uri.object.version diff --git a/client/tests/core/test_runtime.py b/client/tests/core/test_runtime.py index ae42664c47..60e60a7e3b 100644 --- a/client/tests/core/test_runtime.py +++ b/client/tests/core/test_runtime.py @@ -280,7 +280,7 @@ def test_build_venv( uri = URI(name, expected_type=URIType.RUNTIME) sr = StandaloneRuntime(uri) - sr.build(Path(workdir), enable_lock=True, env_prefix_path=venv_dir) + sr.build(workdir=Path(workdir), enable_lock=True, env_prefix_path=venv_dir) assert sr.uri.object.version != "" assert len(sr._version) == 40 build_version = sr._version @@ -477,7 +477,7 @@ def test_build_venv( uri = URI(name, expected_type=URIType.RUNTIME) sr = StandaloneRuntime(uri) sr.build( - Path(workdir), + workdir=Path(workdir), enable_lock=True, env_prefix_path=venv_dir, gen_all_bundles=True, @@ -539,7 +539,7 @@ def test_build_conda( self.fs.create_file(os.path.join(workdir, "dummy.whl"), contents="") uri = URI(name, expected_type=URIType.RUNTIME) sr = StandaloneRuntime(uri) - sr.build(Path(workdir), env_use_shell=True) + sr.build(workdir=Path(workdir), env_use_shell=True) sr.info() sr.history() @@ -604,7 +604,7 @@ def test_build_without_python_version( uri = URI(name, expected_type=URIType.RUNTIME) sr = 
StandaloneRuntime(uri) with self.assertRaises(ConfigFormatError): - sr.build(Path(workdir), env_use_shell=True) + sr.build(workdir=Path(workdir), env_use_shell=True) assert m_check_call.call_args[0][0][:6] == [ "conda", @@ -622,7 +622,7 @@ def test_build_without_python_version( m_py_ver.assert_called_once() m_py_ver.return_value = "3.10" - sr.build(Path(workdir), env_use_shell=True) + sr.build(workdir=Path(workdir), env_use_shell=True) m_py_ver.assert_has_calls([call(), call()]) sw = SWCliConfigMixed() @@ -668,13 +668,14 @@ def test_build_with_docker_image_specified( yaml_content = { "name": name, "mode": "conda", + "environment": {}, } yaml_file = os.path.join(workdir, DefaultYAMLName.RUNTIME) self.fs.create_file(yaml_file, contents=yaml.safe_dump(yaml_content)) uri = URI(name, expected_type=URIType.RUNTIME) sr = StandaloneRuntime(uri) - sr.build(Path(workdir), env_use_shell=True) + sr.build(workdir=Path(workdir), env_use_shell=True) sw = SWCliConfigMixed() runtime_workdir = os.path.join( @@ -694,7 +695,7 @@ def test_build_with_docker_image_specified( self.fs.remove_object(yaml_file) self.fs.create_file(yaml_file, contents=yaml.safe_dump(yaml_content)) sr = StandaloneRuntime(URI(name, expected_type=URIType.RUNTIME)) - sr.build(Path(workdir), env_use_shell=True) + sr.build(workdir=Path(workdir), env_use_shell=True) runtime_workdir = os.path.join( sw.rootdir, "self", diff --git a/client/tests/sdk/test_data_store.py b/client/tests/sdk/test_data_store.py index b25e4d97f2..ad1b8a4b80 100644 --- a/client/tests/sdk/test_data_store.py +++ b/client/tests/sdk/test_data_store.py @@ -147,6 +147,20 @@ def test_write_and_scan(self) -> None: ), "with start and end", ) + + self.assertEqual( + [{"*": 1, "-": True, "i": "y", "j": 11}, {"*": 2, "i": "z"}], + list( + data_store._scan_parquet_file( + path, + columns={"b": "i", "c": "j"}, + start=1, + end=2, + end_inclusive=True, + ) + ), + "with start and end, with end inclusive", + ) self.assertEqual( [ {"*": 0, "a": 0, "b": "x", "c": 10}, @@ -544,6 +558,19 @@ def test_scan_table(self) -> None: ), "with start and end", ) + self.assertEqual( + [{"*": 2, "j": "2"}, {"*": 3, "i": "3"}], + list( + data_store._scan_table( + self.datastore_root, + {"a": "i", "b": "j"}, + start=2, + end=3, + end_inclusive=True, + ) + ), + "with start and end(inclusive)", + ) self.assertEqual( [ {"*": 0, "k": 0, "a": "0", "b": "0"}, @@ -973,6 +1000,13 @@ def test_mixed(self) -> None: list(table.scan(keep_none=True)), "keep none", ) + self.assertEqual( + [ + {"*": 0, "k": 0, "a": "0"}, + ], + list(table.scan(start=0, end=0, end_inclusive=True)), + "one row", + ) self.assertEqual( [ {"*": 0, "k": 0, "x": "0"}, diff --git a/client/tests/sdk/test_dataset.py b/client/tests/sdk/test_dataset.py index 19be112d4b..9f82a01c32 100644 --- a/client/tests/sdk/test_dataset.py +++ b/client/tests/sdk/test_dataset.py @@ -12,6 +12,7 @@ from unittest.mock import patch, MagicMock from concurrent.futures import as_completed, ThreadPoolExecutor +import pytest from requests_mock import Mocker from pyfakefs.fake_filesystem_unittest import TestCase @@ -49,6 +50,7 @@ ArtifactType, BaseArtifact, GrayscaleImage, + DefaultS3LinkAuth, COCOObjectAnnotation, ) from starwhale.core.dataset.store import DatasetStorage @@ -62,7 +64,9 @@ local_standalone_tdsc, get_dataset_consumption, ) +from starwhale.api._impl.dataset.loader import DataRow from starwhale.api._impl.dataset.builder import ( + RowWriter, _data_magic, _header_size, _header_magic, @@ -1406,3 +1410,260 @@ def test_row(self) -> None: for r in (s_row, u_row, 
l_row): copy_r = TabularDatasetRow.from_datastore(**r.asdict()) assert copy_r == r + + +class TestRowWriter(BaseTestCase): + def setUp(self) -> None: + super().setUp() + + @patch("starwhale.api._impl.dataset.builder.SWDSBinBuildExecutor.make_swds") + def test_update(self, m_make_swds: MagicMock) -> None: + rw = RowWriter(dataset_name="mnist", dataset_version="123456") + + assert rw._builder is None + assert not rw.is_alive() + assert rw._queue.empty() + + rw._builder = MagicMock() + + rw.update(DataRow(index=1, data=Binary(b"test"), annotations={"label": 1})) + first_builder = rw._builder + assert rw._builder is not None + assert rw._queue.qsize() == 1 + + rw.update(DataRow(index=2, data=Binary(b"test"), annotations={"label": 2})) + second_builder = rw._builder + assert first_builder == second_builder + assert rw._queue.qsize() == 2 + + rw._builder = None + rw.update(DataRow(index=3, data=Binary(b"test"), annotations={"label": 3})) + assert rw._builder is not None + assert rw.isDaemon() + assert isinstance(rw._builder, SWDSBinBuildExecutor) + assert m_make_swds.call_count == 1 + + @pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning") + @patch("starwhale.api._impl.dataset.builder.SWDSBinBuildExecutor.make_swds") + def test_update_exception(self, m_make_swds: MagicMock) -> None: + rw = RowWriter(dataset_name="mnist", dataset_version="123456") + + assert rw._raise_run_exception() is None + rw._builder = MagicMock() + rw.update(DataRow(index=1, data=Binary(b"test"), annotations={"label": 1})) + assert rw._run_exception is None + + rw._run_exception = ValueError("test") + with self.assertRaises(threading.ThreadError): + rw.update(DataRow(index=1, data=Binary(b"test"), annotations={"label": 1})) + + rw._run_exception = None + rw._builder = None + m_make_swds.side_effect = TypeError("thread test") + with self.assertRaises(threading.ThreadError): + rw.update(DataRow(index=2, data=Binary(b"test"), annotations={"label": 2})) + rw.join() + rw.update(DataRow(index=3, data=Binary(b"test"), annotations={"label": 3})) + + def test_iter(self) -> None: + rw = RowWriter(dataset_name="mnist", dataset_version="123456") + rw._builder = MagicMock() + size = 10 + for i in range(0, size): + rw.update(DataRow(index=i, data=Binary(b"test"), annotations={"label": i})) + + rw.update(None) # type: ignore + assert not rw.is_alive() + assert rw._queue.qsize() == size + 1 + + items = list(rw) + assert len(items) == size + assert items[0].index == 0 + assert items[9].index == 9 + + def test_iter_block(self) -> None: + rw = RowWriter(dataset_name="mnist", dataset_version="123456") + rw._builder = MagicMock() + rw.update(DataRow(index=1, data=Binary(b"test"), annotations={"label": 1})) + + thread = threading.Thread(target=lambda: list(rw), daemon=True) + thread.start() + assert thread.is_alive() + time.sleep(1) + assert thread.is_alive() + + rw.update(None) # type: ignore + time.sleep(0.1) + assert not thread.is_alive() + + def test_iter_none(self) -> None: + rw = RowWriter(dataset_name="mnist", dataset_version="123456") + rw._builder = MagicMock() + size = 10 + for _ in range(0, size): + rw.update(None) # type: ignore + + assert rw._queue.qsize() == size + assert len(list(rw)) == 0 + + def test_iter_merge_none(self) -> None: + rw = RowWriter(dataset_name="mnist", dataset_version="123456") + rw._builder = MagicMock() + size = 10 + for _ in range(0, size): + rw.update(None) # type: ignore + + rw.update(DataRow(index=1, data=Binary(b"test"), annotations={"label": 1})) + rw.update(None) # type: 
ignore + + assert rw._queue.qsize() == size + 2 + items = list(rw) + assert len(items) == 1 + assert items[0].index == 1 + + @patch("starwhale.api._impl.dataset.builder.SWDSBinBuildExecutor.make_swds") + def test_close(self, m_make_swds: MagicMock) -> None: + rw = RowWriter(dataset_name="mnist", dataset_version="123456") + rw.update(DataRow(index=1, data=Binary(b"test"), annotations={"label": 1})) + rw.close() + assert not rw.is_alive() + + with RowWriter(dataset_name="mnist", dataset_version="123456") as context_rw: + context_rw.update( + DataRow(index=1, data=Binary(b"test"), annotations={"label": 1}) + ) + assert not rw.is_alive() + + def test_make_swds_bin(self) -> None: + workdir = Path(self.local_storage) / ".user" / "workdir" + + assert not workdir.exists() + rw = RowWriter(dataset_name="mnist", dataset_version="123456", workdir=workdir) + assert rw._builder is None + size = 100 + for i in range(0, size): + rw.update(DataRow(index=i, data=Binary(b"test"), annotations={"label": i})) + rw.close() + + assert isinstance(rw._builder, SWDSBinBuildExecutor) + assert rw._queue.qsize() == 0 + assert rw.summary.rows == size + assert not rw.summary.include_link + assert rw.summary.annotations == ["label"] + + data_dir = workdir / "data" + assert data_dir.exists() + files = list(data_dir.iterdir()) + assert len(files) == 1 + assert files[0].is_symlink() + + def test_make_user_raw(self) -> None: + user_dir = Path(self.local_storage) / ".user" + raw_data_file = user_dir / "data_file" + raw_content = "123" + ensure_dir(user_dir) + ensure_file(raw_data_file, content=raw_content) + + workdir = user_dir / "workdir" + assert not workdir.exists() + rw = RowWriter(dataset_name="mnist", dataset_version="123456", workdir=workdir) + assert rw._builder is None + size = 100 + for i in range(0, size): + rw.update( + DataRow( + index=i, + data=Link(uri=raw_data_file, with_local_fs_data=True), + annotations={"label": i, "label2": 2}, + ) + ) + rw.close() + + assert isinstance(rw._builder, UserRawBuildExecutor) + assert rw._queue.qsize() == 0 + assert rw.summary.rows == size + assert not rw.summary.include_link + assert rw.summary.include_user_raw + assert rw.summary.annotations == ["label", "label2"] + + data_dir = workdir / "data" + assert data_dir.exists() + files = list(data_dir.iterdir()) + + assert len(files) == 1 + assert files[0].is_symlink() + assert files[0].read_text() == raw_content + + def test_make_link(self) -> None: + user_dir = Path(self.local_storage) / ".user" + raw_data_file = user_dir / "data_file" + raw_content = "123" + ensure_dir(user_dir) + ensure_file(raw_data_file, content=raw_content) + + workdir = user_dir / "workdir" + assert not workdir.exists() + rw = RowWriter(dataset_name="mnist", dataset_version="123456", workdir=workdir) + assert rw._builder is None + size = 100 + for i in range(0, size): + rw.update( + DataRow( + index=i, + data=Link(uri="minio://1/1/1/", auth=DefaultS3LinkAuth), + annotations={"label": i, "label2": 2}, + ) + ) + rw.close() + + assert isinstance(rw._builder, UserRawBuildExecutor) + assert rw._queue.qsize() == 0 + assert rw.summary.rows == size + assert rw.summary.include_link + assert rw.summary.include_user_raw + assert rw.summary.annotations == ["label", "label2"] + + data_dir = workdir / "data" + assert data_dir.exists() + files = list(data_dir.iterdir()) + assert len(files) == 0 + + @patch("starwhale.api._impl.dataset.builder.SWDSBinBuildExecutor.make_swds") + def test_append_swds_bin(self, m_make_swds: MagicMock) -> None: + rw = RowWriter( + 
dataset_name="mnist", + dataset_version="123456", + append=True, + append_from_version="abcdefg", + append_with_swds_bin=True, + ) + assert isinstance(rw._builder, SWDSBinBuildExecutor) + + @patch("starwhale.api._impl.dataset.builder.UserRawBuildExecutor.make_swds") + def test_append_user_raw(self, m_make_swds: MagicMock) -> None: + rw = RowWriter( + dataset_name="mnist", + dataset_version="123456", + append=True, + append_from_version="abcdefg", + append_with_swds_bin=False, + ) + assert isinstance(rw._builder, UserRawBuildExecutor) + + def test_flush(self) -> None: + rw = RowWriter(dataset_name="mnist", dataset_version="123456") + rw._builder = MagicMock() + rw.flush() + + rw.update(DataRow(index=1, data=Binary(b"test"), annotations={"label": 1})) + thread = threading.Thread(target=rw.flush, daemon=True) + thread.start() + time.sleep(0.2) + assert thread.is_alive() + + item = rw._queue.get(block=True) + assert item.index == 1 # type: ignore + time.sleep(0.2) + assert not thread.is_alive() + + rw.flush() diff --git a/client/tests/sdk/test_dataset_sdk.py b/client/tests/sdk/test_dataset_sdk.py new file mode 100644 index 0000000000..b782d97906 --- /dev/null +++ b/client/tests/sdk/test_dataset_sdk.py @@ -0,0 +1,747 @@ +import typing as t +from http import HTTPStatus +from pathlib import Path +from unittest.mock import MagicMock +from concurrent.futures import as_completed, ThreadPoolExecutor + +import yaml +from requests_mock import Mocker + +from starwhale import dataset +from starwhale.consts import HTTPMethod +from starwhale.base.uri import URI +from starwhale.utils.fs import ensure_dir, ensure_file +from starwhale.base.type import URIType +from starwhale.utils.error import ExistedError, NotFoundError, NoSupportError +from starwhale.utils.config import SWCliConfigMixed +from starwhale.core.dataset.type import Binary, DatasetSummary +from starwhale.api._impl.dataset.loader import DataRow + +from .test_base import BaseTestCase + + +class TestDatasetSDK(BaseTestCase): + def setUp(self) -> None: + super().setUp() + + def _init_simple_dataset(self) -> URI: + with dataset("mnist", create=True) as ds: + for i in range(0, 10): + ds.append( + DataRow( + index=i, + data=Binary(f"data-{i}".encode()), + annotations={"label": i}, + ) + ) + ds.commit() + return ds.uri + + def _init_simple_dataset_with_str_id(self) -> URI: + with dataset("mnist", create=True) as ds: + for i in range(0, 10): + ds.append( + DataRow( + index=f"{i}", + data=Binary(f"data-{i}".encode()), + annotations={"label": i}, + ) + ) + ds.commit() + return ds.uri + + def test_create_from_empty(self) -> None: + ds = dataset("mnist", create=True) + assert ds.version != "" + assert ds.project_uri.full_uri == "local/project/self/dataset/mnist" + assert ds.uri.object.name == ds.name == "mnist" + assert ds.uri.object.typ == URIType.DATASET + assert ds.uri.object.version == ds.version + assert ds.uri.project == "self" + assert ds.uri.instance == "local" + + assert not ds.readonly + assert ds._append_from_version == "" + assert not ds._create_by_append + assert len(ds) == 0 + assert bool(ds) + + def test_append(self) -> None: + size = 11 + ds = dataset("mnist", create=True) + assert len(ds) == 0 + ds.append(DataRow(index=0, data=Binary(b""), annotations={"label": 1})) + assert len(ds) == 1 + for i in range(1, size): + ds.append((i, Binary(), {"label": i})) + assert len(ds) == size + + ds.append((Binary(), {"label": 1})) + + with self.assertRaises(TypeError): + ds.append(1) + + with self.assertRaises(ValueError): + ds.append((1, 1, 1, 1, 1)) + + 
ds.commit() + ds.close() + + load_ds = dataset(ds.uri) + assert len(load_ds) == size + 1 + + for i, d in enumerate(ds): + assert d.index == i + + def test_extend(self) -> None: + ds = dataset("mnist", create=True) + assert len(ds) == 0 + size = 10 + ds.extend( + [ + DataRow(index=i, data=Binary(), annotations={"label": i}) + for i in range(0, size) + ] + ) + ds.extend([]) + + with self.assertRaises(TypeError): + ds.extend(None) + + with self.assertRaises(TypeError): + ds.extend([None]) + + assert len(ds) == size + ds.commit() + ds.close() + + load_ds = dataset(ds.uri) + assert load_ds.exists() + assert len(load_ds) == size + assert load_ds[0].index == 0 # type: ignore + assert load_ds[9].index == 9 # type: ignore + + def test_setitem(self) -> None: + ds = dataset("mnist", create=True) + assert len(ds) == 0 + assert ds._row_writer is None + + ds["index-2"] = DataRow( + index="index-2", data=Binary(), annotations={"label": 2} + ) + ds["index-1"] = DataRow( + index="index-1", data=Binary(), annotations={"label": 1} + ) + + assert len(ds) == 2 + assert ds._row_writer is not None + assert ds._row_writer._kw["dataset_name"] == ds.name + assert ds._row_writer._kw["dataset_version"] == ds.version + assert not ds._row_writer._kw["append"] + + ds["index-4"] = "index-4", Binary(), {"label": 4} + ds["index-3"] = Binary(), {"label": 3} + + with self.assertRaises(ValueError): + ds["index-5"] = (1,) + + with self.assertRaises(TypeError): + ds["index-6"] = 1 + + assert len(ds) == 4 + ds.commit() + ds.close() + + load_ds = dataset(ds.uri) + assert len(load_ds) == 4 + index_names = [d.index for d in load_ds] + assert index_names == ["index-1", "index-2", "index-3", "index-4"] + + def test_setitem_exceptions(self) -> None: + ds = dataset("mnist", create=True) + with self.assertRaises(TypeError): + ds[1:3] = ((1, Binary(), {}), (2, Binary(), {})) + + with self.assertRaises(TypeError): + ds[DataRow(1, Binary(), {})] = DataRow(1, Binary(), {}) + + def test_parallel_setitem(self) -> None: + ds = dataset("mnist", create=True) + + size = 100 + + def _do_task(_start: int) -> None: + for i in range(_start, size): + ds.append(DataRow(index=i, data=Binary(), annotations={"label": i})) + + pool = ThreadPoolExecutor(max_workers=10) + tasks = [pool.submit(_do_task, i * 10) for i in range(0, 9)] + list(as_completed(tasks)) + + ds.commit() + ds.close() + + load_ds = dataset(ds.uri) + assert len(load_ds) == size + items = list(load_ds) + assert items[0].index == 0 + assert items[-1].index == 99 + + def test_setitem_same_key(self) -> None: + ds = dataset("mnist", create=True) + ds.append(DataRow(1, Binary(b""), {"label": "1-1"})) + assert len(ds) == 1 + + for i in range(0, 10): + ds[2] = Binary(b""), {"label": f"2-{i}"} + + assert len(ds) == 2 + ds.append(DataRow(3, Binary(b""), {"label": "3-1"})) + + assert len(ds) == 3 + ds.commit() + ds.close() + + load_ds = dataset(ds.uri) + assert len(list(load_ds)) == 3 + assert load_ds[2].annotations == {"label": "2-9"} # type: ignore + assert len(load_ds) == 3 + + def test_readonly(self) -> None: + existed_ds_uri = self._init_simple_dataset() + ds = dataset(existed_ds_uri) + + assert ds.readonly + readonly_msg = "in the readonly mode" + with self.assertRaisesRegex(RuntimeError, readonly_msg): + ds.append(DataRow(1, Binary(), {})) + + with self.assertRaisesRegex(RuntimeError, readonly_msg): + ds.extend([DataRow(1, Binary(), {})]) + + with self.assertRaisesRegex(RuntimeError, readonly_msg): + ds[1] = Binary(), {} + + with self.assertRaisesRegex(RuntimeError, readonly_msg): + ds.flush() 
+ + def test_del_item_from_existed(self) -> None: + existed_ds_uri = self._init_simple_dataset() + ds = dataset(existed_ds_uri) + + with self.assertRaisesRegex(RuntimeError, "in the readonly mode"): + del ds[1] + + ds = dataset(existed_ds_uri, create=True) + del ds[0] + assert len(ds) == 9 + ds.flush() + + del ds[0] + assert len(ds) == 9 + del ds[6:] + assert len(ds) == 5 + + ds.commit() + ds.close() + + ds = dataset(ds.uri) + items = [d.index for d in ds] + assert items == [1, 2, 3, 4, 5] + + def test_del_not_found(self) -> None: + ds = dataset("mnist", create=True) + del ds[0] + del ds["1"] + del ds["not-found"] + + def test_del_item_from_empty(self) -> None: + with dataset("mnist", create=True) as ds: + for i in range(0, 3): + ds.append(DataRow(i, Binary(), {"label": i})) + + ds.flush() + del ds[0] + del ds[1] + ds.commit() + + reopen_ds = dataset(ds.uri) + assert len(reopen_ds) == 1 + items = list(reopen_ds) + assert len(items) == 1 + assert items[0].index == 2 + + def test_build_no_data(self) -> None: + ds = dataset("mnist", create=True) + msg = "no data to build dataset" + with self.assertRaisesRegex(RuntimeError, msg): + ds.build() + + existed_ds_uri = self._init_simple_dataset_with_str_id() + ds = dataset(existed_ds_uri, create=True) + with self.assertRaisesRegex(RuntimeError, msg): + ds.build() + + def test_build_from_handler_empty(self) -> None: + def _handler() -> t.Generator: + for i in range(0, 100): + yield i, Binary(), {"label": i} + + ds = dataset("mnist", create=True) + ds.build_handler = _handler + ds.commit() + ds.close() + + reopen_ds = dataset(ds.uri) + assert len(reopen_ds) == 100 + assert reopen_ds[0].index == 0 # type: ignore + assert reopen_ds[0].annotations == {"label": 0} # type: ignore + items = list(reopen_ds) + assert items[-1].index == 99 + assert items[-1].annotations == {"label": 99} + + def test_build_from_handler_existed(self) -> None: + def _handler() -> t.Generator: + for i in range(0, 100): + yield f"label-{i}", Binary(), {"label": i} + + existed_ds_uri = self._init_simple_dataset_with_str_id() + with dataset(existed_ds_uri, create_from_handler=_handler) as ds: + assert ds._create_by_append + ds.commit() + + reopen_ds = dataset(ds.uri) + assert len(reopen_ds) == 110 + summary = reopen_ds.summary() + assert isinstance(summary, DatasetSummary) + assert summary.rows == 110 + assert not summary.include_link + assert not summary.include_user_raw + assert summary.increased_rows == 100 + items = list(reopen_ds) + assert len(items) == 110 + assert items[0].index == "0" + assert items[-1].index == "label-99" + + def test_build_from_handler_with_copy_src(self) -> None: + def _handler() -> t.Generator: + for i in range(0, 100): + yield DataRow(f"label-{i}", Binary(), {"label": i}) + + workdir = Path(self.local_storage) / ".data" + ensure_dir(workdir) + ensure_file(workdir / "t.py", content="") + + ds = dataset("mnist", create_from_handler=_handler) + ds.build_with_copy_src(workdir) + ds.commit() + ds.close() + + reopen_ds = dataset(ds.uri) + assert reopen_ds.exists() + + _uri = reopen_ds.uri + dataset_dir = ( + Path(self.local_storage) + / _uri.project + / "dataset" + / reopen_ds.name + / reopen_ds.version[:2] + / f"{reopen_ds.version}.swds" + / "src" + ) + assert dataset_dir.exists() + assert (dataset_dir / "t.py").exists() + + def test_forbid_handler(self) -> None: + ds = dataset("mnist", create=True) + for i in range(0, 3): + ds.append(DataRow(i, Binary(), {"label": i})) + + assert ds._trigger_icode_build + assert not ds._trigger_handler_build + + with 
self.assertRaisesRegex( + RuntimeError, "dataset append by interactive code has already been called" + ): + ds.build_handler = MagicMock() + + def test_forbid_icode(self) -> None: + ds = dataset("mnist", create=True) + ds.build_handler = MagicMock() + assert ds._trigger_handler_build + assert not ds._trigger_icode_build + + msg = "no support build from handler and from cache code at the same time" + with self.assertRaisesRegex(NoSupportError, msg): + ds.append(DataRow(1, Binary(), {"label": 1})) + + with self.assertRaisesRegex(NoSupportError, msg): + ds.extend([DataRow(1, Binary(), {"label": 1})]) + + with self.assertRaisesRegex(NoSupportError, msg): + ds[1] = DataRow(1, Binary(), {"label": 1}) + + with self.assertRaisesRegex(NoSupportError, msg): + del ds[1] + + ds = dataset("mnist", create_from_handler=MagicMock()) + assert ds._trigger_handler_build + assert not ds._trigger_icode_build + with self.assertRaisesRegex(NoSupportError, msg): + ds.append(DataRow(1, Binary(), {"label": 1})) + + with self.assertRaisesRegex(NoSupportError, msg): + ds.extend([DataRow(1, Binary(), {"label": 1})]) + + with self.assertRaisesRegex(NoSupportError, msg): + ds[1] = DataRow(1, Binary(), {"label": 1}) + + with self.assertRaisesRegex(NoSupportError, msg): + del ds[1] + + def test_close(self) -> None: + ds = dataset("mnist", create=True) + ds.close() + + existed_ds_uri = self._init_simple_dataset() + ds = dataset(existed_ds_uri) + ds.close() + ds.close() + + def test_create_from_existed(self) -> None: + existed_ds_uri = self._init_simple_dataset() + ds = dataset(existed_ds_uri, create=True) + + assert ds.version != existed_ds_uri.object.version + assert ds.name == existed_ds_uri.object.name + assert ds.project_uri.project == existed_ds_uri.project + assert ds.version == ds.uri.object.version + assert not ds.readonly + assert not ds.exists() + assert ds._append_from_version == existed_ds_uri.object.version + assert ds._create_by_append + assert len(ds) == 10 + + ds.append(DataRow(index=1, data=Binary(b""), annotations={"label": 101})) + ds.append(DataRow(index=100, data=Binary(b""), annotations={"label": 100})) + ds.append(DataRow(index=101, data=Binary(b""), annotations={"label": 101})) + ds.flush() + assert len(ds) == 12 + ds.commit() + ds.close() + + assert ds[1].annotations == {"label": 101} # type: ignore + + _summary = ds.summary() + assert _summary is not None + assert _summary.rows == 12 + + def test_load_from_empty(self) -> None: + with self.assertRaises(ValueError): + dataset("mnist") + + with self.assertRaises(ExistedError): + dataset("mnist/version/not_found") + + def test_load_from_existed(self) -> None: + existed_ds_uri = self._init_simple_dataset() + ds = dataset(existed_ds_uri) + assert ds.version == ds.uri.object.version == existed_ds_uri.object.version + assert ds.readonly + assert ds.name == existed_ds_uri.object.name + + _summary = ds.summary() + assert _summary is not None + assert _summary.rows == len(ds) == 10 + assert ds._append_from_version == "" + assert not ds._create_by_append + + _d = ds[0] + assert isinstance(_d, DataRow) + assert _d.index == 0 + assert _d.data == Binary(b"data-0") + assert _d.annotations == {"label": 0} + + def test_load_with_tag(self) -> None: + existed_ds_uri = self._init_simple_dataset() + name = existed_ds_uri.object.name + ds = dataset(f"{name}/version/latest") + assert ds.exists() + assert ds.version == existed_ds_uri.object.version + + def test_load_with_short_version(self) -> None: + existed_ds_uri = self._init_simple_dataset() + name = 
existed_ds_uri.object.name + version = existed_ds_uri.object.version + ds = dataset(f"{name}/version/{version[:7]}") + assert ds.exists() + assert ds.version == existed_ds_uri.object.version + + def test_iter(self) -> None: + existed_ds_uri = self._init_simple_dataset() + ds = dataset(existed_ds_uri) + items = list(ds) + assert len(items) == 10 + assert items[0].index == 0 + + ds = dataset(existed_ds_uri) + cnt = 0 + for item in ds: + cnt += 1 + assert isinstance(item, DataRow) + assert cnt == 10 + + def test_get_item_by_int_id(self) -> None: + existed_ds_uri = self._init_simple_dataset() + ds = dataset(existed_ds_uri) + assert isinstance(ds[0], DataRow) + assert ds[0].index == 0 # type: ignore + + items: t.List[DataRow] = ds[0:3] # type: ignore + assert isinstance(items, list) + assert len(items) == 3 + assert items[-1].index == 2 + + items: t.List[DataRow] = ds[:] # type: ignore + assert isinstance(items, list) + assert len(items) == 10 + + items: t.List[DataRow] = ds[8:] # type: ignore + assert isinstance(items, list) + assert len(items) == 2 + + items: t.List[DataRow] = ds[::2] # type: ignore + assert isinstance(items, list) + assert len(items) == 5 + assert items[0].index == 0 + assert items[1].index == 2 + assert items[4].index == 8 + + def test_get_item_by_str_id(self) -> None: + existed_ds_uri = self._init_simple_dataset_with_str_id() + ds = dataset(existed_ds_uri) + assert isinstance(ds["0"], DataRow) + assert ds["0"].index == "0" # type: ignore + + items: t.List[DataRow] = ds["0":"3"] # type: ignore + assert isinstance(items, list) + assert len(items) == 3 + assert items[-1].index == "2" + + items: t.List[DataRow] = ds[:] # type: ignore + assert isinstance(items, list) + assert len(items) == 10 + + items: t.List[DataRow] = ds["8":] # type: ignore + assert isinstance(items, list) + assert len(items) == 2 + + def test_tags(self) -> None: + existed_ds_uri = self._init_simple_dataset_with_str_id() + ds = dataset(existed_ds_uri) + tags = list(ds.tags) + assert tags == ["latest", "v0"] + + ds.tags.add("new_tag1") + ds.tags.add(["new_tag2", "new_tag3"]) + + tags = list(ds.tags) + assert set(tags) == set(["latest", "v0", "new_tag1", "new_tag2", "new_tag3"]) + + ds.tags.remove("new_tag1") + ds.tags.remove(["new_tag3", "new_tag2"]) + tags = list(ds.tags) + assert tags == ["latest", "v0"] + + ds.tags.remove("not_found", ignore_errors=True) + assert len(list(ds.tags)) == 2 + + with self.assertRaisesRegex(NotFoundError, "tag:not_found"): + ds.tags.remove("not_found") + + @Mocker() + def test_cloud_init(self, rm: Mocker) -> None: + rm.request( + HTTPMethod.HEAD, + "http://1.1.1.1/api/v1/project/self/dataset/not_found/version/1234/file", + json={"message": "not found"}, + status_code=HTTPStatus.NOT_FOUND, + ) + + with self.assertRaisesRegex(ExistedError, "was not found fo load"): + dataset("http://1.1.1.1/project/self/dataset/not_found/version/1234") + + rm.request( + HTTPMethod.HEAD, + "http://1.1.1.1/api/v1/project/self/dataset/mnist/version/1234/file", + json={"message": "existed"}, + status_code=HTTPStatus.OK, + ) + + rm.request( + HTTPMethod.GET, + "http://1.1.1.1/api/v1/project/self/dataset/mnist", + json={ + "data": { + "versionMeta": yaml.safe_dump( + {"dataset_summary": DatasetSummary(rows=101).asdict()} + ) + } + }, + status_code=HTTPStatus.OK, + ) + + ds = dataset("http://1.1.1.1/project/self/dataset/mnist/version/1234") + assert ds.exists() + _summary = ds.summary() + assert _summary is not None + assert _summary.rows == 101 + + def test_consumption(self) -> None: + existed_ds_uri = 
self._init_simple_dataset_with_str_id() + ds = dataset(existed_ds_uri) + + loader = ds._get_data_loader(disable_consumption=True) + + ds.make_distributed_consumption("1") + assert ds._consumption is not None + + consumption_loader = ds._get_data_loader(disable_consumption=False) + + another_loader = ds._get_data_loader(disable_consumption=True) + assert loader == another_loader + assert loader is not consumption_loader + assert loader.session_consumption is None + assert consumption_loader.session_consumption is not None + + def test_consumption_recreate_exception(self) -> None: + existed_ds_uri = self._init_simple_dataset_with_str_id() + ds = dataset(existed_ds_uri) + ds.make_distributed_consumption("1") + + with self.assertRaisesRegex( + RuntimeError, "distributed consumption has already been created" + ): + ds.make_distributed_consumption("2") + + def test_info(self) -> None: + existed_ds_uri = self._init_simple_dataset_with_str_id() + ds = dataset(existed_ds_uri) + + info = ds.info() + assert isinstance(info, dict) + assert info["name"] == ds.name + assert info["version"] == ds.version + assert info["tags"] == ["latest", "v0"] + assert info["project"] == ds.project_uri.project + + empty_ds = dataset("mnist", create=True) + info = empty_ds.info() + assert info == {} + + def test_remove_recover(self) -> None: + existed_ds_uri = self._init_simple_dataset_with_str_id() + ds = dataset(existed_ds_uri) + list_info, _ = ds.list(ds.project_uri, fullname=True) + assert isinstance(list_info, list) + assert list_info[0]["name"] == ds.name + assert list_info[0]["version"] == ds.version + assert list_info[0]["tags"] == ["latest", "v0"] + + ds.remove() + with self.assertRaisesRegex(RuntimeError, "failed to remove dataset"): + ds.remove() + + list_info, _ = ds.list(ds.project_uri, fullname=True) + assert list_info == [] + + ds.recover() + with self.assertRaisesRegex(RuntimeError, "failed to recover dataset"): + ds.recover() + + list_info, _ = ds.list(ds.project_uri, fullname=True) + assert list_info[0]["version"] == ds.version + + def test_history(self) -> None: + existed_ds_uri = self._init_simple_dataset_with_str_id() + str_ds = dataset(existed_ds_uri) + + history = str_ds.history() + assert len(history) == 1 + + existed_int_ds_uri = self._init_simple_dataset() + int_ds = dataset(existed_int_ds_uri) + + history = int_ds.history() + assert len(history) == 2 + assert history[0]["name"] == history[1]["name"] == int_ds.name + assert {history[0]["version"], history[1]["version"]} == { + int_ds.version, + str_ds.version, + } + + def test_diff(self) -> None: + existed_ds_uri = self._init_simple_dataset_with_str_id() + str_ds = dataset(existed_ds_uri) + + existed_int_ds_uri = self._init_simple_dataset() + int_ds = dataset(existed_int_ds_uri) + + diff = str_ds.diff(int_ds) + assert diff["diff_rows"]["updated"] == 10 + + diff = str_ds.diff(str_ds) + assert diff == {} + + diff = int_ds.diff(str_ds) + assert diff["diff_rows"]["updated"] == 10 + + def test_head(self) -> None: + existed_ds_uri = self._init_simple_dataset_with_str_id() + ds = dataset(existed_ds_uri) + + head = ds.head(n=0) + assert len(head) == 0 + + head = ds.head(n=1) + assert len(head) == 1 + assert head[0]["index"] == 0 + assert "raw" not in head[0]["data"] + + head = ds.head(n=2) + assert len(head) == 2 + assert head[0]["index"] == 0 + assert head[1]["index"] == 1 + assert "raw" not in head[0]["data"] + assert "raw" not in head[1]["data"] + + head = ds.head(n=2, show_raw_data=True) + assert len(head) == 2 + assert head[0]["data"]["raw"] == 
b"data-0" + assert head[1]["data"]["raw"] == b"data-1" + + @Mocker() + def test_copy(self, rm: Mocker) -> None: + existed_ds_uri = self._init_simple_dataset_with_str_id() + ds = dataset(existed_ds_uri) + + sw = SWCliConfigMixed() + sw.update_instance( + uri="http://1.1.1.1", user_name="test", sw_token="123", alias="test" + ) + + rm.request( + HTTPMethod.HEAD, + f"http://1.1.1.1/api/v1/project/self/dataset/mnist/version/{ds.version}", + json={"message": "existed"}, + status_code=HTTPStatus.OK, + ) + + rm.request( + HTTPMethod.POST, + f"http://1.1.1.1/api/v1/project/self/dataset/mnist/version/{ds.version}/file", + json={"data": {"upload_id": "123"}}, + ) + + ds.copy("cloud://test/project/self") diff --git a/client/tests/sdk/test_loader.py b/client/tests/sdk/test_loader.py index aba87e8563..52e8e23242 100644 --- a/client/tests/sdk/test_loader.py +++ b/client/tests/sdk/test_loader.py @@ -25,7 +25,11 @@ TabularDatasetRow, get_dataset_consumption, ) -from starwhale.api._impl.dataset.loader import SWDSBinDataLoader, UserRawDataLoader +from starwhale.api._impl.dataset.loader import ( + DataRow, + SWDSBinDataLoader, + UserRawDataLoader, +) class TestDataLoader(TestCase): @@ -133,7 +137,7 @@ def test_user_raw_local_store( "local/project/self/dataset/mnist/version/1122334455667788." ].key_prefix - loader = get_data_loader(self.dataset_uri) + loader = get_data_loader("mnist/version/1122334455667788") assert isinstance(loader, UserRawDataLoader) assert loader.session_consumption is None rows = list(loader) @@ -636,3 +640,32 @@ def test_remote_batch_sign( self.assertEqual(req_get_file.call_count, 4) self.assertEqual(len(_label_uris_map), 4) + + def test_data_row(self) -> None: + dr = DataRow(index=1, data=Image(), annotations={"label": 1}) + index, data, annotations = dr + assert index == 1 + assert isinstance(data, Image) + assert annotations == {"label": 1} + assert dr[0] == 1 + assert len(dr) == 3 + + dr_another = DataRow(index=2, data=Image(), annotations={"label": 2}) + assert dr < dr_another + assert dr != dr_another + + dr_third = DataRow(index=1, data=Image(fp=b""), annotations={"label": 10}) + assert dr >= dr_third + + dr_none = DataRow(index=1, data=None, annotations={}) + assert dr_none.data is None + + def test_data_row_exceptions(self) -> None: + with self.assertRaises(TypeError): + DataRow(index=b"", data=Image(), annotations={}) # type: ignore + + with self.assertRaises(TypeError): + DataRow(index=1, data=b"", annotations={}) # type: ignore + + with self.assertRaises(TypeError): + DataRow(index=1, data=Image(), annotations=1) # type: ignore diff --git a/client/tests/sdk/test_metric.py b/client/tests/sdk/test_metric.py index a9ad16cb60..b013e1bfa6 100644 --- a/client/tests/sdk/test_metric.py +++ b/client/tests/sdk/test_metric.py @@ -2,6 +2,7 @@ from pathlib import Path from unittest.mock import patch, MagicMock +import pytest from pyfakefs.fake_filesystem_unittest import TestCase from starwhale.api._impl.job import Context, context_holder @@ -16,6 +17,9 @@ def setUp(self) -> None: project="self", ) + @pytest.mark.filterwarnings( + "ignore::sklearn.metrics._classification.UndefinedMetricWarning" + ) @patch("starwhale.api._impl.wrapper.Evaluation.log_metrics") def test_multi_classification_metric(self, log_metric_mock: MagicMock) -> None: def _cmp(handler, data): @@ -39,6 +43,9 @@ def _cmp(handler, data): assert list(rt["labels"].keys()) == ["a", "b", "c", "d"] assert "confusion_matrix/binarylabel" not in rt + @pytest.mark.filterwarnings( + 
"ignore::sklearn.metrics._classification.UndefinedMetricWarning" + ) @patch("starwhale.api._impl.wrapper.Evaluation.log_metrics") @patch("starwhale.api._impl.wrapper.Evaluation.log") def test_multi_classification_metric_with_pa( diff --git a/scripts/client_test/cmds/artifacts_cmd.py b/scripts/client_test/cmds/artifacts_cmd.py index a57227277d..7f2b2aff5b 100644 --- a/scripts/client_test/cmds/artifacts_cmd.py +++ b/scripts/client_test/cmds/artifacts_cmd.py @@ -2,6 +2,7 @@ import typing as t from pathlib import Path +from starwhale.utils.load import import_object from starwhale.core.model.view import ModelTermView from starwhale.core.dataset.type import DatasetConfig from starwhale.core.dataset.view import DatasetTermView @@ -213,7 +214,7 @@ def build_with_api( config = DatasetConfig() if yaml_path.exists(): config = DatasetConfig.create_by_yaml(yaml_path) - config.handler = handler or config.handler + config.handler = import_object(workdir, handler or config.handler) _uri = DatasetTermView.build(workdir, config) LocalDataStore.get_instance().dump() return _uri