Skip to content

Commit

Permalink
Added anthropic support
Browse files Browse the repository at this point in the history
  • Loading branch information
tnahddisttud committed Jul 14, 2024
1 parent 1c20ddb commit d8b69e0
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 3 deletions.
3 changes: 2 additions & 1 deletion mindsql/_utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@
CONFIG_REQUIRED_ERROR = "Configuration is required."
LLAMA_PROMPT_EXCEPTION = "Prompt cannot be empty."
OPENAI_VALUE_ERROR = "OpenAI API key is required"
OPENAI_PROMPT_EMPTY_EXCEPTION = "Prompt cannot be empty."
PROMPT_EMPTY_EXCEPTION = "Prompt cannot be empty."
POSTGRESQL_SHOW_CREATE_TABLE_QUERY = """SELECT 'CREATE TABLE "' || table_name || '" (' || array_to_string(array_agg(column_name || ' ' || data_type), ', ') || ');' AS create_statement FROM information_schema.columns WHERE table_name = '{table}' GROUP BY table_name;"""
ANTHROPIC_VALUE_ERROR = "Anthropic API key is required"
91 changes: 91 additions & 0 deletions mindsql/llms/anthropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from anthropic import Anthropic

from . import ILlm
from .._utils.constants import ANTHROPIC_VALUE_ERROR, PROMPT_EMPTY_EXCEPTION


class AnthropicAi(ILlm):
def __init__(self, config=None, client=None):
"""
Initialize the class with an optional config parameter.
Parameters:
config (any): The configuration parameter.
client (any): The client parameter.
Returns:
None
"""
self.config = config
self.client = client

if client is not None:
self.client = client
return

if 'api_key' not in config:
raise ValueError(ANTHROPIC_VALUE_ERROR)
api_key = config.pop('api_key')
self.client = Anthropic(api_key=api_key, **config)

def system_message(self, message: str) -> any:
"""
Create a system message.
Parameters:
message (str): The message parameter.
Returns:
any
"""
return {"role": "system", "content": message}

def user_message(self, message: str) -> any:
"""
Create a user message.
Parameters:
message (str): The message parameter.
Returns:
any
"""
return {"role": "user", "content": message}

def assistant_message(self, message: str) -> any:
"""
Create an assistant message.
Parameters:
message (str): The message parameter.
Returns:
any
"""
return {"role": "assistant", "content": message}

def invoke(self, prompt, **kwargs) -> str:
"""
Submit a prompt to the model for generating a response.
Parameters:
prompt (str): The prompt parameter.
**kwargs: Additional keyword arguments (optional).
- temperature (float): The temperature parameter for controlling randomness in generation.
- max_tokens (int): Maximum number of tokens to be generated.
Returns:
str: The generated response from the model.
"""
if prompt is None or len(prompt) == 0:
raise Exception(PROMPT_EMPTY_EXCEPTION)

model = self.config.get("model", "claude-3-opus-20240229")
temperature = kwargs.get("temperature", 0.1)
max_tokens = kwargs.get("max_tokens", 1024)
response = self.client.messages.create(model=model, messages=[{"role": "user", "content": prompt}],
max_tokens=max_tokens, temperature=temperature)
for content in response.content:
if isinstance(content, dict) and content.get("type") == "text":
return content["text"]
elif hasattr(content, "text"):
return content.text
4 changes: 2 additions & 2 deletions mindsql/llms/open_ai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from openai import OpenAI

from . import ILlm
from .._utils.constants import OPENAI_VALUE_ERROR, OPENAI_PROMPT_EMPTY_EXCEPTION
from .._utils.constants import OPENAI_VALUE_ERROR, PROMPT_EMPTY_EXCEPTION


class OpenAi(ILlm):
Expand Down Expand Up @@ -77,7 +77,7 @@ def invoke(self, prompt, **kwargs) -> str:
str: The generated response from the model.
"""
if prompt is None or len(prompt) == 0:
raise Exception(OPENAI_PROMPT_EMPTY_EXCEPTION)
raise Exception(PROMPT_EMPTY_EXCEPTION)

model = self.config.get("model", "gpt-3.5-turbo")
temperature = kwargs.get("temperature", 0.1)
Expand Down

0 comments on commit d8b69e0

Please sign in to comment.