From e76d21ca39c5875d2bb3e56b57ba21570384b410 Mon Sep 17 00:00:00 2001 From: tianwei Date: Sun, 25 Jun 2023 13:48:24 +0800 Subject: [PATCH] enhance(client): tune dataset head cli output (#2381) --- client/starwhale/core/dataset/cli.py | 48 ++++++++++++++++++++++++--- client/starwhale/core/dataset/view.py | 37 ++++++++++++--------- client/tests/core/test_dataset.py | 1 + 3 files changed, 67 insertions(+), 19 deletions(-) diff --git a/client/starwhale/core/dataset/cli.py b/client/starwhale/core/dataset/cli.py index 407c18bc7f..f76470fd5e 100644 --- a/client/starwhale/core/dataset/cli.py +++ b/client/starwhale/core/dataset/cli.py @@ -487,20 +487,60 @@ def _tag( view(dataset).tag(tags, remove, quiet) -@dataset_cmd.command("head", help="Print the first 5 rows of the dataset") +@dataset_cmd.command("head") @click.argument("dataset") -@click.option("-n", "--rows", default=5, help="Print the first NUM rows of the dataset") @click.option( - "-d", + "-n", + "--rows", + default=5, + show_default=True, + help="Print the first NUM rows of the dataset", +) +@click.option( + "-srd", "--show-raw-data", is_flag=True, help="Fetch raw data content", ) +@click.option( + "-st", + "--show-types", + is_flag=True, + help="Show data types", +) @click.pass_obj def _head( view: t.Type[DatasetTermView], dataset: str, rows: int, show_raw_data: bool, + show_types: bool, ) -> None: - view(dataset).head(rows, show_raw_data) + """Print the first n rows of the dataset + + DATASET: argument uses the `Dataset URI` format, so you can print rows of a dataset or of a specified-version dataset. 
+ + Examples: + + \b + - print the first 5 rows of the mnist dataset + swcli dataset head -n 5 mnist + + \b + - print the first 10 rows of the mnist(v0 version) dataset and show raw data + swcli dataset head -n 10 mnist/v0 --show-raw-data + + \b + - print the data types of the mnist dataset + swcli dataset head mnist --show-types + + \b + - print the remote cloud dataset's first 5 rows + swcli dataset head cloud://cloud-cn/project/test/dataset/mnist -n 5 + + \b + - print the first 5 rows in the json format + swcli -o json dataset head -n 5 mnist + + """ + view(dataset).head(rows, show_raw_data, show_types) diff --git a/client/starwhale/core/dataset/view.py b/client/starwhale/core/dataset/view.py index d4c0bdb5d4..be6e5ab170 100644 --- a/client/starwhale/core/dataset/view.py +++ b/client/starwhale/core/dataset/view.py @@ -259,25 +259,28 @@ def tag( self.dataset.add_tags(tags, ignore_errors) @BaseTermView._header - def head(self, rows: int, show_raw_data: bool = False) -> None: + def head( + self, rows: int, show_raw_data: bool = False, show_types: bool = False + ) -> None: from starwhale.api._impl.data_store import _get_type for row in self.dataset.head(rows, show_raw_data): console.rule(f"row [{row['index']}]", align="left") - output = ( - f":deciduous_tree: id: {row['index']} \n" - ":cyclone: data:\n" - f"\t :dim_button: type: {row['features']} \n" - ) + output = f":deciduous_tree: id: {row['index']} \n" ":cyclone: features:\n" for _k, _v in row["features"].items(): - ds_type: t.Any - try: - ds_type = _get_type(_v) - except RuntimeError: - ds_type = type(_v) - output += ( - f"\t :droplet: {_k}: value[{_v}], type[{ds_type} | {type(_v)}] \n" - ) + output += f"\t :dim_button: [bold green]{_k}[/] : {_v} \n" + + if show_types: + output += ":school_satchel: features types:\n" + for _k, _v in row["features"].items(): + ds_type: t.Any + try: + ds_type = _get_type(_v) + except RuntimeError: + ds_type = type(_v) + output += ( + f"\t :droplet: [bold green]{_k}[/] : {ds_type} 
| {type(_v)} \n" + ) console.print(output) @@ -329,9 +332,13 @@ def list( # type: ignore def info(self) -> None: self.pretty_json(self.dataset.info()) - def head(self, rows: int, show_raw_data: bool = False) -> None: + def head( + self, rows: int, show_raw_data: bool = False, show_types: bool = False + ) -> None: from starwhale.base.mixin import _do_asdict_convert + # TODO: support show_types in the json format output + info = self.dataset.head(rows, show_raw_data) self.pretty_json(_do_asdict_convert(info)) diff --git a/client/tests/core/test_dataset.py b/client/tests/core/test_dataset.py index 061de56848..a0ffef6501 100644 --- a/client/tests/core/test_dataset.py +++ b/client/tests/core/test_dataset.py @@ -400,6 +400,7 @@ def test_head(self, *args: t.Any) -> None: assert len(results) == 2 DatasetTermView(dataset_uri).head(1, show_raw_data=True) DatasetTermView(dataset_uri).head(2, show_raw_data=True) + DatasetTermView(dataset_uri).head(2, show_raw_data=True, show_types=True) DatasetTermViewJson(dataset_uri).head(1, show_raw_data=False) DatasetTermViewJson(dataset_uri).head(2, show_raw_data=True)