Skip to content

Commit

Permalink
PipelineHandler supports the predict and evaluate function interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut committed May 8, 2023
1 parent 5188222 commit e105bc6
Show file tree
Hide file tree
Showing 17 changed files with 187 additions and 105 deletions.
1 change: 1 addition & 0 deletions .github/workflows/client.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,5 @@ jobs:
env:
GITHUB_ACTION: 1
PYTHON_VERSION: ${{matrix.python-version}}
SKIP_UI_BUILD: 1
run: bash scripts/client_test/cli_test.sh sdk simple
2 changes: 2 additions & 0 deletions .github/workflows/e2e-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ jobs:
- name: Test by client side
working-directory: ./scripts/e2e_test
run: bash start_test.sh client_test
env:
SKIP_UI_BUILD: 1

- name: Post output client-side test logs
if: failure()
Expand Down
83 changes: 51 additions & 32 deletions client/starwhale/api/_impl/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

from __future__ import annotations
import time
import typing as t
import threading
from abc import ABCMeta, abstractmethod
from abc import ABCMeta
from types import TracebackType
from pathlib import Path
from functools import wraps
Expand Down Expand Up @@ -31,14 +31,14 @@
class PipelineHandler(metaclass=ABCMeta):
def __init__(
self,
ppl_batch_size: int = 1,
predict_batch_size: int = 1,
ignore_dataset_data: bool = False,
ignore_error: bool = False,
flush_result: bool = False,
ppl_auto_log: bool = True,
predict_auto_log: bool = True,
dataset_uris: t.Optional[t.List[str]] = None,
) -> None:
self.ppl_batch_size = ppl_batch_size
self.predict_batch_size = predict_batch_size
self.svc = Service()
self.context = Context.get_runtime_context()

