Skip to content

Commit

Permalink
Merge pull request #310 from aiverify-foundation/ms-406
Browse files Browse the repository at this point in the history
[MS-406] Add unit-tests to bookmarks, bookmarks-api and bookmarkarguments
  • Loading branch information
imda-kelvinkok authored Aug 27, 2024
2 parents d53ad64 + 4198bf5 commit 936f62f
Show file tree
Hide file tree
Showing 7 changed files with 1,772 additions and 281 deletions.
12 changes: 6 additions & 6 deletions moonshot/src/api/api_bookmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def api_insert_bookmark(
metric=metric,
bookmark_time="", # bookmark_time will be set to current time in add_bookmark method
)
return Bookmark.get_instance().add_bookmark(bookmark_args)
return Bookmark().add_bookmark(bookmark_args)


def api_get_all_bookmarks() -> list[dict]:
Expand All @@ -49,7 +49,7 @@ def api_get_all_bookmarks() -> list[dict]:
Returns:
list[dict]: A list of bookmarks, each represented as a dictionary.
"""
return Bookmark.get_instance().get_all_bookmarks()
return Bookmark().get_all_bookmarks()


def api_get_bookmark(bookmark_name: str) -> dict:
Expand All @@ -62,7 +62,7 @@ def api_get_bookmark(bookmark_name: str) -> dict:
Returns:
dict: The bookmark details corresponding to the provided ID.
"""
return Bookmark.get_instance().get_bookmark(bookmark_name)
return Bookmark().get_bookmark(bookmark_name)


def api_delete_bookmark(bookmark_name: str) -> dict:
Expand All @@ -72,14 +72,14 @@ def api_delete_bookmark(bookmark_name: str) -> dict:
Args:
bookmark_name (str): The name of the bookmark to be removed.
"""
return Bookmark.get_instance().delete_bookmark(bookmark_name)
return Bookmark().delete_bookmark(bookmark_name)


def api_delete_all_bookmark() -> dict:
"""
Removes all bookmarks from the database.
"""
return Bookmark.get_instance().delete_all_bookmark()
return Bookmark().delete_all_bookmark()


def api_export_bookmarks(export_file_name: str = "bookmarks") -> str:
Expand All @@ -92,4 +92,4 @@ def api_export_bookmarks(export_file_name: str = "bookmarks") -> str:
Returns:
str: The filepath of where the file is written.
"""
return Bookmark.get_instance().export_bookmarks(export_file_name)
return Bookmark().export_bookmarks(export_file_name)
179 changes: 119 additions & 60 deletions moonshot/src/bookmark/bookmark.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,24 @@
from __future__ import annotations

import textwrap
from datetime import datetime

from moonshot.src.bookmark.bookmark_arguments import BookmarkArguments
from moonshot.src.configs.env_variables import EnvVariables
from moonshot.src.messages_constants import (
BOOKMARK_ADD_BOOKMARK_ERROR,
BOOKMARK_ADD_BOOKMARK_SUCCESS,
BOOKMARK_ADD_BOOKMARK_VALIDATION_ERROR,
BOOKMARK_DELETE_ALL_BOOKMARK_ERROR,
BOOKMARK_DELETE_ALL_BOOKMARK_SUCCESS,
BOOKMARK_DELETE_BOOKMARK_ERROR,
BOOKMARK_DELETE_BOOKMARK_ERROR_1,
BOOKMARK_DELETE_BOOKMARK_SUCCESS,
BOOKMARK_EXPORT_BOOKMARK_ERROR,
BOOKMARK_EXPORT_BOOKMARK_VALIDATION_ERROR,
BOOKMARK_GET_BOOKMARK_ERROR,
BOOKMARK_GET_BOOKMARK_ERROR_1,
)
from moonshot.src.storage.storage import Storage
from moonshot.src.utils.log import configure_logger

Expand All @@ -12,36 +29,7 @@
class Bookmark:
_instance = None

def __new__(cls, db_name="bookmark"):
"""
Create a new instance of the Bookmark class or return the existing instance.
Args:
db_name (str): The name of the database.
Returns:
Bookmark: The singleton instance of the Bookmark class.
"""
if cls._instance is None:
cls._instance = super(Bookmark, cls).__new__(cls)
cls._instance.__init_instance(db_name)
return cls._instance

@classmethod
def get_instance(cls, db_name="bookmark"):
"""
Get the singleton instance of the Bookmark class.
Args:
db_name (str): The name of the database.
Returns:
Bookmark: The singleton instance of the Bookmark class.
"""
if cls._instance is None:
cls._instance = super(Bookmark, cls).__new__(cls)
cls._instance.__init_instance(db_name)
return cls._instance
sql_table_name = "bookmark"

