Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MS-334] Dataset Converter #269

Merged
merged 22 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions moonshot/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
api_delete_dataset,
api_get_all_datasets,
api_get_all_datasets_name,
api_create_datasets
)
from moonshot.src.api.api_environment_variables import api_set_environment_variables
from moonshot.src.api.api_metrics import (
Expand Down Expand Up @@ -115,6 +116,7 @@
"api_read_cookbook",
"api_read_cookbooks",
"api_update_cookbook",
"api_create_datasets",
"api_delete_dataset",
"api_get_all_datasets",
"api_get_all_datasets_name",
Expand Down
8 changes: 8 additions & 0 deletions moonshot/integrations/cli/common/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

import cmd2

from moonshot.integrations.cli.common.dataset import (
add_dataset,
add_dataset_args
)
from moonshot.integrations.cli.common.connectors import (
add_endpoint,
add_endpoint_args,
Expand Down Expand Up @@ -56,6 +60,10 @@ def do_delete_prompt_template(self, args: argparse.Namespace) -> None:
def do_add_endpoint(self, args: argparse.Namespace) -> None:
add_endpoint(args)

@cmd2.with_argparser(add_dataset_args)
def do_add_dataset(self, args:argparse.Namespace) -> None:
add_dataset(args)

# ------------------------------------------------------------------------------
# Delete contents
# ------------------------------------------------------------------------------
Expand Down
66 changes: 66 additions & 0 deletions moonshot/integrations/cli/common/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from ast import literal_eval

import cmd2
from rich.console import Console

from moonshot.api import (
api_create_datasets,
)

console = Console()
def add_dataset(args) -> None:
"""
Create a new dataset using the provided arguments and log the result.

This function attempts to create a new dataset by calling the `api_create_datasets`
function with the necessary parameters extracted from `args`. If successful, it logs
the creation of the dataset with its ID. If an exception occurs, it logs the error.

Args:
args: An argparse.Namespace object containing the following attributes:
- name (str): Name of the new dataset.
- description (str): Description of the new dataset.
- reference (str): Reference URL for the new dataset.
- license (str): License type for the new dataset.
- method (str): Method to convert the new dataset ('hf' or 'csv').
- params (dict): Additional parameters for dataset creation.
"""
try:
imda-normanchia marked this conversation as resolved.
Show resolved Hide resolved
new_dataset_id = api_create_datasets(
args.name,
args.description,
args.reference,
args.license,
args.method,
**args.params,
)
print(f"[add_dataset]: Dataset ({new_dataset_id}) created.")
except Exception as e:
print(f"[add_dataset]: {str(e)}")

# ------------------------------------------------------------------------------
# Cmd2 Arguments Parsers
# ------------------------------------------------------------------------------
# Add dataset arguments
add_dataset_args = cmd2.Cmd2ArgumentParser(
description="Add a new dataset. The 'name' argument will be slugified to create a unique identifier.",
epilog=(
"Examples:\n"
"1. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'csv' \"{'csv_file_path': '/path/to/your/file.csv'}\"\n"
"2. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'hf' \"{'dataset_name': 'cais/mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['question','choices'], 'target_col': 'answer'}\""
),
)
add_dataset_args.add_argument("name", type=str, help="Name of the new dataset")
add_dataset_args.add_argument("description", type=str, help="Description of the new dataset")
add_dataset_args.add_argument("reference", type=str, help="Reference of the new dataset")
add_dataset_args.add_argument("license", type=str, help="License of the new dataset")
add_dataset_args.add_argument("method", type=str, choices=['hf', 'csv'], help="Method to convert the new dataset. Choose either 'hf' or 'csv'.")
add_dataset_args.add_argument(
"params",
type=literal_eval,
help=(
"Params of the new dataset in dictionary format. For example: \n"
"1. For 'csv' method: \"{'csv_file_path': '/path/to/your/file.csv'}\"\n"
"2. For 'hf' method: \"{'dataset_name': 'cais_mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['questions','choices'], 'target_col': 'answer'}\""
)
)
47 changes: 46 additions & 1 deletion moonshot/integrations/web_api/routes/dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,59 @@
from dependency_injector.wiring import Provide, inject
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query

from ..container import Container
from ..schemas.dataset_create_dto import DatasetCreateDTO
from ..schemas.dataset_response_dto import DatasetResponseDTO
from ..services.dataset_service import DatasetService
from ..services.utils.exceptions_handler import ServiceException

router = APIRouter(tags=["Datasets"])


@router.post("/api/v1/datasets")
@inject
def create_dataset(
dataset_data: DatasetCreateDTO,
method: str = Query(
...,
description="The method to use for creating the dataset. Supported methods are 'hf' and 'csv'.",
),
dataset_service: DatasetService = Depends(Provide[Container.dataset_service]),
) -> str:
"""
Create a new dataset using the specified method.

Args:
dataset_data (DatasetCreateDTO): The data required to create the dataset.
method (str): The method to use for creating the dataset. Supported methods are "hf" and "csv".
dataset_service (DatasetService, optional): The service responsible for creating the dataset.
Defaults to Depends(Provide[Container.dataset_service]).

Returns:
dict: A message indicating the dataset was created successfully.

Raises:
HTTPException: An error with status code 404 if the dataset file is not found.
An error with status code 400 if there is a validation error.
An error with status code 500 for any other server-side error.
"""
try:
return dataset_service.create_dataset(dataset_data, method)
except ServiceException as e:
if e.error_code == "FileNotFound":
raise HTTPException(
status_code=404, detail=f"Failed to retrieve datasets: {e.msg}"
)
elif e.error_code == "ValidationError":
raise HTTPException(
status_code=400, detail=f"Failed to retrieve datasets: {e.msg}"
)
else:
raise HTTPException(
status_code=500, detail=f"Failed to retrieve datasets: {e.msg}"
)


@router.get("/api/v1/datasets")
@inject
def get_all_datasets(
Expand Down
18 changes: 18 additions & 0 deletions moonshot/integrations/web_api/schemas/dataset_create_dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Optional

from pydantic import Field
from pyparsing import Iterator

from moonshot.src.datasets.dataset_arguments import (
DatasetArguments as DatasetPydanticModel,
)


class DatasetCreateDTO(DatasetPydanticModel):
id: Optional[str] = None
examples: Iterator[dict] = None
name: str = Field(..., min_length=1)
description: str = Field(default="", min_length=1)
license: Optional[str] = ""
reference: Optional[str] = ""
params: dict
8 changes: 6 additions & 2 deletions moonshot/integrations/web_api/services/bookmark_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,22 @@ def get_all_bookmarks(self, name: str | None = None) -> list[BookmarkPydanticMod
@exception_handler
def delete_bookmarks(self, all: bool = False, name: str | None = None) -> dict:
"""
Deletes a single bookmark by its name or all bookmarks if the 'all' flag is set to True.
Deletes a single bookmark by its name or all bookmarks if the 'all' flag is set to True and returns
a boolean indicating the success of the operation.

Args:
all (bool, optional): If True, all bookmarks will be deleted. Defaults to False.
name (str | None, optional): The name of the bookmark to delete. If 'all' is False, 'name' must be provided.

Returns:
dict: True if the deletion was successful, False otherwise.
"""
if all:
result = moonshot_api.api_delete_all_bookmark()
elif name is not None:
result = moonshot_api.api_delete_bookmark(name)
else:
raise ValueError("Either 'all' must be True or 'id' must be provided.")
raise ValueError("Either 'all' must be True or 'name' must be provided.")

if not result["success"]:
raise Exception(
Expand Down
25 changes: 25 additions & 0 deletions moonshot/integrations/web_api/services/dataset_service.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
from .... import api as moonshot_api
from ..schemas.dataset_create_dto import DatasetCreateDTO
from ..schemas.dataset_response_dto import DatasetResponseDTO
from ..services.base_service import BaseService
from ..services.utils.exceptions_handler import exception_handler
from .utils.file_manager import copy_file


class DatasetService(BaseService):
@exception_handler
def create_dataset(self, dataset_data: DatasetCreateDTO, method: str) -> str:
"""
Create a dataset using the specified method.

Args:
dataset_data (DatasetCreateDTO): The data required to create the dataset.
method (str): The method to use for creating the dataset.
Supported methods are "hf" and "csv".

Raises:
Exception: If an error occurs during dataset creation.
"""
new_ds_path = moonshot_api.api_create_datasets(
name=dataset_data.name,
description=dataset_data.description,
reference=dataset_data.reference,
license=dataset_data.license,
method=method,
**dataset_data.params,
)
return copy_file(new_ds_path)

@exception_handler
def get_all_datasets(self) -> list[DatasetResponseDTO]:
datasets = moonshot_api.api_get_all_datasets()
Expand Down
35 changes: 35 additions & 0 deletions moonshot/src/api/api_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pydantic import validate_call

from moonshot.src.datasets.dataset import Dataset
from moonshot.src.datasets.dataset_arguments import DatasetArguments


# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -44,3 +45,37 @@ def api_get_all_datasets_name() -> list[str]:
"""
datasets_name, _ = Dataset.get_available_items()
return datasets_name


def api_create_datasets(
name: str, description: str, reference: str, license: str, method: str, **kwargs
imda-lionelteo marked this conversation as resolved.
Show resolved Hide resolved
) -> str:
"""
This function creates a new dataset.

This function takes the name, description, reference, and license for a new dataset as input. It then creates a new
DatasetArguments object with these details and an empty id. The id is left empty because it will be generated
from the name during the creation process. The function then calls the Dataset's create method to
create the new dataset.

Args:
name (str): The name of the new dataset.
description (str): A brief description of the new dataset.
reference (str): A reference link for the new dataset.
license (str): The license of the new dataset.
method (str): The method to create new dataset. (csv/hf)
kwargs: Additional keyword arguments for the Dataset's create method.
imda-normanchia marked this conversation as resolved.
Show resolved Hide resolved

Returns:
str: The ID of the newly created dataset.
"""
ds_args = DatasetArguments(
id="",
name=name,
description=description,
reference=reference,
license=license,
examples=None,
)

return Dataset.create(ds_args, method, **kwargs)
Loading
Loading