Skip to content

Commit

Permalink
Merge branch 'dev_main' into ms-632
Browse files Browse the repository at this point in the history
  • Loading branch information
imda-lionelteo committed Oct 21, 2024
2 parents 83ae0d2 + 720e1c4 commit 9f5810c
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 98 deletions.
14 changes: 10 additions & 4 deletions connectors/amazon-bedrock-connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import boto3
from botocore.config import Config
from moonshot.src.connectors.connector import Connector, perform_retry
from moonshot.src.connectors.connector_response import ConnectorResponse
from moonshot.src.connectors_endpoints.connector_endpoint_arguments import (
ConnectorEndpointArguments,
)
Expand Down Expand Up @@ -117,7 +118,7 @@ def __init__(self, ep_arguments: ConnectorEndpointArguments):

@Connector.rate_limited
@perform_retry
async def get_response(self, prompt: str) -> str:
async def get_response(self, prompt: str) -> ConnectorResponse:
"""Asynchronously send a prompt to the Amazon Bedrock API and return the generated response
This method uses the Bedrock Converse API, which provides more cross-model standardization
Expand All @@ -139,7 +140,7 @@ async def get_response(self, prompt: str) -> str:
prompt (str): The input prompt to send to the model.
Returns:
str: The text response generated by the selected model.
ConnectorResponse: An object containing the text response generated by the selected model.
"""
connector_prompt = f"{self.pre_prompt}{prompt}{self.post_prompt}"
req_params = {
Expand Down Expand Up @@ -167,6 +168,11 @@ async def get_response(self, prompt: str) -> str:
message,
)
# Ignore any non-text contents, and join together with '\n\n' if multiple are returned:
return "\n\n".join(
map(lambda m: m["text"], filter(lambda m: "text" in m, message["content"]))
return ConnectorResponse(
response="\n\n".join(
map(
lambda m: m["text"],
filter(lambda m: "text" in m, message["content"]),
)
)
)
9 changes: 5 additions & 4 deletions connectors/google-gemini-connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import google.generativeai as genai
from moonshot.src.connectors.connector import Connector, perform_retry
from moonshot.src.connectors.connector_response import ConnectorResponse
from moonshot.src.connectors_endpoints.connector_endpoint_arguments import (
ConnectorEndpointArguments,
)
Expand All @@ -22,20 +23,20 @@ def __init__(self, ep_arguments: ConnectorEndpointArguments):

@Connector.rate_limited
@perform_retry
async def get_response(self, prompt: str) -> str:
async def get_response(self, prompt: str) -> ConnectorResponse:
"""
Asynchronously sends a prompt to the Google Gemini API and returns the generated response.
This method constructs a request with the given prompt, optionally prepended and appended with
predefined strings, and sends it to the Google Gemini API. If a system prompt is set, it is included in the
request. The method then awaits the response from the API, processes it, and returns the resulting message
content as a string.
content wrapped in a ConnectorResponse object.
Args:
prompt (str): The input prompt to send to the Google Gemini API.
Returns:
str: The text response generated by the Google Gemini model.
ConnectorResponse: An object containing the text response generated by the Google Gemini model.
"""
connector_prompt = f"{self.pre_prompt}{prompt}{self.post_prompt}"

Expand All @@ -48,7 +49,7 @@ async def get_response(self, prompt: str) -> str:
response = model.generate_content(
connector_prompt, generation_config=generation_config
)
return await self._process_response(response)
return ConnectorResponse(response=await self._process_response(response))

async def _process_response(self, response: Any) -> str:
"""
Expand Down
4 changes: 2 additions & 2 deletions datasets/cache.json
Original file line number Diff line number Diff line change
Expand Up @@ -1237,10 +1237,10 @@
"description": "Contain prompts that test safety in Singapore-context",
"examples": null,
"num_of_dataset_prompts": 59,
"created_date": "2024-05-27 16:48:35",
"created_date": "2024-10-15 21:25:44",
"reference": "IMDA",
"license": "Apache-2.0",
"hash": "5d61c07c64808b9d"
"hash": "984714efbead7cc1"
},
"medqa-us": {
"id": "medqa-us",
Expand Down
Loading

0 comments on commit 9f5810c

Please sign in to comment.