enhance(client): prefer using the built-in map when building dataset (#…
jialeicui authored Oct 31, 2023
1 parent a89d722 commit 86cbb23
Showing 7 changed files with 85 additions and 80 deletions.
8 changes: 4 additions & 4 deletions client/starwhale/api/_impl/data_store.py
@@ -60,7 +60,6 @@
     _RETRY_HTTP_STATUS_CODES,
 )
 from starwhale.utils.config import SWCliConfigMixed
-from starwhale.utils.dict_util import flatten as flatten_dict
 from starwhale.base.models.base import SwBaseModel
 from starwhale.base.client.models.models import ColumnSchemaDesc, KeyValuePairSchema
 
@@ -803,6 +802,9 @@ def __eq__(self, other: Any) -> bool:
             and self.sparse_pair_types == other.sparse_pair_types
         )
 
+    def __hash__(self) -> int:
+        return hash(str(self))
+
 
 class SwObjectType(SwCompositeType):
     def __init__(self, raw_type: Type, attrs: Dict[str, SwType]) -> None:
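A side note on the `__hash__` addition above: Python removes the inherited `__hash__` whenever a class defines `__eq__`, so this type would otherwise become unhashable. Deriving the hash from `str(self)` keeps equal values hash-equal. A minimal sketch with a hypothetical stand-in class (not the actual Starwhale type):

class PairType:
    # Hypothetical stand-in illustrating the eq/hash contract only.
    def __init__(self, key_type: str, value_type: str) -> None:
        self.key_type = key_type
        self.value_type = value_type

    def __str__(self) -> str:
        return f"map<{self.key_type},{self.value_type}>"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, PairType) and str(self) == str(other)

    # Without this, defining __eq__ sets __hash__ to None and instances
    # can no longer be used in sets or as dict keys.
    def __hash__(self) -> int:
        return hash(str(self))

assert PairType("string", "int64") == PairType("string", "int64")
assert len({PairType("string", "int64"), PairType("string", "int64")}) == 1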
@@ -967,8 +969,7 @@ class Record(UserDict):
     def dumps(self) -> Dict[str, Dict]:
         return {
             "schema": {
-                k: json.loads(SwType.encode_schema(_get_type(v)).json())
-                for k, v in self.items()
+                k: SwType.encode_schema(_get_type(v)).to_dict() for k, v in self.items()
             },
             "data": {k: _get_type(v).encode(v) for k, v in self.items()},
         }
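The `dumps` change replaces a serialize-then-parse round trip with a direct dictionary conversion. A rough sketch of the equivalence, assuming a `to_dict()` helper on the schema model and using a hypothetical dataclass in place of the real `ColumnSchemaDesc`:

import json
from dataclasses import asdict, dataclass

@dataclass
class ColumnSchema:
    # Hypothetical stand-in for the generated schema model.
    name: str
    type: str

    def json(self) -> str:
        return json.dumps(asdict(self))

    def to_dict(self) -> dict:
        return asdict(self)

schema = ColumnSchema(name="a", type="INT64")
# Same result, but to_dict() skips encoding to a JSON string and parsing it back.
assert json.loads(schema.json()) == schema.to_dict()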
@@ -2862,7 +2863,6 @@ def _raise_run_exceptions(self, limits: int) -> None:
             raise TableWriterException(f"{self} run raise {len(_es)} exceptions: {_es}")
 
     def insert(self, record: Dict[str, Any]) -> None:
-        record = flatten_dict(record)
         for k in record:
             for ch in k:
                 if (
4 changes: 1 addition & 3 deletions client/starwhale/api/_impl/evaluation/pipeline.py
@@ -19,7 +19,6 @@
 from starwhale.api.service import Input, Output, Service
 from starwhale.utils.error import ParameterError, FieldTypeOrValueError
 from starwhale.base.context import Context
-from starwhale.base.data_type import JsonDict
 from starwhale.core.job.store import JobStorage
 from starwhale.api._impl.dataset import Dataset
 from starwhale.base.uri.resource import Resource, ResourceType
@@ -438,8 +437,7 @@ def _log_predict_result(
         }
 
         input_features = {
-            f"{self._INPUT_PREFIX}{k}": JsonDict.from_data(v)
-            for k, v in _log_features.items()
+            f"{self._INPUT_PREFIX}{k}": v for k, v in _log_features.items()
         }
         if self.predict_log_mode == PredictLogMode.PICKLE:
             output = dill.dumps(output)
12 changes: 6 additions & 6 deletions client/starwhale/core/dataset/tabular.py
@@ -21,7 +21,7 @@
     InvalidObjectName,
     FieldTypeOrValueError,
 )
-from starwhale.base.data_type import Link, JsonDict, Sequence, BaseArtifact
+from starwhale.base.data_type import Link, Sequence, BaseArtifact
 from starwhale.base.uri.project import Project
 from starwhale.api._impl.wrapper import Dataset as DatastoreWrapperDataset
 from starwhale.api._impl.wrapper import DatasetTableKind
@@ -76,16 +76,16 @@ def __init__(self, mapping: t.Any = None, **kwargs: t.Any) -> None:
                 raise TypeError(f"key:{k} is not str type")
 
             # TODO: add validator for value?
-            converted_mapping[k] = JsonDict.from_data(v)
+            converted_mapping[k] = v
         super().__init__(converted_mapping)
 
     def __getitem__(self, k: str) -> t.Any:
-        return JsonDict.to_data(super().__getitem__(k))
+        return super().__getitem__(k)
 
     def __setitem__(self, k: str, v: t.Any) -> None:
         if not isinstance(k, str):
             raise TypeError(f"key:{k} is not str type")
-        super().__setitem__(k, JsonDict.from_data(v))
+        super().__setitem__(k, v)
 
     @classmethod
     def load_from_datastore(
@@ -138,7 +138,7 @@ def from_datastore(
         for k, v in kw.items():
             if k.startswith(cls._FEATURES_PREFIX):
                 _, name = k.split(cls._FEATURES_PREFIX, 1)
-                _content[name] = JsonDict.to_data(v)
+                _content[name] = v
             else:
                 _extra_kw[k] = v
 
@@ -173,7 +173,7 @@ def asdict(self, ignore_keys: t.Optional[t.List[str]] = None) -> t.Dict:
         d = super().asdict(ignore_keys=ignore_keys or ["features", "extra_kw"])
         d.update(_do_asdict_convert(self.extra_kw))
         for k, v in self.features.items():
-            d[f"{self._FEATURES_PREFIX}{k}"] = JsonDict.from_data(v)
+            d[f"{self._FEATURES_PREFIX}{k}"] = v
         return d
 
     @classmethod
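Taken together, these `tabular.py` edits make the mapping a pass-through: values are stored and returned as built-in objects rather than being converted through `JsonDict`, while keys are still validated as strings. A simplified sketch of that behavior (hypothetical class, not the real implementation):

import typing as t
from collections import UserDict

class FeatureDict(UserDict):
    # Hypothetical pass-through mapping: validate keys, store values as-is.
    def __setitem__(self, k: str, v: t.Any) -> None:
        if not isinstance(k, str):
            raise TypeError(f"key:{k} is not str type")
        super().__setitem__(k, v)

info = FeatureDict()
info["dict"] = {1: "a", b"b": 2}  # non-str inner keys survive untouched
assert info["dict"] == {1: "a", b"b": 2}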
6 changes: 3 additions & 3 deletions client/tests/core/test_dataset.py
@@ -79,12 +79,12 @@ def test_no_support_type_for_build_handler(self) -> None:
 
         def _iter_rows() -> t.Generator:
             for _ in range(0, 5):
-                yield {"a": {1: "a", b"b": "b"}}
+                yield type("can not find")
 
         sd = StandaloneDataset(dataset_uri)
         with self.assertRaisesRegex(
-            RuntimeError,
-            "RowPutThread raise exception: json like dict shouldn't have none-str keys 1",
+            TypeError,
+            "value only supports tuple, dict or DataRow type",
         ):
             sd.build(config=DatasetConfig(name=name, handler=_iter_rows))
 
2 changes: 1 addition & 1 deletion client/tests/sdk/test_data_store.py
@@ -2289,7 +2289,7 @@ def test_insert_and_delete(self) -> None:
             {"k": 0, "a": None},
             {"k": 2, "a": "22"},
             {"k": 3, "b": "3"},
-            {"k": 4, "a": 0, "a/b": 0, "a/c": 1},
+            {"k": 4, "a": 0},
             {
                 "k": 5,
                 "x": data_store.Link("http://test.com/1.jpg"),
113 changes: 59 additions & 54 deletions client/tests/sdk/test_dataset.py
@@ -35,7 +35,6 @@
     Text,
     Image,
     Binary,
-    JsonDict,
     MIMEType,
     Sequence,
     BoundingBox,
@@ -48,7 +47,6 @@
 from starwhale.core.dataset.copy import DatasetCopy
 from starwhale.core.dataset.model import DatasetConfig, StandaloneDataset
 from starwhale.api._impl.data_store import Link as DataStoreRawLink
-from starwhale.api._impl.data_store import SwObject
 from starwhale.core.dataset.tabular import (
     CloudTDSC,
     StandaloneTDSC,
@@ -248,51 +246,53 @@ def test_upload(self, rm: Mocker) -> None:
         } in content["tableSchemaDesc"]["columnSchemaList"]
 
         assert {
-            "type": "OBJECT",
-            "attributes": [
-                {"type": "LIST", "elementType": {"type": "INT64"}, "name": "box"},
-                {
-                    "attributes": [
-                        {"name": "as_mask", "type": "BOOL"},
-                        {"name": "mask_uri", "type": "STRING"},
-                        {"name": "_type", "type": "STRING"},
-                        {"name": "display_name", "type": "STRING"},
-                        {"name": "_mime_type", "type": "STRING"},
-                        {
-                            "type": "LIST",
-                            "name": "shape",
-                            "elementType": {"type": "UNKNOWN"},
-                            "attributes": [{"index": 2, "type": "INT64"}],
-                        },
-                        {"name": "_dtype_name", "type": "STRING"},
-                        {"name": "encoding", "type": "STRING"},
-                        {
-                            "attributes": [
-                                {"name": "_type", "type": "STRING"},
-                                {"name": "uri", "type": "STRING"},
-                                {"name": "scheme", "type": "STRING"},
-                                {"name": "offset", "type": "INT64"},
-                                {"name": "size", "type": "INT64"},
-                                {"name": "data_type", "type": "UNKNOWN"},
-                                {
-                                    "keyType": {"type": "UNKNOWN"},
-                                    "name": "extra_info",
-                                    "type": "MAP",
-                                    "valueType": {"type": "UNKNOWN"},
-                                },
-                            ],
-                            "name": "link",
-                            "pythonType": "starwhale.base.data_type.Link",
-                            "type": "OBJECT",
-                        },
-                    ],
-                    "name": "mask",
-                    "pythonType": "starwhale.base.data_type.Image",
-                    "type": "OBJECT",
-                },
-            ],
-            "pythonType": "starwhale.base.data_type.JsonDict",
             "name": "features/seg",
+            "type": "MAP",
+            "keyType": {"type": "STRING"},
+            "valueType": {
+                "attributes": [
+                    {"name": "as_mask", "type": "BOOL"},
+                    {"name": "mask_uri", "type": "STRING"},
+                    {"name": "_type", "type": "STRING"},
+                    {"name": "display_name", "type": "STRING"},
+                    {"name": "_mime_type", "type": "STRING"},
+                    {
+                        "attributes": [{"index": 2, "type": "INT64"}],
+                        "elementType": {"type": "UNKNOWN"},
+                        "name": "shape",
+                        "type": "LIST",
+                    },
+                    {"name": "_dtype_name", "type": "STRING"},
+                    {"name": "encoding", "type": "STRING"},
+                    {
+                        "attributes": [
+                            {"name": "_type", "type": "STRING"},
+                            {"name": "uri", "type": "STRING"},
+                            {"name": "scheme", "type": "STRING"},
+                            {"name": "offset", "type": "INT64"},
+                            {"name": "size", "type": "INT64"},
+                            {"name": "data_type", "type": "UNKNOWN"},
+                            {
+                                "keyType": {"type": "UNKNOWN"},
+                                "name": "extra_info",
+                                "type": "MAP",
+                                "valueType": {"type": "UNKNOWN"},
+                            },
+                        ],
+                        "name": "link",
+                        "pythonType": "starwhale.base.data_type.Link",
+                        "type": "OBJECT",
+                    },
+                ],
+                "pythonType": "starwhale.base.data_type.Image",
+                "type": "OBJECT",
+            },
+            "sparseKeyValuePairSchema": {
+                "0": {
+                    "keyType": {"type": "STRING"},
+                    "valueType": {"elementType": {"type": "INT64"}, "type": "LIST"},
+                }
+            },
         } in content["tableSchemaDesc"]["columnSchemaList"]
         assert len(content["records"]) > 0
 
@@ -894,18 +894,23 @@ def test_copy(self) -> None:
 
     def test_inner_json_dict(self) -> None:
         info = TabularDatasetInfo(
-            {"int": 1, "dict": {"a": 1, "b": 2}, "list_dict": [{"a": 1}, {"a": 2}]}
+            {
+                "int": 1,
+                "dict": {"a": 1, "b": 2},
+                "list_dict": [{"a": 1}, {"a": 2}],
+                "dict with int key": {1: "a"},
+            },
         )
         assert info["int"] == 1
         assert info["dict"] == {"a": 1, "b": 2}
         assert info["list_dict"] == [{"a": 1}, {"a": 2}]
+        assert info["dict with int key"] == {1: "a"}
 
-        assert isinstance(info.data["dict"], SwObject)
-        assert isinstance(info.data["dict"], JsonDict)
-        assert info.data["dict"].__dict__ == {"a": 1, "b": 2}
+        assert isinstance(info.data["dict"], dict)
+        assert info.data["dict"] == {"a": 1, "b": 2}
         assert isinstance(info.data["list_dict"], list)
-        assert isinstance(info.data["list_dict"][0], JsonDict)
-        assert info.data["list_dict"][0].__dict__ == {"a": 1}
+        assert isinstance(info.data["list_dict"][0], dict)
+        assert info.data["list_dict"][0] == {"a": 1}
 
     def test_exceptions(self) -> None:
         info = TabularDatasetInfo()
@@ -1081,7 +1086,7 @@ def test_row(self) -> None:
 
         u_row_dict = u_row.asdict()
         assert u_row_dict["features/a"] == 1
-        assert u_row_dict["features/b"] == JsonDict({"c": 1})
+        assert u_row_dict["features/b"] == {"c": 1}
         assert l_row.asdict()["id"] == "path/1"
 
         with self.assertRaises(FieldTypeOrValueError):
@@ -1274,7 +1279,7 @@ def test_rows_put_exception(self) -> None:
         )
         with self.assertRaisesRegex(RuntimeError, "RowPutThread raise exception"):
             for i in range(0, 5):
-                mdb.put(DataRow(index=i, features={"a": {b"b": 1}}))
+                mdb.put(DataRow(index=i, features={"a": type("unknown")}))
             mdb.flush()
 
     @pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
20 changes: 11 additions & 9 deletions client/tests/sdk/test_dataset_sdk.py
@@ -148,15 +148,17 @@ def test_create_mode(self) -> None:
         ):
             _ = dataset("mnist", create="not-option")
 
-    def test_no_support_type(self) -> None:
-        with self.assertRaisesRegex(
-            RuntimeError, "json like dict shouldn't have none-str key"
-        ):
-            with dataset("no-support") as ds:
-                ds.append({"dict": {1: "a"}})
-                ds.append({"dict": {b"test": "b"}})
-                ds.append({"dict": {2.0: "c"}})
-                ds.flush()
+    def test_various_dicts(self) -> None:
+        with dataset("test") as ds:
+            ds.append({"dict": {1: "a"}})
+            ds.append({"dict": {b"test": "b"}})
+            ds.append({"dict": {2.0: "c"}})
+            ds.flush()
+
+        with dataset("test") as ds:
+            assert ds[0].features.dict == {1: "a"}
+            assert ds[1].features.dict == {b"test": "b"}
+            assert ds[2].features.dict == {2.0: "c"}
 
     def test_append(self) -> None:
         size = 11
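A hedged usage sketch of what the rewritten test asserts, assuming the public `starwhale.dataset` entry point used in these tests: dict features with non-string keys now survive a build/read round trip instead of raising.

from starwhale import dataset

with dataset("demo") as ds:
    ds.append({"dict": {1: "a"}})        # int key
    ds.append({"dict": {b"test": "b"}})  # bytes key
    ds.flush()

with dataset("demo") as ds:
    assert ds[0].features.dict == {1: "a"}
    assert ds[1].features.dict == {b"test": "b"}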
