Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Sprint 11] New features and fixes #73

Merged
merged 15 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion connectors-endpoints/flageval-flagjudge.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{
"id": "flageval-flagjudge",
"name": "FlagEval flagjudge",
"connector_type": "flageval-connector",
"uri": "http://120.92.208.64:7611/worker_generate_stream",
Expand Down
51 changes: 34 additions & 17 deletions connectors/flageval-connector.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import logging, json
import json

import aiohttp
from aiohttp import ClientResponse
from typing import Callable
from moonshot.src.connectors.connector import Connector, perform_retry
from moonshot.src.connectors.connector_prompt_arguments import ConnectorPromptArguments
from moonshot.src.connectors_endpoints.connector_endpoint_arguments import (
ConnectorEndpointArguments,
)
from moonshot.src.utils.log import configure_logger

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Create a logger for this module
logger = configure_logger(__name__)


class FlagJudgeConnector(Connector):
Expand All @@ -20,7 +20,26 @@ def __init__(self, ep_arguments: ConnectorEndpointArguments):

@Connector.rate_limited
@perform_retry
async def get_response(self, prompt: str, prediction: str, ground_truth: str) -> str:
async def get_response(self, prompt: str) -> str:
"""
Abstract method to be implemented by subclasses to get a response from the connector.

This method should asynchronously send a prompt to the connector's API and return the response.

Args:
prompt (str): The input prompt to be sent to the connector.

Returns:
str: The response received from the connector.
"""
# Default implementation, can be overridden by subclasses
raise NotImplementedError(
"This connector is specifically designed for the FlagevalAnnotator metric. Please select an appropriate connector."
)

async def get_judge_response(
self, prompt: str, prediction: str, ground_truth: str
) -> str:
"""
Retrieve and return a response.
This method is used to retrieve a response, typically from an object or service represented by
Expand All @@ -36,7 +55,7 @@ async def get_response(self, prompt: str, prediction: str, ground_truth: str) ->
"pred": prediction,
"gold": ground_truth,
"echo": False,
"stream": False
"stream": False,
}
async with aiohttp.ClientSession() as session:
async with session.post(
Expand All @@ -48,10 +67,7 @@ async def get_response(self, prompt: str, prediction: str, ground_truth: str) ->

@Connector.rate_limited
@perform_retry
async def get_prediction(
self,
generated_prompt: ConnectorPromptArguments
):
async def get_prediction(self, generated_prompt: ConnectorPromptArguments):
"""
The method then returns the `judge_result` generated by flagjudge model.

Expand All @@ -65,21 +81,22 @@ async def get_prediction(
Exception: If there is an error during prediction.
"""
try:
print(f"Predicting prompt {generated_prompt.prompt_index} [{self.id}]")
logger.info(
f"Predicting prompt {generated_prompt.prompt_index} [{self.id}]"
)

judge_result = await self.get_response(
judge_result = await self.get_judge_response(
generated_prompt.prompt,
generated_prompt.predicted_results,
generated_prompt.target
generated_prompt.target,
)
# Return the judge_result
return judge_result

except Exception as e:
print(f"Failed to get prediction: {str(e)}")
logger.error(f"Failed to get prediction: {str(e)}")
raise e


def _prepare_headers(self) -> dict:
"""
Prepare HTTP headers for authentication using a bearer token.
Expand Down Expand Up @@ -121,9 +138,9 @@ async def _process_response(self, response: ClientResponse) -> str:
text = data["text"].strip()
output = text
return output

except Exception as exception:
print(
logger.error(
f"An exception has occurred: {str(exception)}, {await response.text()}"
)
raise exception
33 changes: 19 additions & 14 deletions databases-modules/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def create_connection(self) -> bool:
object's initialization.

If the connection is successfully established, it returns True.
If an error occurs during the connection process, it prints an error message with the details of the
If an error occurs during the connection process, it logs an error message with the details of the
SQLite error and returns False.

Returns:
Expand All @@ -44,7 +44,7 @@ def close_connection(self) -> None:
Closes the connection to the SQLite database.

If the connection is already established, it attempts to close it and sets the connection attribute to None.
If an error occurs during the closing process, it prints an error message with the details of the SQLite error.
If an error occurs during the closing process, it logs an error message with the details of the SQLite error.

Returns:
None
Expand All @@ -68,7 +68,7 @@ def create_table(self, create_table_sql: str) -> None:

This method attempts to create a table in the SQLite database using the provided SQL query.
If the connection to the SQLite database is established, it executes the SQL query.
If an error occurs during the table creation process, it prints an error message with the details of the
If an error occurs during the table creation process, it logs an error message with the details of the
SQLite error.

Args:
Expand Down Expand Up @@ -99,7 +99,7 @@ def create_record(self, record: tuple, create_record_sql: str) -> tuple | None:
If the operation is successful, it commits the transaction and returns the ID of the inserted record along
with the record data.

If an error occurs during the record insertion process, it prints an error message with the details of the
If an error occurs during the record insertion process, it logs an error message with the details of the
SQLite error and returns None.

Args:
Expand Down Expand Up @@ -135,7 +135,7 @@ def read_record(self, record: tuple, read_record_sql: str) -> tuple | None:

If the connection to the SQLite database is established, it executes the SQL query with the record and returns
the fetched record.
If an error occurs during the record reading process, it prints an error message with the details of the SQLite
If an error occurs during the record reading process, it logs an error message with the details of the SQLite
error and returns None.

Args:
Expand Down Expand Up @@ -164,7 +164,7 @@ def read_records(self, read_records_sql: str) -> list[tuple] | None:

This method attempts to execute a provided SQL query to read data from a table within the SQLite database.
If the connection to the database is established, it executes the query and returns all fetched rows as a list.
In case of an error during the execution of the query, it prints an error message detailing the issue.
In case of an error during the execution of the query, it logs an error message detailing the issue.

Args:
read_records_sql (str): The SQL query string used to read data from a table.
Expand All @@ -191,7 +191,7 @@ def update_record(self, record: tuple, update_record_sql: str) -> None:

This method attempts to update a record in the SQLite database using the provided SQL query and record.
If the connection to the SQLite database is established, it executes the SQL query with the record.
If an error occurs during the record updating process, it prints an error message with the details of the
If an error occurs during the record updating process, it logs an error message with the details of the
SQLite error.

Args:
Expand All @@ -217,7 +217,7 @@ def delete_record_by_id(self, record_id: int, delete_record_sql: str) -> None:

This method attempts to delete a record from the SQLite database using the provided SQL query and record ID.
If the connection to the SQLite database is established, it executes the SQL query with the record ID.
If an error occurs during the record deletion process, it prints an error message with the details of the
If an error occurs during the record deletion process, it logs an error message with the details of the
SQLite error.

Args:
Expand All @@ -234,15 +234,19 @@ def delete_record_by_id(self, record_id: int, delete_record_sql: str) -> None:
cursor.execute(delete_record_sql, (record_id,))
self.sqlite_conn.commit()
except sqlite3.Error as sqlite3_error:
print(f"Error deleting record from database - {str(sqlite3_error)}")
logger.error(
f"Error deleting record from database - {str(sqlite3_error)}"
)

def delete_records_in_table(self, delete_record_sql: str) -> None:
"""
Deletes all records from a table in the SQLite database using the provided SQL query.

This method attempts to delete all records from a specific table in the SQLite database using the provided SQL query.
This method attempts to delete all records from a specific table in the SQLite database using the provided SQL
query.

If the connection to the SQLite database is established, it executes the SQL query to delete the records.
If an error occurs during the deletion process, it prints an error message with the details of the SQLite error.
If an error occurs during the deletion process, it logs an error message with the details of the SQLite error.

Args:
delete_record_sql (str): The SQL query to delete all records from a table.
Expand All @@ -257,8 +261,9 @@ def delete_records_in_table(self, delete_record_sql: str) -> None:
cursor.execute(delete_record_sql)
self.sqlite_conn.commit()
except sqlite3.Error as sqlite3_error:
print(f"Error deleting records from database - {str(sqlite3_error)}")

logger.error(
f"Error deleting records from database - {str(sqlite3_error)}"
)

def check_database_table_exists(self, table_name: str) -> bool | None:
"""
Expand Down Expand Up @@ -298,7 +303,7 @@ def delete_database_table(self, delete_table_sql: str) -> None:

This method attempts to delete a table from the SQLite database using the provided SQL query.
If the connection to the SQLite database is established, it executes the SQL query to delete the table.
If an error occurs during the table deletion process, it prints an error message with the details of the
If an error occurs during the table deletion process, it logs an error message with the details of the
SQLite error.

Args:
Expand Down
14 changes: 10 additions & 4 deletions metrics/cache.json
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,16 @@
"description": "Sample Metric will provide examples on connecting to LLMs.",
"hash": "c8d1e48a91f66eb8"
},
"singapore-toxicity-classifier": {
"id": "singapore-toxicity-classifier",
"name": "Singapore Toxicity Classifier",
"lionguardclassifier": {
"id": "lionguardclassifier",
"name": "LionGuard Classifier",
"description": "This classifier measures how unsafe a given input is in the Singapore context.\nIt can calculate the number of toxic sentences detected. ",
"hash": "012e026bfa28ac13"
"hash": "e1e5a08a2651091a"
},
"flagjudgeannotator": {
"id": "flagjudgeannotator",
"name": "FlagJudge Annotator",
"description": "Calculates the number of correct response using FlagJudge (A judge model from FlagEval Group).",
"hash": "b9207a4b57a331da"
}
}
9 changes: 4 additions & 5 deletions metrics/flagjudgeannotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async def get_results(
self, prompts: Any, predicted_results: Any, targets: Any, *args, **kwargs
) -> dict:
"""
Calculates the number of refused response using GPT4.
Calculates the number of refused response using GPT4.

Args:
prompts (Any): The prompts used for prediction.
Expand All @@ -59,14 +59,14 @@ async def get_results(
prompt_index=prompt_index,
prompt=prompt,
target=target,
predicted_results=result
predicted_results=result,
)

prompt_index += 1
# We need to deepcopy because the connector will overwrite the prompt argument with the predicted results
# and the duration taken.
my_new_prompt = copy.deepcopy(sample_prompt_argument)

judge_result = await evaluation_model.get_prediction(my_new_prompt)
judge_results.append(judge_result)

Expand All @@ -79,7 +79,6 @@ async def get_results(
else:
number_of_wrong += 1


total = number_of_correct + number_of_wrong
correct_rate = number_of_correct / total * 100

Expand All @@ -89,4 +88,4 @@ async def get_results(
"number_of_wrong": number_of_wrong,
"total": total,
"grading_criteria": {"correct_rate": correct_rate},
}
}
Loading