Skip to content

Commit

Permalink
fix(datastore): lazy to init table name for eval & replace name with …
Browse files Browse the repository at this point in the history
…id (#1840)
  • Loading branch information
goldenxinxing authored Feb 16, 2023
1 parent 68382f0 commit 7d52d61
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 109 deletions.
35 changes: 0 additions & 35 deletions client/starwhale/api/_impl/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,41 +1409,6 @@ def get_data_store(instance_uri: str = "", token: str = "") -> DataStore:
)


def table_name_generator(project: Union[str, int], table: str) -> str:
    """Return the datastore table path for ``table`` scoped under ``project``.

    Args:
        project: Project name or numeric project id; interpolated as-is.
        table: Table name relative to the project.

    Returns:
        The string ``project/<project>/<table>``.
    """
    return f"project/{project}/{table}"


def gen_table_name(project: Union[str, int], table: str, instance_uri: str = "") -> str:
    """Build the datastore table name for ``table`` within ``project``.

    Args:
        project: Project name or numeric project id.
        table: Table name relative to the project.
        instance_uri: Target instance. When empty, the environment variable
            named by ``SWEnv.instance_uri`` is consulted instead.

    Returns:
        ``project/<project>/<table>``. When a non-standalone instance is in
        effect and ``project`` is not an int, ``<project>`` is replaced by the
        value returned from ``_get_remote_project_id``.
    """
    _instance_uri = instance_uri or os.getenv(SWEnv.instance_uri)
    # `not _instance_uri` covers both "unset" (None) and an empty-string
    # environment value; neither can be used for a remote lookup.
    # An int project is used as-is, with no remote lookup.
    if (
        not _instance_uri
        or _instance_uri == STANDALONE_INSTANCE
        or isinstance(project, int)
    ):
        return table_name_generator(project, table)
    return table_name_generator(
        _get_remote_project_id(_instance_uri, project), table
    )


@http_retry
def _get_remote_project_id(instance_uri: str, project: Union[str, int]) -> Any:
    """Query the instance for ``project`` and return the ``id`` of the reply.

    Issues ``GET <instance_uri>/api/v1/project/<project>``. The Authorization
    header is the token from ``SWCliConfigMixed().get_sw_token`` for the
    instance, falling back to the environment variable named by
    ``SWEnv.instance_token`` (empty string if neither is set).

    Args:
        instance_uri: Base URL of the remote instance.
        project: Project identifier placed in the request path.

    Returns:
        The ``id`` entry of the ``data`` object in the JSON response.

    Raises:
        requests.HTTPError: From ``raise_for_status()`` on an error status.
        KeyError: If the response has no ``data`` object or it lacks ``id``.
    """
    resp = requests.get(
        urllib.parse.urljoin(instance_uri, f"/api/v1/project/{project}"),
        headers={
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": (
                SWCliConfigMixed().get_sw_token(instance=instance_uri)
                or os.getenv(SWEnv.instance_token, "")
            ),
        },
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json().get("data", {})["id"]


def _flatten(record: Dict[str, Any]) -> Dict[str, Any]:
def _new(key_prefix: str, src: Dict[str, Any], dest: Dict[str, Any]) -> None:
for k, v in src.items():
Expand Down
117 changes: 90 additions & 27 deletions client/starwhale/api/_impl/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import os
import re
import urllib
import threading
from enum import Enum, unique
from typing import Any, Dict, List, Union, Iterator, Optional
from typing import Any, Dict, List, Union, Callable, Iterator, Optional
from functools import lru_cache

import dill
import requests
from loguru import logger

from starwhale.consts import VERSION_PREFIX_CNT
from starwhale.consts import VERSION_PREFIX_CNT, STANDALONE_INSTANCE

from . import data_store
from ...consts.env import SWEnv
from ...utils.retry import http_retry
from ...utils.config import SWCliConfigMixed


class Logger:
Expand Down Expand Up @@ -71,6 +78,45 @@ def _deserialize(data: bytes) -> Any:
return dill.loads(data)


