
Commit

Merge pull request #305 from aiverify-foundation/ms-344
[MS-344] Redo unit test for CLI Benchmark
imda-kelvinkok committed Aug 27, 2024
2 parents 6204797 + 0ff221a commit d53ad64
Showing 18 changed files with 9,695 additions and 1,291 deletions.
268 changes: 226 additions & 42 deletions moonshot/integrations/cli/benchmark/cookbook.py

Large diffs are not rendered by default.

61 changes: 53 additions & 8 deletions moonshot/integrations/cli/benchmark/datasets.py
@@ -9,6 +9,13 @@
api_get_all_datasets,
api_get_all_datasets_name,
)
from moonshot.integrations.cli.cli_errors import (
ERROR_BENCHMARK_DELETE_DATASET_DATASET_VALIDATION,
ERROR_BENCHMARK_LIST_DATASETS_FIND_VALIDATION,
ERROR_BENCHMARK_LIST_DATASETS_PAGINATION_VALIDATION,
ERROR_BENCHMARK_LIST_DATASETS_PAGINATION_VALIDATION_1,
ERROR_BENCHMARK_VIEW_DATASET_DATASET_FILENAME_VALIDATION,
)
from moonshot.integrations.cli.common.display_helper import display_view_str_format
from moonshot.integrations.cli.utils.process_data import filter_data

@@ -23,22 +30,43 @@ def list_datasets(args) -> list | None:
List all available datasets.
This function retrieves all available datasets by calling the api_get_all_datasets function from the
moonshot.api module. It then displays the datasets using the _display_datasets function. If an exception occurs,
it prints an error message.
moonshot.api module. It then filters the datasets based on the provided keyword and pagination arguments.
If there are no datasets, it prints a message indicating that no datasets were found.
Args:
args: A namespace object from argparse. It should have an optional attribute:
find (str): Optional field to find dataset(s) with a keyword.
pagination (str): Optional field to paginate datasets.
args: A namespace object from argparse. It should have optional attributes:
find (str): Optional keyword to filter datasets.
pagination (str): Optional tuple to paginate datasets.
Returns:
list | None: A list of Dataset or None if there is no result.
list | None: A list of datasets or None if there are no datasets.
"""
try:
print("Listing datasets may take a while...")
if args.find is not None:
if not isinstance(args.find, str) or not args.find:
raise TypeError(ERROR_BENCHMARK_LIST_DATASETS_FIND_VALIDATION)

if args.pagination is not None:
if not isinstance(args.pagination, str) or not args.pagination:
raise TypeError(ERROR_BENCHMARK_LIST_DATASETS_PAGINATION_VALIDATION)
try:
pagination = literal_eval(args.pagination)
if not (
isinstance(pagination, tuple)
and len(pagination) == 2
and all(isinstance(i, int) for i in pagination)
):
raise ValueError(
ERROR_BENCHMARK_LIST_DATASETS_PAGINATION_VALIDATION_1
)
except (ValueError, SyntaxError):
raise ValueError(ERROR_BENCHMARK_LIST_DATASETS_PAGINATION_VALIDATION_1)
else:
pagination = ()

datasets_list = api_get_all_datasets()
keyword = args.find.lower() if args.find else ""
pagination = literal_eval(args.pagination) if args.pagination else ()

if datasets_list:
filtered_datasets_list = filter_data(datasets_list, keyword, pagination)
@@ -48,8 +76,10 @@ def list_datasets(args) -> list | None:

console.print("[red]There are no datasets found.[/red]")
return None

except Exception as e:
print(f"[list_datasets]: {str(e)}")
return None


def view_dataset(args) -> None:
@@ -69,6 +99,13 @@ def view_dataset(args) -> None:
"""
try:
print("Viewing datasets may take a while...")
if (
not isinstance(args.dataset_filename, str)
or not args.dataset_filename
or args.dataset_filename is None
):
raise TypeError(ERROR_BENCHMARK_VIEW_DATASET_DATASET_FILENAME_VALIDATION)

datasets_list = api_get_all_datasets()
datasets_name_list = api_get_all_datasets_name()

@@ -92,7 +129,7 @@ def delete_dataset(args) -> None:
Args:
args: A namespace object from argparse. It should have the following attribute:
dataset_name (str): The name of the dataset to delete.
dataset (str): The name of the dataset to delete.
Returns:
None
@@ -104,7 +141,15 @@
if confirmation.lower() != "y":
console.print("[bold yellow]Dataset deletion cancelled.[/]")
return

try:
if (
args.dataset is None
or not isinstance(args.dataset, str)
or not args.dataset
):
raise ValueError(ERROR_BENCHMARK_DELETE_DATASET_DATASET_VALIDATION)

api_delete_dataset(args.dataset)
print("[delete_dataset]: Dataset deleted.")
except Exception as e:
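
Since the PR's purpose is to redo the unit tests for the CLI benchmark commands, the guard clauses added above are the behaviour under test: a bad `find` or `pagination` value raises inside the `try` block, is caught by the generic handler, printed as `[list_datasets]: ...`, and the function returns None. Below is a minimal pytest sketch of that path; the test name, the use of `argparse.Namespace` to stand in for parsed CLI arguments, and the assumption that the `ERROR_*` constants are plain message strings are illustrative, not taken from the tests added in this PR.

```python
# Hypothetical test sketch, not the test file added in this PR.
from argparse import Namespace

from moonshot.integrations.cli.benchmark.datasets import list_datasets
from moonshot.integrations.cli.cli_errors import (
    ERROR_BENCHMARK_LIST_DATASETS_FIND_VALIDATION,
)


def test_list_datasets_rejects_non_string_find(capsys):
    # A non-string `find` should trip the TypeError guard before any API call.
    args = Namespace(find=123, pagination=None)

    result = list_datasets(args)

    # The except block prints the error message and returns None instead of raising.
    captured = capsys.readouterr()
    assert result is None
    assert ERROR_BENCHMARK_LIST_DATASETS_FIND_VALIDATION in captured.out
```
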
55 changes: 48 additions & 7 deletions moonshot/integrations/cli/benchmark/metrics.py
@@ -5,6 +5,13 @@
from rich.table import Table

from moonshot.api import api_delete_metric, api_get_all_metric, api_get_all_metric_name
from moonshot.integrations.cli.cli_errors import (
ERROR_BENCHMARK_DELETE_METRIC_METRIC_VALIDATION,
ERROR_BENCHMARK_LIST_METRICS_FIND_VALIDATION,
ERROR_BENCHMARK_LIST_METRICS_PAGINATION_VALIDATION,
ERROR_BENCHMARK_LIST_METRICS_PAGINATION_VALIDATION_1,
ERROR_BENCHMARK_VIEW_METRIC_METRIC_FILENAME_VALIDATION,
)
from moonshot.integrations.cli.utils.process_data import filter_data

console = Console()
@@ -18,23 +25,44 @@ def list_metrics(args) -> list | None:
List all available metrics.
This function retrieves all available metrics by calling the api_get_all_metric function from the
moonshot.api module. It then displays the metrics using the _display_metrics function. If an exception occurs,
it prints an error message.
moonshot.api module. It then filters the metrics based on the provided keyword and pagination arguments.
If there are no metrics, it prints a message indicating that no metrics were found.
Args:
args: A namespace object from argparse. It should have an optional attribute:
find (str): Optional field to find metric(s) with a keyword.
pagination (str): Optional field to paginate metrics.
args: A namespace object from argparse. It should have optional attributes:
find (str): Optional field to find metric(s) with a keyword.
pagination (str): Optional field to paginate metrics.
Returns:
list | None: A list of Metric or None if there is no result.
list | None: A list of metrics or None if there are no metrics.
"""

try:
print("Listing metrics may take a while...")
if args.find is not None:
if not isinstance(args.find, str) or not args.find:
raise TypeError(ERROR_BENCHMARK_LIST_METRICS_FIND_VALIDATION)

if args.pagination is not None:
if not isinstance(args.pagination, str) or not args.pagination:
raise TypeError(ERROR_BENCHMARK_LIST_METRICS_PAGINATION_VALIDATION)
try:
pagination = literal_eval(args.pagination)
if not (
isinstance(pagination, tuple)
and len(pagination) == 2
and all(isinstance(i, int) for i in pagination)
):
raise ValueError(
ERROR_BENCHMARK_LIST_METRICS_PAGINATION_VALIDATION_1
)
except (ValueError, SyntaxError):
raise ValueError(ERROR_BENCHMARK_LIST_METRICS_PAGINATION_VALIDATION_1)
else:
pagination = ()

metrics_list = api_get_all_metric()
keyword = args.find.lower() if args.find else ""
pagination = literal_eval(args.pagination) if args.pagination else ()

if metrics_list:
filtered_metrics_list = filter_data(metrics_list, keyword, pagination)
@@ -44,8 +72,10 @@ def list_metrics(args) -> list | None:

console.print("[red]There are no metrics found.[/red]")
return None

except Exception as e:
print(f"[list_metrics]: {str(e)}")
return None


def view_metric(args) -> None:
@@ -65,6 +95,13 @@ def view_metric(args) -> None:
"""
try:
print("Viewing metrics may take a while...")
if (
not isinstance(args.metric_filename, str)
or not args.metric_filename
or args.metric_filename is None
):
raise TypeError(ERROR_BENCHMARK_VIEW_METRIC_METRIC_FILENAME_VALIDATION)

metrics_list = api_get_all_metric()
metrics_name_list = api_get_all_metric_name()

@@ -100,7 +137,11 @@ def delete_metric(args) -> None:
if confirmation.lower() != "y":
console.print("[bold yellow]Metric deletion cancelled.[/]")
return

try:
if args.metric is None or not isinstance(args.metric, str) or not args.metric:
raise ValueError(ERROR_BENCHMARK_DELETE_METRIC_METRIC_VALIDATION)

api_delete_metric(args.metric)
print("[delete_metric]: Metric deleted.")
except Exception as e:
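
The same pagination check appears in both `list_datasets` and `list_metrics`: the CLI string is parsed with `literal_eval` and accepted only if it evaluates to a 2-tuple of ints. The standalone sketch below isolates that pattern to show which inputs pass; the `parse_pagination` helper and its error messages are hypothetical, not part of the moonshot codebase.

```python
# Standalone illustration of the pagination check used in the diff above;
# the helper name and error messages are hypothetical.
from ast import literal_eval


def parse_pagination(raw: str) -> tuple[int, int]:
    """Accept only strings that evaluate to a 2-tuple of ints, e.g. "(1, 10)"."""
    try:
        value = literal_eval(raw)
    except (ValueError, SyntaxError):
        # Anything literal_eval cannot parse into a Python literal is rejected.
        raise ValueError("pagination must look like '(page, items_per_page)'")
    if not (
        isinstance(value, tuple)
        and len(value) == 2
        and all(isinstance(i, int) for i in value)
    ):
        raise ValueError("pagination must be a tuple of two ints")
    return value


print(parse_pagination("(1, 10)"))   # (1, 10) -> accepted
# parse_pagination("[1, 10]")        # list, not tuple -> ValueError
# parse_pagination("(1, 'a')")       # non-int element -> ValueError
```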