From 679580860ec0fad06f0d8cb81ffd86aa202fa596 Mon Sep 17 00:00:00 2001
From: Chuan Xu
Date: Fri, 12 Aug 2022 13:46:19 +0800
Subject: [PATCH] add support for RemoteDataStore and make LocalDataStore work
 with multiprocessing

---
 client/.flake8                           |   2 +-
 client/starwhale/api/_impl/data_store.py | 563 +++++++++++++-----
 client/starwhale/api/_impl/wrapper.py    |   4 +-
 client/tests/sdk/test_data_store.py      | 723 +++++++++++++++++++----
 client/tests/sdk/test_wrapper.py         |   6 +-
 5 files changed, 1011 insertions(+), 287 deletions(-)

diff --git a/client/.flake8 b/client/.flake8
index 6f7cd1004d..764e1583a9 100644
--- a/client/.flake8
+++ b/client/.flake8
@@ -4,7 +4,7 @@
 # E203: whitespace before ':'
 # E731: do not assign a lambda expression, use a def
 # W605: invalid escape sequence '\#'
-ignore= H101,W503,E203,E731,W605
+ignore= H101,W503,E203,E731,W605,C901
 max-line-length = 200
 max-complexity = 18
 exclude =
diff --git a/client/starwhale/api/_impl/data_store.py b/client/starwhale/api/_impl/data_store.py
index 5e3ecacad4..ec6d173297 100644
--- a/client/starwhale/api/_impl/data_store.py
+++ b/client/starwhale/api/_impl/data_store.py
@@ -3,19 +3,51 @@
 import sys
 import json
 import atexit
+import base64
+import struct
+import urllib.parse
 import pathlib
+import binascii
 import threading
-from typing import Any, Set, cast, Dict, List, Tuple, Iterator, Optional
+from typing import Any, Set, cast, Dict, List, Tuple, Union, Iterator, Optional
 
 import numpy as np
 import pyarrow as pa  # type: ignore
+import requests
 import pyarrow.parquet as pq  # type: ignore
-from loguru import logger
 from typing_extensions import Protocol
 
 from starwhale.utils.fs import ensure_dir
 from starwhale.utils.config import SWCliConfigMixed
 
+try:
+    import fcntl
+
+    has_fcntl = True
+except ImportError:
+    has_fcntl = False
+
+
+def _check_move(src: str, dest: str) -> bool:
+    if has_fcntl:
+        with open(os.path.join(os.path.dirname(src), ".lock"), "w") as f:
+            try:
+                fcntl.flock(f, fcntl.LOCK_EX)  # type: ignore
+            except OSError:
+                return False
+            try:
+                os.rename(src, dest)
+                return True
+            finally:
+                fcntl.flock(f, fcntl.LOCK_UN)  # type: ignore
+    else:
+        # windows: no fcntl; os.rename raises FileExistsError if dest exists
+        try:
+            os.rename(src, dest)
+            return True
+        except FileExistsError:
+            return False
+
 
 class Type:
     def __init__(
@@ -32,34 +64,63 @@ def serialize(self, value: Any) -> Any:
     def deserialize(self, value: Any) -> Any:
         return value
 
+    def encode(self, value: Any) -> Optional[str]:
+        if value is None:
+            return None
+        if self is UNKNOWN:
+            return None
+        if self is BOOL:
+            if value:
+                return "1"
+            else:
+                return "0"
+        if self is STRING:
+            return cast(str, value)
+        if self is BYTES:
+            return base64.b64encode(value).decode()
+        if self.name == "int":
+            return f"{value:x}"
+        if self.name == "float":
+            if self.nbits == 16:
+                return binascii.hexlify(struct.pack(">e", value)).decode()
+            if self.nbits == 32:
+                return binascii.hexlify(struct.pack(">f", value)).decode()
+            if self.nbits == 64:
+                return binascii.hexlify(struct.pack(">d", value)).decode()
+        raise RuntimeError("invalid type " + str(self))
+
+    def decode(self, value: Optional[str]) -> Any:
+        if value is None:
+            return None
+        if self is UNKNOWN:
+            return None
+        if self is BOOL:
+            return value == "1"
+        if self is STRING:
+            return value
+        if self is BYTES:
+            return base64.b64decode(value)
+        if self.name == "int":
+            return int(value, 16)
+        if self.name == "float":
+            raw = binascii.unhexlify(value)
+            if self.nbits == 16:
+                return struct.unpack(">e", raw)[0]
+            if self.nbits == 32:
+                return struct.unpack(">f", raw)[0]
+            if self.nbits == 64:
+                return struct.unpack(">d", raw)[0]
+        raise RuntimeError("invalid type " + str(self))
+
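A minimal round-trip sketch of the wire format that encode/decode implement
(the values mirror this patch's own test fixtures; INT64, FLOAT64, and BYTES
are the type constants defined below):

    INT64.encode(255)    # -> "ff": ints serialize as lowercase hex
    FLOAT64.encode(1.0)  # -> "3ff0000000000000": big-endian IEEE-754 hex
    BYTES.encode(b"1")   # -> "MQ==": bytes serialize as base64
    INT64.decode("ff")   # -> 255: decode reverses encode
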
     def __str__(self) -> str:
         if self.name == "int" or self.name == "float":
-            return f"{self.name}{self.nbits}"
+            return f"{self.name}{self.nbits}".upper()
         else:
-            return self.name
+            return self.name.upper()
 
     __repr__ = __str__
 
 
-class LinkType(Type):
-    def __init__(self) -> None:
-        super().__init__("link", pa.string(), 32, "")
-
-    def serialize(self, value: Any) -> Any:
-        if value is None:
-            return None
-        assert isinstance(value, Link)
-        return str(value)
-
-    def deserialize(self, value: Any) -> Any:
-        if value is None:
-            return None
-        d = json.loads(value)
-        return Link(
-            d.get("uri", None), d.get("display_text", None), d.get("mime_type", None)
-        )
-
-
 class Link:
     def __init__(
         self,
@@ -89,7 +150,7 @@ def __eq__(self, other: Any) -> bool:
         )
 
 
-NONE = Type("none", None, 1, None)
+UNKNOWN = Type("unknown", None, 1, None)
 INT8 = Type("int", pa.int8(), 8, 0)
 INT16 = Type("int", pa.int16(), 16, 0)
 INT32 = Type("int", pa.int32(), 32, 0)
@@ -98,12 +159,11 @@ def __eq__(self, other: Any) -> bool:
 FLOAT32 = Type("float", pa.float32(), 32, 0.0)
 FLOAT64 = Type("float", pa.float64(), 64, 0.0)
 BOOL = Type("bool", pa.bool_(), 1, 0)
-STRING = Type("str", pa.string(), 32, "")
+STRING = Type("string", pa.string(), 32, "")
 BYTES = Type("bytes", pa.binary(), 32, b"")
-LINK = LinkType()
 
 _TYPE_DICT: Dict[Any, Type] = {
-    type(None): NONE,
+    type(None): UNKNOWN,
     np.byte: INT8,
     np.int8: INT8,
     np.int16: INT16,
@@ -118,7 +178,6 @@ def __eq__(self, other: Any) -> bool:
     bool: BOOL,
     str: STRING,
     bytes: BYTES,
-    Link: LINK,
 }
 
 _TYPE_NAME_DICT = {str(v): v for k, v in _TYPE_DICT.items()}
@@ -141,6 +200,14 @@ def __eq__(self, other: Any) -> bool:
     )
 
 
+class TableSchemaDesc:
+    def __init__(
+        self, key_column: Optional[str], columns: Optional[List[ColumnSchema]]
+    ) -> None:
+        self.key_column = key_column
+        self.columns = columns
+
+
 class TableSchema:
     def __init__(self, key_column: str, columns: List[ColumnSchema]) -> None:
         self.key_column = key_column
@@ -149,6 +216,35 @@ def __init__(self, key_column: str, columns: List[ColumnSchema]) -> None:
     def copy(self) -> "TableSchema":
         return TableSchema(self.key_column, list(self.columns.values()))
 
+    def merge(self, other: "TableSchema") -> None:
+        if self.key_column != other.key_column:
+            raise RuntimeError(
+                f"conflicting key column, expected {self.key_column}, actual {other.key_column}"
+            )
+        new_schema = {}
+        for col in other.columns.values():
+            column_schema = self.columns.get(col.name, None)
+            if (
+                column_schema is not None
+                and column_schema.type is not UNKNOWN
+                and col.type is not UNKNOWN
+                and col.type is not column_schema.type
+                and col.type.name != column_schema.type.name
+            ):
+                raise RuntimeError(
+                    f"conflicting column type, name {col.name}, expected {column_schema.type}, actual {col.type}"
+                )
+            if (
+                column_schema is None
+                or column_schema.type is UNKNOWN
+                or (
+                    col.type is not UNKNOWN
+                    and col.type.nbits > column_schema.type.nbits
+                )
+            ):
+                new_schema[col.name] = col
+        self.columns.update(new_schema)
+
     @staticmethod
     def parse(json_str: str) -> "TableSchema":
         d = json.loads(json_str)
@@ -171,7 +267,7 @@ def __str__(self) -> str:
             }
         )
 
-    __rept__ = __str__
+    __repr__ = __str__
 
     def __eq__(self, other: Any) -> bool:
         return (
@@ -208,6 +304,7 @@ def _scan_parquet_file(
     columns: Optional[Dict[str, str]] = None,
     start: Optional[Any] = None,
    end: Optional[Any] = None,
+    keep_none: bool = False,
 ) -> Iterator[dict]:
     f = pq.ParquetFile(path)
     schema_arrow = f.schema_arrow
@@ -263,7 +360,12 
@@ def _scan_parquet_file( if value is not None: d["-"] = value elif name.startswith("~") and value: - d.pop(columns.get(name[1:], ""), "") + alias = columns.get(name[1:], "") + if alias != "": + if keep_none: + d[alias] = None + else: + d.pop(alias, "") else: alias = columns.get(name, "") if alias != "" and value is not None: @@ -271,7 +373,9 @@ def _scan_parquet_file( yield d -def _merge_scan(iters: List[Iterator[Dict[str, Any]]]) -> Iterator[dict]: +def _merge_scan( + iters: List[Iterator[Dict[str, Any]]], keep_none: bool +) -> Iterator[dict]: class Node: def __init__(self, index: int, iter: Iterator[dict]) -> None: self.index = index @@ -291,10 +395,10 @@ def nextItem(self) -> None: self.key = "" nodes = [] - for _i, _iter in enumerate(iters): - _node = Node(_i, _iter) - if not _node.exhausted: - nodes.append(_node) + for i, iter in enumerate(iters): + node = Node(i, iter) + if not node.exhausted: + nodes.append(node) while len(nodes) > 0: key = min(nodes, key=lambda x: x.key).key @@ -312,17 +416,14 @@ def nextItem(self) -> None: nodes[i].nextItem() if len(d) > 0: d["*"] = key + if not keep_none: + d = {k: v for k, v in d.items() if v is not None} yield d nodes = [node for node in nodes if not node.exhausted] def _get_table_files(path: str) -> List[str]: - if not os.path.exists(path): - logger.warning(f"not find path {path} as table file path") - return [] - - if not os.path.isdir(path): - raise RuntimeError(f"{path} is not a directory") + ensure_dir(path) patches = [] base_index = -1 @@ -344,10 +445,7 @@ def _get_table_files(path: str) -> List[str]: def _read_table_schema(path: str) -> TableSchema: - if not os.path.exists(path): - raise RuntimeError(f"path not found: {path}") - if not os.path.isdir(path): - raise RuntimeError(f"{path} is not a directory") + ensure_dir(path) files = _get_table_files(path) if len(files) == 0: @@ -369,24 +467,16 @@ def _scan_table( columns: Optional[Dict[str, str]] = None, start: Optional[Any] = None, end: Optional[Any] = None, - explicit_none: bool = False, + keep_none: bool = False, ) -> Iterator[dict]: iters = [] for file in _get_table_files(path): - iters.append(_scan_parquet_file(file, columns, start, end)) - column_names = [] - if len(iters) > 0: - schema = _read_table_schema(path) - column_names = [ - col.name - for col in schema.columns.values() - if col.name != "-" and not col.name.startswith("~") - ] - for record in _merge_scan(iters): - if explicit_none: - for col in column_names: - record.setdefault(col, None) - yield record + if os.path.basename(file).startswith("patch"): + keep = True + else: + keep = keep_none + iters.append(_scan_parquet_file(file, columns, start, end, keep)) + return _merge_scan(iters, keep_none) def _records_to_table( @@ -442,8 +532,8 @@ def _update_schema(schema: TableSchema, record: Dict[str, Any]) -> TableSchema: column_schema = schema.columns.get(col, None) if ( column_schema is not None - and column_schema.type is not NONE - and value_type is not NONE + and column_schema.type is not UNKNOWN + and value_type is not UNKNOWN and value_type is not column_schema.type and value_type.name != column_schema.type.name ): @@ -452,9 +542,9 @@ def _update_schema(schema: TableSchema, record: Dict[str, Any]) -> TableSchema: ) if column_schema is None: new_schema.columns[col] = ColumnSchema(col, value_type) - elif column_schema.type is NONE: + elif column_schema.type is UNKNOWN: new_schema.columns[col].type = value_type - elif value_type is not NONE and value_type.nbits > column_schema.type.nbits: + elif value_type is not UNKNOWN 
and value_type.nbits > column_schema.type.nbits: new_schema.columns[col].type = value_type return new_schema @@ -472,23 +562,12 @@ def get_schema(self) -> TableSchema: with self.lock: return self.schema.copy() - def load(self, root_path: str) -> None: - for record in _scan_table(_get_table_path(root_path, self.table_name)): - key = record.get(self.schema.key_column, None) - actual = record.pop("*") - if record.get(self.schema.key_column, None) != key: - raise RuntimeError( - f"failed to load table {self.table_name}: key column={self.schema.key_column}, expected key:{key}, actual key:{actual}" - ) - record.pop("-", None) - self.records[key] = record - def scan( self, columns: Optional[Dict[str, str]] = None, start: Optional[Any] = None, end: Optional[Any] = None, - explicit_none: bool = False, + keep_none: bool = False, ) -> Iterator[Dict[str, Any]]: with self.lock: schema = self.schema.copy() @@ -503,16 +582,13 @@ def scan( records.sort(key=lambda x: cast(str, x[self.schema.key_column])) for r in records: if columns is None: - d = r + d = dict(r) else: d = {columns[k]: v for k, v in r.items() if k in columns} if "-" in r: d["-"] = r["-"] - d["*"] = r[self.schema.key_column] - if explicit_none: - for col in schema.columns.values(): - d.setdefault(col.name, None) - else: + d["*"] = r[schema.key_column] + if not keep_none: d = {k: v for k, v in d.items() if v is not None} yield d @@ -538,26 +614,86 @@ def delete(self, keys: List[Any]) -> None: self.nbytes -= _get_size(r) def dump(self, root_path: str) -> None: + with self.lock: + schema = self.schema.copy() path = _get_table_path(root_path, self.table_name) ensure_dir(path) + while True: + max_index = -1 + for file in os.listdir(path): + type, index = _parse_parquet_name(file) + if type != "" and index > max_index: + max_index = index + if max_index < 0: + filename = "base-0.parquet" + else: + filename = f"base-{max_index + 1}.parquet" + temp_filename = f"temp.{os.getpid()}" + if max_index >= 0: + s = _read_table_schema(os.path.join(path)) + s.merge(schema) + schema = s + _write_parquet_file( + os.path.join(path, temp_filename), + _records_to_table( + schema, + list( + _merge_scan( + [ + _scan_table(path, keep_none=True), + self.scan(keep_none=True), + ], + True, + ) + ), + [], + ), + ) + if _check_move( + os.path.join(path, temp_filename), os.path.join(path, filename) + ): + break - max_index = -1 - for file in os.listdir(path): - type, index = _parse_parquet_name(file) - if type != "" and index > max_index: - max_index = index - if max_index < 0: - filename = "base-0.parquet" - else: - filename = f"base-{max_index + 1}.parquet" - _write_parquet_file( - os.path.join(path, filename), - _records_to_table( - self.schema, - list(self.records.values()), - list(self.deletes), - ), - ) + +class TableDesc: + def __init__( + self, + table_name: str, + columns: Union[Dict[str, str], List[str], None] = None, + keep_none: bool = False, + ) -> None: + self.table_name = table_name + self.columns: Optional[Dict[str, str]] = None + self.keep_none = keep_none + if columns is not None: + self.columns = {} + if isinstance(columns, dict): + alias_map: Dict[str, str] = {} + for col, alias in columns.items(): + key = alias_map.setdefault(alias, col) + if key != col: + raise RuntimeError( + f"duplicate alias {alias} for column {col} and {key}" + ) + self.columns = columns + else: + for col in columns: + if col in self.columns: + raise RuntimeError(f"duplicate column name {col}") + self.columns[col] = col + + def to_dict(self) -> Dict[str, Any]: + ret: Dict[str, 
Any] = { + "tableName": self.table_name, + } + if self.columns is not None: + ret["columns"] = [ + {"columnName": col, "alias": alias} + for col, alias in self.columns.items() + ] + if self.keep_none: + ret["keepNone"] = True + return ret class LocalDataStore: @@ -581,8 +717,11 @@ def __init__(self, root_path: str) -> None: self.name_pattern = re.compile(r"^[A-Za-z0-9-_/]+$") self.tables: Dict[str, MemoryTable] = {} - def put( - self, table_name: str, schema: TableSchema, records: List[Dict[str, Any]] + def update_table( + self, + table_name: str, + schema: TableSchema, + records: List[Dict[str, Any]], ) -> None: if self.name_pattern.match(table_name) is None: raise RuntimeError( @@ -590,14 +729,19 @@ def put( ) for r in records: for k in r.keys(): - if self.name_pattern.match(k) is None: + if k != "-" and self.name_pattern.match(k) is None: raise RuntimeError( f"invalid column name {k}, only letters(A-Z, a-z), digits(0-9), hyphen('-'), and underscore('_') are allowed" ) table = self.tables.get(table_name, None) if table is None: - table = MemoryTable(table_name, schema) - table.load(self.root_path) + table_path = _get_table_path(self.root_path, table_name) + if _get_table_files(table_path): + table_schema = _read_table_schema(table_path) + table_schema.merge(schema) + else: + table_schema = schema + table = MemoryTable(table_name, table_schema) self.tables[table_name] = table if schema.key_column != table.schema.key_column: raise RuntimeError( @@ -609,11 +753,6 @@ def put( raise RuntimeError( f"key {schema.key_column} should not be none, record: {r.keys()}" ) - for name in r.keys(): - if self.name_pattern.match(name) is None: - raise RuntimeError( - f"invalid column name {name}, only letters(A-Z, a-z), digits(0-9), hyphen('-'), and underscore('_') are allowed" - ) if "-" in r: table.delete([key]) else: @@ -621,60 +760,41 @@ def put( def scan_tables( self, - tables: List[Tuple[str, str, bool]], - columns: Optional[Dict[str, str]] = None, + tables: List[TableDesc], start: Optional[Any] = None, end: Optional[Any] = None, + keep_none: bool = False, ) -> Iterator[Dict[str, Any]]: - # check for alias duplications - if columns is not None: - alias_map: Dict[str, str] = {} - for col, alias in columns.items(): - key = alias_map.setdefault(alias, col) - if key != col: - raise RuntimeError( - f"duplicate alias {alias} for column {col} and {key}" - ) - class TableInfo: def __init__( self, name: str, key_column_type: pa.DataType, columns: Optional[Dict[str, str]], - explicit_none: bool, - path: str, + keep_none: bool, ) -> None: self.name = name self.key_column_type = key_column_type self.columns = columns - self.explicit_none = explicit_none - self.path = path + self.keep_none = keep_none infos: List[TableInfo] = [] - for table_name, table_alias, explicit_none in tables: - table_path = _get_table_path(self.root_path, table_name) - table = self.tables.get(table_name, None) - if table is None: - schema = _read_table_schema(table_path) - else: + for table_desc in tables: + table = self.tables.get(table_desc.table_name, None) + if table is not None: schema = table.get_schema() - key_column_type = schema.columns[schema.key_column].type.pa_type - column_names = schema.columns.keys() - col_prefix = table_alias + "." 
- - cols: Optional[Dict[str, str]] - if columns is None or col_prefix + "*" in columns: - cols = None else: - cols = {} - for name in column_names: - alias = columns.get(name, "") - alias = columns.get(col_prefix + name, alias) - if alias != "": - cols[name] = alias + schema = _read_table_schema( + _get_table_path(self.root_path, table_desc.table_name) + ) + key_column_type = schema.columns[schema.key_column].type.pa_type infos.append( - TableInfo(table_name, key_column_type, cols, explicit_none, table_path) + TableInfo( + table_desc.table_name, + key_column_type, + table_desc.columns, + table_desc.keep_none, + ) ) # check for key type conflictions @@ -687,20 +807,46 @@ def __init__( f"{info.name} has a key of type {info.key_column_type}," f" while {infos[0].name} has a key of type {infos[0].key_column_type}" ) + iters = [] for info in infos: + table_path = _get_table_path(self.root_path, info.name) if info.name in self.tables: - iters.append( - self.tables[info.name].scan( - info.columns, start, end, info.explicit_none + if _get_table_files(table_path): + iters.append( + _merge_scan( + [ + _scan_table( + table_path, + info.columns, + start, + end, + info.keep_none, + ), + self.tables[info.name].scan( + info.columns, start, end, True + ), + ], + info.keep_none, + ) + ) + else: + iters.append( + self.tables[info.name].scan( + info.columns, start, end, info.keep_none + ) ) - ) else: iters.append( - _scan_table(info.path, info.columns, start, end, info.explicit_none) + _scan_table( + table_path, + info.columns, + start, + end, + info.keep_none, + ) ) - - for record in _merge_scan(iters): + for record in _merge_scan(iters, keep_none): record.pop("*", None) yield record @@ -710,33 +856,122 @@ def dump(self) -> None: class RemoteDataStore: - def put( - self, table_name: str, schema: TableSchema, records: List[Dict[str, Any]] + def __init__(self, instance_uri: str) -> None: + self.instance_uri = instance_uri + self.token = os.getenv("SW_TOKEN") + if self.token is None: + raise RuntimeError("SW_TOKEN is not found in environment") + + def update_table( + self, + table_name: str, + schema: TableSchema, + records: List[Dict[str, Any]], ) -> None: - ... + data: Dict[str, Any] = {"tableName": table_name} + schema_data: Dict[str, Any] = { + "keyColumn": schema.key_column, + "columnSchemaList": [ + {"name": col.name, "type": str(col.type)} + for col in schema.columns.values() + ], + } + data["tableSchemaDesc"] = schema_data + if records is not None: + encoded: List[Dict[str, List[Dict[str, Optional[str]]]]] = [] + for record in records: + r: List[Dict[str, Optional[str]]] = [] + for k, v in record.items(): + if k == "-": + r.append({"key": "-", "value": "1"}) + else: + r.append({"key": k, "value": schema.columns[k].type.encode(v)}) + encoded.append({"values": r}) + data["records"] = encoded + + assert self.token is not None + resp = requests.post( + urllib.parse.urljoin(self.instance_uri, "/api/v1/datastore/updateTable"), + data=json.dumps(data, separators=(",", ":")), + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": self.token, + }, + timeout=5.0, + ) + resp.raise_for_status() def scan_tables( self, - tables: List[Tuple[str, str, bool]], - columns: Optional[Dict[str, str]] = None, + tables: List[TableDesc], start: Optional[Any] = None, end: Optional[Any] = None, + keep_none: bool = False, ) -> Iterator[Dict[str, Any]]: - ... 
+ post_data: Dict[str, Any] = {"tables": [table.to_dict() for table in tables]} + key_type = _get_type(start) + assert key_type is not None + if end is not None: + post_data["end"] = key_type.encode(end) + if start is not None: + post_data["start"] = key_type.encode(start) + post_data["limit"] = 1000 + if keep_none: + post_data["keepNone"] = True + assert self.token is not None + while True: + resp = requests.post( + urllib.parse.urljoin(self.instance_uri, "/api/v1/datastore/scanTable"), + data=json.dumps(post_data, separators=(",", ":")), + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": self.token, + }, + timeout=5.0, + ) + resp.raise_for_status() + resp_json: Dict[str, Any] = resp.json() + records = resp_json.get("records", None) + if records is None or len(records) == 0: + break + if "columnTypes" not in resp_json: + raise RuntimeError("no column types in response") + column_types = { + col: _TYPE_NAME_DICT[type] + for col, type in resp_json["columnTypes"].items() + } + for record in records: + r = {} + for k, v in record.items(): + col_type = column_types.get(k, None) + if col_type is None: + raise RuntimeError( + f"unknown type for column {k}, record={record}" + ) + r[k] = col_type.decode(v) + yield r + if len(records) == 1000: + post_data["start"] = resp_json["lastKey"] + post_data["startInclusive"] = False + else: + break class DataStore(Protocol): - def put( - self, table_name: str, schema: TableSchema, records: List[Dict[str, Any]] + def update_table( + self, + table_name: str, + schema: TableSchema, + records: List[Dict[str, Any]], ) -> None: ... def scan_tables( self, - tables: List[Tuple[str, str, bool]], - columns: Optional[Dict[str, str]] = None, + tables: List[TableDesc], start: Optional[Any] = None, end: Optional[Any] = None, + keep_none: bool = False, ) -> Iterator[Dict[str, Any]]: ... 
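For reference, a hypothetical end-to-end use of the reworked interface (the
table name "demo" and the root path are made up; the calls match the
LocalDataStore, TableSchema, and TableDesc signatures introduced above):

    ds = LocalDataStore("/tmp/sw-demo")  # hypothetical root path
    schema = TableSchema(
        "k", [ColumnSchema("k", INT64), ColumnSchema("a", STRING)]
    )
    ds.update_table("demo", schema, [{"k": 0, "a": "0"}, {"k": 1, "a": None}])
    # keep_none=True surfaces explicit None values instead of dropping them
    rows = list(ds.scan_tables([TableDesc("demo", keep_none=True)], keep_none=True))
    # rows == [{"k": 0, "a": "0"}, {"k": 1, "a": None}]
    ds.dump()  # persists the in-memory table as base-*.parquet under the root
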
@@ -746,7 +981,7 @@ def get_data_store() -> DataStore: if instance is None or instance == "local": return LocalDataStore.get_instance() else: - return RemoteDataStore() + return RemoteDataStore(instance) def _flatten(record: Dict[str, Any]) -> Dict[str, Any]: @@ -823,4 +1058,4 @@ def run(self) -> None: break records = self.records self.records = [] - self.data_store.put(self.table_name, self.schema, records) + self.data_store.update_table(self.table_name, self.schema, records) diff --git a/client/starwhale/api/_impl/wrapper.py b/client/starwhale/api/_impl/wrapper.py index ed52b72788..fccf728e3f 100644 --- a/client/starwhale/api/_impl/wrapper.py +++ b/client/starwhale/api/_impl/wrapper.py @@ -69,7 +69,7 @@ def log_metrics( def get_results(self) -> Iterator[Dict[str, Any]]: return self._data_store.scan_tables( - [(self._results_table_name, "result", False)] + [data_store.TableDesc(self._results_table_name)] ) @@ -95,7 +95,7 @@ def put(self, data_id: Union[int, str], **kwargs: Any) -> None: def scan(self, start: Any, end: Any) -> Iterator[Dict[str, Any]]: return self._data_store.scan_tables( - [(self._meta_table_name, "meta", False)], start=start, end=end + [data_store.TableDesc(self._meta_table_name)], start=start, end=end ) def __str__(self) -> str: diff --git a/client/tests/sdk/test_data_store.py b/client/tests/sdk/test_data_store.py index df9eb04e82..8f291afd14 100644 --- a/client/tests/sdk/test_data_store.py +++ b/client/tests/sdk/test_data_store.py @@ -1,6 +1,8 @@ import os +import json import unittest from typing import Dict, List +from unittest.mock import Mock, patch import numpy as np import pyarrow as pa # type: ignore @@ -52,7 +54,6 @@ def test_write_and_scan(self) -> None: "a": [0, 1, 2], "b": ["x", "y", "z"], "c": [10, 11, 12], - "d": [None, None, str(data_store.Link(""))], "-": [None, True, None], "~c": [False, False, True], }, @@ -64,7 +65,6 @@ def test_write_and_scan(self) -> None: data_store.ColumnSchema("a", data_store.INT64), data_store.ColumnSchema("b", data_store.STRING), data_store.ColumnSchema("c", data_store.INT64), - data_store.ColumnSchema("d", data_store.LINK), data_store.ColumnSchema("-", data_store.BOOL), data_store.ColumnSchema("~c", data_store.BOOL), ], @@ -77,7 +77,7 @@ def test_write_and_scan(self) -> None: [ {"*": 0, "a": 0, "b": "x", "c": 10}, {"*": 1, "a": 1, "b": "y", "c": 11, "-": True}, - {"*": 2, "a": 2, "b": "z", "d": data_store.Link("")}, + {"*": 2, "a": 2, "b": "z"}, ], list(data_store._scan_parquet_file(path)), "scan all", @@ -142,12 +142,25 @@ def test_write_and_scan(self) -> None: ), "with start and end", ) + self.assertEqual( + [ + {"*": 0, "a": 0, "b": "x", "c": 10}, + {"*": 1, "a": 1, "b": "y", "c": 11, "-": True}, + {"*": 2, "a": 2, "b": "z", "c": None}, + ], + list(data_store._scan_parquet_file(path, keep_none=True)), + "keep none", + ) def test_merge_scan(self) -> None: - self.assertEqual([], list(data_store._merge_scan([])), "no iter") + self.assertEqual([], list(data_store._merge_scan([], False)), "no iter") self.assertEqual( [{"*": 0, "a": 0}, {"*": 1, "a": 1}], - list(data_store._merge_scan([iter([{"*": 0, "a": 0}, {"*": 1, "a": 1}])])), + list( + data_store._merge_scan( + [iter([{"*": 0, "a": 0}, {"*": 1, "a": 1}])], False + ) + ), "one iter - ignore none", ) self.assertEqual( @@ -157,7 +170,8 @@ def test_merge_scan(self) -> None: [ iter([{"*": 0, "a": 0}, {"*": 2, "a": 2}]), iter([{"*": 1, "a": 1}, {"*": 3, "a": 3}]), - ] + ], + False, ) ), "two iters", @@ -169,7 +183,8 @@ def test_merge_scan(self) -> None: [ iter([{"*": 0, "a": 0}, 
{"*": 1, "a": 1}]), iter([{"*": 2, "a": 2}, {"*": 3, "a": 3}]), - ] + ], + False, ) ), "two iters without range overlap", @@ -188,7 +203,8 @@ def test_merge_scan(self) -> None: {"*": 3, "a": 3}, ] ), - ] + ], + False, ) ), "0 and 4", @@ -202,6 +218,7 @@ def test_merge_scan(self) -> None: iter([{"*": 0, "a": 0}, {"*": 3, "a": 3}]), iter([{"*": 2, "a": 2}]), ], + False, ) ), "1 and 3", @@ -215,10 +232,65 @@ def test_merge_scan(self) -> None: iter([{"*": 1, "-": True}, {"*": 1, "b": 2}]), iter([{"*": 1, "c": 3, "-": False}]), ], + False, ) ), "removal", ) + self.assertEqual( + [ + {"*": 0, "a": "0"}, + {"*": 1, "a": "1"}, + {"*": 3}, + ], + list( + data_store._merge_scan( + [ + iter( + [ + {"*": 1, "a": "1"}, + {"*": 3, "a": "3"}, + ] + ), + iter( + [ + {"*": 0, "a": "0"}, + {"*": 3, "a": None}, + ] + ), + ], + False, + ) + ), + "keep none 1", + ) + self.assertEqual( + [ + {"*": 0, "a": "0"}, + {"*": 1, "a": "1"}, + {"*": 3, "a": None}, + ], + list( + data_store._merge_scan( + [ + iter( + [ + {"*": 1, "a": "1"}, + {"*": 3, "a": "3"}, + ] + ), + iter( + [ + {"*": 0, "a": "0"}, + {"*": 3, "a": None}, + ] + ), + ], + True, + ) + ), + "keep none 2", + ) self.assertEqual( [ {"*": 0, "a": "0", "b": "0"}, @@ -262,6 +334,7 @@ def test_merge_scan(self) -> None: ] ), ], + True, ) ), "mixed", @@ -413,6 +486,7 @@ def test_scan_table(self) -> None: "k": [0, 1, 2, 3], "a": [None, None, None, "3"], "b": ["0", "1", "2", "3"], + "~b": [False, False, False, True], }, metadata={ "schema": str( @@ -422,6 +496,7 @@ def test_scan_table(self) -> None: data_store.ColumnSchema("k", data_store.INT64), data_store.ColumnSchema("a", data_store.STRING), data_store.ColumnSchema("b", data_store.STRING), + data_store.ColumnSchema("~b", data_store.BOOL), ], ) ) @@ -438,7 +513,7 @@ def test_scan_table(self) -> None: {"*": 0, "k": 0, "a": "0", "b": "0"}, {"*": 1, "k": 1, "a": "1", "b": "1"}, {"*": 2, "k": 2, "b": "2"}, - {"*": 3, "k": 3, "a": "3", "b": "3"}, + {"*": 3, "k": 3, "a": "3"}, {"*": 5, "k": 5, "a": "5.5"}, ], list(data_store._scan_table(self.datastore_root)), @@ -449,14 +524,14 @@ def test_scan_table(self) -> None: {"*": 0, "i": "0", "j": "0"}, {"*": 1, "i": "1", "j": "1"}, {"*": 2, "j": "2"}, - {"*": 3, "i": "3", "j": "3"}, + {"*": 3, "i": "3"}, {"*": 5, "i": "5.5"}, ], list(data_store._scan_table(self.datastore_root, {"a": "i", "b": "j"})), "some columns", ) self.assertEqual( - [{"*": 2, "j": "2"}, {"*": 3, "i": "3", "j": "3"}], + [{"*": 2, "j": "2"}, {"*": 3, "i": "3"}], list( data_store._scan_table( self.datastore_root, {"a": "i", "b": "j"}, start=2, end=5 @@ -468,12 +543,12 @@ def test_scan_table(self) -> None: [ {"*": 0, "k": 0, "a": "0", "b": "0"}, {"*": 1, "k": 1, "a": "1", "b": "1"}, - {"*": 2, "k": 2, "a": None, "b": "2"}, - {"*": 3, "k": 3, "a": "3", "b": "3"}, - {"*": 5, "k": 5, "a": "5.5", "b": None}, + {"*": 2, "k": 2, "b": "2"}, + {"*": 3, "k": 3, "a": "3", "b": None}, + {"*": 5, "k": 5, "a": "5.5"}, ], - list(data_store._scan_table(self.datastore_root, explicit_none=True)), - "explicit none", + list(data_store._scan_table(self.datastore_root, keep_none=True)), + "keep none", ) def test_update_schema(self) -> None: @@ -545,18 +620,18 @@ def test_update_schema(self) -> None: ) self.assertEqual( data_store.TableSchema( - "a", [data_store.ColumnSchema("a", data_store.NONE)] + "a", [data_store.ColumnSchema("a", data_store.UNKNOWN)] ), data_store._update_schema(data_store.TableSchema("a", []), {"a": None}), "none 1", ) self.assertEqual( data_store.TableSchema( - "a", [data_store.ColumnSchema("a", 
data_store.NONE)] + "a", [data_store.ColumnSchema("a", data_store.UNKNOWN)] ), data_store._update_schema( data_store.TableSchema( - "a", [data_store.ColumnSchema("a", data_store.NONE)] + "a", [data_store.ColumnSchema("a", data_store.UNKNOWN)] ), {"a": None}, ), @@ -568,7 +643,7 @@ def test_update_schema(self) -> None: ), data_store._update_schema( data_store.TableSchema( - "a", [data_store.ColumnSchema("a", data_store.NONE)] + "a", [data_store.ColumnSchema("a", data_store.UNKNOWN)] ), {"a": 0}, ), @@ -601,77 +676,45 @@ def test_mixed(self) -> None: table.insert({"k": 2, "a": "2"}) table.insert({"k": 3, "a": "3"}) table.insert({"k": 1, "b": "1"}) - table.insert({"k": 4, "x": data_store.Link("t")}) table.delete([2]) + table.insert({"k": 1, "a": None}) self.assertEqual( [ {"*": 0, "k": 0, "a": "0"}, - {"*": 1, "k": 1, "a": "1", "b": "1"}, + {"*": 1, "k": 1, "b": "1"}, {"*": 2, "k": 2, "-": True}, {"*": 3, "k": 3, "a": "3"}, - {"*": 4, "k": 4, "x": data_store.Link("t")}, ], list(table.scan()), "scan all", ) self.assertEqual( [ - {"*": 0, "k": 0, "a": "0", "b": None, "x": None}, - {"*": 1, "k": 1, "a": "1", "b": "1", "x": None}, - {"*": 2, "k": 2, "a": None, "b": None, "x": None, "-": True}, - {"*": 3, "k": 3, "a": "3", "b": None, "x": None}, - {"*": 4, "k": 4, "a": None, "b": None, "x": data_store.Link("t")}, + {"*": 0, "k": 0, "a": "0"}, + {"*": 1, "k": 1, "a": None, "b": "1"}, + {"*": 2, "k": 2, "-": True}, + {"*": 3, "k": 3, "a": "3"}, ], - list(table.scan(explicit_none=True)), - "explicit none", + list(table.scan(keep_none=True)), + "keep none", ) self.assertEqual( [ {"*": 0, "k": 0, "x": "0"}, - {"*": 1, "k": 1, "x": "1"}, + {"*": 1, "k": 1}, {"*": 2, "k": 2, "-": True}, {"*": 3, "k": 3, "x": "3"}, - {"*": 4, "k": 4}, ], list(table.scan({"k": "k", "a": "x"})), "some columns", ) - table.dump(self.datastore_root) - self.assertEqual( - [os.path.join(self.datastore_root, "test", "base-0.parquet")], - data_store._get_table_files(os.path.join(self.datastore_root, "test")), - "dump 1", - ) - table.dump(self.datastore_root) - self.assertEqual( - [os.path.join(self.datastore_root, "test", "base-1.parquet")], - data_store._get_table_files(os.path.join(self.datastore_root, "test")), - "dump 2", - ) - table = data_store.MemoryTable( - "test", - data_store.TableSchema( - "k", [data_store.ColumnSchema("k", data_store.INT64)] - ), - ) - table.load(self.datastore_root) - self.assertEqual( - [ - {"*": 0, "k": 0, "a": "0"}, - {"*": 1, "k": 1, "a": "1", "b": "1"}, - {"*": 3, "k": 3, "a": "3"}, - {"*": 4, "k": 4, "x": data_store.Link("t")}, - ], - list(table.scan()), - "load", - ) class TestLocalDataStore(BaseTestCase): - def test_data_store_put(self) -> None: + def test_data_store_update_table(self) -> None: ds = data_store.LocalDataStore(self.datastore_root) with self.assertRaises(RuntimeError, msg="invalid column name"): - ds.put( + ds.update_table( "test", data_store.TableSchema( "+", [data_store.ColumnSchema("+", data_store.INT64)] @@ -679,14 +722,14 @@ def test_data_store_put(self) -> None: [{"+": 0}], ) with self.assertRaises(RuntimeError, msg="no key field"): - ds.put( + ds.update_table( "test", data_store.TableSchema( "k", [data_store.ColumnSchema("k", data_store.INT64)] ), [{"a": 0}], ) - ds.put( + ds.update_table( "project/a_b/eval/test-0", data_store.TableSchema( "k", @@ -700,10 +743,14 @@ def test_data_store_put(self) -> None: ) self.assertEqual( [{"k": 0, "a": "0", "b": "0"}, {"k": 1, "a": "1"}], - list(ds.scan_tables([("project/a_b/eval/test-0", "test", False)])), + list( + ds.scan_tables( + 
[data_store.TableDesc("project/a_b/eval/test-0", None, False)] + ) + ), "name check", ) - ds.put( + ds.update_table( "test", data_store.TableSchema( "k", @@ -717,10 +764,10 @@ def test_data_store_put(self) -> None: ) self.assertEqual( [{"k": 0, "a": "0", "b": "0"}, {"k": 1, "a": "1"}], - list(ds.scan_tables([("test", "test", False)])), + list(ds.scan_tables([data_store.TableDesc("test", None, False)])), "base", ) - ds.put( + ds.update_table( "test", data_store.TableSchema( "k", @@ -743,10 +790,23 @@ def test_data_store_put(self) -> None: {"k": 2, "b": "2"}, {"k": 3, "a": "3", "b": "3"}, ], - list(ds.scan_tables([("test", "test", False)])), + list(ds.scan_tables([data_store.TableDesc("test", None, False)])), "batch+patch", ) - ds.put( + self.assertEqual( + [ + {"k": 0, "a": "0", "b": "0"}, + {"k": 2, "a": None, "b": "2"}, + {"k": 3, "a": "3", "b": "3"}, + ], + list( + ds.scan_tables( + [data_store.TableDesc("test", None, True)], keep_none=True + ) + ), + "batch+patch keep none", + ) + ds.update_table( "test", data_store.TableSchema( "k", @@ -769,34 +829,13 @@ def test_data_store_put(self) -> None: {"k": 2, "b": "2"}, {"k": 3, "a": "33", "b": "3", "c": 3}, ], - list(ds.scan_tables([("test", "test", False)])), + list(ds.scan_tables([data_store.TableDesc("test", None, False)])), "overwrite", ) - ds.put( - "test", - data_store.TableSchema( - "k", - [ - data_store.ColumnSchema("k", data_store.INT64), - data_store.ColumnSchema("x", data_store.LINK), - ], - ), - [{"k": 4, "x": data_store.Link("tt", "a", "b")}], - ) - self.assertEqual( - [ - {"k": 1, "a": "1", "b": "1"}, - {"k": 2, "b": "2"}, - {"k": 3, "a": "33", "b": "3", "c": 3}, - {"k": 4, "x": data_store.Link("tt", "a", "b")}, - ], - list(ds.scan_tables([("test", "test", False)])), - "link", - ) def test_data_store_scan(self) -> None: ds = data_store.LocalDataStore(self.datastore_root) - ds.put( + ds.update_table( "1", data_store.TableSchema( "k", @@ -813,7 +852,7 @@ def test_data_store_scan(self) -> None: {"k": 3, "a": "3", "b": "3"}, ], ) - ds.put( + ds.update_table( "2", data_store.TableSchema( "a", @@ -829,7 +868,7 @@ def test_data_store_scan(self) -> None: {"a": 3, "b": "3"}, ], ) - ds.put( + ds.update_table( "3", data_store.TableSchema( "a", @@ -845,17 +884,17 @@ def test_data_store_scan(self) -> None: {"a": 3, "x": "3"}, ], ) - ds.put( + ds.update_table( "4", data_store.TableSchema( "x", [ - data_store.ColumnSchema("b", data_store.STRING), + data_store.ColumnSchema("x", data_store.STRING), ], ), [{"x": "0"}, {"x": "1"}, {"x": "2"}, {"x": "3"}], ) - ds.put( + ds.update_table( "5", data_store.TableSchema( "a", @@ -869,9 +908,18 @@ def test_data_store_scan(self) -> None: with open(os.path.join(self.datastore_root, "6"), "w"): pass with self.assertRaises(RuntimeError, msg="duplicate alias"): - list(ds.scan_tables([("1", "1", False)], {"k": "v", "a": "v"})) + list( + ds.scan_tables([data_store.TableDesc("1", {"k": "v", "a": "v"}, False)]) + ) with self.assertRaises(RuntimeError, msg="conflicting key type"): - list(ds.scan_tables([("1", "1", False), ("4", "4", False)])) + list( + ds.scan_tables( + [ + data_store.TableDesc("1", None, False), + data_store.TableDesc("4", None, False), + ] + ) + ) self.assertEqual( [ {"k": 0, "a": "0", "b": "0"}, @@ -879,7 +927,7 @@ def test_data_store_scan(self) -> None: {"k": 2, "b": "2"}, {"k": 3, "a": "3", "b": "3"}, ], - list(ds.scan_tables([("1", "1", False)])), + list(ds.scan_tables([data_store.TableDesc("1", None, False)])), "scan all", ) self.assertEqual( @@ -892,10 +940,10 @@ def 
test_data_store_scan(self) -> None: list( ds.scan_tables( [ - ("1", "1", False), - ("2", "2", False), - ("3", "3", False), - ("5", "5", False), + data_store.TableDesc("1", None, False), + data_store.TableDesc("2", None, False), + data_store.TableDesc("3", None, False), + data_store.TableDesc("5", None, False), ] ) ), @@ -911,12 +959,15 @@ def test_data_store_scan(self) -> None: list( ds.scan_tables( [ - ("1", "1", False), - ("2", "2", False), - ("3", "3", False), - ("5", "5", False), + data_store.TableDesc("1", {"a": "a", "b": "c"}, False), + data_store.TableDesc("2", {"a": "a", "b": "c"}, False), + data_store.TableDesc( + "3", {"a": "a", "b": "c", "x": "x"}, False + ), + data_store.TableDesc( + "5", {"a": "a", "b": "c", "x": "y"}, False + ), ], - {"a": "a", "b": "c", "5.x": "y", "3.*": ""}, ) ), "some columns", @@ -926,12 +977,15 @@ def test_data_store_scan(self) -> None: list( ds.scan_tables( [ - ("1", "1", False), - ("2", "2", False), - ("3", "3", False), - ("5", "5", False), + data_store.TableDesc("1", {"a": "a", "b": "c"}, False), + data_store.TableDesc("2", {"a": "a", "b": "c"}, False), + data_store.TableDesc( + "3", {"a": "a", "b": "c", "x": "x"}, False + ), + data_store.TableDesc( + "5", {"a": "a", "b": "c", "x": "y"}, False + ), ], - {"a": "a", "b": "c", "5.x": "y", "3.*": ""}, 1, 2, ) @@ -939,6 +993,441 @@ def test_data_store_scan(self) -> None: "with start and end", ) + ds.update_table( + "1", + data_store.TableSchema( + "k", + [ + data_store.ColumnSchema("k", data_store.INT64), + data_store.ColumnSchema("a", data_store.STRING), + ], + ), + [ + {"k": 0, "a": None}, + ], + ) + ds.dump() + ds = data_store.LocalDataStore(self.datastore_root) + self.assertEqual( + [ + {"k": 0, "a": None, "b": "0"}, + {"k": 1, "a": "1", "b": "1"}, + {"k": 2, "b": "2"}, + {"k": 3, "a": "3", "b": "3"}, + ], + list( + ds.scan_tables( + [ + data_store.TableDesc("1", None, True), + ], + keep_none=True, + ) + ), + "scan disk", + ) + + ds.update_table( + "1", + data_store.TableSchema( + "k", + [ + data_store.ColumnSchema("k", data_store.INT64), + data_store.ColumnSchema("c", data_store.INT32), + ], + ), + [ + {"k": 0, "c": 1}, + {"k": 1, "-": True}, + ], + ) + self.assertEqual( + [ + {"k": 0, "a": None, "b": "0", "c": 1}, + {"k": 2, "b": "2"}, + {"k": 3, "a": "3", "b": "3"}, + ], + list( + ds.scan_tables( + [ + data_store.TableDesc("1", None, True), + ], + keep_none=True, + ) + ), + "scan mem and disk", + ) + + ds.dump() + ds = data_store.LocalDataStore(self.datastore_root) + self.assertEqual( + [ + {"k": 0, "a": None, "b": "0", "c": 1}, + {"k": 2, "b": "2"}, + {"k": 3, "a": "3", "b": "3"}, + ], + list( + ds.scan_tables( + [ + data_store.TableDesc("1", None, True), + ], + keep_none=True, + ) + ), + "merge dump", + ) + + +class TestRemoteDataStore(unittest.TestCase): + def setUp(self) -> None: + os.environ["SW_TOKEN"] = "tt" + self.ds = data_store.RemoteDataStore("http://test") + + @patch("starwhale.api._impl.data_store.requests.post") + def test_update_table(self, mock_post: Mock) -> None: + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = "" + self.ds.update_table( + "t1", + data_store.TableSchema( + "k", + [ + data_store.ColumnSchema("k", data_store.INT64), + data_store.ColumnSchema("a", data_store.STRING), + ], + ), + [ + {"k": 1, "a": "1"}, + {"k": 2, "a": "2"}, + {"k": 3, "-": True}, + {"k": 4, "a": None}, + ], + ) + self.ds.update_table( + "t1", + data_store.TableSchema( + "k", + [ + data_store.ColumnSchema("k", data_store.INT64), + ], + ), + [], + ) + 
self.ds.update_table( + "t1", + data_store.TableSchema( + "k", + [ + data_store.ColumnSchema("k", data_store.INT64), + data_store.ColumnSchema("b", data_store.BOOL), + data_store.ColumnSchema("c", data_store.INT8), + data_store.ColumnSchema("d", data_store.INT16), + data_store.ColumnSchema("e", data_store.INT32), + data_store.ColumnSchema("f", data_store.FLOAT16), + data_store.ColumnSchema("g", data_store.FLOAT32), + data_store.ColumnSchema("h", data_store.FLOAT64), + data_store.ColumnSchema("i", data_store.BYTES), + ], + ), + [ + { + "k": 1, + "b": True, + "c": 1, + "d": 1, + "e": 1, + "f": 1.0, + "g": 1.0, + "h": 1.0, + "i": b"1", + } + ], + ) + mock_post.assert_any_call( + "http://test/api/v1/datastore/updateTable", + data=json.dumps( + { + "tableName": "t1", + "tableSchemaDesc": { + "keyColumn": "k", + "columnSchemaList": [ + {"name": "k", "type": "INT64"}, + {"name": "a", "type": "STRING"}, + ], + }, + "records": [ + { + "values": [ + {"key": "k", "value": "1"}, + {"key": "a", "value": "1"}, + ] + }, + { + "values": [ + {"key": "k", "value": "2"}, + {"key": "a", "value": "2"}, + ] + }, + { + "values": [ + {"key": "k", "value": "3"}, + {"key": "-", "value": "1"}, + ] + }, + { + "values": [ + {"key": "k", "value": "4"}, + {"key": "a", "value": None}, + ] + }, + ], + }, + separators=(",", ":"), + ), + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": "tt", + }, + timeout=5.0, + ) + mock_post.assert_any_call( + "http://test/api/v1/datastore/updateTable", + data=json.dumps( + { + "tableName": "t1", + "tableSchemaDesc": { + "keyColumn": "k", + "columnSchemaList": [ + {"name": "k", "type": "INT64"}, + ], + }, + "records": [], + }, + separators=(",", ":"), + ), + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": "tt", + }, + timeout=5.0, + ) + mock_post.assert_any_call( + "http://test/api/v1/datastore/updateTable", + data=json.dumps( + { + "tableName": "t1", + "tableSchemaDesc": { + "keyColumn": "k", + "columnSchemaList": [ + {"name": "k", "type": "INT64"}, + {"name": "b", "type": "BOOL"}, + {"name": "c", "type": "INT8"}, + {"name": "d", "type": "INT16"}, + {"name": "e", "type": "INT32"}, + {"name": "f", "type": "FLOAT16"}, + {"name": "g", "type": "FLOAT32"}, + {"name": "h", "type": "FLOAT64"}, + {"name": "i", "type": "BYTES"}, + ], + }, + "records": [ + { + "values": [ + {"key": "k", "value": "1"}, + {"key": "b", "value": "1"}, + {"key": "c", "value": "1"}, + {"key": "d", "value": "1"}, + {"key": "e", "value": "1"}, + {"key": "f", "value": "3c00"}, + {"key": "g", "value": "3f800000"}, + {"key": "h", "value": "3ff0000000000000"}, + {"key": "i", "value": "MQ=="}, + ] + } + ], + }, + separators=(",", ":"), + ), + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": "tt", + }, + timeout=5.0, + ) + + @patch("starwhale.api._impl.data_store.requests.post") + def test_scan_table(self, mock_post: Mock) -> None: + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = { + "columnTypes": { + "a": "BOOL", + "b": "INT8", + "c": "INT16", + "d": "INT32", + "e": "INT64", + "f": "FLOAT16", + "g": "FLOAT32", + "h": "FLOAT64", + "i": "STRING", + "j": "BYTES", + }, + "records": [ + { + "a": "1", + "b": "1", + "c": "1", + "d": "1", + "e": "1", + "f": "3c00", + "g": "3f800000", + "h": "3ff0000000000000", + "i": "1", + "j": "MQ==", + } + ], + } + self.assertEqual( + [ + { + "a": True, + "b": 1, + "c": 1, + "d": 1, + "e": 1, + "f": 1.0, + "g": 1.0, + "h": 1.0, + "i": "1", + "j": b"1", + } + ], 
+ list( + self.ds.scan_tables( + [ + data_store.TableDesc("t1", {"a": "b"}, True), + data_store.TableDesc("t2", ["a"]), + data_store.TableDesc("t3"), + ], + 1, + 1, + True, + ) + ), + "all types", + ) + mock_post.return_value.json.side_effect = [ + { + "columnTypes": {"a": "INT32"}, + "records": [{"a": f"{i:x}"} for i in range(1000)], + "lastKey": f"{999:x}", + }, + { + "columnTypes": {"a": "INT32"}, + "records": [{"a": f"{i+1000:x}"} for i in range(1000)], + "lastKey": f"{1999:x}", + }, + { + "columnTypes": {"a": "INT32"}, + "records": [{"a": f"{2000:x}"}], + }, + ] + self.assertEqual( + [{"a": i} for i in range(2001)], + list(self.ds.scan_tables([data_store.TableDesc("t1")])), + "scan page", + ) + mock_post.assert_any_call( + "http://test/api/v1/datastore/scanTable", + data=json.dumps( + { + "tables": [ + { + "tableName": "t1", + "columns": [{"columnName": "a", "alias": "b"}], + "keepNone": True, + }, + { + "tableName": "t2", + "columns": [{"columnName": "a", "alias": "a"}], + }, + { + "tableName": "t3", + }, + ], + "end": "1", + "start": "1", + "limit": 1000, + "keepNone": True, + }, + separators=(",", ":"), + ), + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": "tt", + }, + timeout=5.0, + ) + mock_post.assert_any_call( + "http://test/api/v1/datastore/scanTable", + data=json.dumps( + { + "tables": [ + { + "tableName": "t1", + }, + ], + "limit": 1000, + }, + separators=(",", ":"), + ), + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": "tt", + }, + timeout=5.0, + ) + mock_post.assert_any_call( + "http://test/api/v1/datastore/scanTable", + data=json.dumps( + { + "tables": [ + { + "tableName": "t1", + }, + ], + "limit": 1000, + "start": f"{999:x}", + "startInclusive": False, + }, + separators=(",", ":"), + ), + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": "tt", + }, + timeout=5.0, + ) + mock_post.assert_any_call( + "http://test/api/v1/datastore/scanTable", + data=json.dumps( + { + "tables": [ + { + "tableName": "t1", + }, + ], + "limit": 1000, + "start": f"{1999:x}", + "startInclusive": False, + }, + separators=(",", ":"), + ), + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": "tt", + }, + timeout=5.0, + ) + class TestTableWriter(BaseTestCase): def setUp(self) -> None: @@ -966,7 +1455,7 @@ def test_insert_and_delete(self) -> None: [{"k": 0}, {"k": 2, "a": "22"}, {"k": 3, "b": "3"}], list( data_store.LocalDataStore.get_instance().scan_tables( - [("p/test", "test", False)] + [data_store.TableDesc("p/test", None, False)] ) ), "scan all", diff --git a/client/tests/sdk/test_wrapper.py b/client/tests/sdk/test_wrapper.py index 8b4966b984..b91274ee66 100644 --- a/client/tests/sdk/test_wrapper.py +++ b/client/tests/sdk/test_wrapper.py @@ -35,7 +35,7 @@ def test_log_metrics(self) -> None: [{"id": "tt", "a": 0, "b": 1, "a/b": 2}], list( data_store.get_data_store().scan_tables( - [("project/test/eval/summary", "summary", False)] + [data_store.TableDesc("project/test/eval/summary")] ) ), ) @@ -51,13 +51,13 @@ def test_put_and_scan(self) -> None: dataset.put("0", a=1, b=2) dataset.put("1", a=2, b=3) dataset.put("2", a=3, b=4) - dataset.put("3", a=4, b=5, c=data_store.Link("a", "b", "c")) + dataset.put("3", a=4, b=5) dataset.close() self.assertEqual( [ {"id": "1", "a": 2, "b": 3}, {"id": "2", "a": 3, "b": 4}, - {"id": "3", "a": 4, "b": 5, "c": data_store.Link("a", "b", "c")}, + {"id": "3", "a": 4, "b": 5}, ], list(dataset.scan("1", "4")), "scan",