feat(dataset): support Link type for dataset (#944)
* support Link type for dataset

* add test for dataset loader
tianweidut authored Aug 23, 2022
1 parent e284eca commit ee3acba
Showing 17 changed files with 1,015 additions and 240 deletions.
268 changes: 240 additions & 28 deletions client/starwhale/api/_impl/dataset.py

Large diffs are not rendered by default.

274 changes: 191 additions & 83 deletions client/starwhale/api/_impl/loader.py

Large diffs are not rendered by default.

13 changes: 4 additions & 9 deletions client/starwhale/api/_impl/model.py
@@ -17,24 +17,18 @@
 import loguru
 import jsonlines

-from starwhale.utils import now_str, in_production
+from starwhale.utils import now_str
 from starwhale.consts import CURRENT_FNAME
 from starwhale.base.uri import URI
 from starwhale.utils.fs import ensure_dir, ensure_file
 from starwhale.base.type import URIType, RunSubDirType
 from starwhale.utils.log import StreamWrapper
+from starwhale.utils.error import FieldTypeOrValueError
 from starwhale.api._impl.job import Context
 from starwhale.api._impl.loader import DataField, ResultLoader, get_data_loader
 from starwhale.api._impl.wrapper import Evaluation
 from starwhale.core.dataset.model import Dataset

-_TASK_ROOT_DIR = "/var/starwhale" if in_production() else "/tmp/starwhale"
-
-_ptype = t.Union[str, None, Path]
-_p: t.Callable[[_ptype, str], Path] = (
-    lambda p, sub: Path(p) if p else Path(_TASK_ROOT_DIR) / sub
-)
-

 class _LogType:
     SW = "starwhale"
@@ -95,6 +89,7 @@ def __init__(
         self._ppl_data_field = "result"
         self._label_field = "label"
         self.evaluation = self._init_datastore()
+
         self._monkey_patch()

     def _init_dir(self) -> None:
@@ -258,7 +253,7 @@ def _starwhale_internal_run_cmp(self) -> None:
     def _starwhale_internal_run_ppl(self) -> None:
         self._update_status(self.STATUS.START)
         if not self.context.dataset_uris:
-            raise RuntimeError("no dataset uri!")
+            raise FieldTypeOrValueError("context.dataset_uris is empty")
         # TODO: support multi dataset uris
         _dataset_uri = URI(self.context.dataset_uris[0], expected_type=URIType.DATASET)
         _dataset = Dataset.get_dataset(_dataset_uri)
6 changes: 6 additions & 0 deletions client/starwhale/api/dataset.py
@@ -1,4 +1,7 @@
 from ._impl.dataset import (
+    Link,
+    MIMEType,
+    S3LinkAuth,
     BuildExecutor,
     MNISTBuildExecutor,
     SWDSBinBuildExecutor,
@@ -12,4 +15,7 @@
"MNISTBuildExecutor",
"UserRawBuildExecutor",
"SWDSBinBuildExecutor",
"S3LinkAuth",
"Link",
"MIMEType",
]
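These exports make Link, MIMEType, and S3LinkAuth part of the public SDK surface, so a dataset build script can reference remote objects instead of copying raw bytes into the swds archive. A minimal sketch of constructing a link — the parameter names (uri, auth, mime_type) and the S3LinkAuth fields shown here are assumptions for illustration, not confirmed signatures; the real definitions live in the client/starwhale/api/_impl/dataset.py diff that is not rendered above:

from starwhale.api.dataset import Link, MIMEType, S3LinkAuth

# All parameter names below are illustrative assumptions; consult
# client/starwhale/api/_impl/dataset.py for the actual signatures.
auth = S3LinkAuth(
    name="mnist-minio",                # hypothetical credential-set label
    access_key="AKIA...",              # hypothetical field
    secret="...",                      # hypothetical field
    endpoint="http://127.0.0.1:9000",  # hypothetical field
)

link = Link(
    uri="s3://starwhale/mnist/t10k-images-idx3-ubyte",  # object stays remote
    auth=auth,
    mime_type=MIMEType.GRAYSCALE,      # assumed enum member
)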
1 change: 1 addition & 0 deletions client/starwhale/consts/__init__.py
@@ -146,3 +146,4 @@ class SWDSSubFileType:
 DEFAULT_CONDA_CHANNEL = "conda-forge"

 WHEEL_FILE_EXTENSION = ".whl"
+AUTH_ENV_FNAME = ".auth_env"
21 changes: 6 additions & 15 deletions client/starwhale/core/dataset/dataset.py
@@ -7,7 +7,6 @@

 from starwhale.utils import load_yaml, convert_to_bytes
 from starwhale.consts import DEFAULT_STARWHALE_API_VERSION
-from starwhale.base.type import DataFormatType, ObjectStoreType
 from starwhale.utils.error import NoSupportError


@@ -25,27 +24,19 @@ def __init__(
         self,
         rows: int = 0,
         increased_rows: int = 0,
-        data_format_type: t.Union[DataFormatType, str] = DataFormatType.UNDEFINED,
-        object_store_type: t.Union[ObjectStoreType, str] = ObjectStoreType.UNDEFINED,
         label_byte_size: int = 0,
         data_byte_size: int = 0,
+        include_link: bool = False,
+        include_user_raw: bool = False,
         **kw: t.Any,
     ) -> None:
         self.rows = rows
         self.increased_rows = increased_rows
         self.unchanged_rows = rows - increased_rows
-        self.data_format_type: DataFormatType = (
-            DataFormatType(data_format_type)
-            if isinstance(data_format_type, str)
-            else data_format_type
-        )
-        self.object_store_type: ObjectStoreType = (
-            ObjectStoreType(object_store_type)
-            if isinstance(object_store_type, str)
-            else object_store_type
-        )
         self.label_byte_size = label_byte_size
         self.data_byte_size = data_byte_size
+        self.include_link = include_link
+        self.include_user_raw = include_user_raw

     def as_dict(self) -> t.Dict[str, t.Any]:
         d = deepcopy(self.__dict__)
@@ -55,12 +46,12 @@ def as_dict(self) -> t.Dict[str, t.Any]:
         return d

     def __str__(self) -> str:
-        return f"Dataset Summary: rows({self.rows}), data_format({self.data_format_type}), object_store({self.object_store_type})"
+        return f"Dataset Summary: rows({self.rows}), include user-raw({self.include_user_raw}), include link({self.include_link})"

     def __repr__(self) -> str:
         return (
             f"Dataset Summary: rows({self.rows}, increased: {self.increased_rows}), "
-            f"data_format({self.data_format_type}), object_store({self.object_store_type}),"
+            f"include user-raw({self.include_user_raw}), include link({self.include_link}),"
             f"size(data:{self.data_byte_size}, label: {self.label_byte_size})"
         )

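The two booleans replace the old DataFormatType/ObjectStoreType enums as the summary's shape descriptors. A quick sketch of the new reporting, using only the constructor and __str__ shown in this diff (the class name DatasetSummary and its import path are assumed from context):

from starwhale.core.dataset.dataset import DatasetSummary  # assumed location

summary = DatasetSummary(rows=10, increased_rows=10, include_user_raw=True)
print(summary)
# Dataset Summary: rows(10), include user-raw(True), include link(False)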
2 changes: 1 addition & 1 deletion client/starwhale/core/dataset/model.py
@@ -299,7 +299,7 @@ def _call_make_swds(self, workdir: Path, swds_config: DatasetConfig) -> None:
             dataset_version=self._version,
             project_name=self.uri.project,
             data_dir=workdir / swds_config.data_dir,
-            output_dir=self.store.data_dir,
+            workdir=self.store.snapshot_workdir,
             data_filter=swds_config.data_filter,
             label_filter=swds_config.label_filter,
             alignment_bytes_size=swds_config.attr.alignment_size,
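The build executor now receives the whole snapshot workdir instead of a bare output directory; judging from the updated tests at the bottom of this commit, generated files land in a data/ subdirectory of that workdir. An inferred sketch of the resulting layout (everything beyond data/ is an assumption):

# <snapshot_workdir>/
#   data/
#     data_ubyte_0.swds_bin   # swds-bin workflow
#     mnist-data-0            # user-raw workflow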
14 changes: 14 additions & 0 deletions client/starwhale/utils/__init__.py
@@ -197,3 +197,17 @@ def make_dir_gitignore(d: Path) -> None:

     ensure_dir(d)
     ensure_file(d / ".gitignore", "*")
+
+
+def load_dotenv(fpath: Path) -> None:
+    if not fpath.exists():
+        return
+
+    with fpath.open("r") as f:
+        for line in f.readlines():
+            line = line.strip()
+            if not line or line.startswith("#") or "=" not in line:
+                continue
+
+            k, v = line.split("=", 1)
+            os.environ[k.strip()] = v.strip()
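Paired with the new AUTH_ENV_FNAME = ".auth_env" constant, this gives the dataset loader a dotenv-style way to restore link-auth credentials at runtime: blank lines, comments, and lines without "=" are skipped, and everything else becomes an environment variable. A usage sketch — the file location and variable names below are illustrative, not taken from this commit:

import os
from pathlib import Path

from starwhale.utils import load_dotenv

# Hypothetical .auth_env content, written at dataset build time:
#   # s3 credentials for resolving Link uris
#   SW_S3_ACCESS_KEY=minioadmin
#   SW_S3_SECRET=minioadmin
load_dotenv(Path("/tmp/starwhale/run/.auth_env"))  # silently no-ops if missing

print(os.environ.get("SW_S3_ACCESS_KEY"))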
6 changes: 6 additions & 0 deletions client/starwhale/utils/fs.py
@@ -4,6 +4,7 @@
 import typing as t
 import hashlib
 import tarfile
+from enum import IntEnum
 from pathlib import Path

 from starwhale.utils import console, timestamp_to_datatimestr
@@ -14,6 +15,11 @@
 _MIN_GUESS_NAME_LENGTH = 5


+class FilePosition(IntEnum):
+    START = 0
+    END = -1
+
+
 def ensure_file(path: t.Union[str, Path], content: str, mode: int = 0o644) -> None:
     p = Path(path)
     try:
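FilePosition gives names to the 0/-1 byte-offset sentinels that link-based loading passes around. A hedged sketch of the idea — the helper below is illustrative, since the loader's real call sites are in the unrendered loader.py diff:

from starwhale.utils.fs import FilePosition

def read_range(path: str, start: int, end: int) -> bytes:
    # Illustrative helper: END (-1) means "read until EOF";
    # IntEnum members compare equal to plain ints, so this works
    # whether callers pass FilePosition members or raw offsets.
    with open(path, "rb") as f:
        f.seek(start)
        if end == FilePosition.END:
            return f.read()
        return f.read(end - start)

# read_range("mnist-data-0", FilePosition.START, FilePosition.END) -> whole file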
21 changes: 10 additions & 11 deletions client/tests/sdk/test_dataset.py
@@ -3,7 +3,6 @@
 from pathlib import Path

 from starwhale.utils.fs import ensure_dir
-from starwhale.base.type import DataFormatType, ObjectStoreType
 from starwhale.api._impl.dataset import (
     _data_magic,
     _header_size,
@@ -48,7 +47,7 @@ def setUp(self) -> None:
         super().setUp()

         self.raw_data = os.path.join(self.local_storage, ".user", "data")
-        self.output_data = os.path.join(self.local_storage, ".user", "output")
+        self.workdir = os.path.join(self.local_storage, ".user", "workdir")

         ensure_dir(self.raw_data)
         with open(os.path.join(self.raw_data, "mnist-data-0"), "wb") as f:
@@ -63,7 +62,7 @@ def test_user_raw_workflow(self) -> None:
dataset_version="332211",
project_name="self",
data_dir=Path(self.raw_data),
output_dir=Path(self.output_data),
workdir=Path(self.workdir),
data_filter="mnist-data-*",
label_filter="mnist-data-*",
alignment_bytes_size=64,
@@ -72,9 +71,9 @@
         summary = e.make_swds()

         assert summary.rows == 10
-        assert summary.data_format_type == DataFormatType.USER_RAW
-        assert summary.object_store_type == ObjectStoreType.LOCAL
-        data_path = Path(self.output_data, "mnist-data-0")
+        assert summary.include_user_raw
+        assert not summary.include_link
+        data_path = Path(self.workdir, "data", "mnist-data-0")

         assert data_path.exists()
         assert data_path.stat().st_size == 28 * 28 * summary.rows + 16
@@ -90,7 +89,7 @@ def test_swds_bin_workflow(self) -> None:
dataset_version="112233",
project_name="self",
data_dir=Path(self.raw_data),
output_dir=Path(self.output_data),
workdir=Path(self.workdir),
data_filter="mnist-data-*",
label_filter="mnist-data-*",
alignment_bytes_size=64,
@@ -103,13 +102,13 @@
         assert summary.rows == 10
         assert summary.increased_rows == 10
         assert summary.unchanged_rows == 0
-        assert summary.data_format_type == DataFormatType.SWDS_BIN
-        assert summary.object_store_type == ObjectStoreType.LOCAL
+        assert not summary.include_user_raw
+        assert not summary.include_link

-        data_path = Path(self.output_data, "data_ubyte_0.swds_bin")
+        data_path = Path(self.workdir, "data", "data_ubyte_0.swds_bin")

         for i in range(0, 5):
-            assert Path(self.output_data) / f"data_ubyte_{i}.swds_bin"
+            assert Path(self.workdir) / "data" / f"data_ubyte_{i}.swds_bin"

         data_content = data_path.read_bytes()
         _parser = _header_struct.unpack(data_content[:_header_size])
(The remaining changed files in this commit are not rendered.)
