From ee3acba97054d28c1c5cd99c19dd1b051204c42d Mon Sep 17 00:00:00 2001 From: tianwei Date: Tue, 23 Aug 2022 10:22:55 +0800 Subject: [PATCH] feat(dataset): support `Link` type for dataset (#944) * support Link type for dataset * add test for dataset loader --- client/starwhale/api/_impl/dataset.py | 268 +++++++++++++++-- client/starwhale/api/_impl/loader.py | 274 +++++++++++------ client/starwhale/api/_impl/model.py | 13 +- client/starwhale/api/dataset.py | 6 + client/starwhale/consts/__init__.py | 1 + client/starwhale/core/dataset/dataset.py | 21 +- client/starwhale/core/dataset/model.py | 2 +- client/starwhale/utils/__init__.py | 14 + client/starwhale/utils/fs.py | 6 + client/tests/sdk/test_dataset.py | 21 +- client/tests/sdk/test_loader.py | 359 +++++++++++++++++++++++ client/tests/sdk/test_model.py | 21 +- client/tests/utils/test_common.py | 35 ++- example/PennFudanPed/code/utils.py | 145 +++++---- example/mnist/dataset.yaml | 2 +- example/mnist/mnist/ppl.py | 36 +-- example/mnist/mnist/process.py | 31 +- 17 files changed, 1015 insertions(+), 240 deletions(-) create mode 100644 client/tests/sdk/test_loader.py diff --git a/client/starwhale/api/_impl/dataset.py b/client/starwhale/api/_impl/dataset.py index 262efb7615..324a59ca71 100644 --- a/client/starwhale/api/_impl/dataset.py +++ b/client/starwhale/api/_impl/dataset.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os import sys import json import shutil @@ -7,22 +8,25 @@ import typing as t from abc import ABCMeta, abstractmethod from copy import deepcopy +from enum import Enum, unique from types import TracebackType from pathlib import Path from binascii import crc32 +from functools import partial import jsonlines from loguru import logger from starwhale.utils import console, validate_obj_name from starwhale.consts import ( + AUTH_ENV_FNAME, VERSION_PREFIX_CNT, STANDALONE_INSTANCE, SWDS_DATA_FNAME_FMT, DUMPED_SWDS_META_FNAME, ) from starwhale.base.uri import URI -from starwhale.utils.fs import ensure_dir +from starwhale.utils.fs import ensure_dir, FilePosition from starwhale.base.type import ( URIType, InstanceType, @@ -31,6 +35,7 @@ ObjectStoreType, ) from starwhale.utils.error import ( + FormatError, NotFoundError, NoSupportError, InvalidObjectName, @@ -52,6 +57,173 @@ _header_version = 0 +@unique +class LinkType(Enum): + FUSE = "fuse" + S3 = "s3" + UNDEFINED = "undefined" + # TODO: support hdfs, http, ssh link type + + +@unique +class MIMEType(Enum): + PNG = "image/png" + JPEG = "image/jpeg" + WEBP = "image/webp" + SVG = "image/svg+xml" + GIF = "image/gif" + APNG = "image/apng" + AVIF = "image/avif" + MP4 = "video/mp4" + AVI = "video/avi" + WAV = "audio/wav" + MP3 = "audio/mp3" + PLAIN = "text/plain" + CSV = "text/csv" + HTML = "text/html" + GRAYSCALE = "x/grayscale" + UNDEFINED = "x/undefined" + + @classmethod + def create_by_file_suffix(cls, name: str) -> MIMEType: + # ref: https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types + _map = { + ".png": cls.PNG, + ".jpeg": cls.JPEG, + ".jpg": cls.JPEG, + ".jfif": cls.JPEG, + ".pjpeg": cls.JPEG, + ".pjp": cls.JPEG, + ".webp": cls.WEBP, + ".svg": cls.SVG, + ".gif": cls.GIF, + ".apng": cls.APNG, + ".htm": cls.HTML, + ".html": cls.HTML, + ".mp3": cls.MP3, + ".mp4": cls.MP4, + ".avif": cls.AVIF, + ".avi": cls.AVI, + ".wav": cls.WAV, + ".csv": cls.CSV, + ".txt": cls.PLAIN, + } + return _map.get(Path(name).suffix, MIMEType.UNDEFINED) + + +_LAType = t.TypeVar("_LAType", bound="LinkAuth") + + +class LinkAuth(metaclass=ABCMeta): + def __init__(self, name: str = 
"", ltype: LinkType = LinkType.UNDEFINED) -> None: + self.name = name.strip() + self.ltype = ltype + self._do_validate() + + def _do_validate(self) -> None: + if self.ltype not in LinkType: + raise NoSupportError(f"Link Type: {self.ltype}") + + @abstractmethod + def dump_env(self) -> t.List[str]: + raise NotImplementedError + + @classmethod + def from_env(cls: t.Type[_LAType], name: str = "") -> _LAType: + raise NotImplementedError + + +class S3LinkAuth(LinkAuth): + _ENDPOINT_FMT = "USER.S3.{name}ENDPOINT" + _REGION_FMT = "USER.S3.{name}REGION" + _SECRET_FMT = "USER.S3.{name}SECRET" + _ACCESS_KEY_FMT = "USER.S3.{name}ACCESS_KEY" + _fmt: t.Callable[[str], str] = ( + lambda x: (f"{x}." if x.strip() else x).strip().upper() + ) + + def __init__( + self, + name: str = "", + access_key: str = "", + secret: str = "", + endpoint: str = "", + region: str = "", + ) -> None: + super().__init__(name, LinkType.S3) + self.access_key = access_key + self.secret = secret + self.endpoint = endpoint + self.region = region + + def dump_env(self) -> t.List[str]: + _name = S3LinkAuth._fmt(self.name) + _map = { + self._SECRET_FMT: self.secret, + self._REGION_FMT: self.region, + self._ACCESS_KEY_FMT: self.access_key, + self._ENDPOINT_FMT: self.endpoint, + } + return [f"{k.format(name=_name)}={v}" for k, v in _map.items()] + + @classmethod + def from_env(cls, name: str = "") -> S3LinkAuth: + _env = os.environ.get + + _name = cls._fmt(name) + _secret_name = cls._SECRET_FMT.format(name=_name) + _access_name = cls._ACCESS_KEY_FMT.format(name=_name) + + _secret = _env(_secret_name, "") + _access = _env(_access_name, "") + return cls( + name, + _access, + _secret, + endpoint=_env(cls._ENDPOINT_FMT.format(name=_name), ""), + region=_env(cls._REGION_FMT.format(name=_name), ""), + ) + + +FuseLinkAuth = partial(LinkAuth, ltype=LinkType.FUSE) +DefaultS3LinkAuth = S3LinkAuth() + + +class Link: + def __init__( + self, + uri: str, + auth: t.Optional[LinkAuth] = DefaultS3LinkAuth, + offset: int = FilePosition.START, + size: int = -1, + mime_type: MIMEType = MIMEType.UNDEFINED, + ) -> None: + self.uri = uri.strip() + self.offset = offset + self.size = size + self.auth = auth + + if mime_type == MIMEType.UNDEFINED or mime_type not in MIMEType: + self.mime_type = MIMEType.create_by_file_suffix(self.uri) + else: + self.mime_type = mime_type + + self.do_validate() + + def do_validate(self) -> None: + if self.offset < 0: + raise FieldTypeOrValueError(f"offset({self.offset}) must be non-negative") + + if self.size < -1: + raise FieldTypeOrValueError(f"size({self.size}) must be non-negative or -1") + + def __str__(self) -> str: + return f"Link {self.uri}" + + def __repr__(self) -> str: + return f"Link uri:{self.uri}, offset:{self.offset}, size:{self.size}, mime type:{self.mime_type}" + + class TabularDatasetRow: def __init__( self, @@ -63,6 +235,8 @@ def __init__( data_offset: int = 0, data_size: int = 0, data_origin: DataOriginType = DataOriginType.NEW, + data_mime_type: MIMEType = MIMEType.UNDEFINED, + auth_name: str = "", **kw: t.Any, ) -> None: self.id = id @@ -72,7 +246,9 @@ def __init__( self.data_size = data_size self.data_origin = data_origin self.object_store_type = object_store_type + self.data_mime_type = data_mime_type self.label = label.encode() if isinstance(label, str) else label + self.auth_name = auth_name # TODO: add non-starwhale object store related fields, such as address, authority # TODO: add data uri crc for versioning @@ -101,13 +277,17 @@ def __str__(self) -> str: return f"row-{self.id}, data-{self.data_uri}, 
origin-[{self.data_origin}]" def __repr__(self) -> str: - return f"row-{self.id}, data-{self.data_uri}(offset:{self.data_offset}, size:{self.data_size}, format:{self.data_format}), origin-[{self.data_origin}], object store-{self.object_store_type}" + return ( + f"row-{self.id}, data-{self.data_uri}(offset:{self.data_offset}, size:{self.data_size}," + f"format:{self.data_format}, mime type:{self.data_mime_type}), " + f"origin-[{self.data_origin}], object store-{self.object_store_type}" + ) def asdict(self) -> t.Dict[str, t.Union[str, bytes, int]]: d = deepcopy(self.__dict__) - d["data_format"] = self.data_format.value - d["data_origin"] = self.data_origin.value - d["object_store_type"] = self.object_store_type.value + for k, v in d.items(): + if isinstance(v, Enum): + d[k] = v.value return d @@ -269,7 +449,7 @@ def __init__( dataset_version: str, project_name: str, data_dir: Path = Path("."), - output_dir: Path = Path("./sw_output"), + workdir: Path = Path("./sw_output"), data_filter: str = "*", label_filter: str = "*", alignment_bytes_size: int = D_ALIGNMENT_SIZE, @@ -280,7 +460,8 @@ def __init__( self.data_dir = data_dir self.data_filter = data_filter self.label_filter = label_filter - self.output_dir = output_dir + self.workdir = workdir + self.data_output_dir = workdir / "data" self.alignment_bytes_size = alignment_bytes_size self.volume_bytes_size = volume_bytes_size @@ -295,7 +476,7 @@ def __init__( self._prepare() def _prepare(self) -> None: - self.output_dir.mkdir(parents=True, exist_ok=True) + ensure_dir(self.data_output_dir) def __enter__(self: _BDType) -> _BDType: return self @@ -412,14 +593,16 @@ def data_format_type(self) -> DataFormatType: def make_swds(self) -> DatasetSummary: # TODO: add lock fno, wrote_size = 0, 0 - dwriter = (self.output_dir / self._DATA_FMT.format(index=fno)).open("wb") - object_store_type = ObjectStoreType.LOCAL + dwriter = (self.data_output_dir / self._DATA_FMT.format(index=fno)).open("wb") rows, increased_rows = 0, 0 total_label_size, total_data_size = 0, 0 for idx, ((_, data), (_, label)) in enumerate( zip(self.iter_all_dataset_slice(), self.iter_all_label_slice()) ): + if not isinstance(data, bytes) or not isinstance(label, bytes): + raise FormatError("data and label must be bytes type") + # TODO: support inherit data from old dataset version data_origin = DataOriginType.NEW data_offset, data_size = self._write(dwriter, idx, data) @@ -445,9 +628,9 @@ def make_swds(self) -> DatasetSummary: fno += 1 dwriter.close() - dwriter = (self.output_dir / self._DATA_FMT.format(index=fno)).open( - "wb" - ) + dwriter = ( + self.data_output_dir / self._DATA_FMT.format(index=fno) + ).open("wb") rows += 1 if data_origin == DataOriginType.NEW: @@ -461,10 +644,10 @@ def make_swds(self) -> DatasetSummary: summary = DatasetSummary( rows=rows, increased_rows=increased_rows, - data_format_type=self.data_format_type, - object_store_type=object_store_type, label_byte_size=total_label_size, data_byte_size=total_data_size, + include_user_raw=False, + include_link=False, ) return summary @@ -483,28 +666,45 @@ class UserRawBuildExecutor(BaseBuildExecutor): def make_swds(self) -> DatasetSummary: rows, increased_rows = 0, 0 total_label_size, total_data_size = 0, 0 - object_store_type = ObjectStoreType.LOCAL ds_copy_candidates = {} + auth_candidates = {} + include_link = False - for idx, ((data_path, data), (_, label)) in enumerate( + for idx, (data, (_, label)) in enumerate( zip(self.iter_all_dataset_slice(), self.iter_all_label_slice()) ): - data_origin = DataOriginType.NEW - 
data_offset, data_size = data + if isinstance(data, Link): + data_uri = data.uri + data_offset, data_size = data.offset, data.size + if data.auth: + auth = data.auth.name + auth_candidates[f"{data.auth.ltype}.{data.auth.name}"] = data.auth + else: + auth = "" + object_store_type = ObjectStoreType.REMOTE + include_link = True + elif isinstance(data, (tuple, list)): + data_path, (data_offset, data_size) = data + auth = "" + data_uri = str(Path(data_path).relative_to(self.data_dir)) + ds_copy_candidates[data_uri] = data_path + object_store_type = ObjectStoreType.LOCAL + else: + raise FormatError(f"data({data}) type error, no list, tuple or Link") - relative_path = str(Path(data_path).relative_to(self.data_dir)) - ds_copy_candidates[relative_path] = data_path + data_origin = DataOriginType.NEW self.tabular_dataset.put( TabularDatasetRow( id=idx, - data_uri=str(relative_path), + data_uri=str(data_uri), label=label, data_format=self.data_format_type, object_store_type=object_store_type, data_offset=data_offset, data_size=data_size, data_origin=data_origin, + auth_name=auth, ) ) @@ -515,21 +715,33 @@ def make_swds(self) -> DatasetSummary: if data_origin == DataOriginType.NEW: increased_rows += 1 - for fname, src in ds_copy_candidates.items(): - dest = self.output_dir / fname - ensure_dir(dest.parent) - shutil.copyfile(str(src.absolute()), str(dest.absolute())) + self._copy_files(ds_copy_candidates) + self._copy_auth(auth_candidates) summary = DatasetSummary( rows=rows, increased_rows=increased_rows, - data_format_type=self.data_format_type, - object_store_type=object_store_type, label_byte_size=total_label_size, data_byte_size=total_data_size, + include_link=include_link, + include_user_raw=True, ) return summary + def _copy_files(self, ds_copy_candidates: t.Dict[str, Path]) -> None: + for fname, src in ds_copy_candidates.items(): + dest = self.data_output_dir / fname + ensure_dir(dest.parent) + shutil.copyfile(str(src.absolute()), str(dest.absolute())) + + def _copy_auth(self, auth_candidates: t.Dict[str, LinkAuth]) -> None: + if not auth_candidates: + return + + with (self.workdir / AUTH_ENV_FNAME).open("w") as f: + for auth in auth_candidates.values(): + f.write("\n".join(auth.dump_env())) + def iter_data_slice(self, path: str) -> t.Generator[t.Any, None, None]: yield 0, Path(path).stat().st_size diff --git a/client/starwhale/api/_impl/loader.py b/client/starwhale/api/_impl/loader.py index 1e556e5a68..bbbdd114b6 100644 --- a/client/starwhale/api/_impl/loader.py +++ b/client/starwhale/api/_impl/loader.py @@ -6,6 +6,7 @@ import typing as t from abc import ABCMeta, abstractmethod from pathlib import Path +from urllib.parse import urlparse import boto3 import loguru @@ -13,17 +14,21 @@ from botocore.client import Config as S3Config from typing_extensions import Protocol -from starwhale.consts import SWDSBackendType +from starwhale.utils import load_dotenv +from starwhale.consts import AUTH_ENV_FNAME, SWDSBackendType from starwhale.base.uri import URI -from starwhale.base.type import URIType, InstanceType, DataFormatType -from starwhale.utils.error import NoSupportError, FieldTypeOrValueError +from starwhale.utils.fs import FilePosition +from starwhale.base.type import URIType, InstanceType, DataFormatType, ObjectStoreType +from starwhale.utils.error import FormatError, NoSupportError, FieldTypeOrValueError from starwhale.core.dataset.store import DatasetStorage -from .dataset import TabularDataset, TabularDatasetRow +from .dataset import S3LinkAuth, TabularDataset, TabularDatasetRow # TODO: config 
chunk size _CHUNK_SIZE = 8 * 1024 * 1024 # 8MB -_FILE_END_POS = -1 +_DEFAULT_S3_REGION = "local" +_DEFAULT_S3_ENDPOINT = "localhost:9000" +_DEFAULT_S3_BUCKET = "starwhale" class FileLikeObj(Protocol): @@ -34,6 +39,13 @@ def read(self, size: int) -> t.Union[bytes, memoryview]: ... +class S3Uri(t.NamedTuple): + bucket: str + key: str + protocol: str = "s3" + endpoint: str = _DEFAULT_S3_ENDPOINT + + class ObjectStoreS3Connection: def __init__( self, @@ -41,14 +53,19 @@ def __init__( access_key: str, secret_key: str, region: str = "", + bucket: str = "", connect_timeout: float = 10.0, read_timeout: float = 50.0, total_max_attempts: int = 6, ) -> None: - self.endpoint = endpoint + self.endpoint = endpoint.strip() + if self.endpoint and not self.endpoint.startswith(("http://", "https://")): + self.endpoint = f"http://{self.endpoint}" + self.access_key = access_key self.secret_key = secret_key self.region = region + self.bucket = bucket self.connect_timeout = float( os.environ.get("SW_S3_CONNECT_TIMEOUT", connect_timeout) ) @@ -63,70 +80,127 @@ def __str__(self) -> str: __repr__ = __str__ @classmethod - def create_from_env(cls) -> ObjectStoreS3Connection: + def from_uri(cls, uri: str, auth_name: str) -> ObjectStoreS3Connection: + """make S3 Connection by uri + + uri: + - s3://username:password@127.0.0.1:8000@bucket/key + - s3://127.0.0.1:8000@bucket/key + - s3://bucket/key + """ + uri = uri.strip() + if not uri or not uri.startswith("s3://"): + raise NoSupportError( + f"s3 connection only support s3:// prefix, the actual uri is {uri}" + ) + + r = urlparse(uri) + netloc = r.netloc + + link_auth = S3LinkAuth.from_env(auth_name) + access = link_auth.access_key + secret = link_auth.secret + region = link_auth.region + + _nl = netloc.split("@") + if len(_nl) == 1: + endpoint = link_auth.endpoint + bucket = _nl[0] + elif len(_nl) == 2: + endpoint, bucket = _nl + elif len(_nl) == 3: + _key, endpoint, bucket = _nl + access, secret = _key.split(":", 1) + else: + raise FormatError(netloc) + + if not endpoint: + raise FieldTypeOrValueError("endpoint is empty") + + if not access or not secret: + raise FieldTypeOrValueError("no access_key or secret_key") + + if not bucket: + raise FieldTypeOrValueError("bucket is empty") + + return cls( + endpoint=endpoint, + access_key=access, + secret_key=secret, + region=region or _DEFAULT_S3_REGION, + bucket=bucket, + ) + + @classmethod + def from_env(cls) -> ObjectStoreS3Connection: # TODO: support multi s3 backend servers + _env = os.environ.get return ObjectStoreS3Connection( - endpoint=os.environ.get("SW_S3_ENDPOINT", "127.0.0.1:9000"), - access_key=os.environ.get("SW_S3_ACCESS_KEY", "foo"), - secret_key=os.environ.get("SW_S3_SECRET", "bar"), - region=os.environ.get("SW_S3_REGION", "local"), + endpoint=_env("SW_S3_ENDPOINT", _DEFAULT_S3_ENDPOINT), + access_key=_env("SW_S3_ACCESS_KEY", ""), + secret_key=_env("SW_S3_SECRET", ""), + region=_env("SW_S3_REGION", _DEFAULT_S3_REGION), + bucket=_env("SW_S3_BUCKET", _DEFAULT_S3_BUCKET), ) class DatasetObjectStore: def __init__( self, - uri: URI, - backend: str = "", - conn: t.Optional[ObjectStoreS3Connection] = None, - bucket: str = "", + backend: str, + bucket: str, key_prefix: str = "", + **kw: t.Any, ) -> None: - self.uri = uri - _backend = backend or self._get_default_backend() + self.bucket = bucket - self.conn: t.Optional[ObjectStoreS3Connection] self.backend: StorageBackend - - _env_bucket = os.environ.get("SW_S3_BUCKET", "") - - if _backend == SWDSBackendType.S3: - self.conn = conn or 
ObjectStoreS3Connection.create_from_env() - self.bucket = bucket or _env_bucket - self.backend = S3StorageBackend(self.conn) + if backend == SWDSBackendType.S3: + conn = kw.get("conn") or ObjectStoreS3Connection.from_env() + self.backend = S3StorageBackend(conn) else: - self.conn = None - self.bucket = bucket or _env_bucket or self._get_bucket_by_uri() self.backend = FuseStorageBackend() self.key_prefix = key_prefix or os.environ.get("SW_OBJECT_STORE_KEY_PREFIX", "") - self._do_validate() + def __str__(self) -> str: + return f"DatasetObjectStore backend:{self.backend}" - def _do_validate(self) -> None: - if self.uri.object.typ != URIType.DATASET: - raise NoSupportError(f"{self.uri} is not dataset uri") + def __repr__(self) -> str: + return f"DatasetObjectStore backend:{self.backend}, bucket:{self.bucket}, key_prefix:{self.key_prefix}" + + @classmethod + def from_data_link_uri(cls, data_uri: str, auth_name: str) -> DatasetObjectStore: + data_uri = data_uri.strip() + if not data_uri: + raise FieldTypeOrValueError("data_uri is empty") + + # TODO: support other uri type + if data_uri.startswith("s3://"): + backend = SWDSBackendType.S3 + conn = ObjectStoreS3Connection.from_uri(data_uri, auth_name) + bucket = conn.bucket + else: + backend = SWDSBackendType.FUSE + bucket = "" + conn = None - if not self.bucket: - raise FieldTypeOrValueError("no bucket field") + return cls(backend=backend, bucket=bucket, conn=conn) - def _get_default_backend(self) -> str: - _type = self.uri.instance_type + @classmethod + def from_dataset_uri(cls, dataset_uri: URI) -> DatasetObjectStore: + if dataset_uri.object.typ != URIType.DATASET: + raise NoSupportError(f"{dataset_uri} is not dataset uri") + _type = dataset_uri.instance_type if _type == InstanceType.STANDALONE: - return SWDSBackendType.FUSE - elif _type == InstanceType.CLOUD: - return SWDSBackendType.S3 + backend = SWDSBackendType.FUSE + bucket = str(DatasetStorage(dataset_uri).data_dir.absolute()) else: - raise NoSupportError( - f"get object store backend by the instance type({_type})" - ) - - def _get_bucket_by_uri(self) -> str: - if self.uri.instance_type == InstanceType.CLOUD: - raise NoSupportError(f"{self.uri} to fetch bucket") + backend = SWDSBackendType.S3 + bucket = os.environ.get("SW_S3_BUCKET", _DEFAULT_S3_BUCKET) - return str(DatasetStorage(self.uri).data_dir.absolute()) + return cls(backend=backend, bucket=bucket) class DataField(t.NamedTuple): @@ -156,31 +230,74 @@ def __iter__(self) -> t.Any: class DataLoader(metaclass=ABCMeta): def __init__( self, - storage: DatasetObjectStore, - dataset: TabularDataset, + dataset_uri: URI, + start: int = 0, + end: int = sys.maxsize, logger: t.Union[loguru.Logger, None] = None, - deserializer: t.Optional[t.Callable] = None, ): - self.storage = storage - self.dataset = dataset + self.dataset_uri = dataset_uri + self.start = start + self.end = end self.logger = logger or _logger - self.deserializer = deserializer + # TODO: refactor TabularDataset with dataset_uri + # TODO: refactor dataset, tabular_dataset and standalone dataset module + self.tabular_dataset = TabularDataset.from_uri( + dataset_uri, start=start, end=end + ) + self._stores: t.Dict[str, DatasetObjectStore] = {} + + self._load_dataset_auth_env() + + def _load_dataset_auth_env(self) -> None: + # TODO: support multi datasets + if self.dataset_uri.instance_type == InstanceType.STANDALONE: + auth_env_fpath = ( + DatasetStorage(self.dataset_uri).snapshot_workdir / AUTH_ENV_FNAME + ) + load_dotenv(auth_env_fpath) + + def _get_store(self, row: 
TabularDatasetRow) -> DatasetObjectStore: + _k = f"{row.object_store_type.value}.{row.auth_name}" + _store = self._stores.get(_k) + if _store: + return _store + + if row.object_store_type == ObjectStoreType.REMOTE: + _store = DatasetObjectStore.from_data_link_uri(row.data_uri, row.auth_name) + else: + _store = DatasetObjectStore.from_dataset_uri(self.dataset_uri) + + self._stores[_k] = _store + return _store + + def _get_key_compose( + self, row: TabularDatasetRow, store: DatasetObjectStore + ) -> str: + if row.object_store_type == ObjectStoreType.REMOTE: + data_uri = urlparse(row.data_uri).path + else: + data_uri = row.data_uri + if store.key_prefix: + data_uri = os.path.join(store.key_prefix, data_uri.lstrip("/")) + + _key_compose = ( + f"{data_uri}:{row.data_offset}:{row.data_offset + row.data_size - 1}" + ) + return _key_compose def __iter__(self) -> t.Generator[t.Tuple[DataField, DataField], None, None]: - _attr = {"ds_name": self.dataset.name, "ds_version": self.dataset.version} - for row in self.dataset.scan(): + _attr = { + "ds_name": self.tabular_dataset.name, + "ds_version": self.tabular_dataset.version, + } + for row in self.tabular_dataset.scan(): # TODO: tune performance by fetch in batch # TODO: remove ext_attr field + _store = self._get_store(row) + _key_compose = self._get_key_compose(row, _store) - _data_uri = row.data_uri - if self.storage.key_prefix: - _data_uri = os.path.join(self.storage.key_prefix, _data_uri.lstrip("/")) - - _key_compose = ( - f"{_data_uri}:{row.data_offset}:{row.data_offset + row.data_size - 1}" - ) - self.logger.info(f"@{self.storage.bucket}/{_key_compose}") - _file = self.storage.backend._make_file(self.storage.bucket, _key_compose) + self.logger.info(f"@{_store.bucket}/{_key_compose}") + _file = _store.backend._make_file(_store.bucket, _key_compose) for data_content, data_size in self._do_iter(_file, row): label = DataField( idx=row.id, @@ -201,10 +318,10 @@ def _do_iter( raise NotImplementedError def __str__(self) -> str: - return f"[{self.kind.name}]DataLoader for {self.storage.backend}" + return f"[{self.kind.name}]DataLoader for {self.dataset_uri}" def __repr__(self) -> str: - return f"[{self.kind.name}]DataLoader for {self.storage.backend}, extra:{self.storage.conn}" + return f"[{self.kind.name}]DataLoader for {self.dataset_uri}, start:{self.start}, end:{self.end}" @property def kind(self) -> DataFormatType: @@ -261,9 +378,9 @@ def _parse_key(self, key: str) -> t.Tuple[str, int, int]: # TODO: add start end normalize _r = key.split(":") if len(_r) == 1: - return _r[0], 0, _FILE_END_POS + return _r[0], 0, FilePosition.END elif len(_r) == 2: - return _r[0], int(_r[1]), _FILE_END_POS + return _r[0], int(_r[1]), FilePosition.END else: return _r[0], int(_r[1]), int(_r[2]) @@ -388,7 +505,7 @@ def close(self) -> None: def _next_data(self) -> t.Tuple[bytes, int]: end = _CHUNK_SIZE + self._current_s3_start - 1 - end = end if self.end == _FILE_END_POS else min(self.end, end) + end = end if self.end == FilePosition.END else min(self.end, end) data, length = self._do_fetch_data(self._current_s3_start, end) self._current_s3_start += length @@ -397,7 +514,7 @@ def _next_data(self) -> t.Tuple[bytes, int]: def _do_fetch_data(self, _start: int, _end: int) -> t.Tuple[bytes, int]: # TODO: add more exception handle - if self._s3_eof or (_end != _FILE_END_POS and _end < _start): + if self._s3_eof or (_end != FilePosition.END and _end < _start): return b"", 0 resp = self.obj.get(Range=f"bytes={_start}-{_end}") @@ -406,7 +523,7 @@ def _do_fetch_data(self, _start: 
int, _end: int) -> t.Tuple[bytes, int]: out = resp["Body"].read() body.close() - self._s3_eof = _end == _FILE_END_POS or (_end - _start + 1) > length + self._s3_eof = _end == FilePosition.END or (_end - _start + 1) > length return out, length @@ -414,20 +531,11 @@ def get_data_loader( dataset_uri: URI, start: int = 0, end: int = sys.maxsize, - backend: str = "", logger: t.Union[loguru.Logger, None] = None, ) -> DataLoader: from starwhale.core.dataset import model - logger = logger or _logger - object_store = DatasetObjectStore(dataset_uri, backend) - # TODO: refactor dataset, tabular_dataset and standalone dataset module - tabular_dataset = TabularDataset.from_uri(dataset_uri, start=start, end=end) - df_type = model.Dataset.get_dataset(dataset_uri).summary().data_format_type - - if df_type == DataFormatType.SWDS_BIN: - return SWDSBinDataLoader(object_store, tabular_dataset, logger) - elif df_type == DataFormatType.USER_RAW: - return UserRawDataLoader(object_store, tabular_dataset, logger) - else: - raise NoSupportError(f"cannot get data format type({df_type}) data loader") + summary = model.Dataset.get_dataset(dataset_uri).summary() + include_user_raw = summary.include_user_raw + _cls = UserRawDataLoader if include_user_raw else SWDSBinDataLoader + return _cls(dataset_uri, start, end, logger or _logger) diff --git a/client/starwhale/api/_impl/model.py b/client/starwhale/api/_impl/model.py index 51f4b28182..a721eabc05 100644 --- a/client/starwhale/api/_impl/model.py +++ b/client/starwhale/api/_impl/model.py @@ -17,24 +17,18 @@ import loguru import jsonlines -from starwhale.utils import now_str, in_production +from starwhale.utils import now_str from starwhale.consts import CURRENT_FNAME from starwhale.base.uri import URI from starwhale.utils.fs import ensure_dir, ensure_file from starwhale.base.type import URIType, RunSubDirType from starwhale.utils.log import StreamWrapper +from starwhale.utils.error import FieldTypeOrValueError from starwhale.api._impl.job import Context from starwhale.api._impl.loader import DataField, ResultLoader, get_data_loader from starwhale.api._impl.wrapper import Evaluation from starwhale.core.dataset.model import Dataset -_TASK_ROOT_DIR = "/var/starwhale" if in_production() else "/tmp/starwhale" - -_ptype = t.Union[str, None, Path] -_p: t.Callable[[_ptype, str], Path] = ( - lambda p, sub: Path(p) if p else Path(_TASK_ROOT_DIR) / sub -) - class _LogType: SW = "starwhale" @@ -95,6 +89,7 @@ def __init__( self._ppl_data_field = "result" self._label_field = "label" self.evaluation = self._init_datastore() + self._monkey_patch() def _init_dir(self) -> None: @@ -258,7 +253,7 @@ def _starwhale_internal_run_cmp(self) -> None: def _starwhale_internal_run_ppl(self) -> None: self._update_status(self.STATUS.START) if not self.context.dataset_uris: - raise RuntimeError("no dataset uri!") + raise FieldTypeOrValueError("context.dataset_uris is empty") # TODO: support multi dataset uris _dataset_uri = URI(self.context.dataset_uris[0], expected_type=URIType.DATASET) _dataset = Dataset.get_dataset(_dataset_uri) diff --git a/client/starwhale/api/dataset.py b/client/starwhale/api/dataset.py index 751490c26a..3dad877aed 100644 --- a/client/starwhale/api/dataset.py +++ b/client/starwhale/api/dataset.py @@ -1,4 +1,7 @@ from ._impl.dataset import ( + Link, + MIMEType, + S3LinkAuth, BuildExecutor, MNISTBuildExecutor, SWDSBinBuildExecutor, @@ -12,4 +15,7 @@ "MNISTBuildExecutor", "UserRawBuildExecutor", "SWDSBinBuildExecutor", + "S3LinkAuth", + "Link", + "MIMEType", ] diff --git 
a/client/starwhale/consts/__init__.py b/client/starwhale/consts/__init__.py index 824818a385..b64fc181e0 100644 --- a/client/starwhale/consts/__init__.py +++ b/client/starwhale/consts/__init__.py @@ -146,3 +146,4 @@ class SWDSSubFileType: DEFAULT_CONDA_CHANNEL = "conda-forge" WHEEL_FILE_EXTENSION = ".whl" +AUTH_ENV_FNAME = ".auth_env" diff --git a/client/starwhale/core/dataset/dataset.py b/client/starwhale/core/dataset/dataset.py index 130477febf..57e554842f 100644 --- a/client/starwhale/core/dataset/dataset.py +++ b/client/starwhale/core/dataset/dataset.py @@ -7,7 +7,6 @@ from starwhale.utils import load_yaml, convert_to_bytes from starwhale.consts import DEFAULT_STARWHALE_API_VERSION -from starwhale.base.type import DataFormatType, ObjectStoreType from starwhale.utils.error import NoSupportError @@ -25,27 +24,19 @@ def __init__( self, rows: int = 0, increased_rows: int = 0, - data_format_type: t.Union[DataFormatType, str] = DataFormatType.UNDEFINED, - object_store_type: t.Union[ObjectStoreType, str] = ObjectStoreType.UNDEFINED, label_byte_size: int = 0, data_byte_size: int = 0, + include_link: bool = False, + include_user_raw: bool = False, **kw: t.Any, ) -> None: self.rows = rows self.increased_rows = increased_rows self.unchanged_rows = rows - increased_rows - self.data_format_type: DataFormatType = ( - DataFormatType(data_format_type) - if isinstance(data_format_type, str) - else data_format_type - ) - self.object_store_type: ObjectStoreType = ( - ObjectStoreType(object_store_type) - if isinstance(object_store_type, str) - else object_store_type - ) self.label_byte_size = label_byte_size self.data_byte_size = data_byte_size + self.include_link = include_link + self.include_user_raw = include_user_raw def as_dict(self) -> t.Dict[str, t.Any]: d = deepcopy(self.__dict__) @@ -55,12 +46,12 @@ def as_dict(self) -> t.Dict[str, t.Any]: return d def __str__(self) -> str: - return f"Dataset Summary: rows({self.rows}), data_format({self.data_format_type}), object_store({self.object_store_type})" + return f"Dataset Summary: rows({self.rows}), include user-raw({self.include_user_raw}), include link({self.include_link})" def __repr__(self) -> str: return ( f"Dataset Summary: rows({self.rows}, increased: {self.increased_rows}), " - f"data_format({self.data_format_type}), object_store({self.object_store_type})," + f"include user-raw({self.include_user_raw}), include link({self.include_link})," f"size(data:{self.data_byte_size}, label: {self.label_byte_size})" ) diff --git a/client/starwhale/core/dataset/model.py b/client/starwhale/core/dataset/model.py index 7fdbb9ab29..d7eabf5dfc 100644 --- a/client/starwhale/core/dataset/model.py +++ b/client/starwhale/core/dataset/model.py @@ -299,7 +299,7 @@ def _call_make_swds(self, workdir: Path, swds_config: DatasetConfig) -> None: dataset_version=self._version, project_name=self.uri.project, data_dir=workdir / swds_config.data_dir, - output_dir=self.store.data_dir, + workdir=self.store.snapshot_workdir, data_filter=swds_config.data_filter, label_filter=swds_config.label_filter, alignment_bytes_size=swds_config.attr.alignment_size, diff --git a/client/starwhale/utils/__init__.py b/client/starwhale/utils/__init__.py index d18a67cefb..56c07e3a38 100644 --- a/client/starwhale/utils/__init__.py +++ b/client/starwhale/utils/__init__.py @@ -197,3 +197,17 @@ def make_dir_gitignore(d: Path) -> None: ensure_dir(d) ensure_file(d / ".gitignore", "*") + + +def load_dotenv(fpath: Path) -> None: + if not fpath.exists(): + return + + with fpath.open("r") as f: + for line in 
f.readlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + + k, v = line.split("=", 1) + os.environ[k.strip()] = v.strip() diff --git a/client/starwhale/utils/fs.py b/client/starwhale/utils/fs.py index 8a4b02afab..7a0c2d5a91 100644 --- a/client/starwhale/utils/fs.py +++ b/client/starwhale/utils/fs.py @@ -4,6 +4,7 @@ import typing as t import hashlib import tarfile +from enum import IntEnum from pathlib import Path from starwhale.utils import console, timestamp_to_datatimestr @@ -14,6 +15,11 @@ _MIN_GUESS_NAME_LENGTH = 5 +class FilePosition(IntEnum): + START = 0 + END = -1 + + def ensure_file(path: t.Union[str, Path], content: str, mode: int = 0o644) -> None: p = Path(path) try: diff --git a/client/tests/sdk/test_dataset.py b/client/tests/sdk/test_dataset.py index dd7ce7277c..7265ffc2b8 100644 --- a/client/tests/sdk/test_dataset.py +++ b/client/tests/sdk/test_dataset.py @@ -3,7 +3,6 @@ from pathlib import Path from starwhale.utils.fs import ensure_dir -from starwhale.base.type import DataFormatType, ObjectStoreType from starwhale.api._impl.dataset import ( _data_magic, _header_size, @@ -48,7 +47,7 @@ def setUp(self) -> None: super().setUp() self.raw_data = os.path.join(self.local_storage, ".user", "data") - self.output_data = os.path.join(self.local_storage, ".user", "output") + self.workdir = os.path.join(self.local_storage, ".user", "workdir") ensure_dir(self.raw_data) with open(os.path.join(self.raw_data, "mnist-data-0"), "wb") as f: @@ -63,7 +62,7 @@ def test_user_raw_workflow(self) -> None: dataset_version="332211", project_name="self", data_dir=Path(self.raw_data), - output_dir=Path(self.output_data), + workdir=Path(self.workdir), data_filter="mnist-data-*", label_filter="mnist-data-*", alignment_bytes_size=64, @@ -72,9 +71,9 @@ def test_user_raw_workflow(self) -> None: summary = e.make_swds() assert summary.rows == 10 - assert summary.data_format_type == DataFormatType.USER_RAW - assert summary.object_store_type == ObjectStoreType.LOCAL - data_path = Path(self.output_data, "mnist-data-0") + assert summary.include_user_raw + assert not summary.include_link + data_path = Path(self.workdir, "data", "mnist-data-0") assert data_path.exists() assert data_path.stat().st_size == 28 * 28 * summary.rows + 16 @@ -90,7 +89,7 @@ def test_swds_bin_workflow(self) -> None: dataset_version="112233", project_name="self", data_dir=Path(self.raw_data), - output_dir=Path(self.output_data), + workdir=Path(self.workdir), data_filter="mnist-data-*", label_filter="mnist-data-*", alignment_bytes_size=64, @@ -103,13 +102,13 @@ def test_swds_bin_workflow(self) -> None: assert summary.rows == 10 assert summary.increased_rows == 10 assert summary.unchanged_rows == 0 - assert summary.data_format_type == DataFormatType.SWDS_BIN - assert summary.object_store_type == ObjectStoreType.LOCAL + assert not summary.include_user_raw + assert not summary.include_link - data_path = Path(self.output_data, "data_ubyte_0.swds_bin") + data_path = Path(self.workdir, "data", "data_ubyte_0.swds_bin") for i in range(0, 5): - assert Path(self.output_data) / f"data_ubyte_{i}.swds_bin" + assert Path(self.workdir) / "data" / f"data_ubyte_{i}.swds_bin" data_content = data_path.read_bytes() _parser = _header_struct.unpack(data_content[:_header_size]) diff --git a/client/tests/sdk/test_loader.py b/client/tests/sdk/test_loader.py new file mode 100644 index 0000000000..87da1dcfec --- /dev/null +++ b/client/tests/sdk/test_loader.py @@ -0,0 +1,359 @@ +import os +import shutil +from unittest.mock 
import patch, MagicMock + +from pyfakefs.fake_filesystem_unittest import TestCase + +from starwhale.consts import AUTH_ENV_FNAME, SWDSBackendType +from starwhale.base.uri import URI +from starwhale.utils.fs import ensure_dir, ensure_file +from starwhale.base.type import URIType, DataFormatType, DataOriginType, ObjectStoreType +from starwhale.api._impl.loader import ( + get_data_loader, + SWDSBinDataLoader, + UserRawDataLoader, +) +from starwhale.api._impl.dataset import MIMEType, S3LinkAuth, TabularDatasetRow +from starwhale.core.dataset.store import DatasetStorage +from starwhale.core.dataset.dataset import DatasetSummary + +from .. import ROOT_DIR + + +class TestDataLoader(TestCase): + def setUp(self) -> None: + self.setUpPyfakefs() + self.dataset_uri = URI("mnist/version/1122334455667788", URIType.DATASET) + self.swds_dir = os.path.join(ROOT_DIR, "data", "dataset", "swds") + self.fs.add_real_directory(self.swds_dir) + + @patch("starwhale.core.dataset.model.StandaloneDataset.summary") + @patch("starwhale.api._impl.loader.TabularDataset.scan") + def test_user_raw_local_store( + self, m_scan: MagicMock, m_summary: MagicMock + ) -> None: + m_summary.return_value = DatasetSummary( + include_user_raw=True, + include_link=False, + ) + loader = get_data_loader(self.dataset_uri) + assert isinstance(loader, UserRawDataLoader) + + fname = "data" + m_scan.return_value = [ + TabularDatasetRow( + id=0, + object_store_type=ObjectStoreType.LOCAL, + data_uri=fname, + data_offset=16, + data_size=784, + label=b"0", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.UNDEFINED, + data_mime_type=MIMEType.UNDEFINED, + auth_name="", + ) + ] + + raw_data_fpath = os.path.join(ROOT_DIR, "data", "dataset", "mnist", "data") + self.fs.add_real_file(raw_data_fpath) + data_dir = DatasetStorage(self.dataset_uri).data_dir + ensure_dir(data_dir) + shutil.copy(raw_data_fpath, str(data_dir / fname)) + + assert loader._stores == {} + + rows = list(loader) + assert len(rows) == 1 + + _data, _label = rows[0] + + assert _label.idx == 0 + assert _label.data_size == 1 + assert _data.ext_attr == {"ds_name": "mnist", "ds_version": "1122334455667788"} + assert _data.data_size == len(_data.data) + assert len(_data.data) == 28 * 28 + + assert loader.kind == DataFormatType.USER_RAW + assert list(loader._stores.keys()) == ["local."] + assert loader._stores["local."].bucket == str(data_dir) + assert loader._stores["local."].backend.kind == SWDSBackendType.FUSE + assert not loader._stores["local."].key_prefix + + @patch.dict(os.environ, {}) + @patch("starwhale.api._impl.loader.boto3.resource") + @patch("starwhale.core.dataset.model.StandaloneDataset.summary") + @patch("starwhale.api._impl.loader.TabularDataset.scan") + def test_user_raw_remote_store( + self, + m_scan: MagicMock, + m_summary: MagicMock, + m_boto3: MagicMock, + ) -> None: + m_summary.return_value = DatasetSummary( + include_user_raw=True, + include_link=True, + ) + + snapshot_workdir = DatasetStorage(self.dataset_uri).snapshot_workdir + ensure_dir(snapshot_workdir) + envs = { + "USER.S3.SERVER1.SECRET": "11", + "USER.S3.SERVER1.ACCESS_KEY": "11", + "USER.S3.SERVER2.SECRET": "11", + "USER.S3.SERVER2.ACCESS_KEY": "11", + "USER.S3.SERVER2.ENDPOINT": "127.0.0.1:19000", + } + os.environ.update(envs) + auth_env = S3LinkAuth.from_env(name="server1").dump_env() + auth_env.extend(S3LinkAuth.from_env(name="server2").dump_env()) + ensure_file( + snapshot_workdir / AUTH_ENV_FNAME, + content="\n".join(auth_env), + ) + + for k in envs: + os.environ.pop(k) + + loader = 
get_data_loader(self.dataset_uri) + assert isinstance(loader, UserRawDataLoader) + assert loader.kind == DataFormatType.USER_RAW + for k in envs: + assert k in os.environ + + version = "1122334455667788" + + m_scan.return_value = [ + TabularDatasetRow( + id=0, + object_store_type=ObjectStoreType.REMOTE, + data_uri=f"s3://127.0.0.1:9000@starwhale/project/2/dataset/11/{version}", + data_offset=16, + data_size=784, + label=b"0", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.USER_RAW, + data_mime_type=MIMEType.GRAYSCALE, + auth_name="server1", + ), + TabularDatasetRow( + id=1, + object_store_type=ObjectStoreType.REMOTE, + data_uri=f"s3://127.0.0.1:19000@starwhale/project/2/dataset/11/{version}", + data_offset=16, + data_size=784, + label=b"1", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.USER_RAW, + data_mime_type=MIMEType.GRAYSCALE, + auth_name="server2", + ), + TabularDatasetRow( + id=2, + object_store_type=ObjectStoreType.REMOTE, + data_uri=f"s3://starwhale/project/2/dataset/11/{version}", + data_offset=16, + data_size=784, + label=b"1", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.USER_RAW, + data_mime_type=MIMEType.GRAYSCALE, + auth_name="server2", + ), + TabularDatasetRow( + id=3, + object_store_type=ObjectStoreType.REMOTE, + data_uri=f"s3://username:password@127.0.0.1:29000@starwhale/project/2/dataset/11/{version}", + data_offset=16, + data_size=784, + label=b"1", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.USER_RAW, + data_mime_type=MIMEType.GRAYSCALE, + auth_name="server3", + ), + ] + + raw_data_fpath = os.path.join(ROOT_DIR, "data", "dataset", "mnist", "data") + self.fs.add_real_file(raw_data_fpath) + with open(raw_data_fpath, "rb") as f: + raw_content = f.read(-1) + + m_boto3.return_value = MagicMock( + **{ + "Object.return_value": MagicMock( + **{ + "get.return_value": { + "Body": MagicMock(**{"read.return_value": raw_content}), + "ContentLength": len(raw_content), + } + } + ) + } + ) + + assert loader.kind == DataFormatType.USER_RAW + assert loader._stores == {} + + rows = list(loader) + assert len(rows) == 4 + + _data, _label = rows[0] + assert _label.idx == 0 + assert _label.data == b"0" + assert len(_data.data) == 28 * 28 + assert len(_data.data) == _data.data_size + assert len(loader._stores) == 3 + assert loader._stores["remote.server1"].backend.kind == SWDSBackendType.S3 + assert loader._stores["remote.server1"].bucket == "starwhale" + + @patch.dict(os.environ, {}) + @patch("starwhale.api._impl.loader.boto3.resource") + @patch("starwhale.core.dataset.model.CloudDataset.summary") + @patch("starwhale.api._impl.loader.TabularDataset.scan") + def test_swds_bin_s3( + self, m_scan: MagicMock, m_summary: MagicMock, m_boto3: MagicMock + ) -> None: + m_summary.return_value = DatasetSummary( + include_user_raw=False, + include_link=False, + ) + version = "1122334455667788" + dataset_uri = URI( + f"http://127.0.0.1:1234/project/self/dataset/mnist/version/{version}", + expected_type=URIType.DATASET, + ) + loader = get_data_loader(dataset_uri) + assert isinstance(loader, SWDSBinDataLoader) + assert loader.kind == DataFormatType.SWDS_BIN + + fname = "data_ubyte_0.swds_bin" + m_scan.return_value = [ + TabularDatasetRow( + id=0, + object_store_type=ObjectStoreType.LOCAL, + data_uri=fname, + data_offset=0, + data_size=8160, + label=b"0", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.SWDS_BIN, + data_mime_type=MIMEType.UNDEFINED, + auth_name="", + ) + ] + os.environ.update( + { + "SW_S3_BUCKET": 
"starwhale", + "SW_OBJECT_STORE_KEY_PREFIX": f"project/self/dataset/mnist/version/11/{version}", + "SW_S3_ENDPOINT": "starwhale.mock:9000", + "SW_S3_ACCESS_KEY": "foo", + "SW_S3_SECRET": "bar", + } + ) + + with open(os.path.join(self.swds_dir, fname), "rb") as f: + swds_content = f.read(-1) + + m_boto3.return_value = MagicMock( + **{ + "Object.return_value": MagicMock( + **{ + "get.return_value": { + "Body": MagicMock(**{"read.return_value": swds_content}), + "ContentLength": len(swds_content), + } + } + ) + } + ) + assert loader._stores == {} + + rows = list(loader) + assert len(rows) == 1 + _data, _label = rows[0] + assert _label.idx == 0 + assert _label.data == b"0" + assert _label.data_size == 1 + + assert len(_data.data) == _data.data_size + assert _data.data_size == 10 * 28 * 28 + assert _data.ext_attr == {"ds_name": "mnist", "ds_version": version} + + assert list(loader._stores.keys()) == ["local."] + backend = loader._stores["local."].backend + assert backend.kind == SWDSBackendType.S3 + assert backend.s3.Object.call_args[0] == ( + "starwhale", + f"project/self/dataset/mnist/version/11/{version}/{fname}", + ) + + assert loader._stores["local."].bucket == "starwhale" + assert ( + loader._stores["local."].key_prefix + == f"project/self/dataset/mnist/version/11/{version}" + ) + + @patch.dict(os.environ, {}) + @patch("starwhale.core.dataset.model.StandaloneDataset.summary") + @patch("starwhale.api._impl.loader.TabularDataset.scan") + def test_swds_bin_fuse(self, m_scan: MagicMock, m_summary: MagicMock) -> None: + m_summary.return_value = DatasetSummary( + include_user_raw=False, + include_link=False, + rows=2, + increased_rows=2, + ) + loader = get_data_loader(self.dataset_uri) + assert isinstance(loader, SWDSBinDataLoader) + assert loader.kind == DataFormatType.SWDS_BIN + + fname = "data_ubyte_0.swds_bin" + m_scan.return_value = [ + TabularDatasetRow( + id=0, + object_store_type=ObjectStoreType.LOCAL, + data_uri=fname, + data_offset=0, + data_size=8160, + label=b"0", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.SWDS_BIN, + data_mime_type=MIMEType.UNDEFINED, + auth_name="", + ), + TabularDatasetRow( + id=1, + object_store_type=ObjectStoreType.LOCAL, + data_uri=fname, + data_offset=0, + data_size=8160, + label=b"1", + data_origin=DataOriginType.NEW, + data_format=DataFormatType.SWDS_BIN, + data_mime_type=MIMEType.UNDEFINED, + auth_name="", + ), + ] + + data_dir = DatasetStorage(self.dataset_uri).data_dir + ensure_dir(data_dir) + shutil.copyfile(os.path.join(self.swds_dir, fname), str(data_dir / fname)) + assert loader._stores == {} + + rows = list(loader) + assert len(rows) == 2 + + _data, _label = rows[0] + assert _label.idx == 0 + assert _label.data == b"0" + assert _label.data_size == 1 + + assert len(_data.data) == _data.data_size + assert _data.data_size == 10 * 28 * 28 + assert _data.ext_attr == {"ds_name": "mnist", "ds_version": "1122334455667788"} + + assert list(loader._stores.keys()) == ["local."] + assert loader._stores["local."].backend.kind == SWDSBackendType.FUSE + assert loader._stores["local."].bucket == str(data_dir) + assert not loader._stores["local."].key_prefix diff --git a/client/tests/sdk/test_model.py b/client/tests/sdk/test_model.py index b909e94a08..686187b0d9 100644 --- a/client/tests/sdk/test_model.py +++ b/client/tests/sdk/test_model.py @@ -10,18 +10,14 @@ import jsonlines from pyfakefs.fake_filesystem_unittest import TestCase -from starwhale.consts import DEFAULT_PROJECT, SWDSBackendType +from starwhale.consts import DEFAULT_PROJECT from 
starwhale.base.uri import URI from starwhale.utils.fs import ensure_dir, ensure_file -from starwhale.base.type import URIType, DataFormatType, ObjectStoreType +from starwhale.api.model import PipelineHandler +from starwhale.base.type import URIType from starwhale.consts.env import SWEnv from starwhale.api._impl.job import Context -from starwhale.api._impl.model import PipelineHandler -from starwhale.api._impl.loader import ( - get_data_loader, - S3StorageBackend, - UserRawDataLoader, -) +from starwhale.api._impl.loader import get_data_loader, UserRawDataLoader from starwhale.api._impl.dataset import TabularDatasetRow from starwhale.api._impl.wrapper import Evaluation from starwhale.core.dataset.dataset import DatasetSummary @@ -65,15 +61,14 @@ def tearDown(self) -> None: @patch("starwhale.core.dataset.model.StandaloneDataset.summary") def test_s3_loader(self, m_summary: MagicMock, m_resource: MagicMock) -> None: m_summary.return_value = DatasetSummary( - data_format_type=DataFormatType.USER_RAW + include_user_raw=True, ) _loader = get_data_loader( dataset_uri=URI("mnist/version/latest", URIType.DATASET), - backend=SWDSBackendType.S3, ) assert isinstance(_loader, UserRawDataLoader) - assert isinstance(_loader.storage.backend, S3StorageBackend) + assert not _loader._stores @pytest.mark.skip(reason="wait job scheduler feature, cmp will use datastore") def test_cmp(self) -> None: @@ -133,8 +128,8 @@ def test_ppl(self, m_summary: MagicMock, m_scan: MagicMock) -> None: m_summary.return_value = DatasetSummary( rows=1, increased_rows=1, - data_format_type=DataFormatType.SWDS_BIN, - object_store_type=ObjectStoreType.LOCAL, + include_user_raw=False, + include_link=False, label_byte_size=1, data_byte_size=10, ) diff --git a/client/tests/utils/test_common.py b/client/tests/utils/test_common.py index ab3dc934aa..b5e1955891 100644 --- a/client/tests/utils/test_common.py +++ b/client/tests/utils/test_common.py @@ -1,6 +1,13 @@ import os +import typing as t +from pathlib import Path +from unittest.mock import patch -from starwhale.utils import validate_obj_name +import pytest +from pyfakefs.fake_filesystem import FakeFilesystem +from pyfakefs.fake_filesystem_unittest import Patcher + +from starwhale.utils import load_dotenv, validate_obj_name from starwhale.consts import ENV_LOG_LEVEL from starwhale.utils.debug import init_logger @@ -24,3 +31,29 @@ def test_logger() -> None: init_logger(3) assert os.environ[ENV_LOG_LEVEL] == "DEBUG" + + +@pytest.fixture +def fake_fs() -> t.Generator[t.Optional[FakeFilesystem], None, None]: + with Patcher() as patcher: + yield patcher.fs + + +@patch.dict(os.environ, {"TEST_ENV": "1"}, clear=True) +def test_load_dotenv(fake_fs: FakeFilesystem) -> None: + content = """ + # this is a comment line + A=1 + B = 2 + c = + ddd + """ + fpath = "/home/starwhale/test/.auth_env" + fake_fs.create_file(fpath, contents=content) + assert os.environ["TEST_ENV"] == "1" + load_dotenv(Path(fpath)) + assert os.environ["A"] == "1" + assert os.environ["B"] == "2" + assert not os.environ["c"] + assert "ddd" not in os.environ + assert len(os.environ) == 4 diff --git a/example/PennFudanPed/code/utils.py b/example/PennFudanPed/code/utils.py index 0b7ed61988..209d1014c1 100644 --- a/example/PennFudanPed/code/utils.py +++ b/example/PennFudanPed/code/utils.py @@ -1,14 +1,13 @@ -from collections import defaultdict, deque -import datetime -import pickle +import os import time +import errno +import pickle +import datetime +from collections import deque, defaultdict import torch import torch.distributed as 
dist -import errno -import os - class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a @@ -34,7 +33,7 @@ def synchronize_between_processes(self): """ if not is_dist_avail_and_initialized(): return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") dist.barrier() dist.all_reduce(t) t = t.tolist() @@ -69,7 +68,8 @@ def __str__(self): avg=self.avg, global_avg=self.global_avg, max=self.max, - value=self.value) + value=self.value, + ) def all_gather(data): @@ -103,7 +103,9 @@ def all_gather(data): for _ in size_list: tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) if local_size != max_size: - padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + padding = torch.empty( + size=(max_size - local_size,), dtype=torch.uint8, device="cuda" + ) tensor = torch.cat((tensor, padding), dim=0) dist.all_gather(tensor_list, tensor) @@ -159,15 +161,14 @@ def __getattr__(self, attr): return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, attr)) + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) def __str__(self): loss_str = [] for name, meter in self.meters.items(): - loss_str.append( - "{}: {}".format(name, str(meter)) - ) + loss_str.append("{}: {}".format(name, str(meter))) return self.delimiter.join(loss_str) def synchronize_between_processes(self): @@ -180,31 +181,35 @@ def add_meter(self, name, meter): def log_every(self, iterable, print_freq, header=None): i = 0 if not header: - header = '' + header = "" start_time = time.time() end = time.time() - iter_time = SmoothedValue(fmt='{avg:.4f}') - data_time = SmoothedValue(fmt='{avg:.4f}') - space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" if torch.cuda.is_available(): - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}', - 'max mem: {memory:.0f}' - ]) + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) else: - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}' - ]) + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + ) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) @@ -214,22 +219,37 @@ def log_every(self, iterable, print_freq, header=None): eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB)) + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) else: - print(log_msg.format( - i, len(iterable), eta=eta_string, 
- meters=str(self), - time=str(iter_time), data=str(data_time))) + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('{} Total time: {} ({:.4f} s / it)'.format( - header, total_time_str, total_time / len(iterable))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) def collate_fn(batch): @@ -237,7 +257,6 @@ def collate_fn(batch): def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): - def f(x): if x >= warmup_iters: return 1 @@ -260,10 +279,11 @@ def setup_for_distributed(is_master): This function disables printing when not in master process """ import builtins as __builtin__ + builtin_print = __builtin__.print def print(*args, **kwargs): - force = kwargs.pop('force', False) + force = kwargs.pop("force", False) if is_master or force: builtin_print(*args, **kwargs) @@ -300,25 +320,30 @@ def save_on_master(*args, **kwargs): def init_distributed_mode(args): - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ['WORLD_SIZE']) - args.gpu = int(os.environ['LOCAL_RANK']) - elif 'SLURM_PROCID' in os.environ: - args.rank = int(os.environ['SLURM_PROCID']) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) args.gpu = args.rank % torch.cuda.device_count() else: - print('Not using distributed mode') + print("Not using distributed mode") args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) - args.dist_backend = 'nccl' - print('| distributed init (rank {}): {}'.format( - args.rank, args.dist_url), flush=True) - torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, - world_size=args.world_size, rank=args.rank) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) torch.distributed.barrier() - setup_for_distributed(args.rank == 0) \ No newline at end of file + setup_for_distributed(args.rank == 0) diff --git a/example/mnist/dataset.yaml b/example/mnist/dataset.yaml index 190f619ef5..b8626b9805 100644 --- a/example/mnist/dataset.yaml +++ b/example/mnist/dataset.yaml @@ -8,7 +8,7 @@ process: mnist.process:RawDataSetProcessExecutor desc: MNIST data and label test dataset tag: - - bin + - bin attr: alignment_size: 4k diff --git a/example/mnist/mnist/ppl.py b/example/mnist/mnist/ppl.py index 716291a07f..30ba46754f 100644 --- a/example/mnist/mnist/ppl.py +++ b/example/mnist/mnist/ppl.py @@ -6,11 +6,14 @@ from PIL import Image from torchvision import transforms +from starwhale.api.job import Context from starwhale.api.model import PipelineHandler from starwhale.api.metric import multi_classification -from starwhale.api.job import Context -from .model import Net +try: + from .model import Net +except ImportError: + from model import Net ROOTDIR = Path(__file__).parent.parent IMAGE_WIDTH = 28 @@ -41,7 +44,6 @@ def handle_label(self, label, **kw): def cmp(self, _data_loader): 
_result, _label, _pr = [], [], [] for _data in _data_loader: - # logger.debug(f"cmp data:{_data}") _label.extend([int(l) for l in _data[self._label_field]]) # unpack data according to the return value of function ppl (pred, pr) = _data[self._ppl_data_field] @@ -72,19 +74,19 @@ def _load_model(self, device): return model -def load_test_env(fuse=True): - _p = lambda p: str((ROOTDIR / "test" / p).resolve()) - - os.environ["SW_TASK_STATUS_DIR"] = _p("task_volume/status") - os.environ["SW_TASK_LOG_DIR"] = _p("task_volume/log") - os.environ["SW_TASK_RESULT_DIR"] = _p("task_volume/result") - - fname = "swds_fuse.json" if fuse else "swds_s3.json" - # fname = "swds_fuse_simple.json" if fuse else "swds_s3_simple.json" - os.environ["SW_TASK_INPUT_CONFIG"] = _p(fname) - - if __name__ == "__main__": - load_test_env(fuse=False) - mnist = MNISTInference() + from starwhale.api.job import Context + + context = Context( + workdir=Path("."), + src_dir=Path("."), + dataset_uris=["mnist/version/latest"], + project="self", + version="latest", + kw={ + "status_dir": "/tmp/mnist/status", + "log_dir": "/tmp/mnist/log", + }, + ) + mnist = MNISTInference(context) mnist._starwhale_internal_run_ppl() diff --git a/example/mnist/mnist/process.py b/example/mnist/mnist/process.py index 53bd140644..20028845d0 100644 --- a/example/mnist/mnist/process.py +++ b/example/mnist/mnist/process.py @@ -1,7 +1,13 @@ import struct from pathlib import Path -from starwhale.api.dataset import SWDSBinBuildExecutor, UserRawBuildExecutor +from starwhale.api.dataset import ( + Link, + MIMEType, + S3LinkAuth, + SWDSBinBuildExecutor, + UserRawBuildExecutor, +) def _do_iter_label_slice(path: str): @@ -51,3 +57,26 @@ def iter_data_slice(self, path: str): def iter_label_slice(self, path: str): return _do_iter_label_slice(path) + + +class LinkRawDataSetProcessExecutor(RawDataSetProcessExecutor): + _auth = S3LinkAuth(name="mnist", access_key="minioadmin", secret="minioadmin") + _endpoint = "10.131.0.1:9000" + _bucket = "users" + + def iter_all_dataset_slice(self): + offset = 16 + size = 28 * 28 + uri = ( + f"s3://{self._endpoint}@{self._bucket}/dataset/mnist/t10k-images-idx3-ubyte" + ) + for _ in range(10000): + link = Link( + f"{uri}", + self._auth, + offset=offset, + size=size, + mime_type=MIMEType.GRAYSCALE, + ) + yield link + offset += size
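
A minimal usage sketch of the link-auth round trip this patch introduces, assuming the `S3LinkAuth`, `load_dotenv`, and `.auth_env` pieces defined above; the credentials, endpoint, and workdir path below are placeholders, not values from the patch:

    from pathlib import Path

    from starwhale.utils import load_dotenv
    from starwhale.api.dataset import S3LinkAuth

    # Build side: an auth named "server1" is dumped as KEY=VALUE lines into the
    # dataset workdir (this is what UserRawBuildExecutor._copy_auth writes to .auth_env).
    auth = S3LinkAuth(
        name="server1",
        access_key="AK-EXAMPLE",   # placeholder credentials
        secret="SK-EXAMPLE",
        endpoint="127.0.0.1:9000",  # placeholder endpoint
        region="local",
    )
    auth_env = Path("/tmp/demo-workdir/.auth_env")  # placeholder workdir
    auth_env.parent.mkdir(parents=True, exist_ok=True)
    auth_env.write_text("\n".join(auth.dump_env()))

    # Loader side: DataLoader._load_dataset_auth_env loads the same file into the
    # process environment, and ObjectStoreS3Connection.from_uri later recovers the
    # credentials by auth name through S3LinkAuth.from_env.
    load_dotenv(auth_env)
    restored = S3LinkAuth.from_env(name="server1")
    assert restored.access_key == "AK-EXAMPLE"
    assert restored.endpoint == "127.0.0.1:9000"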