diff --git a/client/starwhale/api/_impl/data_store.py b/client/starwhale/api/_impl/data_store.py
index ef30cd34ec..6d72d63200 100644
--- a/client/starwhale/api/_impl/data_store.py
+++ b/client/starwhale/api/_impl/data_store.py
@@ -2,15 +2,21 @@
 import os
 import re
+import abc
+import sys
 import json
 import time
 import atexit
 import base64
 import struct
 import urllib
+import inspect
+import zipfile
 import binascii
+import tempfile
 import importlib
 import threading
+import contextlib
 from abc import ABCMeta, abstractmethod
 from http import HTTPStatus
 from typing import (
@@ -47,7 +53,7 @@
 from starwhale.utils.retry import http_retry
 from starwhale.utils.config import SWCliConfigMixed
 
-datastore_table_file_ext = ".sw-datastore.json"
+datastore_table_file_ext = ".sw-datastore"
 
 
 class SwType(metaclass=ABCMeta):
@@ -664,20 +670,16 @@ def __eq__(self, other: Any) -> bool:
 
 
 def _get_table_path(root_path: str, table_name: str) -> str:
-    return str(Path(root_path) / table_name.strip("/")) + datastore_table_file_ext
-
-
-def _parse_data_table_name(name: str) -> Tuple[str, int]:
-    try:
-        if name.endswith(datastore_table_file_ext):
-            if name.startswith("base-"):
-                return "base", int(name[5 : -len(datastore_table_file_ext)])
-            elif name.startswith("patch-"):
-                return "patch", int(name[6 : -len(datastore_table_file_ext)])
-    except ValueError:
-        # ignore invalid filename
-        pass
-    return "", 0
+    """
+    Get the table path from the table name: if exactly one existing file matches the name, return its path; otherwise return the default extension-less path.
+    """
+    expect_prefix = Path(root_path) / (table_name.strip("/") + datastore_table_file_ext)
+    paths = list(expect_prefix.parent.glob(f"{expect_prefix.name}*"))
+    if len(paths) > 1:
+        raise RuntimeError(f"ambiguous table {table_name}, multiple files match: {paths}")
+    if len(paths) == 1:
+        return str(paths[0])
+    return str(expect_prefix)
 
 
 def _merge_scan(
@@ -795,6 +797,83 @@ def _get_seq_num() -> int:
     return time.monotonic_ns()
 
 
+class Compressor(abc.ABC):
+    @abc.abstractmethod
+    def extension(self) -> str:
+        """
+        Return the extension of the compressed file.
+        """
+        ...
+
+    @abc.abstractmethod
+    def compress(self, source: Path) -> Path:
+        """
+        Compress the file and return the path to the compressed file.
+        """
+        ...
+
+    @contextlib.contextmanager
+    @abc.abstractmethod
+    def decompress(self, source: Path) -> Iterator[Path]:
+        """
+        Decompress the file and return the path to the temporary decompressed file.
+        The temporary file is deleted when the context manager exits.
+        """
+        ...
+
+
+class NoCompressor(Compressor):
+    def extension(self) -> str:
+        return ".json"
+
+    def compress(self, source: Path) -> Path:
+        # this should never be called; if it were implemented, we would need to
+        # duplicate the file, because the dump method removes the source file
+        raise RuntimeError("should not be called")
+
+    @contextlib.contextmanager
+    def decompress(self, source: Path) -> Iterator[Path]:
+        # keep compatibility with existing plain json files
+        yield source
+
+
+class ZipCompressor(Compressor):
+    def extension(self) -> str:
+        return ".zip"
+
+    def compress(self, source: Path) -> Path:
+        output = tempfile.mktemp()
+        with zipfile.ZipFile(output, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
+            zipf.write(source, source.name)
+        return Path(output)
+
+    @contextlib.contextmanager
+    def decompress(self, source: Path) -> Iterator[Path]:
+        with zipfile.ZipFile(source, "r") as zipf:
+            # extract into a temp dir
+            tmp_dir = tempfile.TemporaryDirectory()
+            zipf.extractall(tmp_dir.name)
+            file_name = zipf.namelist()[0]
+            try:
+                yield Path(tmp_dir.name) / file_name
+            finally:
+                tmp_dir.cleanup()
+
+
+# collect all the concrete Compressor subclasses defined in this module
+compressors: Dict[str, Compressor] = {}
+for name, obj in inspect.getmembers(sys.modules[__name__]):
+    if inspect.isclass(obj) and issubclass(obj, Compressor) and obj != Compressor:
+        compressors[obj.__name__] = obj()
+
+
+def get_compressor(file: Path) -> Compressor:
+    for compressor in compressors.values():
+        if file.suffix == compressor.extension():
+            return compressor
+    raise ValueError(f"Unknown compressor for file {file}")
+
+
 class MemoryTable:
     def __init__(self, table_name: str, key_column: ColumnSchema) -> None:
         self.table_name = table_name
@@ -802,6 +881,7 @@ def __init__(self, table_name: str, key_column: ColumnSchema) -> None:
         self.records: Dict[Any, InnerRecord] = {}
         self.lock = threading.Lock()
         self.dirty = False
+        self.compressor = ZipCompressor()
 
     def scan(
         self,
@@ -874,14 +954,16 @@ def _parse_meta(cls, meta: Any) -> ColumnSchema:
     def loads(cls, file: Path, table_name: str) -> MemoryTable:
         if not file.exists() or not file.is_file():
             raise RuntimeError(f"File {file} does not exist")
-        with jsonlines.open(file) as reader:
-            meta = reader.read()
-            key_column = cls._parse_meta(meta)
-            table = MemoryTable(table_name, key_column)
-            for record in reader:
-                ir = InnerRecord.loads(record)
-                table.records[ir.key] = ir
-            return table
+
+        with get_compressor(file).decompress(file) as f:
+            with jsonlines.open(f) as reader:
+                meta = reader.read()
+                key_column = cls._parse_meta(meta)
+                table = MemoryTable(table_name, key_column)
+                for record in reader:
+                    ir = InnerRecord.loads(record)
+                    table.records[ir.key] = ir
+                return table
 
     def dump(self, root_path: str, if_dirty: bool = True) -> None:
         root = Path(root_path)
@@ -893,7 +975,7 @@ def _dump(self, root_path: Path, if_dirty: bool = True) -> None:
        if if_dirty and not self.dirty:
             return
 
-        dst = root_path / f"{self.table_name}{datastore_table_file_ext}"
+        dst = Path(_get_table_path(str(root_path), self.table_name))
         base = dst.parent
         temp_filename = base / f"temp.{os.getpid()}"
         ensure_dir(base)
@@ -907,23 +989,39 @@ def _dump(self, root_path: Path, if_dirty: bool = True) -> None:
                     continue
                 writer.write(ir.dumps())
 
-        os.rename(temp_filename, dst)
+        compressed = self.compressor.compress(temp_filename)
+        os.unlink(temp_filename)
+        if dst.suffix == datastore_table_file_ext:
+            # the dst file must not exist yet: we never save a table as a bare
+            # .sw-datastore file, so append the compressor's extension
+            ext = datastore_table_file_ext + self.compressor.extension()
+        else:
+            # the dst file is a compressed file, change the extension
+            ext = self.compressor.extension()
+
+        # make dst file have the same extension as the compressed file
+        new_dst = dst.with_suffix(ext)
+        os.rename(compressed, new_dst)
+        if new_dst != dst and dst.exists():
+            # remove the old file if it is not the new file name
+            dst.unlink()
 
         self.dirty = False
 
-    def _dump_from_local_file(self, file: Path, output: Writer) -> Set[str]:
+    def _dump_from_local_file(self, existing: Path, output: Writer) -> Set[str]:
         dumped_keys: Set[str] = set()
-        if not file.exists():
+        if not existing.exists():
             return dumped_keys
 
-        with jsonlines.open(file, mode="r") as reader:
-            self._parse_meta(reader.read())
-            for i in reader:
-                ir = InnerRecord.loads(i)
-                r = self.records.get(ir.key)
-                ir.update(r)
-                dumped_keys.add(ir.key)
-                output.write(ir.dumps())
+        with get_compressor(existing).decompress(existing) as f:
+            with jsonlines.open(f, mode="r") as reader:
+                self._parse_meta(reader.read())
+                for i in reader:
+                    ir = InnerRecord.loads(i)
+                    r = self.records.get(ir.key)
+                    ir.update(r)
+                    dumped_keys.add(ir.key)
+                    output.write(ir.dumps())
 
         return dumped_keys
diff --git a/client/tests/sdk/test_data_store.py b/client/tests/sdk/test_data_store.py
index 0ab35d0880..e57f74c876 100644
--- a/client/tests/sdk/test_data_store.py
+++ b/client/tests/sdk/test_data_store.py
@@ -20,45 +20,18 @@ class TestBasicFunctions(BaseTestCase):
     def test_get_table_path(self) -> None:
         self.assertEqual(
-            os.path.join("a", "b.sw-datastore.json"),
+            os.path.join("a", "b.sw-datastore"),
             data_store._get_table_path("a", "b"),
         )
         self.assertEqual(
-            os.path.join("a", "b", "c.sw-datastore.json"),
+            os.path.join("a", "b", "c.sw-datastore"),
             data_store._get_table_path("a", "b/c"),
         )
         self.assertEqual(
-            os.path.join("a", "b", "c", "d.sw-datastore.json"),
+            os.path.join("a", "b", "c", "d.sw-datastore"),
             data_store._get_table_path("a", "b/c/d"),
         )
 
-    def test_parse_data_table_name(self) -> None:
-        self.assertEqual(
-            ("", 0),
-            data_store._parse_data_table_name("base-123.txt"),
-            "invalid extension",
-        )
-        self.assertEqual(
-            ("", 0),
-            data_store._parse_data_table_name("base_1.sw-datastore.json"),
-            "invalid prefix",
-        )
-        self.assertEqual(
-            ("", 0),
-            data_store._parse_data_table_name("base-i.sw-datastore.json"),
-            "invalid index",
-        )
-        self.assertEqual(
-            ("base", 123),
-            data_store._parse_data_table_name("base-123.sw-datastore.json"),
-            "base",
-        )
-        self.assertEqual(
-            ("patch", 123),
-            data_store._parse_data_table_name("patch-123.sw-datastore.json"),
-            "patch",
-        )
-
     def test_merge_scan(self) -> None:
         self.assertEqual([], list(data_store._merge_scan([], False)), "no iter")
         self.assertEqual(
diff --git a/client/tests/sdk/test_track.py b/client/tests/sdk/test_track.py
index 2b1edeb1a2..6d1544d748 100644
--- a/client/tests/sdk/test_track.py
+++ b/client/tests/sdk/test_track.py
@@ -439,7 +439,7 @@ def test_handle_metrics(self) -> None:
         assert isinstance(h._table_writers["metrics/user"], TableWriter)
 
         h.flush()
-        datastore_file_path = workdir / "metrics" / "user.sw-datastore.json"
+        datastore_file_path = workdir / "metrics" / "user.sw-datastore.zip"
         assert datastore_file_path.exists()
         assert datastore_file_path.is_file()
 
@@ -495,7 +495,7 @@ def test_handle_artifacts(self) -> None:
 
         h.flush()
 
-        datastore_file_path = workdir / "artifacts" / "user.sw-datastore.json"
+        datastore_file_path = workdir / "artifacts" / "user.sw-datastore.zip"
         assert datastore_file_path.exists()
         assert datastore_file_path.is_file()
 
@@ -549,8 +549,8 @@ def test_run(self) -> None:
         assert "metrics/_system" in h._table_writers
         assert "artifacts/user" in h._table_writers
 
-        assert (workdir / "metrics" / "user.sw-datastore.json").exists()
-        assert (workdir / "metrics" / "_system.sw-datastore.json").exists()
+        assert (workdir / "metrics" / "user.sw-datastore.zip").exists()
+        assert (workdir / "metrics" / "_system.sw-datastore.zip").exists()
         assert (workdir / "artifacts" / "_files").exists()
         assert len(list((workdir / "artifacts" / "_files").iterdir())) != 0
         assert (workdir / "params" / "user.json").exists()