def table_name_formatter(project: Union[str, int], table: str) -> str:
    """Return the datastore table path for ``table`` scoped under ``project``.

    Args:
        project: Project name or numeric project id; interpolated as-is.
        table: Table name relative to the project.

    Returns:
        The string ``project/<project>/<table>``.
    """
    return f"project/{project}/{table}"


def _gen_storage_table_name(
    project: Union[str, int], table: str, instance_uri: str = ""
) -> str:
    """Build the storage table name for ``table`` within ``project``.

    Args:
        project: Project name or numeric project id.
        table: Table name relative to the project.
        instance_uri: Target instance. When empty, the environment variable
            named by ``SWEnv.instance_uri`` is consulted instead.

    Returns:
        ``project/<project>/<table>``. When a non-standalone instance is in
        effect and ``project`` is not an int, ``<project>`` is replaced by the
        value returned from ``_get_remote_project_id``.
    """
    _instance_uri = instance_uri or os.getenv(SWEnv.instance_uri)
    # `not _instance_uri` covers both "unset" (None) and an empty-string
    # environment value; neither can be used for a remote lookup.
    # An int project is used as-is, with no remote lookup.
    if (
        not _instance_uri
        or _instance_uri == STANDALONE_INSTANCE
        or isinstance(project, int)
    ):
        return table_name_formatter(project, table)
    return table_name_formatter(
        _get_remote_project_id(_instance_uri, project), table
    )


# lru_cache sits outermost, so a successful lookup for a given
# (instance_uri, project) pair is remembered for the life of the process and
# later calls skip the HTTP request; calls that raise are not cached.
@lru_cache(maxsize=None)
@http_retry
def _get_remote_project_id(instance_uri: str, project: str) -> Any:
    """Query the instance for ``project`` and return the ``id`` of the reply.

    Issues ``GET <instance_uri>/api/v1/project/<project>``. The Authorization
    header is the token from ``SWCliConfigMixed().get_sw_token`` for the
    instance, falling back to the environment variable named by
    ``SWEnv.instance_token`` (empty string if neither is set).

    Args:
        instance_uri: Base URL of the remote instance.
        project: Project identifier placed in the request path.

    Returns:
        The ``id`` entry of the ``data`` object in the JSON response.

    Raises:
        requests.HTTPError: From ``raise_for_status()`` on an error status.
        KeyError: If the response has no ``data`` object or it lacks ``id``.
    """
    resp = requests.get(
        urllib.parse.urljoin(instance_uri, f"/api/v1/project/{project}"),
        headers={
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": (
                SWCliConfigMixed().get_sw_token(instance=instance_uri)
                or os.getenv(SWEnv.instance_token, "")
            ),
        },
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json().get("data", {})["id"]