sql_create_bookmark_table = """
CREATE TABLE IF NOT EXISTS bookmark (
Expand Down Expand Up @@ -77,19 +65,46 @@ def get_instance(cls, db_name="bookmark"):
DELETE FROM bookmark;
"""

def __init_instance(self, db_name) -> None:
def __new__(cls, db_name="bookmark"):
"""
Create a new instance of the Bookmark class or return the existing instance.
This method ensures that only one instance of the Bookmark class is created (singleton pattern).
If an instance already exists, it returns that instance. Otherwise, it creates a new instance
and initializes it with the provided database name.
Args:
db_name (str): The name of the database. Defaults to "bookmark".
Returns:
Bookmark: The singleton instance of the Bookmark class.
"""
if cls._instance is None:
cls._instance = super(Bookmark, cls).__new__(cls)
cls._instance.__init_instance(db_name)
return cls._instance

def __init_instance(self, db_name: str = "bookmark") -> None:
"""
Initialize the database instance for the Bookmark class.
This method sets up the database connection for the Bookmark class. It creates a new database
connection using the provided database name and checks if the required table exists. If the table
does not exist, it creates the table.
Args:
db_name (str): The name of the database.
db_name (str): The name of the database. Defaults to "bookmark".
"""
self.db_instance = Storage.create_database_connection(
EnvVariables.BOOKMARKS.name, db_name, "db"
)
Storage.create_database_table(
self.db_instance, Bookmark.sql_create_bookmark_table
)

if not Storage.check_database_table_exists(
self.db_instance, Bookmark.sql_table_name
):
Storage.create_database_table(
self.db_instance, Bookmark.sql_create_bookmark_table
)

def add_bookmark(self, bookmark: BookmarkArguments) -> dict:
"""
Expand All @@ -99,7 +114,7 @@ def add_bookmark(self, bookmark: BookmarkArguments) -> dict:
bookmark (BookmarkArguments): The bookmark data to add.
Returns:
bool: True if the bookmark was added successfully, False otherwise.
dict: A dictionary containing the success status and a message.
"""
bookmark.bookmark_time = datetime.now().replace(microsecond=0).isoformat(" ")

Expand All @@ -119,12 +134,14 @@ def add_bookmark(self, bookmark: BookmarkArguments) -> dict:
self.db_instance, data, Bookmark.sql_insert_bookmark_record
)
if results is not None:
return {"success": True, "message": "Bookmark added successfully."}
return {"success": True, "message": BOOKMARK_ADD_BOOKMARK_SUCCESS}
else:
raise Exception("Error inserting record into database.")
raise Exception(BOOKMARK_ADD_BOOKMARK_VALIDATION_ERROR)
except Exception as e:
error_message = f"Failed to add bookmark record: {e}"
return {"success": False, "message": error_message}
return {
"success": False,
"message": BOOKMARK_ADD_BOOKMARK_ERROR.format(message=str(e)),
}

def get_all_bookmarks(self) -> list[dict]:
"""
Expand All @@ -137,7 +154,9 @@ def get_all_bookmarks(self) -> list[dict]:
self.db_instance,
Bookmark.sql_select_bookmarks_record,
)
if list_of_bookmarks_tuples:
if isinstance(list_of_bookmarks_tuples, list) and all(
isinstance(item, tuple) for item in list_of_bookmarks_tuples
):
list_of_bookmarks = [
BookmarkArguments.from_tuple_to_dict(bookmark_tuple)
for bookmark_tuple in list_of_bookmarks_tuples
Expand All @@ -151,49 +170,66 @@ def get_bookmark(self, bookmark_name: str) -> dict:
Retrieve a bookmark by its unique name.
Args:
bookmark_name (int): The unique name for the bookmark.
bookmark_name (str): The unique name for the bookmark.
Returns:
dict: The bookmark information as a dictionary.
Raises:
RuntimeError: If the bookmark cannot be found.
"""
if bookmark_name is not None:
if isinstance(bookmark_name, str) and bookmark_name:
bookmark_info = Storage.read_database_record(
self.db_instance, (bookmark_name,), Bookmark.sql_select_bookmark_record
)
if bookmark_info is not None:
if (
bookmark_info is not None
and isinstance(bookmark_info, tuple)
and all(isinstance(item, str) for item in bookmark_info)
):
return BookmarkArguments.from_tuple_to_dict(bookmark_info)
else:
raise RuntimeError(
f"[Bookmark] No record found for bookmark name {bookmark_name}"
BOOKMARK_GET_BOOKMARK_ERROR.format(message=bookmark_name)
)
else:
raise RuntimeError(f"[Bookmark] Invalid bookmark name: {bookmark_name}")
raise RuntimeError(
BOOKMARK_GET_BOOKMARK_ERROR_1.format(message=bookmark_name)
)

