From a7b40adf2b77377c05a682e8e0724c535197721d Mon Sep 17 00:00:00 2001 From: Jialei <3217223+jialeicui@users.noreply.github.com> Date: Mon, 27 Mar 2023 14:50:28 +0800 Subject: [PATCH] refactor(client): datastore in standalone support version and mutable schema (#2004) * refactor(client): datastore in standalone support version and mutable schema * change file ext --- client/starwhale/api/_impl/data_store.py | 561 ++++++++------------ client/tests/sdk/test_data_store.py | 622 ++++------------------- client/tests/sdk/test_track.py | 16 +- 3 files changed, 329 insertions(+), 870 deletions(-) diff --git a/client/starwhale/api/_impl/data_store.py b/client/starwhale/api/_impl/data_store.py index d7252fc184..930cdda019 100644 --- a/client/starwhale/api/_impl/data_store.py +++ b/client/starwhale/api/_impl/data_store.py @@ -1,11 +1,13 @@ +from __future__ import annotations + import os import re import json +import time import atexit import base64 import struct import urllib -import pathlib import binascii import importlib import threading @@ -13,7 +15,6 @@ from http import HTTPStatus from typing import ( Any, - Set, cast, Dict, List, @@ -24,12 +25,14 @@ Iterator, Optional, ) +from pathlib import Path +from collections import UserDict, OrderedDict import dill import numpy as np import pyarrow as pa # type: ignore import requests -import pyarrow.parquet as pq # type: ignore +import jsonlines from loguru import logger from typing_extensions import Protocol @@ -48,10 +51,12 @@ except ImportError: has_fcntl = False +datastore_table_file_ext = ".sw-datastore.json" + def _check_move(src: str, dest: str) -> bool: if has_fcntl: - with open(os.path.join(os.path.dirname(src), ".lock"), "w") as f: + with open(os.path.join(os.path.dirname(dest), ".lock"), "w") as f: try: fcntl.flock(f, fcntl.LOCK_EX) # type: ignore except OSError: @@ -584,11 +589,38 @@ def __eq__(self, other: Any) -> bool: and self.type == other.type ) + def dumps(self) -> Dict[str, Any]: + return {"name": self.name, "type": SwType.encode_schema(self.type)} + + @staticmethod + def loads(obj: Dict[str, Any]) -> ColumnSchema: + return ColumnSchema(obj["name"], SwType.decode_schema(obj["type"])) + class TableEmptyException(Exception): pass +class Record(UserDict): + def dumps(self) -> Dict[str, Dict]: + return { + "schema": {k: SwType.encode_schema(_get_type(v)) for k, v in self.items()}, + "data": {k: _get_type(v).encode(v) for k, v in self.items()}, + } + + @staticmethod + def loads(obj: Dict[str, Dict]) -> Record: + schema = obj["schema"] + data = obj["data"] + record = Record() + for k, v in schema.items(): + record[k] = SwType.decode_schema(v).decode(data[k]) + return record + + +Records = List[Record] + + class TableSchemaDesc: def __init__( self, key_column: Optional[str], columns: Optional[List[ColumnSchema]] @@ -657,105 +689,22 @@ def __eq__(self, other: Any) -> bool: def _get_table_path(root_path: str, table_name: str) -> str: - return str(pathlib.Path(root_path) / table_name.lstrip("/")) + return str(Path(root_path) / table_name.strip("/")) + datastore_table_file_ext -def _parse_parquet_name(name: str) -> Tuple[str, int]: +def _parse_data_table_name(name: str) -> Tuple[str, int]: try: - if name.endswith(".parquet"): + if name.endswith(datastore_table_file_ext): if name.startswith("base-"): - return "base", int(name[5:-8]) + return "base", int(name[5 : -len(datastore_table_file_ext)]) elif name.startswith("patch-"): - return "patch", int(name[6:-8]) + return "patch", int(name[6 : -len(datastore_table_file_ext)]) except ValueError: # 
ignore invalid filename pass return "", 0 -def _write_parquet_file(filename: str, table: pa.Table) -> None: - with pq.ParquetWriter(filename, table.schema) as writer: - writer.write_table(table) - - -def _scan_parquet_file( - path: str, - columns: Optional[Dict[str, str]] = None, - start: Optional[Any] = None, - end: Optional[Any] = None, - keep_none: bool = False, - end_inclusive: bool = False, -) -> Iterator[dict]: - f = pq.ParquetFile(path) - schema_arrow = f.schema_arrow - if columns is None: - columns = { - name: name - for name in schema_arrow.names - if name != "-" and not name.startswith("~") - } - schema = TableSchema.parse(f.metadata.metadata[b"schema"].decode("utf-8")) - key_index = schema_arrow.get_field_index(schema.key_column) - if key_index < 0: - raise RuntimeError( - f"key {schema.key_column} is not found in names: {schema_arrow.names}" - ) - key_alias = columns.get(schema.key_column, None) - all_cols = [schema.key_column] - if schema_arrow.get_field_index("-") >= 0: - all_cols.append("-") - for name, alias in columns.items(): - if ( - name != schema.key_column - and name != "-" - and schema_arrow.get_field_index(name) >= 0 - ): - all_cols.append(name) - for name in schema_arrow.names: - if name.startswith("~") and name[1:] in columns: - all_cols.append(name) - - for i in range(f.num_row_groups): - stats = f.metadata.row_group(i).column(key_index).statistics - _end_check: Callable = lambda x, y: x > y if end_inclusive else x >= y - if (end is not None and _end_check(stats.min, end)) or ( - start is not None and stats.max < start - ): - continue - table = f.read_row_group(i, all_cols) - names = table.schema.names - types = [schema.columns[name].type for name in names] - n_rows = table[0].length() - n_cols = len(names) - for j in range(n_rows): - key = types[0].deserialize(table[0][j].as_py()) - if (start is not None and key < start) or ( - end is not None and _end_check(key, end) - ): - continue - d = {"*": key} - if key_alias is not None: - d[key_alias] = key - for k in range(1, n_cols): - name = names[k] - value = types[k].deserialize(table[k][j].as_py()) - if name == "-": - if value is not None: - d["-"] = value - elif name.startswith("~") and value: - alias = columns.get(name[1:], "") - if alias != "": - if keep_none: - d[alias] = None - else: - d.pop(alias, "") - else: - alias = columns.get(name, "") - if alias != "" and value is not None: - d[alias] = value - yield d - - def _merge_scan( iters: List[Iterator[Dict[str, Any]]], keep_none: bool ) -> Iterator[dict]: @@ -765,9 +714,9 @@ def __init__(self, index: int, iter: Iterator[dict]) -> None: self.iter = iter self.item: Optional[Dict[str, Any]] = None self.exhausted = False - self.nextItem() + self.next_item() - def nextItem(self) -> None: + def next_item(self) -> None: try: self.item = next(self.iter) self.exhausted = False @@ -796,7 +745,7 @@ def nextItem(self) -> None: d.clear() else: d.update(item) - nodes[i].nextItem() + nodes[i].next_item() if len(d) > 0: d["*"] = key if not keep_none: @@ -805,129 +754,73 @@ def nextItem(self) -> None: nodes = [node for node in nodes if not node.exhausted] -def _get_table_files(path: str) -> List[str]: - ensure_dir(path) - - patches = [] - base_index = -1 - for file in os.listdir(path): - type, index = _parse_parquet_name(file) - if type == "base" and index > base_index: - base_index = index - elif type == "patch": - patches.append(index) - if base_index >= 0: - ret = [os.path.join(path, f"base-{base_index}.parquet")] - else: - ret = [] - patches.sort() - for i in patches: - if 
i > base_index: - ret.append(os.path.join(path, f"patch-{i}.parquet")) - return ret - - -def _read_table_schema(path: str) -> TableSchema: - ensure_dir(path) - - files = _get_table_files(path) - if len(files) == 0: - raise TableEmptyException(f"table path: {path}") - - schema = pq.read_schema(files[-1]) - if schema.metadata is None: - raise RuntimeError(f"no metadata for file {files[-1]}") +def _update_schema(key_column: str, record: Dict[str, Any]) -> TableSchema: + new_schema = TableSchema(key_column, []) + for col, value in record.items(): + value_type = _get_type(value) + new_schema.columns[col] = ColumnSchema(col, value_type) + return new_schema - schema_data = schema.metadata.get(b"schema", None) - if schema_data is None: - raise RuntimeError(f"no schema for file {files[-1]}") - return TableSchema.parse(schema_data.decode()) +class InnerRecord: + def __init__(self, key: Any, record: Optional[Record] = None) -> None: + self.key = key + self.records: OrderedDict[int, Record] = OrderedDict() + self.ordered = True + if record is not None: + self.append(record) + + def append(self, record: Record) -> int: + self.ordered = False + seq = self._get_seq_num() + self.records[seq] = record + return seq + + def _reorder(self) -> None: + if not self.ordered: + self.records = OrderedDict(sorted(self.records.items())) + self.ordered = True + + def get_record(self, revision: Optional[int] = None) -> Dict[str, Any]: + self._reorder() + ret: Dict[str, Any] = dict() + for seq, record in self.records.items(): + if revision is None or seq <= revision: + if "-" in record and record["-"]: + ret = record.data + else: + if "-" in ret: + ret = dict() + ret.update(record) + return ret + def dumps(self) -> Dict[str, Any]: + return { + "key": self.key, + "records": {seq: record.dumps() for seq, record in self.records.items()}, + } -def _scan_table( - path: str, - columns: Optional[Dict[str, str]] = None, - start: Optional[Any] = None, - end: Optional[Any] = None, - keep_none: bool = False, - end_inclusive: bool = False, -) -> Iterator[dict]: - iters = [] - for file in _get_table_files(path): - if os.path.basename(file).startswith("patch"): - keep = True - else: - keep = keep_none - iters.append(_scan_parquet_file(file, columns, start, end, keep, end_inclusive)) - return _merge_scan(iters, keep_none) - - -def _records_to_table( - schema: TableSchema, records: List[Dict[str, Any]], deletes: List[Any] -) -> pa.Table: - if len(records) == 0: - return - schema = schema.copy() - if len(deletes) > 0: - schema.columns["-"] = ColumnSchema("-", BOOL) - for key in deletes: - records.append({schema.key_column: key, "-": True}) - records.sort(key=lambda x: cast(str, x.get(schema.key_column)), reverse=True) - d: Dict[str, Any] = {} - nulls: Dict[str, List[int]] = {} - for i in range(len(records)): - record = records[len(records) - 1 - i] - for col, col_schema in schema.columns.items(): - if col in record: - value = record.get(col) - if value is None: - nulls.setdefault(col, []).append(i) - else: - value = None - d.setdefault(col, []).append(col_schema.type.serialize(value)) - for col, indexes in nulls.items(): - schema.columns["~" + col] = ColumnSchema("~" + col, BOOL) - data = [False] * len(records) - for i in indexes: - data[i] = True - d["~" + col] = data - pa_schema = pa.schema( - [(k, v.type.pa_type) for k, v in schema.columns.items()], - {"schema": str(schema)}, - ) - return pa.Table.from_pydict(d, schema=pa_schema) - - -def _update_schema(schema: TableSchema, record: Dict[str, Any]) -> TableSchema: - new_schema = 
schema.copy() - for col, value in record.items(): - value_type = _get_type(value) - column_schema = schema.columns.get(col, None) - if column_schema is None: - new_schema.columns[col] = ColumnSchema(col, value_type) - else: - try: - new_schema.columns[col].type = new_schema.columns[col].type.merge( - value_type - ) - except RuntimeError as e: - raise RuntimeError(f"can not insert a record with field {col}") from e + @staticmethod + def loads(data: Dict[str, Any]) -> InnerRecord: + ret = InnerRecord(data["key"]) + ret.ordered = False + ret.records = OrderedDict( + {int(seq): Record.loads(record) for seq, record in data["records"].items()} + ) + return ret - return new_schema + @staticmethod + def _get_seq_num() -> int: + return time.monotonic_ns() class MemoryTable: - def __init__(self, table_name: str, schema: TableSchema) -> None: + def __init__(self, table_name: str, key_column: ColumnSchema) -> None: self.table_name = table_name - self.schema = schema.copy() - self.records: Dict[Any, Dict[str, Any]] = {} - self.deletes: Set[Any] = set() + self.key_column = key_column + self.records: Dict[Any, InnerRecord] = {} self.lock = threading.Lock() - - def get_schema(self) -> TableSchema: - with self.lock: - return self.schema.copy() + self.dirty = False def scan( self, @@ -936,91 +829,99 @@ def scan( end: Optional[Any] = None, keep_none: bool = False, end_inclusive: bool = False, + revision: Optional[int] = None, ) -> Iterator[Dict[str, Any]]: _end_check: Callable = lambda x, y: x <= y if end_inclusive else x < y with self.lock: - schema = self.schema.copy() - records = [] - for key in self.deletes: - if (start is None or key >= start) and ( - end is None or _end_check(key, end) - ): - records.append({self.schema.key_column: key, "-": True}) - for k, v in self.records.items(): if (start is None or k >= start) and ( end is None or _end_check(k, end) ): records.append(v) - records.sort(key=lambda x: cast(str, x[self.schema.key_column])) - for r in records: + records.sort(key=lambda x: cast(str, x.key)) + for ir in records: + r = ir.get_record(revision) if columns is None: d = dict(r) else: d = {columns[k]: v for k, v in r.items() if k in columns} if "-" in r: d["-"] = r["-"] - d["*"] = r[schema.key_column] + d["*"] = r[self.key_column.name] if not keep_none: d = {k: v for k, v in d.items() if v is not None} yield d - def insert(self, record: Dict[str, Any]) -> None: + def insert(self, record: Dict[str, Any]) -> int: + self.dirty = True with self.lock: - self.schema = _update_schema(self.schema, record) - key = record.get(self.schema.key_column) - r = self.records.setdefault(key, record) - if r is not record: - r.update(record) - - def delete(self, keys: List[Any]) -> None: + key = record.get(self.key_column.name) + r = self.records.setdefault(key, InnerRecord(key)) + return r.append(Record(record)) + + def delete(self, keys: List[Any]) -> int | None: + """ + Delete records by keys. If the key is not found, it will be ignored. + Returns the sequence number of the last delete operation, or None if no delete operation is performed. 
+ """ + seq = None with self.lock: for key in keys: - self.deletes.add(key) - self.records.pop(key, None) + r = self.records.get(key, None) + if r is not None: + self.dirty = True + seq = r.append(Record({self.key_column.name: key, "-": True})) + return seq + + def _dump_meta(self) -> Dict[str, Any]: + return { + "key_column": self.key_column.dumps(), + "version": "0.1", + } - def dump(self, root_path: str) -> None: - with self.lock: - schema = self.schema.copy() - path = _get_table_path(root_path, self.table_name) - ensure_dir(path) - while True: - max_index = -1 - for file in os.listdir(path): - type, index = _parse_parquet_name(file) - if type != "" and index > max_index: - max_index = index - if max_index < 0: - filename = "base-0.parquet" - else: - filename = f"base-{max_index + 1}.parquet" - temp_filename = f"temp.{os.getpid()}" - if max_index >= 0: - s = _read_table_schema(os.path.join(path)) - s.merge(schema) - schema = s - _write_parquet_file( - os.path.join(path, temp_filename), - _records_to_table( - schema, - list( - _merge_scan( - [ - _scan_table(path, keep_none=True), - self.scan(keep_none=True), - ], - True, - ) - ), - [], - ), - ) - if _check_move( - os.path.join(path, temp_filename), os.path.join(path, filename) - ): - break + @classmethod + def _parse_meta(cls, meta: Dict[str, Any]) -> ColumnSchema: + if meta["version"] != "0.1": + raise ValueError(f"Unsupported version {meta['version']}") + return ColumnSchema.loads(meta["key_column"]) + + @classmethod + def loads(cls, file: Path, table_name: str) -> MemoryTable: + if not file.exists() or not file.is_file(): + raise RuntimeError(f"File {file} does not exist") + with jsonlines.open(file) as reader: + meta = reader.read() + if not isinstance(meta, dict): + raise RuntimeError(f"Invalid meta data {meta} in {file}") + key_column = cls._parse_meta(meta) + table = MemoryTable(table_name, key_column) + for record in reader: + ir = InnerRecord.loads(record) + table.records[ir.key] = ir + return table + + def dump(self, root_path: str, if_dirty: bool = True) -> None: + if if_dirty and not self.dirty: + return + dst = os.path.join(root_path, f"{self.table_name}{datastore_table_file_ext}") + base = os.path.dirname(dst) + temp_filename = os.path.join(base, f"temp.{os.getpid()}") + ensure_dir(base) + + # dump key column info + with jsonlines.open(temp_filename, mode="w") as writer: + writer.write(self._dump_meta()) + for ir in self.records.values(): + writer.write(ir.dumps()) + + # TODO: remove the lock, and disable concurrent access to the table + while not _check_move(temp_filename, dst): + # wait for the file to be released + time.sleep(0.1) + + self.dirty = False class TableDesc: @@ -1105,10 +1006,15 @@ def update_table( raise RuntimeError( f"invalid column name {k}, only letters(A-Z, a-z), digits(0-9), hyphen('-'), and underscore('_') are allowed" ) - table = self._get_table(table_name, schema) - if schema.key_column != table.schema.key_column: + table = self._get_table(table_name, schema.columns[schema.key_column]) + # this will never happen, makes mypy happy + if table is None: raise RuntimeError( - f"invalid key column, expected {table.schema.key_column}, actual {schema.key_column}" + f"table {table_name} does not exist and can not be created" + ) + if schema.key_column != table.key_column.name: + raise RuntimeError( + f"invalid key column, expected {table.key_column}, actual {schema.key_column}" ) for r in records: @@ -1122,17 +1028,23 @@ def update_table( else: table.insert(r) - def _get_table(self, table_name: str, schema: 
TableSchema) -> MemoryTable: + def _get_table( + self, table_name: str, key_column: ColumnSchema | None, create: bool = True + ) -> MemoryTable | None: with self.lock: table = self.tables.get(table_name, None) if table is None: - table_path = _get_table_path(self.root_path, table_name) - if _get_table_files(table_path): - table_schema = _read_table_schema(table_path) - table_schema.merge(schema) + file = Path(_get_table_path(self.root_path, table_name)) + if file.exists(): + table = MemoryTable.loads(file, table_name) else: - table_schema = schema - table = MemoryTable(table_name, table_schema) + if not create: + return None + if key_column is None: + raise RuntimeError( + f"key column is required for table {table_name}" + ) + table = MemoryTable(table_name, key_column) self.tables[table_name] = table return table @@ -1148,7 +1060,7 @@ class TableInfo: def __init__( self, name: str, - key_column_type: pa.DataType, + key_column_type: SwType, columns: Optional[Dict[str, str]], keep_none: bool, ) -> None: @@ -1159,14 +1071,10 @@ def __init__( infos: List[TableInfo] = [] for table_desc in tables: - table = self.tables.get(table_desc.table_name, None) - if table is not None: - schema = table.get_schema() - else: - schema = _read_table_schema( - _get_table_path(self.root_path, table_desc.table_name) - ) - key_column_type = schema.columns[schema.key_column].type.pa_type + table = self._get_table(table_desc.table_name, None, create=False) + if table is None: + continue + key_column_type = table.key_column.type infos.append( TableInfo( table_desc.table_name, @@ -1176,7 +1084,7 @@ def __init__( ) ) - # check for key type conflictions + # check for key type conflicts for info in infos: if info is infos[0]: continue @@ -1189,52 +1097,15 @@ def __init__( iters = [] for info in infos: - table_path = _get_table_path(self.root_path, info.name) - if info.name in self.tables: - if _get_table_files(table_path): - iters.append( - _merge_scan( - [ - _scan_table( - table_path, - info.columns, - start, - end, - info.keep_none, - end_inclusive, - ), - self.tables[info.name].scan( - info.columns, - start, - end, - True, - end_inclusive, - ), - ], - info.keep_none, - ) - ) - else: - iters.append( - self.tables[info.name].scan( - info.columns, - start, - end, - info.keep_none, - end_inclusive, - ) - ) - else: - iters.append( - _scan_table( - table_path, - info.columns, - start, - end, - info.keep_none, - end_inclusive, - ) + iters.append( + self.tables[info.name].scan( + info.columns, + start, + end, + info.keep_none, + end_inclusive, ) + ) for record in _merge_scan(iters, keep_none): record.pop("*", None) r: Dict[str, Any] = {} @@ -1424,13 +1295,13 @@ def __init__( ) -> None: super().__init__(name=f"TableWriter-{table_name}") self.table_name = table_name - self.schema = TableSchema(key_column, []) + self.key_column = key_column self.data_store = data_store or get_data_store() self._cond = threading.Condition() self._stopped = False - self._records: List[Dict[str, Any]] = [] - self._updating_records: List[Dict[str, Any]] = [] + self._records: List[Tuple[TableSchema, List[Dict[str, Any]]]] = [] + self._updating_records: List[Tuple[TableSchema, List[Dict[str, Any]]]] = [] self._queue_run_exceptions: List[Exception] = [] self._run_exceptions_limits = max(run_exceptions_limits, 0) @@ -1475,19 +1346,20 @@ def insert(self, record: Dict[str, Any]) -> None: self._insert(record) def delete(self, key: Any) -> None: - self._insert({self.schema.key_column: key, "-": True}) + self._insert({self.key_column: key, "-": True}) def 
_insert(self, record: Dict[str, Any]) -> None: self._raise_run_exceptions(self._run_exceptions_limits) - key = record.get(self.schema.key_column, None) + key = record.get(self.key_column, None) if key is None: raise RuntimeError( - f"the key {self.schema.key_column} should not be none, record:{record}" + f"the key {self.key_column} should not be none, record:{record}" ) with self._cond: - self.schema = _update_schema(self.schema, record) - self._records.append(record) + schema = _update_schema(self.key_column, record) + # TODO: group the records with the same schema + self._records.append((schema, [record])) self._cond.notify() def flush(self) -> None: @@ -1507,9 +1379,8 @@ def run(self) -> None: self._records = [] try: - self.data_store.update_table( - self.table_name, self.schema, self._updating_records - ) + for schema, records in self._updating_records: + self.data_store.update_table(self.table_name, schema, records) except Exception as e: logger.exception(e) self._queue_run_exceptions.append(e) diff --git a/client/tests/sdk/test_data_store.py b/client/tests/sdk/test_data_store.py index a0708e50c1..094993aee4 100644 --- a/client/tests/sdk/test_data_store.py +++ b/client/tests/sdk/test_data_store.py @@ -3,221 +3,59 @@ import time import unittest import concurrent.futures -from typing import Dict, List from unittest.mock import Mock, patch import numpy as np -import pyarrow as pa # type: ignore import requests from requests_mock import Mocker -from starwhale import Text from starwhale.consts import HTTPMethod from starwhale.api._impl import data_store -from starwhale.api._impl.data_store import ( - SwType, - TableEmptyException, - TableWriterException, -) +from starwhale.api._impl.data_store import INT64, ColumnSchema, TableWriterException from .. 
import BaseTestCase class TestBasicFunctions(BaseTestCase): def test_get_table_path(self) -> None: - self.assertEqual(os.path.join("a", "b"), data_store._get_table_path("a", "b")) self.assertEqual( - os.path.join("a", "b", "c"), data_store._get_table_path("a", "b/c") + os.path.join("a", "b.sw-datastore.json"), + data_store._get_table_path("a", "b"), ) self.assertEqual( - os.path.join("a", "b", "c", "d"), data_store._get_table_path("a", "b/c/d") + os.path.join("a", "b", "c.sw-datastore.json"), + data_store._get_table_path("a", "b/c"), + ) + self.assertEqual( + os.path.join("a", "b", "c", "d.sw-datastore.json"), + data_store._get_table_path("a", "b/c/d"), ) - def test_parse_parquet_name(self) -> None: + def test_parse_data_table_name(self) -> None: self.assertEqual( ("", 0), - data_store._parse_parquet_name("base-123.txt"), + data_store._parse_data_table_name("base-123.txt"), "invalid extension", ) self.assertEqual( ("", 0), - data_store._parse_parquet_name("base_1.parquet"), + data_store._parse_data_table_name("base_1.sw-datastore.json"), "invalid prefix", ) self.assertEqual( ("", 0), - data_store._parse_parquet_name("base-i.parquet"), + data_store._parse_data_table_name("base-i.sw-datastore.json"), "invalid index", ) self.assertEqual( - ("base", 123), data_store._parse_parquet_name("base-123.parquet"), "base" - ) - self.assertEqual( - ("patch", 123), data_store._parse_parquet_name("patch-123.parquet"), "patch" - ) - - def test_write_and_scan(self) -> None: - path = os.path.join(self.datastore_root, "base-0.parquet") - data_store._write_parquet_file( - path, - pa.Table.from_pydict( - { - "a": [0, 1, 2], - "b": ["x", "y", "z"], - "c": [10, 11, 12], - "-": [None, True, None], - "~c": [False, False, True], - }, - metadata={ - "schema": str( - data_store.TableSchema( - "a", - [ - data_store.ColumnSchema("a", data_store.INT64), - data_store.ColumnSchema("b", data_store.STRING), - data_store.ColumnSchema("c", data_store.INT64), - data_store.ColumnSchema("-", data_store.BOOL), - data_store.ColumnSchema("~c", data_store.BOOL), - ], - ) - ) - }, - ), - ) - self.assertEqual( - [ - {"*": 0, "a": 0, "b": "x", "c": 10}, - {"*": 1, "a": 1, "b": "y", "c": 11, "-": True}, - {"*": 2, "a": 2, "b": "z"}, - ], - list(data_store._scan_parquet_file(path)), - "scan all", - ) - self.assertEqual( - [ - {"*": 0, "i": "x", "j": 10}, - {"*": 1, "i": "y", "j": 11, "-": True}, - {"*": 2, "i": "z"}, - ], - list(data_store._scan_parquet_file(path, columns={"b": "i", "c": "j"})), - "some columns", - ) - self.assertEqual( - [ - {"*": 0, "i": "x", "j": 10}, - {"*": 1, "i": "y", "j": 11, "-": True}, - {"*": 2, "i": "z"}, - ], - list( - data_store._scan_parquet_file( - path, columns={"b": "i", "c": "j", "x": "x"} - ) - ), - "extra column", - ) - self.assertEqual( - [ - {"*": 0, "i": "x", "j": 10}, - {"*": 1, "i": "y", "j": 11, "-": True}, - {"*": 2, "i": "z"}, - ], - list( - data_store._scan_parquet_file( - path, columns={"b": "i", "c": "j", "-": "x"} - ) - ), - "'-' column", - ) - self.assertEqual( - [{"*": 1, "i": "y", "j": 11, "-": True}, {"*": 2, "i": "z"}], - list( - data_store._scan_parquet_file( - path, columns={"b": "i", "c": "j"}, start=1 - ) - ), - "with start", - ) - self.assertEqual( - [{"*": 0, "i": "x", "j": 10}], - list( - data_store._scan_parquet_file(path, columns={"b": "i", "c": "j"}, end=1) - ), - "with end", - ) - self.assertEqual( - [{"*": 1, "i": "y", "j": 11, "-": True}], - list( - data_store._scan_parquet_file( - path, columns={"b": "i", "c": "j"}, start=1, end=2 - ) - ), - "with start and end", - ) - - 
self.assertEqual( - [{"*": 1, "-": True, "i": "y", "j": 11}, {"*": 2, "i": "z"}], - list( - data_store._scan_parquet_file( - path, - columns={"b": "i", "c": "j"}, - start=1, - end=2, - end_inclusive=True, - ) - ), - "with start and end, with end inclusive", - ) - self.assertEqual( - [ - {"*": 0, "a": 0, "b": "x", "c": 10}, - {"*": 1, "a": 1, "b": "y", "c": 11, "-": True}, - {"*": 2, "a": 2, "b": "z", "c": None}, - ], - list(data_store._scan_parquet_file(path, keep_none=True)), - "keep none", - ) - - def test_only_one_end_inclusive(self) -> None: - path = os.path.join(self.datastore_root, "base-10.parquet") - data_store._write_parquet_file( - path, - pa.Table.from_pydict( - { - "a": [0], - "b": ["x"], - "c": [10], - "-": [None], - "~c": [False], - }, - metadata={ - "schema": str( - data_store.TableSchema( - "a", - [ - data_store.ColumnSchema("a", data_store.INT64), - data_store.ColumnSchema("b", data_store.STRING), - data_store.ColumnSchema("c", data_store.INT64), - data_store.ColumnSchema("-", data_store.BOOL), - data_store.ColumnSchema("~c", data_store.BOOL), - ], - ) - ) - }, - ), + ("base", 123), + data_store._parse_data_table_name("base-123.sw-datastore.json"), + "base", ) self.assertEqual( - 1, - len( - list( - data_store._scan_parquet_file( - path, - start=0, - end=0, - end_inclusive=True, - ) - ) - ), - "end inclusive and end is max", + ("patch", 123), + data_store._parse_data_table_name("patch-123.sw-datastore.json"), + "patch", ) def test_merge_scan(self) -> None: @@ -408,230 +246,6 @@ def test_merge_scan(self) -> None: "mixed", ) - def test_get_table_files(self) -> None: - data: Dict[str, List[str]] = { - "0": [], - "1": ["base-1.parquet"], - "2": ["base-0.parquet", "patch-1.parquet", "patch-3.parquet"], - "3": [ - "base-0.parquet", - "patch-1.parquet", - "base-1.parquet", - "patch-2.parquet", - ], - } - for dir, files in data.items(): - dir = os.path.join(self.datastore_root, dir) - os.makedirs(dir) - for file in files: - file = os.path.join(dir, file) - with open(file, "w"): - pass - self.assertEqual( - [], - data_store._get_table_files(os.path.join(self.datastore_root, "0")), - "empty", - ) - self.assertEqual( - [os.path.join(self.datastore_root, "1", "base-1.parquet")], - data_store._get_table_files(os.path.join(self.datastore_root, "1")), - "base only", - ) - self.assertEqual( - [ - os.path.join(self.datastore_root, "2", f) - for f in ("base-0.parquet", "patch-1.parquet", "patch-3.parquet") - ], - data_store._get_table_files(os.path.join(self.datastore_root, "2")), - "base and patches", - ) - self.assertEqual( - [ - os.path.join(self.datastore_root, "3", f) - for f in ("base-1.parquet", "patch-2.parquet") - ], - data_store._get_table_files(os.path.join(self.datastore_root, "3")), - "multiple bases", - ) - - def test_scan_table(self) -> None: - data_store._write_parquet_file( - os.path.join(self.datastore_root, "base-0.parquet"), - pa.Table.from_pydict( - {"a": [0, 1, 2], "t": [7, 7, 7]}, - metadata={ - "schema": str( - data_store.TableSchema( - "a", - [ - data_store.ColumnSchema("a", data_store.INT64), - data_store.ColumnSchema("t", data_store.INT64), - ], - ) - ) - }, - ), - ) - data_store._write_parquet_file( - os.path.join(self.datastore_root, "patch-1.parquet"), - pa.Table.from_pydict( - { - "a": [0, 1, 2], - "b": ["x", "y", "z"], - "c": [10, 11, 12], - "d": [None, True, None], - }, - metadata={ - "schema": str( - data_store.TableSchema( - "a", - [ - data_store.ColumnSchema("a", data_store.INT64), - data_store.ColumnSchema("b", data_store.STRING), - 
data_store.ColumnSchema("c", data_store.INT64), - data_store.ColumnSchema("d", data_store.BOOL), - ], - ) - ) - }, - ), - ) - data_store._write_parquet_file( - os.path.join(self.datastore_root, "base-1.parquet"), - pa.Table.from_pydict( - {"k": [1, 3, 4, 5], "a": ["1", "3", "4", "5"]}, - metadata={ - "schema": str( - data_store.TableSchema( - "k", - [ - data_store.ColumnSchema("k", data_store.INT64), - data_store.ColumnSchema("a", data_store.STRING), - ], - ) - ) - }, - ), - ) - data_store._write_parquet_file( - os.path.join(self.datastore_root, "patch-2.parquet"), - pa.Table.from_pydict( - {"k": [0, 2, 3, 5], "a": ["0", "2", "3.3", "5.5"]}, - metadata={ - "schema": str( - data_store.TableSchema( - "k", - [ - data_store.ColumnSchema("k", data_store.INT64), - data_store.ColumnSchema("a", data_store.STRING), - ], - ) - ) - }, - ), - ) - data_store._write_parquet_file( - os.path.join(self.datastore_root, "patch-3.parquet"), - pa.Table.from_pydict( - {"k": [2, 4], "-": [True, True]}, - metadata={ - "schema": str( - data_store.TableSchema( - "k", - [ - data_store.ColumnSchema("k", data_store.INT64), - data_store.ColumnSchema("-", data_store.BOOL), - ], - ) - ) - }, - ), - ) - data_store._write_parquet_file( - os.path.join(self.datastore_root, "patch-4.parquet"), - pa.Table.from_pydict( - { - "k": [0, 1, 2, 3], - "a": [None, None, None, "3"], - "b": ["0", "1", "2", "3"], - "~b": [False, False, False, True], - }, - metadata={ - "schema": str( - data_store.TableSchema( - "k", - [ - data_store.ColumnSchema("k", data_store.INT64), - data_store.ColumnSchema("a", data_store.STRING), - data_store.ColumnSchema("b", data_store.STRING), - data_store.ColumnSchema("~b", data_store.BOOL), - ], - ) - ) - }, - ), - ) - self.assertEqual( - [], - list(data_store._scan_table(os.path.join(self.datastore_root, "no"))), - "empty", - ) - self.assertEqual( - [ - {"*": 0, "k": 0, "a": "0", "b": "0"}, - {"*": 1, "k": 1, "a": "1", "b": "1"}, - {"*": 2, "k": 2, "b": "2"}, - {"*": 3, "k": 3, "a": "3"}, - {"*": 5, "k": 5, "a": "5.5"}, - ], - list(data_store._scan_table(self.datastore_root)), - "scan all", - ) - self.assertEqual( - [ - {"*": 0, "i": "0", "j": "0"}, - {"*": 1, "i": "1", "j": "1"}, - {"*": 2, "j": "2"}, - {"*": 3, "i": "3"}, - {"*": 5, "i": "5.5"}, - ], - list(data_store._scan_table(self.datastore_root, {"a": "i", "b": "j"})), - "some columns", - ) - self.assertEqual( - [{"*": 2, "j": "2"}, {"*": 3, "i": "3"}], - list( - data_store._scan_table( - self.datastore_root, {"a": "i", "b": "j"}, start=2, end=5 - ) - ), - "with start and end", - ) - self.assertEqual( - [{"*": 2, "j": "2"}, {"*": 3, "i": "3"}], - list( - data_store._scan_table( - self.datastore_root, - {"a": "i", "b": "j"}, - start=2, - end=3, - end_inclusive=True, - ) - ), - "with start and end(inclusive)", - ) - self.assertEqual( - [ - {"*": 0, "k": 0, "a": "0", "b": "0"}, - {"*": 1, "k": 1, "a": "1", "b": "1"}, - {"*": 2, "k": 2, "b": "2"}, - {"*": 3, "k": 3, "a": "3", "b": None}, - {"*": 5, "k": 5, "a": "5.5"}, - ], - list(data_store._scan_table(self.datastore_root, keep_none=True)), - "keep none", - ) - def test_get_type(self) -> None: self.assertEqual(data_store.UNKNOWN, data_store._get_type(None), "unknown") self.assertEqual(data_store.BOOL, data_store._get_type(False), "bool") @@ -948,7 +562,7 @@ def test_update_schema(self) -> None: data_store.TableSchema( "a", [data_store.ColumnSchema("a", data_store.INT64)] ), - data_store._update_schema(data_store.TableSchema("a", []), {"a": 1}), + data_store._update_schema("a", {"a": 1}), "new field 1", ) 
self.assertEqual( @@ -960,77 +574,23 @@ def test_update_schema(self) -> None: ], ), data_store._update_schema( - data_store.TableSchema( - "a", [data_store.ColumnSchema("a", data_store.INT64)] - ), - {"b": ""}, + "a", + {"a": 1, "b": ""}, ), "new field 2", ) - with self.assertRaises(RuntimeError, msg="conflict"): - data_store._update_schema( - data_store.TableSchema( - "a", - [ - data_store.ColumnSchema("a", data_store.INT64), - data_store.ColumnSchema("b", data_store.STRING), - ], - ), - {"b": 0}, - ) self.assertEqual( data_store.TableSchema( "a", [data_store.ColumnSchema("a", data_store.UNKNOWN)] ), - data_store._update_schema(data_store.TableSchema("a", []), {"a": None}), + data_store._update_schema("a", {"a": None}), "none 1", ) - self.assertEqual( - data_store.TableSchema( - "a", [data_store.ColumnSchema("a", data_store.UNKNOWN)] - ), - data_store._update_schema( - data_store.TableSchema( - "a", [data_store.ColumnSchema("a", data_store.UNKNOWN)] - ), - {"a": None}, - ), - "none 2", - ) - self.assertEqual( - data_store.TableSchema( - "a", [data_store.ColumnSchema("a", data_store.INT64)] - ), - data_store._update_schema( - data_store.TableSchema( - "a", [data_store.ColumnSchema("a", data_store.UNKNOWN)] - ), - {"a": 0}, - ), - "none 3", - ) - self.assertEqual( - data_store.TableSchema( - "a", [data_store.ColumnSchema("a", data_store.INT64)] - ), - data_store._update_schema( - data_store.TableSchema( - "a", [data_store.ColumnSchema("a", data_store.INT64)] - ), - {"a": None}, - ), - "none 4", - ) class TestMemoryTable(BaseTestCase): def test_mixed(self) -> None: - table = data_store.MemoryTable( - "test", - data_store.TableSchema( - "k", [data_store.ColumnSchema("k", data_store.INT64)] - ), - ) + table = data_store.MemoryTable("test", ColumnSchema("k", INT64)) table.insert({"k": 0, "a": "0"}) table.insert({"k": 1, "a": "1"}) table.insert({"k": 2, "a": "2"}) @@ -1048,6 +608,7 @@ def test_mixed(self) -> None: list(table.scan()), "scan all", ) + self.assertEqual( [ {"*": 0, "k": 0, "a": "0"}, @@ -1076,54 +637,87 @@ def test_mixed(self) -> None: "some columns", ) - def test_write_with_object(self) -> None: - table = data_store.MemoryTable( - "test", - data_store.TableSchema( - "k", [data_store.ColumnSchema("k", data_store.INT64)] - ), + def test_revision(self): + table = data_store.MemoryTable("test", ColumnSchema("k", INT64)) + table.insert({"k": 0, "a": "0"}) + rev = table.insert({"k": 0, "a": "1"}) + self.assertEqual( + [{"*": 0, "k": 0, "a": "1"}], + list(table.scan(revision=rev)), + "revision 1", + ) + table.delete([0]) + self.assertEqual( + [{"*": 0, "k": 0, "a": "1"}], + list(table.scan(revision=rev)), + "revision 1 after deletion", + ) + table.delete([1]) + self.assertEqual( + [{"*": 0, "k": 0, "a": "1"}], + list(table.scan(revision=rev)), + "revision 1 after deleting non-existing row", + ) + + rev2 = table.insert({"k": 0, "a": "2"}) + self.assertEqual( + [{"*": 0, "k": 0, "a": "1"}], + list(table.scan(revision=rev)), + "revision 1 after inserting new row", ) - table.insert({"k": 0, "data/text": Text("my_text")}) + self.assertEqual( - [{"k": 0, "data/text": Text("my_text"), "*": 0}], + [{"*": 0, "k": 0, "a": "2"}], list(table.scan()), - "get", + "latest 1", + ) + self.assertEqual( + [{"*": 0, "k": 0, "a": "2"}], + list(table.scan(revision=rev2)), + "latest 1 with rev", ) - column_schemas = [] - for col in table.get_schema().columns.values(): - d = SwType.encode_schema(col.type) - d["name"] = col.name - column_schemas.append(d) + def test_mutable_schema(self): + table = 
data_store.MemoryTable("test", ColumnSchema("k", INT64)) + table.insert({"k": 0, "a": "0"}) + self.assertEqual( + [{"*": 0, "k": 0, "a": "0"}], + list(table.scan()), + "insert string get string", + ) - # TODO wait to resolve UNKNOWN type + table.insert({"k": 0, "a": 1}) self.assertEqual( - column_schemas, - [ - {"type": "INT64", "name": "k"}, - { - "type": "OBJECT", - "attributes": [ - {"type": "STRING", "name": "_content"}, - {"type": "BYTES", "name": "fp"}, - {"type": "BYTES", "name": "_BaseArtifact__cache_bytes"}, - {"type": "STRING", "name": "_type"}, - {"type": "STRING", "name": "display_name"}, - {"type": "STRING", "name": "_mime_type"}, - { - "type": "TUPLE", - "elementType": {"type": "UNKNOWN"}, - "name": "shape", - }, - {"type": "STRING", "name": "_dtype_name"}, - {"type": "STRING", "name": "encoding"}, - {"type": "UNKNOWN", "name": "link"}, - {"type": "UNKNOWN", "name": "owner"}, - ], - "pythonType": "starwhale.core.dataset.type.Text", - "name": "data/text", - }, - ], + [{"*": 0, "k": 0, "a": 1}], + list(table.scan()), + "insert int get int", + ) + + table.delete([0]) + self.assertEqual( + [{"*": 0, "-": True, "k": 0}], + list(table.scan()), + "delete all", + ) + + rev = table.insert({"k": 0, "a": "2", "b": 1, "c": 2.0}) + self.assertEqual( + [{"*": 0, "k": 0, "a": "2", "b": 1, "c": 2.0}], + list(table.scan()), + "insert multiple types", + ) + + table.insert({"k": 0, "a": 3, "b": 2.1, "c": None, "d": None}) + self.assertEqual( + [{"*": 0, "k": 0, "a": 3, "b": 2.1, "c": None, "d": None}], + list(table.scan(keep_none=True)), + "change multiple types", + ) + + self.assertEqual( + [{"*": 0, "k": 0, "a": "2", "b": 1, "c": 2.0}], + list(table.scan(revision=rev)), + "get changes by revision", ) @@ -2358,7 +1952,7 @@ def test_scan_table(self, mock_post: Mock) -> None: { "data": { "columnTypes": [{"name": "a", "type": "INT32"}], - "records": [{"a": f"{i+1000:x}"} for i in range(1000)], + "records": [{"a": f"{i + 1000:x}"} for i in range(1000)], "lastKey": f"{1999:x}", } }, @@ -2448,19 +2042,13 @@ def tearDown(self) -> None: super().tearDown() def test_writer(self): - _writer = data_store.TableWriter("p/test_flush", "id") - _writer.insert({"id": 0, "result": "data"}) - with self.assertRaises(TableEmptyException): - list(_writer.data_store.scan_tables([data_store.TableDesc("p/test_flush")])) - _writer.close() - _writer2 = data_store.TableWriter("p/test_flush2", "id") _writer2.insert({"id": 0, "result": "data"}) _writer2.flush() self.assertEqual( len( list( - _writer.data_store.scan_tables( + _writer2.data_store.scan_tables( [data_store.TableDesc("p/test_flush2")] ) ) @@ -2482,7 +2070,7 @@ def test_writer(self): self.assertEqual( len( list( - _writer.data_store.scan_tables( + _writer3.data_store.scan_tables( [data_store.TableDesc("p/test_flush3")] ) ) @@ -2510,15 +2098,15 @@ def test_insert_and_delete(self) -> None: "y": data_store.Link("http://test.com/1.jpg", "1", "image/jpeg"), } ) - with self.assertRaises(RuntimeError, msg="conflicting type"): - self.writer.insert({"k": 4, "a": 0}) + # change type + self.writer.insert({"k": 4, "a": 0}) self.writer.close() self.assertEqual( [ {"k": 0, "a": None}, {"k": 2, "a": "22"}, {"k": 3, "b": "3"}, - {"k": 4, "a/b": 0, "a/c": 1}, + {"k": 4, "a": 0, "a/b": 0, "a/c": 1}, { "k": 5, "x": data_store.Link("http://test.com/1.jpg"), diff --git a/client/tests/sdk/test_track.py b/client/tests/sdk/test_track.py index 0c6cb300d5..f6649b9076 100644 --- a/client/tests/sdk/test_track.py +++ b/client/tests/sdk/test_track.py @@ -452,9 +452,9 @@ def 
test_handle_metrics(self) -> None: assert isinstance(h._table_writers["metrics/user"], TableWriter) h.flush() - parquet_path = workdir / "metrics" / "user" / "base-0.parquet" - assert parquet_path.exists() - assert parquet_path.is_file() + datastore_file_path = workdir / "metrics" / "user.sw-datastore.json" + assert datastore_file_path.exists() + assert datastore_file_path.is_file() records = list(h._data_store.scan_tables([TableDesc("metrics/user")])) assert len(records) == 2 @@ -508,9 +508,9 @@ def test_handle_artifacts(self) -> None: h.flush() - parquet_path = workdir / "artifacts" / "user" / "base-0.parquet" - assert parquet_path.exists() - assert parquet_path.is_file() + datastore_file_path = workdir / "artifacts" / "user.sw-datastore.json" + assert datastore_file_path.exists() + assert datastore_file_path.is_file() files_dir = workdir / "artifacts" / "_files" assert files_dir.exists() @@ -559,8 +559,8 @@ def test_run(self) -> None: assert "metrics/_system" in h._table_writers assert "artifacts/user" in h._table_writers - assert (workdir / "metrics" / "user" / "base-0.parquet").exists() - assert (workdir / "metrics" / "_system" / "base-0.parquet").exists() + assert (workdir / "metrics" / "user.sw-datastore.json").exists() + assert (workdir / "metrics" / "_system.sw-datastore.json").exists() assert (workdir / "artifacts" / "_files").exists() assert len(list((workdir / "artifacts" / "_files").iterdir())) != 0 assert (workdir / "params" / "user.json").exists()
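
Design notes on the new datastore layout (for reviewers):

The per-table Parquet directory (`base-N.parquet` plus `patch-N.parquet` files merged at scan time) is replaced by a single line-oriented JSON file per table, named `<table>.sw-datastore.json`. The first line of the file is a meta object (`{"key_column": ..., "version": "0.1"}`, see `MemoryTable._dump_meta`) and every following line is one `InnerRecord` dump. A minimal sketch of loading and scanning such a file through the APIs this patch adds; the file path is hypothetical:

```python
from pathlib import Path

from starwhale.api._impl.data_store import MemoryTable

# hypothetical table file under a standalone datastore root
f = Path("/tmp/datastore/project/self/eval/summary.sw-datastore.json")

# MemoryTable.loads reads the meta line, validates version "0.1",
# and replays one InnerRecord per remaining line
table = MemoryTable.loads(f, "project/self/eval/summary")
for row in table.scan():
    # "*" always carries the row key; "-" marks a tombstoned (deleted) row
    print(row["*"], row)
```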
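Row versioning piggybacks on `time.monotonic_ns()`: every `insert`/`delete` appends a `Record` under a fresh sequence number inside the row's `InnerRecord`, and `scan(revision=...)` replays only the entries with `seq <= revision`. A short usage sketch, matching the behavior exercised by `test_revision` and `test_mutable_schema` (table and column names are illustrative):

```python
from starwhale.api._impl.data_store import INT64, ColumnSchema, MemoryTable

table = MemoryTable("demo", ColumnSchema("k", INT64))
rev = table.insert({"k": 0, "a": "first"})  # insert() now returns a revision
table.insert({"k": 0, "a": "second"})       # a later write shadows the first
table.delete([0])                           # delete() appends a "-" tombstone

# time travel: scan the table as it looked right after the first insert
assert list(table.scan(revision=rev)) == [{"*": 0, "k": 0, "a": "first"}]
# latest view: only the tombstone survives
assert list(table.scan()) == [{"*": 0, "-": True, "k": 0}]
```

One design consequence worth noting: sequence numbers come from a monotonic process clock, so revisions order writes within a single process but are not wall-clock timestamps and should not be compared across processes.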
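Dumping stays crash-safe by writing to a `temp.<pid>` file first and only then moving it over the destination; `_check_move` guards the move with an `fcntl` lock on a `.lock` file that now lives next to `dest` (previously next to `src`), and `MemoryTable.dump` retries with a 100 ms sleep until the move succeeds. A condensed, standalone sketch of that write-then-locked-rename pattern; `locked_move` is a hypothetical stand-in for `_check_move` (the real helper takes a blocking lock and also handles platforms without `fcntl`):

```python
import os
import time
import fcntl

def locked_move(src: str, dest: str) -> bool:
    """Move src over dest while holding an exclusive lock on a .lock file
    in dest's directory; report failure instead of blocking forever."""
    with open(os.path.join(os.path.dirname(dest), ".lock"), "w") as f:
        try:
            fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB)  # sketch uses the non-blocking variant
        except OSError:
            return False  # another writer holds the lock; let the caller retry
        os.replace(src, dest)  # atomic when src and dest share a filesystem
        return True

def dump(temp_filename: str, dst: str) -> None:
    while not locked_move(temp_filename, dst):
        time.sleep(0.1)  # mirrors the retry loop in MemoryTable.dump
```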
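Because the on-disk schema is no longer table-wide, `TableWriter` drops its accumulated `TableSchema`: `_insert` derives a schema per record via `_update_schema(key_column, record)` and queues `(schema, [record])` pairs for `update_table`. The visible effect is that a column may legally change type between rows; compare the relaxed `test_insert_and_delete` (now "change type" instead of raising) and the new `test_mutable_schema`. A small sketch of the derivation, reusing the module's own helper (the two records are illustrative):

```python
from starwhale.api._impl.data_store import _update_schema

s1 = _update_schema("k", {"k": 0, "a": "text"})  # column "a" inferred as STRING here
s2 = _update_schema("k", {"k": 1, "a": 1})       # ...and as INT64 for the next record

# the old code merged these into one table schema and raised on the conflict;
# per-record schemas make both rows valid
assert s1.columns["a"].type != s2.columns["a"].type
```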