class Evaluation(Logger):
def __init__(self, eval_id: str, project: str, instance: str = ""):
if not eval_id:
Expand All @@ -85,20 +131,41 @@ def __init__(self, eval_id: str, project: str, instance: str = ""):
self.eval_id = eval_id
self.project = project
self.instance = instance
self._results_table_name = self._get_datastore_table_name("results")
self._summary_table_name = data_store.gen_table_name(
project=self.project, table="eval/summary", instance_uri=self.instance
self._tables: Dict[str, str] = {}
self._eval_table_name: Callable[
[str], str
] = (
lambda name: f"eval/{self.eval_id[:VERSION_PREFIX_CNT]}/{self.eval_id}/{name}"
)
self._eval_summary_table_name = "eval/summary"
self._data_store = data_store.get_data_store(instance_uri=instance)
self._init_writers([self._results_table_name, self._summary_table_name])
self._init_writers([])

def _get_storage_table_name(self, table: str) -> str:
    """Return the storage table name for ``table``, resolving it on first use.

    Resolved names are memoized in ``self._tables`` so that
    ``_gen_storage_table_name`` is called at most once per ``table`` for this
    object. The lookup and insert run under ``self._lock``.

    Args:
        table: Table name relative to the project.

    Returns:
        The name produced by ``_gen_storage_table_name`` for this object's
        project and instance.
    """
    # NOTE(review): self._lock is not set in the visible code — presumably
    # provided by the parent class; confirm.
    with self._lock:
        storage_name = self._tables.get(table)
        if storage_name is None:
            storage_name = _gen_storage_table_name(
                project=self.project,
                table=table,
                instance_uri=self.instance,
            )
            self._tables[table] = storage_name
        return storage_name

def _log(self, table_name: str, record: Dict[str, Any]) -> None:
    """Log ``record`` to the storage table that ``table_name`` maps to.

    ``table_name`` is translated through ``_get_storage_table_name`` before
    the call is delegated to the parent class's ``_log``.
    """
    _storage_table_name = self._get_storage_table_name(table_name)
    super()._log(_storage_table_name, record)

def _get_datastore_table_name(self, name: str) -> str:
return data_store.gen_table_name(
project=self.project,
table=f"eval/{self.eval_id[:VERSION_PREFIX_CNT]}/{self.eval_id}/{name}",
instance_uri=self.instance,
def _get(self, table_name: str) -> Iterator[Dict[str, Any]]:
    """Scan the storage table that ``table_name`` maps to.

    ``table_name`` is translated through ``_get_storage_table_name`` and the
    result of ``self._data_store.scan_tables`` for that single table is
    returned unchanged.
    """
    return self._data_store.scan_tables(
        [data_store.TableDesc(self._get_storage_table_name(table=table_name))]
    )

def _flush(self, table_name: str) -> None:
    """Flush the storage table that ``table_name`` maps to.

    ``table_name`` is translated through ``_get_storage_table_name`` before
    the call is delegated to the parent class's ``_flush``.
    """
    _storage_table_name = self._get_storage_table_name(table_name)
    super()._flush(_storage_table_name)

def log_result(
self,
data_id: Union[int, str],
Expand All @@ -109,7 +176,8 @@ def log_result(
record = {"id": data_id, "result": _serialize(result) if serialize else result}
for k, v in kwargs.items():
record[k.lower()] = _serialize(v) if serialize else v
self._log(self._results_table_name, record)

self._log(self._eval_table_name("results"), record)

def log_metrics(
self, metrics: Optional[Dict[str, Any]] = None, **kwargs: Any
Expand All @@ -124,18 +192,17 @@ def log_metrics(
else:
for k, v in kwargs.items():
record[k.lower()] = v
self._log(self._summary_table_name, record)

self._log(self._eval_summary_table_name, record)

def log(self, table_name: str, **kwargs: Any) -> None:
record = {}
for k, v in kwargs.items():
record[k.lower()] = v
self._log(self._get_datastore_table_name(table_name), record)
self._log(self._eval_table_name(table_name), record)

def get_results(self, deserialize: bool = False) -> Iterator[Dict[str, Any]]:
for data in self._data_store.scan_tables(
[data_store.TableDesc(self._results_table_name)]
):
for data in self._get(self._eval_table_name("results")):
if deserialize:
for _k, _v in data.items():
if _k == "id":
Expand All @@ -144,27 +211,23 @@ def get_results(self, deserialize: bool = False) -> Iterator[Dict[str, Any]]:
yield data

def get_metrics(self) -> Dict[str, Any]:
for metrics in self._data_store.scan_tables(
[data_store.TableDesc(self._summary_table_name)]
):
for metrics in self._get(self._eval_summary_table_name):
if metrics["id"] == self.eval_id:
return metrics

return {}

def get(self, table_name: str) -> Iterator[Dict[str, Any]]:
return self._data_store.scan_tables(
[data_store.TableDesc(self._get_datastore_table_name(table_name))]
)
return self._get(self._eval_table_name(table_name))

def flush_result(self) -> None:
self._flush(self._results_table_name)
self._flush(self._eval_table_name("results"))

def flush_metrics(self) -> None:
self._flush(self._summary_table_name)
self._flush(self._eval_summary_table_name)

def flush(self, table_name: str) -> None:
self._flush(table_name)
self._flush(self._eval_table_name(table_name))


@unique
Expand All @@ -191,7 +254,7 @@ def __init__(
self.dataset_id = dataset_id
self.project = project
self._kind = kind
self._table_name = data_store.gen_table_name(
self._table_name = _gen_storage_table_name(
project=project,
table=f"dataset/{self.dataset_id}/{kind.value}",
instance_uri=instance_name,
Expand Down
19 changes: 0 additions & 19 deletions client/tests/sdk/test_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2382,25 +2382,6 @@ def test_scan_table(self, mock_post: Mock) -> None:
timeout=60,
)

@patch("starwhale.api._impl.data_store.requests.get")
def test_gen_table_name(self, mock_get: Mock) -> None:
    """Check both the plain and the remote-lookup paths of gen_table_name."""
    # No instance_uri argument: the project name is expected to be used
    # verbatim (presumably the instance env var is unset here — confirm).
    table_name = data_store.gen_table_name(project="starwhale", table="test")
    assert table_name == "project/starwhale/test"

    # With an explicit instance URI the project id is looked up over HTTP;
    # only the outgoing request is asserted, the return value is ignored.
    instance_uri = "http://1.1.1.1:8182"
    data_store.gen_table_name(
        project="starwhale", table="test", instance_uri=instance_uri
    )

    mock_get.assert_called_with(
        f"{instance_uri}/api/v1/project/starwhale",
        headers={
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": "",
        },
        timeout=60,
    )


class TestTableWriter(BaseTestCase):
def setUp(self) -> None:
Expand Down
30 changes: 30 additions & 0 deletions client/tests/sdk/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from unittest.mock import patch

from requests_mock import Mocker

Expand All @@ -9,6 +10,7 @@
from .test_base import BaseTestCase


@patch.dict(os.environ, {"SW_TOKEN": "sw_token"})
class TestEvaluation(BaseTestCase):
def setUp(self) -> None:
super().setUp()
Expand All @@ -18,6 +20,34 @@ def tearDown(self) -> None:
os.environ.pop(SWEnv.instance_uri, None)
os.environ.pop(SWEnv.instance_token, None)

@Mocker()
def test_gen_table_name(self, request_mock: Mocker) -> None:
    """Check table naming for a "local" instance and for an HTTP instance."""
    request_mock.request(
        HTTPMethod.GET,
        "http://localhost:80/api/v1/project/project-test",
        json={"data": {"id": 1, "name": "project-test"}},
    )

    # instance="local": the project name is expected to appear verbatim.
    # (Local is named `evaluation` rather than `eval` to avoid shadowing the
    # builtin.)
    evaluation = wrapper.Evaluation("123456", "project-test", instance="local")

    result_table_name = evaluation._eval_table_name("results")
    assert result_table_name == "eval/12/123456/results"

    table_name_1 = evaluation._get_storage_table_name("table-1")
    assert table_name_1 == "project/project-test/table-1"

    # HTTP instance: the project name is replaced by the id (1) served by the
    # mocked project endpoint above.
    evaluation = wrapper.Evaluation(
        "123456", "project-test", instance="http://localhost:80"
    )

    result_table = evaluation._eval_table_name("results")
    result_table_name = evaluation._get_storage_table_name(result_table)
    assert result_table == "eval/12/123456/results"
    assert result_table_name == "project/1/eval/12/123456/results"

    table_name_1 = evaluation._get_storage_table_name("table-1")
    assert table_name_1 == "project/1/table-1"

def test_log_results_and_scan(self) -> None:
eval = wrapper.Evaluation("tt", "test")
eval.log_result("0", 3)
Expand Down
Loading

0 comments on commit 7d52d61

Please sign in to comment.