def delete_bookmark(self, bookmark_name: str) -> dict:
"""
Delete a bookmark by its unique name.
Args:
bookmark_name (str): The unique name for the bookmark to be deleted.
Returns:
dict: A dictionary containing the success status and a message.
"""
if bookmark_name is not None:
if isinstance(bookmark_name, str) and bookmark_name:
try:
sql_delete_bookmark_record = f"""
sql_delete_bookmark_record = textwrap.dedent(
f"""
DELETE FROM bookmark WHERE name = '{bookmark_name}';
"""
)
Storage.delete_database_record_in_table(
self.db_instance, sql_delete_bookmark_record
)
return {"success": True, "message": "Bookmark record deleted."}
return {"success": True, "message": BOOKMARK_DELETE_BOOKMARK_SUCCESS}
except Exception as e:
error_message = f"Failed to delete bookmark record: {e}"
return {"success": False, "message": error_message}
return {
"success": False,
"message": BOOKMARK_DELETE_BOOKMARK_ERROR.format(message=str(e)),
}
else:
error_message = f"[Bookmark] Invalid bookmark name: {bookmark_name}"
return {"success": False, "message": error_message}
return {
"success": False,
"message": BOOKMARK_DELETE_BOOKMARK_ERROR_1.format(
message=bookmark_name
),
}

def delete_all_bookmark(self) -> dict:
"""
Expand All @@ -206,10 +242,12 @@ def delete_all_bookmark(self) -> dict:
Storage.delete_database_record_in_table(
self.db_instance, Bookmark.sql_delete_bookmark_records
)
return {"success": True, "message": "All bookmark records deleted."}
return {"success": True, "message": BOOKMARK_DELETE_ALL_BOOKMARK_SUCCESS}
except Exception as e:
error_message = f"Failed to delete all bookmark records: {e}"
return {"success": False, "message": error_message}
return {
"success": False,
"message": BOOKMARK_DELETE_ALL_BOOKMARK_ERROR.format(message=str(e)),
}

def export_bookmarks(self, export_file_name: str = "bookmarks") -> str:
"""
Expand All @@ -224,13 +262,27 @@ def export_bookmarks(self, export_file_name: str = "bookmarks") -> str:
Returns:
str: The path to the exported JSON file containing the bookmarks.
Raises:
Exception: If the export file name is invalid or an error occurs during export.
"""
if not isinstance(export_file_name, str) or not export_file_name:
error_message = BOOKMARK_EXPORT_BOOKMARK_ERROR.format(
message=BOOKMARK_EXPORT_BOOKMARK_VALIDATION_ERROR
)
logger.error(error_message)
raise Exception(error_message)

list_of_bookmarks_tuples = Storage.read_database_records(
self.db_instance,
Bookmark.sql_select_bookmarks_record,
)

if list_of_bookmarks_tuples is not None:
if (
list_of_bookmarks_tuples is not None
and isinstance(list_of_bookmarks_tuples, list)
and all(isinstance(item, tuple) for item in list_of_bookmarks_tuples)
):
bookmarks_json = [
BookmarkArguments.from_tuple_to_dict(bookmark_tuple)
for bookmark_tuple in list_of_bookmarks_tuples
Expand All @@ -246,12 +298,19 @@ def export_bookmarks(self, export_file_name: str = "bookmarks") -> str:
"json",
)
except Exception as e:
logger.error(f"Failed to export bookmarks - {str(e)}.")
raise e
error_message = BOOKMARK_EXPORT_BOOKMARK_ERROR.format(message=str(e))
logger.error(error_message)
raise Exception(error_message)

def close(self) -> None:
"""
Close the database connection.
Close the database connection and set the Bookmark instance to None.
This method ensures that the database connection is properly closed and the singleton
instance of the Bookmark class is reset to None, allowing for a fresh instance to be created
if needed in the future.
"""
if self.db_instance:
Storage.close_database_connection(self.db_instance)

Bookmark._instance = None
10 changes: 10 additions & 0 deletions moonshot/src/bookmark/bookmark_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from pydantic import BaseModel, Field

from moonshot.src.messages_constants import (
BOOKMARK_ARGUMENTS_FROM_TUPLE_TO_DICT_VALIDATION_ERROR,
)


class BookmarkArguments(BaseModel):
name: str = Field(min_length=1)
Expand All @@ -24,7 +28,13 @@ def from_tuple_to_dict(cls, values: tuple) -> dict:
Returns:
dict: A dictionary representing the BookmarkArguments.
Raises:
ValueError: If the number of values in the tuple is less than 10.
"""
if len(values) < 10:
raise ValueError(BOOKMARK_ARGUMENTS_FROM_TUPLE_TO_DICT_VALIDATION_ERROR)

return {
"name": values[1],
"prompt": values[2],
Expand Down
Loading

0 comments on commit 936f62f

Please sign in to comment.