Skip to content

Commit

Permalink
enhance(client): tune dataset head cli output (#2381)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut authored Jun 25, 2023
1 parent f8dfa71 commit e76d21c
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 19 deletions.
48 changes: 44 additions & 4 deletions client/starwhale/core/dataset/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,20 +487,60 @@ def _tag(
view(dataset).tag(tags, remove, quiet)


@dataset_cmd.command("head", help="Print the first 5 rows of the dataset")
@dataset_cmd.command("head")
@click.argument("dataset")
@click.option("-n", "--rows", default=5, help="Print the first NUM rows of the dataset")
@click.option(
"-d",
"-n",
"--rows",
default=5,
show_default=True,
help="Print the first NUM rows of the dataset",
)
@click.option(
"-srd",
"--show-raw-data",
is_flag=True,
help="Fetch raw data content",
)
@click.option(
"-st",
"--show-types",
is_flag=True,
help="Show data types",
)
@click.pass_obj
def _head(
view: t.Type[DatasetTermView],
dataset: str,
rows: int,
show_raw_data: bool,
show_types: bool,
) -> None:
view(dataset).head(rows, show_raw_data)
"""Print the first n rows of the dataset
DATASET: argument uses the `Dataset URI` format, so you can print rows from the latest version of a dataset or from a specified-version dataset.
Examples:
\b
- print the first 5 rows of the mnist dataset
swcli dataset head -n 5 mnist
\b
- print the first 10 rows of the mnist (v0 version) dataset and show raw data
swcli dataset head -n 10 mnist/v0 --show-raw-data
\b
- print the data types of the mnist dataset
swcli dataset head mnist --show-types
\b
- print the remote cloud dataset's first 5 rows
swcli dataset head cloud://cloud-cn/project/test/dataset/mnist -n 5
\b
- print the first 5 rows in the json format
swcli -o json dataset head -n 5 mnist
"""
view(dataset).head(rows, show_raw_data, show_types)
37 changes: 22 additions & 15 deletions client/starwhale/core/dataset/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,25 +259,28 @@ def tag(
self.dataset.add_tags(tags, ignore_errors)

@BaseTermView._header
def head(self, rows: int, show_raw_data: bool = False) -> None:
def head(
self, rows: int, show_raw_data: bool = False, show_types: bool = False
) -> None:
from starwhale.api._impl.data_store import _get_type

for row in self.dataset.head(rows, show_raw_data):
console.rule(f"row [{row['index']}]", align="left")
output = (
f":deciduous_tree: id: {row['index']} \n"
":cyclone: data:\n"
f"\t :dim_button: type: {row['features']} \n"
)
output = f":deciduous_tree: id: {row['index']} \n" ":cyclone: features:\n"
for _k, _v in row["features"].items():
ds_type: t.Any
try:
ds_type = _get_type(_v)
except RuntimeError:
ds_type = type(_v)
output += (
f"\t :droplet: {_k}: value[{_v}], type[{ds_type} | {type(_v)}] \n"
)
output += f"\t :dim_button: [bold green]{_k}[/] : {_v} \n"

if show_types:
output += ":school_satchel: features types:\n"
for _k, _v in row["features"].items():
ds_type: t.Any
try:
ds_type = _get_type(_v)
except RuntimeError:
ds_type = type(_v)
output += (
f"\t :droplet: [bold green]{_k}[/] : {ds_type} | {type(_v)} \n"
)

console.print(output)

Expand Down Expand Up @@ -329,9 +332,13 @@ def list( # type: ignore
def info(self) -> None:
self.pretty_json(self.dataset.info())

def head(self, rows: int, show_raw_data: bool = False) -> None:
def head(
self, rows: int, show_raw_data: bool = False, show_types: bool = False
) -> None:
from starwhale.base.mixin import _do_asdict_convert

# TODO: support show_types in the json format output

info = self.dataset.head(rows, show_raw_data)
self.pretty_json(_do_asdict_convert(info))

Expand Down
1 change: 1 addition & 0 deletions client/tests/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def test_head(self, *args: t.Any) -> None:
assert len(results) == 2
DatasetTermView(dataset_uri).head(1, show_raw_data=True)
DatasetTermView(dataset_uri).head(2, show_raw_data=True)
DatasetTermView(dataset_uri).head(2, show_raw_data=True, show_types=True)
DatasetTermViewJson(dataset_uri).head(1, show_raw_data=False)
DatasetTermViewJson(dataset_uri).head(2, show_raw_data=True)

Expand Down

0 comments on commit e76d21c

Please sign in to comment.