chore(client): compress data store table file
jialeicui committed Apr 7, 2023
1 parent 3d1da21 commit 38db434
Showing 3 changed files with 140 additions and 69 deletions.
168 changes: 133 additions & 35 deletions client/starwhale/api/_impl/data_store.py
@@ -2,15 +2,21 @@

import os
import re
+import abc
+import sys
import json
import time
import atexit
import base64
import struct
import urllib
+import inspect
+import zipfile
import binascii
+import tempfile
import importlib
import threading
+import contextlib
from abc import ABCMeta, abstractmethod
from http import HTTPStatus
from typing import (
@@ -47,7 +53,7 @@
from starwhale.utils.retry import http_retry
from starwhale.utils.config import SWCliConfigMixed

datastore_table_file_ext = ".sw-datastore.json"
datastore_table_file_ext = ".sw-datastore"


class SwType(metaclass=ABCMeta):
@@ -664,20 +670,16 @@ def __eq__(self, other: Any) -> bool:


def _get_table_path(root_path: str, table_name: str) -> str:
-    return str(Path(root_path) / table_name.strip("/")) + datastore_table_file_ext
-
-
-def _parse_data_table_name(name: str) -> Tuple[str, int]:
-    try:
-        if name.endswith(datastore_table_file_ext):
-            if name.startswith("base-"):
-                return "base", int(name[5 : -len(datastore_table_file_ext)])
-            elif name.startswith("patch-"):
-                return "patch", int(name[6 : -len(datastore_table_file_ext)])
-    except ValueError:
-        # ignore invalid filename
-        pass
-    return "", 0
+    """
+    Get the table path from the table name.
+
+    Return the matched file path when exactly one file matches the table name,
+    or the bare expected path when no file exists yet.
+    """
+    expect_prefix = Path(root_path) / (table_name.strip("/") + datastore_table_file_ext)
+    paths = list(expect_prefix.parent.glob(f"{expect_prefix.name}*"))
+    if len(paths) > 1:
+        raise RuntimeError(f"multiple files found for table {table_name}: {paths}")
+    if len(paths) == 1:
+        return str(paths[0])
+    return str(expect_prefix)


def _merge_scan(
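
Since a table file may now carry any registered compression extension, the lookup globs for the .sw-datastore prefix instead of assuming .json. A minimal sketch of the resulting behavior, assuming the module path starwhale.api._impl.data_store from this repository and a scratch directory:

import tempfile
from pathlib import Path

from starwhale.api._impl import data_store

with tempfile.TemporaryDirectory() as root:
    # no file yet: the bare prefix comes back, and dump() picks the extension
    print(data_store._get_table_path(root, "b"))  # <root>/b.sw-datastore

    # exactly one candidate: that file wins, whatever its compression
    (Path(root) / "b.sw-datastore.zip").touch()
    print(data_store._get_table_path(root, "b"))  # <root>/b.sw-datastore.zip

    # a second candidate makes the lookup ambiguous and raises RuntimeError
    (Path(root) / "b.sw-datastore.json").touch()
    # data_store._get_table_path(root, "b")  # -> RuntimeError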
@@ -795,13 +797,91 @@ def _get_seq_num() -> int:
return time.monotonic_ns()


+class Compressor(abc.ABC):
+    @abc.abstractmethod
+    def extension(self) -> str:
+        """
+        Return the extension of the compressed file.
+        """
+        ...
+
+    @abc.abstractmethod
+    def compress(self, source: Path) -> Path:
+        """
+        Compress the file and return the path to the compressed file.
+        """
+        ...
+
+    @contextlib.contextmanager
+    @abc.abstractmethod
+    def decompress(self, source: Path) -> Iterator[Path]:
+        """
+        Decompress the file and yield the path to the temporary decompressed
+        file; the temp file is deleted when the context manager exits.
+        """
+        ...
+
+
+class NoCompressor(Compressor):
+    def extension(self) -> str:
+        return ".json"
+
+    def compress(self, source: Path) -> Path:
+        # never called; if it were, we would need to duplicate the file
+        # because the dump method removes the source file
+        raise RuntimeError("should not be called")
+
+    @contextlib.contextmanager
+    def decompress(self, source: Path) -> Iterator[Path]:
+        # compatible with the existing plain json files
+        yield source
+
+
+class ZipCompressor(Compressor):
+    def extension(self) -> str:
+        return ".zip"
+
+    def compress(self, source: Path) -> Path:
+        output = tempfile.mktemp()
+        with zipfile.ZipFile(output, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
+            zipf.write(source, source.name)
+        return Path(output)
+
+    @contextlib.contextmanager
+    def decompress(self, source: Path) -> Iterator[Path]:
+        with zipfile.ZipFile(source, "r") as zipf:
+            # extract to a temp dir
+            tmp_dir = tempfile.TemporaryDirectory()
+            zipf.extractall(tmp_dir.name)
+            file_name = zipf.namelist()[0]
+            try:
+                yield Path(tmp_dir.name) / file_name
+            finally:
+                tmp_dir.cleanup()
+
+
+# collect all the compressors defined in this module
+compressors: Dict[str, Compressor] = {}
+for name, obj in inspect.getmembers(sys.modules[__name__]):
+    if inspect.isclass(obj) and issubclass(obj, Compressor) and obj != Compressor:
+        compressors[obj.__name__] = obj()
+
+
+def get_compressor(file: Path) -> Compressor:
+    for compressor in compressors.values():
+        if file.suffix == compressor.extension():
+            return compressor
+    raise ValueError(f"Unknown compressor for file {file}")

class MemoryTable:
    def __init__(self, table_name: str, key_column: ColumnSchema) -> None:
        self.table_name = table_name
        self.key_column = key_column
        self.records: Dict[Any, InnerRecord] = {}
        self.lock = threading.Lock()
        self.dirty = False
+        self.compressor = ZipCompressor()
def scan(
self,
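
The hunk above also registers every Compressor subclass defined in the module, and get_compressor dispatches purely on the file suffix. A small round-trip sketch (file name and content are illustrative):

import tempfile
from pathlib import Path

from starwhale.api._impl.data_store import ZipCompressor, get_compressor

src = Path(tempfile.mkdtemp()) / "user.sw-datastore.json"
src.write_text('{"k": 1}\n')

# compress() returns a suffix-less temp file; give it the .zip suffix so
# get_compressor() can dispatch on it afterwards
packed = ZipCompressor().compress(src).rename(src.with_suffix(".zip"))

assert isinstance(get_compressor(packed), ZipCompressor)
with get_compressor(packed).decompress(packed) as plain:
    assert plain.read_text() == src.read_text()  # content survives the trip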
@@ -874,14 +954,16 @@ def _parse_meta(cls, meta: Any) -> ColumnSchema:
    def loads(cls, file: Path, table_name: str) -> MemoryTable:
        if not file.exists() or not file.is_file():
            raise RuntimeError(f"File {file} does not exist")
-        with jsonlines.open(file) as reader:
-            meta = reader.read()
-            key_column = cls._parse_meta(meta)
-            table = MemoryTable(table_name, key_column)
-            for record in reader:
-                ir = InnerRecord.loads(record)
-                table.records[ir.key] = ir
-            return table
+        with get_compressor(file).decompress(file) as f:
+            with jsonlines.open(f) as reader:
+                meta = reader.read()
+                key_column = cls._parse_meta(meta)
+                table = MemoryTable(table_name, key_column)
+                for record in reader:
+                    ir = InnerRecord.loads(record)
+                    table.records[ir.key] = ir
+                return table

    def dump(self, root_path: str, if_dirty: bool = True) -> None:
        root = Path(root_path)
@@ -893,7 +975,7 @@ def _dump(self, root_path: Path, if_dirty: bool = True) -> None:
        if if_dirty and not self.dirty:
            return

-        dst = root_path / f"{self.table_name}{datastore_table_file_ext}"
+        dst = Path(_get_table_path(str(root_path), self.table_name))
        base = dst.parent
        temp_filename = base / f"temp.{os.getpid()}"
        ensure_dir(base)
Expand All @@ -907,23 +989,39 @@ def _dump(self, root_path: Path, if_dirty: bool = True) -> None:
                    continue
                writer.write(ir.dumps())

-        os.rename(temp_filename, dst)
+        compressed = self.compressor.compress(temp_filename)
+        os.unlink(temp_filename)
+        if dst.suffix == datastore_table_file_ext:
+            # dst must not exist yet: a table is never saved as a bare
+            # .sw-datastore file, so append the compressor's extension
+            ext = datastore_table_file_ext + self.compressor.extension()
+        else:
+            # dst already exists (possibly a legacy .json); swap its extension
+            ext = self.compressor.extension()
+
+        # give the dst file the same extension as the compressed file
+        new_dst = dst.with_suffix(ext)
+        os.rename(compressed, new_dst)
+        if new_dst != dst and dst.exists():
+            # remove the old file if it differs from the new file name
+            dst.unlink()
        self.dirty = False

-    def _dump_from_local_file(self, file: Path, output: Writer) -> Set[str]:
+    def _dump_from_local_file(self, existing: Path, output: Writer) -> Set[str]:
        dumped_keys: Set[str] = set()

-        if not file.exists():
+        if not existing.exists():
            return dumped_keys

-        with jsonlines.open(file, mode="r") as reader:
-            self._parse_meta(reader.read())
-            for i in reader:
-                ir = InnerRecord.loads(i)
-                r = self.records.get(ir.key)
-                ir.update(r)
-                dumped_keys.add(ir.key)
-                output.write(ir.dumps())
+        with get_compressor(existing).decompress(existing) as f:
+            with jsonlines.open(f, mode="r") as reader:
+                self._parse_meta(reader.read())
+                for i in reader:
+                    ir = InnerRecord.loads(i)
+                    r = self.records.get(ir.key)
+                    ir.update(r)
+                    dumped_keys.add(ir.key)
+                    output.write(ir.dumps())

        return dumped_keys

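The extension bookkeeping in _dump covers three cases: a first dump (dst is the bare .sw-datastore prefix), a re-dump over a legacy uncompressed .json table, and a re-dump over an earlier .zip. A quick pathlib sketch of the suffix arithmetic, with illustrative paths:

from pathlib import Path

zip_ext = ".zip"  # what ZipCompressor().extension() returns

# 1) first dump: nothing on disk, so the compressor extension is appended
print(Path("metrics/user.sw-datastore").with_suffix(".sw-datastore" + zip_ext))
# -> metrics/user.sw-datastore.zip

# 2) legacy uncompressed table: .json is swapped for .zip, and the old
#    file is unlinked afterwards because new_dst != dst
print(Path("metrics/user.sw-datastore.json").with_suffix(zip_ext))
# -> metrics/user.sw-datastore.zip

# 3) earlier compressed dump: the name is unchanged and os.rename simply
#    overwrites the file in place
print(Path("metrics/user.sw-datastore.zip").with_suffix(zip_ext))
# -> metrics/user.sw-datastore.zip
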
33 changes: 3 additions & 30 deletions client/tests/sdk/test_data_store.py
@@ -20,45 +20,18 @@
class TestBasicFunctions(BaseTestCase):
    def test_get_table_path(self) -> None:
        self.assertEqual(
-            os.path.join("a", "b.sw-datastore.json"),
+            os.path.join("a", "b.sw-datastore"),
            data_store._get_table_path("a", "b"),
        )
        self.assertEqual(
-            os.path.join("a", "b", "c.sw-datastore.json"),
+            os.path.join("a", "b", "c.sw-datastore"),
            data_store._get_table_path("a", "b/c"),
        )
        self.assertEqual(
-            os.path.join("a", "b", "c", "d.sw-datastore.json"),
+            os.path.join("a", "b", "c", "d.sw-datastore"),
            data_store._get_table_path("a", "b/c/d"),
        )

-    def test_parse_data_table_name(self) -> None:
-        self.assertEqual(
-            ("", 0),
-            data_store._parse_data_table_name("base-123.txt"),
-            "invalid extension",
-        )
-        self.assertEqual(
-            ("", 0),
-            data_store._parse_data_table_name("base_1.sw-datastore.json"),
-            "invalid prefix",
-        )
-        self.assertEqual(
-            ("", 0),
-            data_store._parse_data_table_name("base-i.sw-datastore.json"),
-            "invalid index",
-        )
-        self.assertEqual(
-            ("base", 123),
-            data_store._parse_data_table_name("base-123.sw-datastore.json"),
-            "base",
-        )
-        self.assertEqual(
-            ("patch", 123),
-            data_store._parse_data_table_name("patch-123.sw-datastore.json"),
-            "patch",
-        )

    def test_merge_scan(self) -> None:
        self.assertEqual([], list(data_store._merge_scan([], False)), "no iter")
        self.assertEqual(
8 changes: 4 additions & 4 deletions client/tests/sdk/test_track.py
@@ -439,7 +439,7 @@ def test_handle_metrics(self) -> None:
        assert isinstance(h._table_writers["metrics/user"], TableWriter)

        h.flush()
-        datastore_file_path = workdir / "metrics" / "user.sw-datastore.json"
+        datastore_file_path = workdir / "metrics" / "user.sw-datastore.zip"
        assert datastore_file_path.exists()
        assert datastore_file_path.is_file()

@@ -495,7 +495,7 @@ def test_handle_artifacts(self) -> None:

        h.flush()

-        datastore_file_path = workdir / "artifacts" / "user.sw-datastore.json"
+        datastore_file_path = workdir / "artifacts" / "user.sw-datastore.zip"
        assert datastore_file_path.exists()
        assert datastore_file_path.is_file()

@@ -549,8 +549,8 @@ def test_run(self) -> None:
assert "metrics/_system" in h._table_writers
assert "artifacts/user" in h._table_writers

assert (workdir / "metrics" / "user.sw-datastore.json").exists()
assert (workdir / "metrics" / "_system.sw-datastore.json").exists()
assert (workdir / "metrics" / "user.sw-datastore.zip").exists()
assert (workdir / "metrics" / "_system.sw-datastore.zip").exists()
assert (workdir / "artifacts" / "_files").exists()
assert len(list((workdir / "artifacts" / "_files").iterdir())) != 0
assert (workdir / "params" / "user.json").exists()
