
Commit

add dataset sdk interface
tianweidut committed Dec 2, 2022
1 parent 76ee74b commit 5495c48
Showing 33 changed files with 2,334 additions and 256 deletions.
5 changes: 5 additions & 0 deletions client/starwhale/__init__.py
@@ -10,6 +10,7 @@
Image,
Video,
Binary,
Dataset,
LinkAuth,
LinkType,
MIMEType,
@@ -28,10 +29,14 @@
from starwhale.api.evaluation import Evaluation
from starwhale.core.dataset.tabular import get_dataset_consumption

dataset = Dataset.dataset

__all__ = [
"__version__",
"PipelineHandler",
"multi_classification",
"Dataset",
"dataset",
"URI",
"URIType",
"step",
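
The change above re-exports Dataset and binds a module-level `dataset` alias, so the SDK entry point is importable from the package root. A minimal usage sketch, assuming `Dataset.dataset` is a classmethod factory that takes a dataset URI string (the URI below is hypothetical):

    from starwhale import dataset

    # hypothetical URI; the returned handle is assumed to support row iteration
    ds = dataset("mnist/version/latest")
    for row in ds:
        print(row)
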
62 changes: 51 additions & 11 deletions client/starwhale/api/_impl/data_store.py
@@ -11,7 +11,19 @@
import threading
from abc import ABCMeta, abstractmethod
from http import HTTPStatus
from typing import Any, Set, cast, Dict, List, Type, Tuple, Union, Iterator, Optional
from typing import (
Any,
Set,
cast,
Dict,
List,
Type,
Tuple,
Union,
Callable,
Iterator,
Optional,
)

import dill
import numpy as np
@@ -665,6 +677,7 @@ def _scan_parquet_file(
start: Optional[Any] = None,
end: Optional[Any] = None,
keep_none: bool = False,
end_inclusive: bool = False,
) -> Iterator[dict]:
f = pq.ParquetFile(path)
schema_arrow = f.schema_arrow
@@ -708,7 +721,10 @@ def _scan_parquet_file(
n_cols = len(names)
for j in range(n_rows):
key = types[0].deserialize(table[0][j].as_py())
if (start is not None and key < start) or (end is not None and key >= end):
_end_check: Callable = lambda x, y: x > y if end_inclusive else x >= y
if (start is not None and key < start) or (
end is not None and _end_check(key, end)
):
continue
d = {"*": key}
if key_alias is not None:
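
The new `end_inclusive` flag flips the upper-bound comparison: by default a key equal to `end` is skipped (half-open range), while with `end_inclusive=True` only keys strictly greater than `end` are skipped. A standalone sketch of the predicate, separate from the parquet machinery (`_should_skip` is an invented name):

    def _should_skip(key, start, end, end_inclusive=False):
        # True means the key falls outside [start, end) or [start, end]
        end_check = (lambda x, y: x > y) if end_inclusive else (lambda x, y: x >= y)
        return (start is not None and key < start) or (
            end is not None and end_check(key, end)
        )

    assert _should_skip(5, 0, 5) is True                       # [0, 5): 5 is out
    assert _should_skip(5, 0, 5, end_inclusive=True) is False  # [0, 5]: 5 is in
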
Expand Down Expand Up @@ -828,14 +844,15 @@ def _scan_table(
start: Optional[Any] = None,
end: Optional[Any] = None,
keep_none: bool = False,
end_inclusive: bool = False,
) -> Iterator[dict]:
iters = []
for file in _get_table_files(path):
if os.path.basename(file).startswith("patch"):
keep = True
else:
keep = keep_none
iters.append(_scan_parquet_file(file, columns, start, end, keep))
iters.append(_scan_parquet_file(file, columns, start, end, keep, end_inclusive))
return _merge_scan(iters, keep_none)


@@ -911,16 +928,24 @@ def scan(
start: Optional[Any] = None,
end: Optional[Any] = None,
keep_none: bool = False,
end_inclusive: bool = False,
) -> Iterator[Dict[str, Any]]:
_end_check: Callable = lambda x, y: x <= y if end_inclusive else x < y

with self.lock:
schema = self.schema.copy()
records = [
{self.schema.key_column: key, "-": True}
for key in self.deletes
if (start is None or key >= start) and (end is None or key < end)
]

records = []
for key in self.deletes:
if (start is None or key >= start) and (
end is None or _end_check(key, end)
):
records.append({self.schema.key_column: key, "-": True})

for k, v in self.records.items():
if (start is None or k >= start) and (end is None or k < end):
if (start is None or k >= start) and (
end is None or _end_check(k, end)
):
records.append(v)
records.sort(key=lambda x: cast(str, x[self.schema.key_column]))
for r in records:
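
Note that `MemoryTable.scan` uses the mirrored predicate (`<=`/`<` rather than `>`/`>=`) because here a True result keeps the row instead of skipping it. Deleted keys are emitted as tombstone rows marked `"-": True` so a later merge can suppress them; a self-contained illustration of that pattern (invented data, not Starwhale code):

    deletes = {2}
    records = {1: {"k": 1, "v": "a"}, 3: {"k": 3, "v": "c"}}

    rows = [{"k": key, "-": True} for key in deletes]  # tombstones for deletes
    rows += list(records.values())                     # live rows
    rows.sort(key=lambda r: r["k"])
    # key order now interleaves tombstones with live records; a merge step
    # drops any key whose newest entry carries "-": True
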
@@ -1105,6 +1130,7 @@ def scan_tables(
start: Optional[Any] = None,
end: Optional[Any] = None,
keep_none: bool = False,
end_inclusive: bool = False,
) -> Iterator[Dict[str, Any]]:
class TableInfo:
def __init__(
@@ -1163,9 +1189,14 @@ def __init__(
start,
end,
info.keep_none,
end_inclusive,
),
self.tables[info.name].scan(
info.columns, start, end, True
info.columns,
start,
end,
True,
end_inclusive,
),
],
info.keep_none,
@@ -1174,7 +1205,11 @@ def __init__(
else:
iters.append(
self.tables[info.name].scan(
info.columns, start, end, info.keep_none
info.columns,
start,
end,
info.keep_none,
end_inclusive,
)
)
else:
Expand All @@ -1185,6 +1220,7 @@ def __init__(
start,
end,
info.keep_none,
end_inclusive,
)
)
for record in _merge_scan(iters, keep_none):
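
`LocalDataStore.scan_tables` builds one key-sorted iterator per table and hands them all to `_merge_scan`. The core idea is a k-way merge of sorted streams, as in this minimal sketch using `heapq.merge` (the real `_merge_scan` additionally merges columns across tables and honors tombstones):

    import heapq

    a = [{"*": 1, "x": 10}, {"*": 3, "x": 30}]
    b = [{"*": 2, "y": 20}, {"*": 4, "y": 40}]

    merged = heapq.merge(a, b, key=lambda r: r["*"])
    print([r["*"] for r in merged])  # [1, 2, 3, 4]
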
@@ -1284,6 +1320,7 @@ def scan_tables(
start: Optional[Any] = None,
end: Optional[Any] = None,
keep_none: bool = False,
end_inclusive: bool = False,
) -> Iterator[Dict[str, Any]]:
post_data: Dict[str, Any] = {"tables": [table.to_dict() for table in tables]}
key_type = _get_type(start)
@@ -1294,6 +1331,8 @@ def scan_tables(
post_data["limit"] = 1000
if keep_none:
post_data["keepNone"] = True
if end_inclusive:
post_data["endInclusive"] = True
assert self.token is not None
while True:
resp_json = self._do_scan_table_request(post_data)
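
On the remote data store the flag travels to the server as `endInclusive` in the scan request body. An example payload for an inclusive scan (the field names come from the code above; the values and the table-spec shape produced by `table.to_dict()` are assumptions):

    post_data = {
        "tables": [{"tableName": "project/self/eval"}],  # shape assumed
        "start": "0000",
        "end": "0010",
        "limit": 1000,
        "keepNone": True,
        "endInclusive": True,
    }
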
@@ -1338,6 +1377,7 @@ def scan_tables(
start: Optional[Any] = None,
end: Optional[Any] = None,
keep_none: bool = False,
end_inclusive: bool = False,
) -> Iterator[Dict[str, Any]]:
...

2 changes: 2 additions & 0 deletions client/starwhale/api/_impl/dataset/__init__.py
@@ -17,6 +17,7 @@
COCOObjectAnnotation,
)

from .model import Dataset
from .loader import get_data_loader, SWDSBinDataLoader, UserRawDataLoader
from .builder import BuildExecutor, SWDSBinBuildExecutor, UserRawBuildExecutor

@@ -43,4 +44,5 @@
"BoundingBox",
"GrayscaleImage",
"COCOObjectAnnotation",
"Dataset",
]
(diff for the remaining 30 files not shown)