Expand All @@ -48,7 +48,7 @@ def __init__(
self.ignore_dataset_data = ignore_dataset_data
self.ignore_error = ignore_error
self.flush_result = flush_result
self.ppl_auto_log = ppl_auto_log
self.predict_auto_log = predict_auto_log

_logdir = JobStorage.local_run_dir(self.context.project, self.context.version)
_run_dir = (
Expand Down Expand Up @@ -84,15 +84,6 @@ def __exit__(

self._timeline_writer.close()

@abstractmethod
def ppl(self, data: t.Any, **kw: t.Any) -> t.Any:
# TODO: how to handle each element is not equal.
raise NotImplementedError

@abstractmethod
def cmp(self, *args: t.Any, **kw: t.Any) -> t.Any:
raise NotImplementedError

def _record_status(func): # type: ignore
@wraps(func) # type: ignore
def _wrapper(*args: t.Any, **kwargs: t.Any) -> None:
Expand All @@ -113,27 +104,55 @@ def _wrapper(*args: t.Any, **kwargs: t.Any) -> None:
return _wrapper

@_record_status # type: ignore
def _starwhale_internal_run_cmp(self) -> None:
def _starwhale_internal_run_evaluate(self) -> None:
now = now_str()
try:
if self.ppl_auto_log:
self.cmp(self.evaluation_store.get_results(deserialize=True))
if self.predict_auto_log:
self._do_evaluate(self.evaluation_store.get_results(deserialize=True))
else:
self.cmp()
self._do_evaluate()
except Exception as e:
console.exception(f"cmp exception: {e}")
console.exception(f"evaluate exception: {e}")
self._timeline_writer.write(
{"time": now, "status": False, "exception": str(e)}
)
raise
else:
self._timeline_writer.write({"time": now, "status": True, "exception": ""})

def _is_ppl_batch(self) -> bool:
return self.ppl_batch_size > 1
def _do_predict(self, *args: t.Any, **kw: t.Any) -> t.Any:
    """Dispatch a prediction call to the user-defined handler method.

    Prefers the new-style ``predict`` method; falls back to the legacy
    ``ppl`` method for backward compatibility. Exactly one of the two
    must be defined on the subclass.

    :param args: positional arguments forwarded to the handler.
    :param kw: keyword arguments forwarded to the handler.
    :return: whatever the user-defined handler returns.
    :raises ParameterError: if both or neither of ``predict``/``ppl`` exist.
    """
    new_style = getattr(self, "predict", None)
    legacy = getattr(self, "ppl", None)

    if new_style and legacy:
        raise ParameterError("predict and ppl cannot be defined at the same time")

    handler = new_style or legacy
    if not handler:
        raise ParameterError(
            "predict or ppl must be defined, predict function is recommended"
        )
    return handler(*args, **kw)

def _do_evaluate(self, *args: t.Any, **kw: t.Any) -> t.Any:
    """Dispatch an evaluation call to the user-defined handler method.

    Prefers the new-style ``evaluate`` method; falls back to the legacy
    ``cmp`` method for backward compatibility. Exactly one of the two
    must be defined on the subclass.

    :param args: positional arguments forwarded to the handler.
    :param kw: keyword arguments forwarded to the handler.
    :return: whatever the user-defined handler returns.
    :raises ParameterError: if both or neither of ``evaluate``/``cmp`` exist.
    """
    new_style = getattr(self, "evaluate", None)
    legacy = getattr(self, "cmp", None)

    if new_style and legacy:
        raise ParameterError("evaluate and cmp cannot be defined at the same time")

    handler = new_style or legacy
    if not handler:
        raise ParameterError(
            "evaluate or cmp must be defined, evaluate function is recommended"
        )
    return handler(*args, **kw)

@_record_status # type: ignore
def _starwhale_internal_run_ppl(self) -> None:
def _starwhale_internal_run_predict(self) -> None:
if not self.dataset_uris:
raise FieldTypeOrValueError("context.dataset_uris is empty")
join_str = "_#@#_"
Expand All @@ -146,13 +165,13 @@ def _starwhale_internal_run_ppl(self) -> None:
dataset_info = ds.info
cnt = 0
idx_prefix = f"{_uri.typ}-{_uri.name}-{_uri.version}"
for rows in ds.batch_iter(self.ppl_batch_size):
for rows in ds.batch_iter(self.predict_batch_size):
_start = time.time()
_exception = None
_results: t.Any = b""
try:
if self._is_ppl_batch():
_results = self.ppl(
if self.predict_batch_size > 1:
_results = self._do_predict(
[row.features for row in rows],
index=[row.index for row in rows],
index_with_dataset=[
Expand All @@ -162,7 +181,7 @@ def _starwhale_internal_run_ppl(self) -> None:
)
else:
_results = [
self.ppl(
self._do_predict(
rows[0].features,
index=rows[0].index,
index_with_dataset=f"{idx_prefix}{join_str}{rows[0].index}",
Expand Down Expand Up @@ -198,7 +217,7 @@ def _starwhale_internal_run_ppl(self) -> None:
}
)

if self.ppl_auto_log:
if self.predict_auto_log:
if not self.ignore_dataset_data:
for artifact in TabularDatasetRow.artifacts_of(_features):
if artifact.link:
Expand All @@ -213,7 +232,7 @@ def _starwhale_internal_run_ppl(self) -> None:
serialize=True,
)

if self.flush_result and self.ppl_auto_log:
if self.flush_result and self.predict_auto_log:
self.evaluation_store.flush_result()

console.info(
Expand Down Expand Up @@ -446,9 +465,9 @@ def _register_predict(
needs=needs,
replicas=replicas,
extra_kwargs=dict(
ppl_batch_size=batch_size,
predict_batch_size=batch_size,
ignore_error=not fail_on_error,
ppl_auto_log=auto_log,
predict_auto_log=auto_log,
ignore_dataset_data=not auto_log,
dataset_uris=datasets,
),
Expand Down Expand Up @@ -514,6 +533,6 @@ def _register_evaluate(
replicas=1,
needs=needs,
extra_kwargs=dict(
ppl_auto_log=use_predict_auto_log,
predict_auto_log=use_predict_auto_log,
),
)(func)
13 changes: 7 additions & 6 deletions client/starwhale/api/_impl/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,15 @@ def _preload_registering_handlers(
and issubclass(v, PipelineHandler)
and v != PipelineHandler
):
ppl_func = getattr(v, "ppl")
cmp_func = getattr(v, "cmp")
Handler.register(replicas=2, name="ppl")(ppl_func)
# compatible with old version: ppl and cmp function are renamed to predict and evaluate
predict_func = getattr(v, "predict", None) or getattr(v, "ppl")
evaluate_func = getattr(v, "evaluate", None) or getattr(v, "cmp")
Handler.register(replicas=2, name="predict")(predict_func)
Handler.register(
replicas=1,
needs=[ppl_func],
name="cmp",
)(cmp_func)
needs=[predict_func],
name="evaluate",
)(evaluate_func)


def generate_jobs_yaml(
Expand Down
24 changes: 12 additions & 12 deletions client/starwhale/base/scheduler/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def status(self) -> str:
return self.__status

def _get_internal_func_name(self, func_name: str) -> str:
if func_name == "ppl":
return "_starwhale_internal_run_ppl"
elif func_name == "cmp":
return "_starwhale_internal_run_cmp"
if func_name in ("ppl", "predict"):
return "_starwhale_internal_run_predict"
elif func_name in ("cmp", "evaluate"):
return "_starwhale_internal_run_evaluate"
else:
raise RuntimeError(
f"failed to map func name({func_name}) into PipelineHandler internal func name"
Expand All @@ -68,8 +68,8 @@ def _run_in_pipeline_handler_cls(
from starwhale.api._impl.evaluation import PipelineHandler

patch_func_map = {
"ppl": lambda *args, **kwargs: ...,
"cmp": lambda *args, **kwargs: ...,
"predict": lambda *args, **kwargs: ...,
"evaluate": lambda *args, **kwargs: ...,
}

if func_name not in patch_func_map:
Expand Down Expand Up @@ -107,9 +107,9 @@ def _do_execute(self) -> None:
if cls_ is None:
func = getattr(module, self.step.func_name)
if getattr(func, DecoratorInjectAttr.Evaluate, False):
self._run_in_pipeline_handler_cls(func, "cmp")
self._run_in_pipeline_handler_cls(func, "evaluate")
elif getattr(func, DecoratorInjectAttr.Predict, False):
self._run_in_pipeline_handler_cls(func, "ppl")
self._run_in_pipeline_handler_cls(func, "predict")
elif getattr(func, DecoratorInjectAttr.Step, False):
func()
else:
Expand All @@ -127,17 +127,17 @@ def _do_execute(self) -> None:
with cls_() as instance:
func = getattr(instance, func_name)
if getattr(func, DecoratorInjectAttr.Evaluate, False):
self._run_in_pipeline_handler_cls(func, "cmp")
self._run_in_pipeline_handler_cls(func, "evaluate")
elif getattr(func, DecoratorInjectAttr.Predict, False):
self._run_in_pipeline_handler_cls(func, "ppl")
self._run_in_pipeline_handler_cls(func, "predict")
else:
func()
else:
func = getattr(cls_(), func_name)
if getattr(func, DecoratorInjectAttr.Evaluate, False):
self._run_in_pipeline_handler_cls(func, "cmp")
self._run_in_pipeline_handler_cls(func, "evaluate")
elif getattr(func, DecoratorInjectAttr.Predict, False):
self._run_in_pipeline_handler_cls(func, "ppl")
self._run_in_pipeline_handler_cls(func, "predict")
else:
func()

Expand Down
4 changes: 2 additions & 2 deletions client/starwhale/core/job/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def info(

if "location" in _rt:
console.rule("Process dirs")
console.print(f":cactus: ppl: {_rt['location']['ppl']}")
console.print(f":camel: cmp: {_rt['location']['cmp']}")
console.print(f":cactus: predict: {_rt['location']['predict']}")
console.print(f":camel: evaluate: {_rt['location']['evaluate']}")

if "tasks" in _rt:
self._print_tasks(_rt["tasks"][0])
Expand Down
11 changes: 6 additions & 5 deletions client/starwhale/core/runtime/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,11 +1868,6 @@ def _install_dependencies_within_restore(
# We assume the equation in the runtime auto-lock build mode:
# the lock files = pip_pkg + pip_req_file + conda_pkg + conda_env_file
raw_deps = []
for dep in deps["raw_deps"]:
kind = DependencyType(dep["kind"])
if kind in (DependencyType.NATIVE_FILE, DependencyType.WHEEL):
raw_deps.append(dep)

for lf in lock_files:
if lf.endswith(RuntimeLockFileType.CONDA):
raw_deps.append({"deps": lf, "kind": DependencyType.CONDA_ENV_FILE})
Expand All @@ -1882,6 +1877,12 @@ def _install_dependencies_within_restore(
raise NoSupportError(
f"lock file({lf}) cannot be converted into raw_deps"
)

# NATIVE_FILE and WHEEL must be installed after CONDA_ENV_FILE or PIP_REQ_FILE installation.
for dep in deps["raw_deps"]:
kind = DependencyType(dep["kind"])
if kind in (DependencyType.NATIVE_FILE, DependencyType.WHEEL):
raw_deps.append(dep)
else:
raw_deps = deps["raw_deps"]

Expand Down
2 changes: 2 additions & 0 deletions client/starwhale/utils/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn

from starwhale.utils import console
from starwhale.consts import ENV_DISABLE_PROGRESS_BAR


Expand All @@ -23,6 +24,7 @@ def run_with_progress_bar(
*Progress.get_default_columns(),
TimeElapsedColumn(),
refresh_per_second=1,
console=console.rich_console,
) as progress:
task = progress.add_task(
f"[red]{title}", total=sum([o[1] for o in operations])
Expand Down
7 changes: 6 additions & 1 deletion client/starwhale/utils/venv.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,12 @@ def conda_install_req(
return

configs = configs or {}
prefix_cmd = [get_conda_bin(), "run" if use_pip_install else "install"]
prefix_cmd = [get_conda_bin()]

if use_pip_install:
prefix_cmd += ["run", "--live-stream"]
else:
prefix_cmd += ["install"]

if env_name:
prefix_cmd += ["--name", env_name]
Expand Down
8 changes: 4 additions & 4 deletions client/tests/sdk/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_cmp(
)
Context.set_runtime_context(context)
with SimpleHandler() as _handler:
_handler._starwhale_internal_run_cmp()
_handler._starwhale_internal_run_evaluate()

status_file_path = os.path.join(_status_dir, "current")
assert os.path.exists(status_file_path)
Expand Down Expand Up @@ -202,7 +202,7 @@ def test_ppl(
)
Context.set_runtime_context(context)
with SimpleHandler() as _handler:
_handler._starwhale_internal_run_ppl()
_handler._starwhale_internal_run_predict()

m_eval_log.assert_called_once()
status_file_path = os.path.join(_status_dir, "current")
Expand Down Expand Up @@ -300,7 +300,7 @@ def cmp(self, _data_loader: t.Any) -> t.Any:
Context.set_runtime_context(context)
# mock
with Dummy(flush_result=True) as _handler:
_handler._starwhale_internal_run_ppl()
_handler._starwhale_internal_run_predict()

context = Context(
workdir=Path(),
Expand All @@ -312,7 +312,7 @@ def cmp(self, _data_loader: t.Any) -> t.Any:
)
Context.set_runtime_context(context)
with Dummy() as _handler:
_handler._starwhale_internal_run_cmp()
_handler._starwhale_internal_run_evaluate()


class TestEvaluationLogStore(BaseTestCase):
Expand Down
Loading

0 comments on commit e105bc6

Please sign in to comment.