Skip to content

Commit

Permalink
feat(client): add filter params for model/dataset/runtime list command (
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamlandliu authored Jan 6, 2023
1 parent 039020b commit 9c28447
Show file tree
Hide file tree
Showing 20 changed files with 557 additions and 55 deletions.
43 changes: 42 additions & 1 deletion client/starwhale/base/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from starwhale.utils.config import SWCliConfigMixed

from .uri import URI
from .store import BundleField


class BaseBundle(metaclass=ABCMeta):
Expand Down Expand Up @@ -76,9 +77,49 @@ def list(
project_uri: URI,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
filters: t.Optional[t.Union[t.Dict[str, t.Any], t.List[str]]] = None,
) -> t.Tuple[t.Dict[str, t.Any], t.Dict[str, t.Any]]:
filters = filters or {}
_cls = cls._get_cls(project_uri)
return _cls.list(project_uri, page, size) # type: ignore
_filter = cls.get_filter_dict(filters, cls.get_filter_fields())
return _cls.list(project_uri, page, size, _filter) # type: ignore

@classmethod
def get_filter_dict(
cls,
filters: t.Union[t.Dict[str, t.Any], t.List[str]],
fields: t.Optional[t.List[str]] = None,
) -> t.Dict[str, t.Any]:
fields = fields or []
if isinstance(filters, t.Dict):
return {k: v for k, v in filters.items() if k in fields}

_filter_dict: t.Dict[str, t.Any] = {}
for _f in filters:
_item = _f.split("=", 1)
if _item[0] in fields:
_filter_dict[_item[0]] = _item[1] if len(_item) > 1 else ""
return _filter_dict

@classmethod
def get_filter_fields(cls) -> t.List[str]:
return ["name", "owner", "latest"]

@classmethod
def do_bundle_filter(
cls,
bundle_field: BundleField,
filters: t.Union[t.Dict[str, t.Any], t.List[str]],
) -> bool:
filter_dict = cls.get_filter_dict(filters, cls.get_filter_fields())
_name = filter_dict.get("name")
if _name and not bundle_field.name.startswith(_name):
return False
_latest = filter_dict.get("latest") is not None
if _latest and "latest" not in bundle_field.tags:
return False

return True

@abstractclassmethod
def _get_cls(cls, uri: URI) -> t.Any:
Expand Down
17 changes: 13 additions & 4 deletions client/starwhale/base/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,23 +235,32 @@ def _fetch_bundle_all_list(
uri_typ: str,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
filter_dict: t.Optional[t.Dict[str, t.Any]] = None,
) -> t.Tuple[t.Dict[str, t.Any], t.Dict[str, t.Any]]:
filter_dict = filter_dict or {}
_params = {"pageNum": page, "pageSize": size}
_params.update(filter_dict)
r = self.do_http_request(
f"/project/{project_uri.project}/{uri_typ}",
params={"pageNum": page, "pageSize": size},
params=_params,
instance_uri=project_uri,
).json()

objects = {}

_page = page
_size = size
if filter_dict.get("latest") is not None:
_page = 1
_size = 1

for o in r["data"]["list"]:
_name = f"[{o['id']}] {o['name']}"
objects[_name] = self._fetch_bundle_history(
name=o["id"],
project_uri=project_uri,
typ=uri_typ,
page=page,
size=size,
page=_page,
size=_size,
)[0]

return objects, self.parse_pager(r)
Expand Down
26 changes: 24 additions & 2 deletions client/starwhale/core/dataset/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _diff(
view(base_uri).diff(URI(compare_uri, expected_type=URIType.DATASET), show_details)


@dataset_cmd.command("list", aliases=["ls"], help="List dataset")
@dataset_cmd.command("list", aliases=["ls"])
@click.option("-p", "--project", default="", help="Project URI")
@click.option("-f", "--fullname", is_flag=True, help="Show fullname of dataset version")
@click.option("-sr", "--show-removed", is_flag=True, help="Show removed datasets")
Expand All @@ -128,6 +128,13 @@ def _diff(
@click.option(
"--size", type=int, default=DEFAULT_PAGE_SIZE, help="Page size for dataset list"
)
@click.option(
"filters",
"-fl",
"--filter",
multiple=True,
help="Filter output based on conditions provided.",
)
@click.pass_obj
def _list(
view: DatasetTermView,
Expand All @@ -136,8 +143,23 @@ def _list(
show_removed: bool,
page: int,
size: int,
filters: list,
) -> None:
view.list(project, fullname, show_removed, page, size)
"""
List Dataset
The filtering flag (-fl or --filter) format is a key=value pair or a flag.
If there is more than one filter, then pass multiple flags.\n
(e.g. --filter name=mnist --filter latest)
\b
The currently supported filters are:
name\tTEXT\tThe prefix of the dataset name
owner\tTEXT\tThe name or id of the dataset owner
latest\tFLAG\t[Cloud] Only show the latest version
\t \t[Standalone] Only show the version with "latest" tag
"""
view.list(project, fullname, show_removed, page, size, filters)


@dataset_cmd.command("info", help="Show dataset details")
Expand Down
11 changes: 10 additions & 1 deletion client/starwhale/core/dataset/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,19 @@ def list(
project_uri: URI,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
filters: t.Optional[t.Union[t.Dict[str, t.Any], t.List[str]]] = None,
) -> t.Tuple[t.Dict[str, t.Any], t.Dict[str, t.Any]]:
filters = filters or {}
rs = defaultdict(list)

for _bf in DatasetStorage.iter_all_bundles(
project_uri,
bundle_type=BundleType.DATASET,
uri_type=URIType.DATASET,
):
if not cls.do_bundle_filter(_bf, filters):
continue

_mf = _bf.path / DEFAULT_MANIFEST_NAME
if not _mf.exists():
continue
Expand Down Expand Up @@ -521,9 +526,13 @@ def list(
project_uri: URI,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
filter_dict: t.Optional[t.Dict[str, t.Any]] = None,
) -> t.Tuple[t.Dict[str, t.Any], t.Dict[str, t.Any]]:
filter_dict = filter_dict or {}
crm = CloudRequestMixed()
return crm._fetch_bundle_all_list(project_uri, URIType.DATASET, page, size)
return crm._fetch_bundle_all_list(
project_uri, URIType.DATASET, page, size, filter_dict
)

def summary(self) -> t.Optional[DatasetSummary]:
resp = self.do_http_request(
Expand Down
13 changes: 9 additions & 4 deletions client/starwhale/core/dataset/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,16 @@ def list(
show_removed: bool = False,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
filters: t.Optional[t.List[str]] = None,
) -> t.Tuple[t.List[t.Dict[str, t.Any]], t.Dict[str, t.Any]]:

filters = filters or []
if isinstance(project_uri, str):
_uri = URI(project_uri, expected_type=URIType.PROJECT)
else:
_uri = project_uri

fullname = fullname or (_uri.instance_type == InstanceType.CLOUD)
_datasets, _pager = Dataset.list(_uri, page, size)
_datasets, _pager = Dataset.list(_uri, page, size, filters)
_data = BaseTermView.list_data(_datasets, show_removed, fullname)
return _data, _pager

Expand Down Expand Up @@ -235,9 +236,11 @@ def list(
show_removed: bool = False,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
filters: t.Optional[t.List[str]] = None,
) -> t.Tuple[t.List[t.Dict[str, t.Any]], t.Dict[str, t.Any]]:
filters = filters or []
_datasets, _pager = super().list(
project_uri, fullname, show_removed, page, size
project_uri, fullname, show_removed, page, size, filters
)
custom_column: t.Dict[str, t.Callable[[t.Any], str]] = {
"tags": lambda x: ",".join(x),
Expand All @@ -258,9 +261,11 @@ def list( # type: ignore
show_removed: bool = False,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
filters: t.Optional[t.List[str]] = None,
) -> None:
filters = filters or []
_datasets, _pager = super().list(
project_uri, fullname, show_removed, page, size
project_uri, fullname, show_removed, page, size, filters
)
cls.pretty_json(_datasets)

Expand Down
26 changes: 24 additions & 2 deletions client/starwhale/core/model/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _diff(
view(base_uri).diff(URI(compare_uri, expected_type=URIType.MODEL), show_details)


@model_cmd.command("list", aliases=["ls"], help="List Model")
@model_cmd.command("list", aliases=["ls"])
@click.option("-p", "--project", default="", help="Project URI")
@click.option("-f", "--fullname", is_flag=True, help="Show fullname of model version")
@click.option("-sr", "--show-removed", is_flag=True, help="Show removed model")
Expand All @@ -130,6 +130,13 @@ def _diff(
@click.option(
"--size", type=int, default=DEFAULT_PAGE_SIZE, help="Page size for model list"
)
@click.option(
"filters",
"-fl",
"--filter",
multiple=True,
help="Filter output based on conditions provided.",
)
@click.pass_obj
def _list(
view: t.Type[ModelTermView],
Expand All @@ -138,8 +145,23 @@ def _list(
show_removed: bool,
page: int,
size: int,
filters: list,
) -> None:
view.list(project, fullname, show_removed, page, size)
"""
List Model
The filtering flag (-fl or --filter) format is a key=value pair or a flag.
If there is more than one filter, then pass multiple flags.\n
(e.g. --filter name=mnist --filter latest)
\b
The currently supported filters are:
name\tTEXT\tThe prefix of the model name
owner\tTEXT\tThe name or id of the model owner
latest\tFLAG\t[Cloud] Only show the latest version
\t \t[Standalone] Only show the version with "latest" tag
"""
view.list(project, fullname, show_removed, page, size, filters)


@model_cmd.command("history", help="Show model history")
Expand Down
11 changes: 10 additions & 1 deletion client/starwhale/core/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,13 +537,18 @@ def list(
project_uri: URI,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
filters: t.Optional[t.Union[t.Dict[str, t.Any], t.List[str]]] = None,
) -> t.Tuple[t.Dict[str, t.Any], t.Dict[str, t.Any]]:
filters = filters or {}
rs = defaultdict(list)
for _bf in ModelStorage.iter_all_bundles(
project_uri,
bundle_type=BundleType.MODEL,
uri_type=URIType.MODEL,
):
if not cls.do_bundle_filter(_bf, filters):
continue

if _bf.path.is_file():
# for origin swmp(tar)
_manifest = ModelStorage.get_manifest_by_path(
Expand Down Expand Up @@ -736,9 +741,13 @@ def list(
project_uri: URI,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
filter_dict: t.Optional[t.Dict[str, t.Any]] = None,
) -> t.Tuple[t.Dict[str, t.Any], t.Dict[str, t.Any]]:
filter_dict = filter_dict or {}
crm = CloudRequestMixed()
return crm._fetch_bundle_all_list(project_uri, URIType.MODEL, page, size)
return crm._fetch_bundle_all_list(
project_uri, URIType.MODEL, page, size, filter_dict
)

def build(self, *args: t.Any, **kwargs: t.Any) -> None:
raise NoSupportError("no support build model in the cloud instance")
Expand Down
16 changes: 13 additions & 3 deletions client/starwhale/core/model/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,12 @@ def list(
show_removed: bool = False,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
filters: t.Optional[t.Union[t.Dict[str, t.Any], t.List[str]]] = None,
) -> t.Tuple[t.List[t.Dict[str, t.Any]], t.Dict[str, t.Any]]:
filters = filters or {}
_uri = URI(project_uri, expected_type=URIType.PROJECT)
fullname = fullname or (_uri.instance_type == InstanceType.CLOUD)
_models, _pager = Model.list(_uri, page, size)
_models, _pager = Model.list(_uri, page, size, filters)
_data = BaseTermView.list_data(_models, show_removed, fullname)
return _data, _pager

Expand Down Expand Up @@ -308,8 +310,12 @@ def list(
show_removed: bool = False,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
filters: t.Optional[t.List[str]] = None,
) -> t.Tuple[t.List[t.Dict[str, t.Any]], t.Dict[str, t.Any]]:
_models, _pager = super().list(project_uri, fullname, show_removed, page, size)
filters = filters or []
_models, _pager = super().list(
project_uri, fullname, show_removed, page, size, filters
)
custom_column: t.Dict[str, t.Callable[[t.Any], str]] = {
"tags": lambda x: ",".join(x),
"size": lambda x: pretty_bytes(x),
Expand All @@ -329,8 +335,12 @@ def list( # type: ignore
show_removed: bool = False,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
filters: t.Optional[t.List[str]] = None,
) -> None:
_models, _pager = super().list(project_uri, fullname, show_removed, page, size)
filters = filters or []
_models, _pager = super().list(
project_uri, fullname, show_removed, page, size, filters
)
cls.pretty_json(_models)

def info(self, fullname: bool = False) -> None:
Expand Down
26 changes: 24 additions & 2 deletions client/starwhale/core/runtime/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def _restore(target: str) -> None:
RuntimeTermView.restore(target)


@runtime_cmd.command("list", aliases=["ls"], help="List runtime")
@runtime_cmd.command("list", aliases=["ls"])
@click.option("-p", "--project", default="", help="Project URI")
@click.option("-f", "--fullname", is_flag=True, help="Show fullname of runtime version")
@click.option("-sr", "--show-removed", is_flag=True, help="Show removed runtime")
Expand All @@ -275,6 +275,13 @@ def _restore(target: str) -> None:
@click.option(
"--size", type=int, default=DEFAULT_PAGE_SIZE, help="Page size for tasks list"
)
@click.option(
"filters",
"-fl",
"--filter",
multiple=True,
help="Filter output based on conditions provided.",
)
@click.pass_obj
def _list(
view: t.Type[RuntimeTermView],
Expand All @@ -283,8 +290,23 @@ def _list(
show_removed: bool,
page: int,
size: int,
filters: list,
) -> None:
view.list(project, fullname, show_removed, page, size)
"""
List Runtime
The filtering flag (-fl or --filter) format is a key=value pair or a flag.
If there is more than one filter, then pass multiple flags.\n
(e.g. --filter name=mnist --filter latest)
\b
The currently supported filters are:
name\tTEXT\tThe prefix of the runtime name
owner\tTEXT\tThe name or id of the runtime owner
latest\tFLAG\t[Cloud] Only show the latest version
\t \t[Standalone] Only show the version with "latest" tag
"""
view.list(project, fullname, show_removed, page, size, filters)


@runtime_cmd.command(
Expand Down
Loading

0 comments on commit 9c28447

Please sign in to comment.