Skip to content

Commit

Permalink
example: refactor PennFudanPed example with new api (#1167)
Browse files Browse the repository at this point in the history
refactor pfp code
  • Loading branch information
tianweidut authored Sep 13, 2022
1 parent 383184a commit fbd46ca
Show file tree
Hide file tree
Showing 33 changed files with 773 additions and 999 deletions.
6 changes: 4 additions & 2 deletions client/starwhale/api/_impl/metric.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing as t
from enum import Enum, unique
from functools import wraps

from sklearn.metrics import ( # type: ignore
Expand All @@ -17,7 +18,8 @@
from .model import PipelineHandler


class MetricKind:
@unique
class MetricKind(Enum):
MultiClassification = "multi_classification"


Expand All @@ -40,7 +42,7 @@ def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]:
else:
y_true, y_pred = _rt

_r: t.Dict[str, t.Any] = {"kind": MetricKind.MultiClassification}
_r: t.Dict[str, t.Any] = {"kind": MetricKind.MultiClassification.value}
cr = classification_report(
y_true, y_pred, output_dict=True, labels=all_labels
)
Expand Down
25 changes: 10 additions & 15 deletions client/starwhale/api/_impl/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import io
import os
import sys
import math
import base64
Expand All @@ -22,7 +21,6 @@
from starwhale.utils.fs import ensure_dir, ensure_file
from starwhale.base.type import URIType, RunSubDirType
from starwhale.utils.log import StreamWrapper
from starwhale.consts.env import SWEnv
from starwhale.utils.error import FieldTypeOrValueError
from starwhale.api._impl.job import Context
from starwhale.core.job.model import STATUS
Expand Down Expand Up @@ -93,7 +91,9 @@ def __init__(
# TODO: split status/result files
self._timeline_writer = _jl_writer(self.status_dir / "timeline")

self.evaluation = self._init_datastore()
self.evaluation = Evaluation(
eval_id=self.context.version, project=self.context.project
)
self._monkey_patch()

def _init_dir(self) -> None:
Expand All @@ -108,11 +108,6 @@ def _init_dir(self) -> None:
ensure_dir(self.status_dir)
ensure_dir(self.log_dir)

def _init_datastore(self) -> Evaluation:
os.environ[SWEnv.project] = self.context.project
os.environ[SWEnv.eval_version] = self.context.version
return Evaluation()

def _init_logger(self) -> t.Tuple[loguru.Logger, loguru.Logger]:
# TODO: remove logger first?
# TODO: add custom log format, include daemonset pod name
Expand Down Expand Up @@ -181,11 +176,11 @@ def ppl(self, data: t.Any, **kw: t.Any) -> t.Any:
def cmp(self, ppl_result: PPLResultIterator) -> t.Any:
raise NotImplementedError

def _builtin_serialize(self, *data: t.Any) -> bytes:
def _builtin_serialize(self, data: t.Any) -> bytes:
return dill.dumps(data) # type: ignore

def ppl_result_serialize(self, *data: t.Any) -> bytes:
return self._builtin_serialize(*data)
def ppl_result_serialize(self, data: t.Any) -> bytes:
return self._builtin_serialize(data)

def ppl_result_deserialize(self, data: bytes) -> t.Any:
return dill.loads(base64.b64decode(data))
Expand All @@ -194,7 +189,7 @@ def annotations_serialize(self, data: t.Any) -> bytes:
return self._builtin_serialize(data)

def annotations_deserialize(self, data: bytes) -> bytes:
return dill.loads(base64.b64decode(data))[0] # type: ignore
return dill.loads(base64.b64decode(data)) # type: ignore

def deserialize(self, data: t.Dict[str, t.Any]) -> t.Any:
data["result"] = self.ppl_result_deserialize(data["result"])
Expand Down Expand Up @@ -275,14 +270,14 @@ def _starwhale_internal_run_ppl(self) -> None:
else:
exception = None

self._do_record(_idx, _annotations, exception, *pred)
self._do_record(_idx, _annotations, exception, pred)

def _do_record(
self,
idx: int,
annotations: t.Dict,
exception: t.Optional[Exception],
*args: t.Any,
pred: t.Any,
) -> None:
_timeline = {
"time": now_str(),
Expand All @@ -296,7 +291,7 @@ def _do_record(
_b64: t.Callable[[bytes], str] = lambda x: base64.b64encode(x).decode("ascii")
self.evaluation.log_result(
data_id=idx,
result=_b64(self.ppl_result_serialize(*args)),
result=_b64(self.ppl_result_serialize(pred)),
annotations=_b64(self.annotations_serialize(annotations)),
)
self._update_status(STATUS.RUNNING)
Expand Down
13 changes: 7 additions & 6 deletions client/starwhale/api/_impl/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,20 @@ def _log(self, table_name: str, record: Dict[str, Any]) -> None:


class Evaluation(Logger):
def __init__(self, eval_id: Optional[str] = None):
if eval_id is None:
eval_id = os.getenv(SWEnv.eval_version, None)
if eval_id is None:
def __init__(self, eval_id: str = "", project: str = ""):
eval_id = eval_id or os.getenv(SWEnv.eval_version, "")
if not eval_id:
raise RuntimeError("eval id should not be None")
if re.match(r"^[A-Za-z0-9-_]+$", eval_id) is None:
raise RuntimeError(
f"invalid eval id {eval_id}, only letters(A-Z, a-z), digits(0-9), hyphen('-'), and underscore('_') are allowed"
)
self.eval_id = eval_id
self.project = os.getenv(SWEnv.project)
if self.project is None:

self.project = project or os.getenv(SWEnv.project, "")
if not self.project:
raise RuntimeError(f"{SWEnv.project} is not set")

self._results_table_name = self._get_datastore_table_name("results")
self._summary_table_name = f"project/{self.project}/eval/summary"
self._init_writers([self._results_table_name, self._summary_table_name])
Expand Down
10 changes: 9 additions & 1 deletion client/starwhale/core/dataset/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io
import os
import base64
import typing as t
from abc import ABCMeta, abstractmethod
from enum import Enum, unique
Expand Down Expand Up @@ -166,6 +167,9 @@ class ArtifactType(Enum):
Text = "text"


_TBAType = t.TypeVar("_TBAType", bound="BaseArtifact")


class BaseArtifact(ASDictMixin, metaclass=ABCMeta):
def __init__(
self,
Expand Down Expand Up @@ -224,6 +228,10 @@ def to_bytes(self) -> bytes:
else:
raise NoSupportError(f"read raw for type:{type(self.fp)}")

def carry_raw_data(self: _TBAType) -> _TBAType:
self._raw_base64_data = base64.b64encode(self.to_bytes()).decode()
return self

def astype(self) -> t.Dict[str, t.Any]:
return {
"type": self.type,
Expand Down Expand Up @@ -379,7 +387,7 @@ def __init__(
image_id: int,
category_id: int,
segmentation: t.Union[t.List, t.Dict],
area: float,
area: t.Union[float, int],
bbox: t.Union[BoundingBox, t.List[float]],
iscrowd: int,
) -> None:
Expand Down
39 changes: 20 additions & 19 deletions client/starwhale/core/eval/model.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,25 @@
from __future__ import annotations

import os
import json
import typing as t
import subprocess
from abc import ABCMeta, abstractmethod
from http import HTTPStatus
from collections import defaultdict

from loguru import logger

from starwhale.utils import load_yaml
from starwhale.consts import HTTPMethod, DEFAULT_PAGE_IDX, DEFAULT_PAGE_SIZE
from starwhale.base.uri import URI
from starwhale.utils.fs import move_dir
from starwhale.api._impl import wrapper
from starwhale.base.type import InstanceType, JobOperationType
from starwhale.base.cloud import CloudRequestMixed
from starwhale.consts.env import SWEnv
from starwhale.utils.http import ignore_error
from starwhale.utils.error import NotFoundError, NoSupportError
from starwhale.utils.config import SWCliConfigMixed
from starwhale.utils.process import check_call
from starwhale.core.eval.store import EvaluationStorage
from starwhale.api._impl.metric import MetricKind
from starwhale.core.eval.executor import EvalExecutor
from starwhale.core.runtime.process import Process as RuntimeProcess

Expand Down Expand Up @@ -95,22 +92,26 @@ def _get_version(self) -> str:
raise NotImplementedError

def _get_report(self) -> t.Dict[str, t.Any]:
# use datastore
os.environ[SWEnv.project] = self.uri.project
os.environ[SWEnv.eval_version] = self._get_version()
logger.debug(
f"eval instance:{self.uri.instance}, project:{self.uri.project}, eval_id:{self._get_version()}"
)
_evaluation = wrapper.Evaluation()
_summary = _evaluation.get_metrics()
return dict(
summary=_summary,
labels={str(i): l for i, l in enumerate(list(_evaluation.get("labels")))},
confusion_matrix=dict(
binarylabel=list(_evaluation.get("confusion_matrix/binarylabel"))
),
kind=_summary["kind"],
evaluation = wrapper.Evaluation(
eval_id=self._get_version(), project=self.uri.project
)
summary = evaluation.get_metrics()
kind = summary.get("kind", "")

ret = {
"kind": kind,
"summary": summary,
}

if kind == MetricKind.MultiClassification.value:
ret["labels"] = {
str(i): l for i, l in enumerate(list(evaluation.get("labels")))
}
ret["confusion_matrix"] = {
"binarylabel": list(evaluation.get("confusion_matrix/binarylabel"))
}

return ret

@classmethod
def _get_job_cls(
Expand Down
54 changes: 24 additions & 30 deletions client/starwhale/core/eval/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

from rich import box
from loguru import logger
from rich.tree import Tree
from rich.panel import Panel
from rich.table import Table
from rich.pretty import Pretty
from rich.columns import Columns

from starwhale.utils import Order, console, sort_obj_list
from starwhale.consts import (
Expand All @@ -18,6 +19,7 @@
from starwhale.base.type import URIType, InstanceType, JobOperationType
from starwhale.base.view import BaseTermView
from starwhale.core.eval.model import EvaluationJob
from starwhale.api._impl.metric import MetricKind


class JobTermView(BaseTermView):
Expand Down Expand Up @@ -130,7 +132,21 @@ def info(self, page: int = DEFAULT_PAGE_IDX, size: int = DEFAULT_PAGE_SIZE) -> N
self._print_tasks(_rt["tasks"][0])

if "report" in _rt:
self._render_job_report(_rt["report"])
_report = _rt["report"]
_kind = _rt["report"].get("kind", "")

if "summary" in _report:
self._render_summary_report(_report["summary"], _kind)

if _kind == MetricKind.MultiClassification.value:
self._render_multi_classification_job_report(_rt["report"])

def _render_summary_report(self, summary: t.Dict[str, t.Any], kind: str) -> None:
console.rule(f"[bold green]{kind.upper()} Summary")
contents = [
Panel(f"[b]{k}[/b]\n[yellow]{v}", expand=True) for k, v in summary.items()
]
console.print(Columns(contents))

def _print_tasks(self, tasks: t.List[t.Dict[str, t.Any]]) -> None:
table = Table(box=box.SIMPLE)
Expand Down Expand Up @@ -159,39 +175,17 @@ def _print_tasks(self, tasks: t.List[t.Dict[str, t.Any]]) -> None:
)
console.print(table)

# TODO: use new result format
def _render_job_report(self, report: t.Dict[str, t.Any]) -> None:
def _render_multi_classification_job_report(
self, report: t.Dict[str, t.Any]
) -> None:
if not report:
console.print(":turtle: no report")
return

labels: t.Dict[str, t.Any] = report.get("labels", {})
sort_label_names = sorted(list(labels.keys()))

def _print_report() -> None:
# TODO: add other kind report
def _r(_tree: t.Any, _obj: t.Any) -> None:
if not isinstance(_obj, dict):
_tree.add(str(_obj))

for _k, _v in _obj.items():
if _k == "id":
continue
if isinstance(_v, (list, tuple)):
_k = f"{_k}: [green]{'|'.join(_v)}"
elif isinstance(_v, dict):
_k = _k
elif isinstance(_v, str):
_k = f"{_k}:{_v}"
else:
_k = f"{_k}: [green]{_v:.4f}"

_ntree = _tree.add(_k)
if isinstance(_v, dict):
_r(_ntree, _v)

tree = Tree("Summary")
_r(tree, report["summary"])
def _print_labels() -> None:
if len(labels) == 0:
return

Expand All @@ -209,7 +203,7 @@ def _r(_tree: t.Any, _obj: t.Any) -> None:
)

console.rule(f"[bold green]{report['kind'].upper()} Report")
console.print(self.comparison(tree, table))
console.print(table)

def _print_confusion_matrix() -> None:
cm = report.get("confusion_matrix", {})
Expand All @@ -236,7 +230,7 @@ def _print_confusion_matrix() -> None:
console.rule(f"[bold green]{report['kind'].upper()} Confusion Matrix")
console.print(self.comparison(mtable, btable))

_print_report()
_print_labels()
_print_confusion_matrix()

@classmethod
Expand Down
1 change: 1 addition & 0 deletions client/tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def some(self):
Context(
workdir=Path(_model_data_dir),
version="rwerwe9",
project="self",
),
"some",
)
2 changes: 2 additions & 0 deletions example/PennFudanPed/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
data/
model/
12 changes: 12 additions & 0 deletions example/PennFudanPed/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Declare targets as phony so they always run, even if a file/dir with the
# same name exists (".POHNY" was a typo and had no effect in GNU Make).
.PHONY: train
train:
	mkdir -p models
	python3 pfp/train.py

.PHONY: download-data
download-data:
	rm -rf data
	mkdir -p data
	curl -o data/pfp.zip https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip
	unzip data/pfp.zip -d data
	# pfp.zip is a regular file; -f is sufficient (and safer than -rf)
	rm -f data/pfp.zip
Loading

0 comments on commit fbd46ca

Please sign in to comment.