feat(evaluation): support multi datasets (#1502)
goldenxinxing authored Nov 17, 2022
1 parent 16899cc commit b3aa333
Showing 22 changed files with 390 additions and 274 deletions.
2 changes: 1 addition & 1 deletion client/scripts/sw-docker-entrypoint
@@ -104,7 +104,7 @@ restore_activate_runtime() {

run() {
echo "--> start to run swmp ${STEP}, use $(which swcli) cli @ ${SWMP_DIR} ..."
swcli ${VERBOSE} model eval "${SWMP_DIR}"/src --dataset=${SW_DATASET_URI} --step=${STEP} --task-index=${TASK_INDEX} --override-task-num=${TASK_NUM} --version=${SW_EVALUATION_VERSION} || exit 1
swcli ${VERBOSE} model eval "${SWMP_DIR}"/src --step=${STEP} --task-index=${TASK_INDEX} --override-task-num=${TASK_NUM} --version=${SW_EVALUATION_VERSION} || exit 1
}

welcome() {
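
The entrypoint no longer passes --dataset on the swcli command line; the dataset URIs now reach the container through the SW_DATASET_URI environment variable instead (see the executor and CLI option changes below). A minimal container-side sketch, assuming hypothetical URI values rather than anything from the commit:

    import os

    # SW_DATASET_URI is exported by the executor as a space-separated list of dataset URIs;
    # the default value here is a placeholder, not a real dataset.
    for uri in os.environ.get("SW_DATASET_URI", "mnist/version/latest cifar10/version/latest").split():
        print(uri)
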
87 changes: 45 additions & 42 deletions client/starwhale/api/_impl/model.py
@@ -13,12 +13,13 @@
import loguru
import jsonlines

from starwhale import URI
from starwhale.utils import now_str
from starwhale.consts import CURRENT_FNAME
from starwhale.api.job import Context
from starwhale.utils.fs import ensure_dir, ensure_file
from starwhale.api._impl import wrapper
from starwhale.base.type import RunSubDirType
from starwhale.base.type import URIType, RunSubDirType
from starwhale.utils.log import StreamWrapper
from starwhale.utils.error import FieldTypeOrValueError
from starwhale.api._impl.job import context_holder
@@ -218,50 +219,52 @@ def _starwhale_internal_run_ppl(self) -> None:

if not self.context.dataset_uris:
raise FieldTypeOrValueError("context.dataset_uris is empty")
# TODO: support multi dataset uris
# TODO: user custom config batch size, max_retries
ds_uri = self.context.dataset_uris[0]
consumption = get_dataset_consumption(
dataset_uri=ds_uri, session_id=self.context.version
)
loader = get_data_loader(ds_uri, session_consumption=consumption)

cnt = 0
for _idx, _data, _annotations in loader:
cnt += 1
_start = time.time()
result: t.Any = b""
exception = None
try:
# TODO: inspect profiling
result = self.ppl(_data, annotations=_annotations, index=_idx)
except Exception as e:
exception = e
self._sw_logger.exception(f"[{_idx}] data handle -> failed")
if not self.ignore_error:
self._update_status(STATUS.FAILED)
raise
else:
exception = None

self._sw_logger.debug(
f"[{_idx}] use {time.time() - _start:.3f}s, session-id:{self.context.version} @{self.context.step}-{self.context.index}"
)

self._timeline_writer.write(
{
"time": now_str(),
"status": exception is None,
"exception": str(exception),
"index": _idx,
}
# TODO: user custom config batch size, max_retries
for ds_uri in self.context.dataset_uris:
_uri = URI(ds_uri, expected_type=URIType.DATASET)
consumption = get_dataset_consumption(
dataset_uri=_uri, session_id=self.context.version
)
loader = get_data_loader(_uri, session_consumption=consumption)

result_storage.save(
data_id=_idx,
result=result,
annotations={} if self.ignore_annotations else _annotations,
)
cnt = 0
for _idx, _data, _annotations in loader:
cnt += 1
_start = time.time()
result: t.Any = b""
exception = None
_unique_id = f"{_uri.object}_{_idx}"
try:
# TODO: inspect profiling
result = self.ppl(_data, annotations=_annotations, index=_unique_id)
except Exception as e:
exception = e
self._sw_logger.exception(f"[{_unique_id}] data handle -> failed")
if not self.ignore_error:
self._update_status(STATUS.FAILED)
raise
else:
exception = None

self._sw_logger.debug(
f"[{_unique_id}] use {time.time() - _start:.3f}s, session-id:{self.context.version} @{self.context.step}-{self.context.index}"
)

self._timeline_writer.write(
{
"time": now_str(),
"status": exception is None,
"exception": str(exception),
"index": f"{_unique_id}",
}
)

result_storage.save(
data_id=f"{_unique_id}",
result=result,
annotations={} if self.ignore_annotations else _annotations,
)

if self.flush_result:
result_storage.flush()
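
Because every configured dataset now flows through the same loop, the per-row identifier is prefixed with the dataset object name (_unique_id = f"{_uri.object}_{_idx}"), so rows from different datasets cannot collide in result storage or in the timeline. An illustrative sketch of the id scheme, using made-up dataset names:

    # Hypothetical dataset object names; this only demonstrates the id layout used above.
    for dataset_object in ("mnist", "cifar10"):
        for idx in range(2):
            unique_id = f"{dataset_object}_{idx}"
            print(unique_id)  # mnist_0, mnist_1, cifar10_0, cifar10_1
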
6 changes: 4 additions & 2 deletions client/starwhale/core/eval/cli.py
@@ -61,9 +61,11 @@ def _list(
@click.option("--model", required=True, help="model uri or model.yaml dir path")
# TODO:support multi dataset
@click.option(
"datasets",
"--dataset",
required=True,
envvar=SWEnv.dataset_uri,
multiple=True,
help=f"dataset uri, env is {SWEnv.dataset_uri}",
)
@click.option("--runtime", default="", help="runtime uri")
@@ -93,7 +95,7 @@ def _run(
project: str,
version: str,
model: str,
dataset: str,
datasets: list,
runtime: str,
name: str,
desc: str,
@@ -109,7 +111,7 @@
project_uri=project,
version=version,
model_uri=model,
dataset_uris=[dataset],
dataset_uris=datasets,
runtime_uri=runtime,
name=name,
desc=desc,
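
The --dataset option is now declared with multiple=True, so click collects repeated flags (or a whitespace-separated environment value) into a tuple. A self-contained sketch of that click pattern, separate from the Starwhale code:

    import click

    @click.command()
    @click.option("datasets", "--dataset", required=True, multiple=True, envvar="SW_DATASET_URI")
    def run(datasets: tuple) -> None:
        # Repeated --dataset flags, or a space-separated SW_DATASET_URI value, arrive as a tuple.
        for ds in datasets:
            click.echo(ds)

    if __name__ == "__main__":
        run()
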
7 changes: 6 additions & 1 deletion client/starwhale/core/eval/executor.py
@@ -236,7 +236,12 @@ def _gen_run_container_cmd(self, typ: str, step: str, task_index: int) -> str:
]
)
# TODO: support multi dataset
cmd.extend(["-e", f"{SWEnv.dataset_uri}={self.dataset_uris[0].full_uri}"])
cmd.extend(
[
"-e",
f"{SWEnv.dataset_uri}={' '.join([ds.full_uri for ds in self.dataset_uris])}",
]
)

cntr_cache_dir = os.environ.get("SW_PIP_CACHE_DIR", CNTR_DEFAULT_PIP_CACHE_DIR)
host_cache_dir = os.path.expanduser("~/.cache/starwhale-pip")
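
On the executor side, all dataset URIs are joined with spaces into a single SW_DATASET_URI value; since click splits whitespace-separated environment values for multiple options, the container-side CLI recovers the full list. A hedged round-trip sketch with hypothetical URIs:

    uris = ["mnist/version/latest", "cifar10/version/latest"]  # placeholders
    env_value = " ".join(uris)                      # what the executor exports as SW_DATASET_URI
    assert tuple(env_value.split()) == tuple(uris)  # what the click option sees in the container
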
6 changes: 4 additions & 2 deletions client/starwhale/core/model/cli.py
@@ -156,17 +156,19 @@ def _extract(model: str, force: bool, target_dir: str) -> None:
)
@click.option("--runtime", default="", help="runtime uri")
@click.option(
"datasets",
"--dataset",
required=True,
envvar=SWEnv.dataset_uri,
multiple=True,
help=f"dataset uri, env is {SWEnv.dataset_uri}",
)
def _eval(
project: str,
target: str,
model_yaml: str,
version: str,
dataset: str,
datasets: list,
step: str,
task_index: int,
override_task_num: int,
@@ -186,5 +188,5 @@
step=step,
task_index=task_index,
task_num=override_task_num,
dataset_uris=[dataset],
dataset_uris=datasets,
)
21 changes: 18 additions & 3 deletions scripts/client_test/cli_test.py
@@ -65,6 +65,10 @@ def standard_workflow(
assert len(self.dataset.list()) == 1
swds = self.dataset.info(_ds_uri)
assert swds
self.dataset.build(workdir=ds_workdir)
assert len(self.dataset.list()) == 2
swds2 = self.dataset.info(_ds_uri)
assert swds2

# 3.runtime build
logging.info("building runtime...")
@@ -78,7 +82,13 @@
# 4.run evaluation on local instance
logging.info("running evaluation at local...")
assert len(self.evaluation.list()) == 0
assert self.evaluation.run(model=_model_uri, dataset=_ds_uri)
assert self.evaluation.run(
model=_model_uri,
datasets=[
f'{ds_name}/version/{swds["version"]}',
f'{ds_name}/version/{swds2["version"]}',
],
)
_eval_list = self.evaluation.list()
assert len(_eval_list) == 1

@@ -101,7 +111,12 @@ def standard_workflow(
force=True,
)
assert self.dataset.copy(
src_uri=_ds_uri,
src_uri=f'{ds_name}/version/{swds["version"]}',
target_project=f"cloud://cloud/project/{cloud_project}",
force=True,
)
assert self.dataset.copy(
src_uri=f'{ds_name}/version/{swds2["version"]}',
target_project=f"cloud://cloud/project/{cloud_project}",
force=True,
)
@@ -122,7 +137,7 @@ def standard_workflow(
logging.info("running evaluation at cloud...")
assert self.evaluation.run(
model=swmp["version"],
dataset=swds["version"],
datasets=[swds["version"], swds2["version"]],
runtime=swrt["version"],
project=cloud_project,
step_spec=step_spec_file,
10 changes: 6 additions & 4 deletions scripts/client_test/cmds/eval_cmd.py
@@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, List

from . import CLI
from .base.invoke import invoke
@@ -11,7 +11,7 @@ class Evaluation:
def run(
self,
model: str,
dataset: str,
datasets: List[str],
project: str = "self",
version: str = "",
runtime: str = "",
@@ -28,7 +28,7 @@ def run(
:param project:
:param version: Evaluation job version
:param model: model uri or model.yaml dir path [required]
:param dataset: dataset uri, one or more [required]
:param datasets: dataset uri, one or more [required]
:param runtime: runtime uri
:param name: job name
:param desc: job description
@@ -40,7 +40,9 @@
:param resource_pool: [ONLY Cloud] which nodes should job run on
:return:
"""
_args = [CLI, self._cmd, "run", "--model", model, "--dataset", dataset]
_args = [CLI, self._cmd, "run", "--model", model]
for ds in datasets:
_args.extend(["--dataset", ds])
if version:
_args.extend(["--version", version])
if runtime:
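
A hypothetical use of the updated test helper (the import path, dataset URIs, and the truthy return value are assumptions, not taken from the commit):

    from cmds.eval_cmd import Evaluation

    ok = Evaluation().run(
        model="mnist/version/latest",
        datasets=["mnist/version/latest", "mnist-test/version/latest"],
    )
    assert ok
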
DataLoader.java (ai.starwhale.mlops.domain.dataset.dataloader)
@@ -16,7 +16,6 @@

package ai.starwhale.mlops.domain.dataset.dataloader;

import ai.starwhale.mlops.common.KeyLock;
import ai.starwhale.mlops.domain.dataset.dataloader.bo.DataReadLog;
import org.springframework.stereotype.Service;

@@ -29,24 +28,12 @@ public DataLoader(DataReadManager dataReadManager) {
}

public DataReadLog next(DataReadRequest request) {
var sessionId = request.getSessionId();
var consumerId = request.getConsumerId();
var session = dataReadManager.getOrGenerateSession(request);

var lock = new KeyLock<>(consumerId);
try {
lock.lock();
dataReadManager.handleConsumerData(request);
} finally {
lock.unlock();
}
dataReadManager.handleConsumerData(consumerId,
request.isSerial(), request.getProcessedData(), session);

// ensure serially in the same session
var sessionLock = new KeyLock<>(sessionId);
try {
sessionLock.lock();
return dataReadManager.getDataReadIndex(request);
} finally {
sessionLock.unlock();
}
return dataReadManager.assignmentData(consumerId, session);
}
}
(diffs for the remaining changed files are not included in this view)