From 52cafb99489359309aa64a44c3e049860b97e054 Mon Sep 17 00:00:00 2001 From: tianwei Date: Thu, 11 Aug 2022 18:51:41 +0800 Subject: [PATCH] enhance(datastore): tune datastore storage dir by swcli config (#902) tune datastore storage dir by swcli config --- client/starwhale/api/_impl/data_store.py | 50 ++++++++++++++---------- client/starwhale/utils/config.py | 16 +------- client/tests/sdk/test_base.py | 21 +++++++--- client/tests/sdk/test_data_store.py | 6 ++- client/tests/sdk/test_wrapper.py | 2 +- 5 files changed, 54 insertions(+), 41 deletions(-) diff --git a/client/starwhale/api/_impl/data_store.py b/client/starwhale/api/_impl/data_store.py index 4d8f2a041a..4681676fa9 100644 --- a/client/starwhale/api/_impl/data_store.py +++ b/client/starwhale/api/_impl/data_store.py @@ -12,6 +12,9 @@ import pyarrow.parquet as pq # type: ignore from typing_extensions import Protocol +from starwhale.utils.fs import ensure_dir +from starwhale.utils.config import SWCliConfigMixed + class Type: def __init__( @@ -286,12 +289,12 @@ def nextItem(self) -> None: self.item = None self.key = "" - n = len(iters) nodes = [] - for i in range(n): - node = Node(i, iters[i]) - if not node.exhausted: - nodes.append(node) + for _i, _iter in enumerate(iters): + _node = Node(_i, _iter) + if not _node.exhausted: + nodes.append(_node) + while len(nodes) > 0: key = min(nodes, key=lambda x: x.key).key d: Dict[str, Any] = {} @@ -315,6 +318,9 @@ def nextItem(self) -> None: def _get_table_files(path: str) -> List[str]: if not os.path.exists(path): return [] + if not os.path.isdir(path): + raise RuntimeError(f"{path} is not a directory") + patches = [] base_index = -1 for file in os.listdir(path): @@ -339,15 +345,19 @@ def _read_table_schema(path: str) -> TableSchema: raise RuntimeError(f"path not found: {path}") if not os.path.isdir(path): raise RuntimeError(f"{path} is not a directory") + files = _get_table_files(path) if len(files) == 0: raise RuntimeError(f"table is empty, path:{path}") + schema = pq.read_schema(files[-1]) if schema.metadata is None: raise RuntimeError(f"no metadata for file {files[-1]}") + schema_data = schema.metadata.get(b"schema", None) if schema_data is None: raise RuntimeError(f"no schema for file {files[-1]}") + return TableSchema.parse(schema_data.decode()) @@ -361,6 +371,7 @@ def _scan_table( iters = [] for file in _get_table_files(path): iters.append(_scan_parquet_file(file, columns, start, end)) + column_names = [] if len(iters) > 0: schema = _read_table_schema(path) column_names = [ @@ -413,7 +424,7 @@ def _records_to_table( def _get_size(d: Any) -> int: ret = sys.getsizeof(d) - if type(d) is dict: + if isinstance(d, dict): for v in d.values(): ret += sys.getsizeof(v) return ret @@ -525,10 +536,8 @@ def delete(self, keys: List[Any]) -> None: def dump(self, root_path: str) -> None: path = _get_table_path(root_path, self.table_name) - if not os.path.exists(path): - os.mkdir(path) - if not os.path.isdir(path): - raise RuntimeError(f"{path} is not a directory") + ensure_dir(path) + max_index = -1 for file in os.listdir(path): type, index = _parse_parquet_name(file) @@ -556,12 +565,11 @@ class LocalDataStore: def get_instance() -> "LocalDataStore": with LocalDataStore._lock: if LocalDataStore._instance is None: - root_path = os.getenv("SW_ROOT_PATH", None) - if root_path is None: - raise RuntimeError( - "data store root path is not defined for standalone instance" - ) - LocalDataStore._instance = LocalDataStore(root_path) + + ds_path = SWCliConfigMixed().datastore_dir + ensure_dir(ds_path) + + LocalDataStore._instance = LocalDataStore(str(ds_path)) atexit.register(LocalDataStore._instance.dump) return LocalDataStore._instance @@ -630,7 +638,7 @@ def __init__( self, name: str, key_column_type: pa.DataType, - columns: Dict[str, str], + columns: Optional[Dict[str, str]], explicit_none: bool, ) -> None: self.name = name @@ -648,6 +656,8 @@ def __init__( key_column_type = schema.columns[schema.key_column].type.pa_type column_names = schema.columns.keys() col_prefix = table_alias + "." + + cols: Optional[Dict[str, str]] if columns is None or col_prefix + "*" in columns: cols = None else: @@ -659,15 +669,15 @@ def __init__( cols[name] = alias infos.append(TableInfo(table_name, key_column_type, cols, explicit_none)) - # check for key type conflication + # check for key type conflictions for info in infos: if info is infos[0]: continue if info.key_column_type != infos[0].key_column_type: raise RuntimeError( "conflicting key field type. " - + f"{info.name} has a key of type {info.key_column_type}," - + f" while {infos[0].name} has a key of type {infos[0].key_column_type}" + f"{info.name} has a key of type {info.key_column_type}," + f" while {infos[0].name} has a key of type {infos[0].key_column_type}" ) iters = [] for info in infos: diff --git a/client/starwhale/utils/config.py b/client/starwhale/utils/config.py index 2ad73e7019..c62cccbc63 100644 --- a/client/starwhale/utils/config.py +++ b/client/starwhale/utils/config.py @@ -109,20 +109,8 @@ def rootdir(self) -> Path: return Path(self._config["storage"]["root"]) @property - def workdir(self) -> Path: - return self.rootdir / "workdir" - - @property - def pkgdir(self) -> Path: - return self.rootdir / "pkg" - - @property - def dataset_dir(self) -> Path: - return self.rootdir / "dataset" - - @property - def eval_run_dir(self) -> Path: - return self.rootdir / "run" / "eval" + def datastore_dir(self) -> Path: + return self.rootdir / ".datastore" @property def sw_remote_addr(self) -> str: diff --git a/client/tests/sdk/test_base.py b/client/tests/sdk/test_base.py index 75413477ec..93ba4870f4 100644 --- a/client/tests/sdk/test_base.py +++ b/client/tests/sdk/test_base.py @@ -1,14 +1,25 @@ import os -import shutil import tempfile import unittest +from starwhale.utils import config as sw_config +from starwhale.consts import ENV_SW_CLI_CONFIG, ENV_SW_LOCAL_STORAGE +from starwhale.utils.fs import empty_dir, ensure_dir + class BaseTestCase(unittest.TestCase): def setUp(self) -> None: - self.root = os.path.join(tempfile.gettempdir(), "datastore_test") - os.makedirs(self.root, exist_ok=True) - os.environ["SW_ROOT_PATH"] = self.root + self._test_local_storage = tempfile.mkdtemp(prefix="sw-test-mock-") + os.environ[ENV_SW_CLI_CONFIG] = os.path.join( + self._test_local_storage, "config.yaml" + ) + os.environ[ENV_SW_LOCAL_STORAGE] = self._test_local_storage + sw_config._config = {} + + self.root = str(sw_config.SWCliConfigMixed().datastore_dir) + ensure_dir(self.root) def tearDown(self) -> None: - shutil.rmtree(self.root) + empty_dir(self._test_local_storage) + os.environ.pop(ENV_SW_CLI_CONFIG, "") + os.environ.pop(ENV_SW_LOCAL_STORAGE, "") diff --git a/client/tests/sdk/test_data_store.py b/client/tests/sdk/test_data_store.py index e00e5280d9..a726719332 100644 --- a/client/tests/sdk/test_data_store.py +++ b/client/tests/sdk/test_data_store.py @@ -1,6 +1,7 @@ import os import unittest from typing import Dict, List +from unittest.mock import patch, MagicMock import numpy as np import pyarrow as pa # type: ignore @@ -940,13 +941,16 @@ def test_data_store_scan(self) -> None: class TestTableWriter(BaseTestCase): def setUp(self) -> None: + self.mock_atexit = patch("starwhale.api._impl.data_store.atexit", MagicMock()) + self.mock_atexit.start() + super().setUp() - os.environ["SW_ROOT_PATH"] = self.root self.writer = data_store.TableWriter("p/test", "k") def tearDown(self) -> None: self.writer.close() super().tearDown() + self.mock_atexit.stop() def test_insert_and_delete(self) -> None: with self.assertRaises(RuntimeError, msg="no key"): diff --git a/client/tests/sdk/test_wrapper.py b/client/tests/sdk/test_wrapper.py index 98cca13d32..8b4966b984 100644 --- a/client/tests/sdk/test_wrapper.py +++ b/client/tests/sdk/test_wrapper.py @@ -5,7 +5,7 @@ from .test_base import BaseTestCase -class TestEvaluaiton(BaseTestCase): +class TestEvaluation(BaseTestCase): def setUp(self) -> None: super().setUp() os.environ["SW_PROJECT"] = "test"