Skip to content

Commit

Permalink
Anthropic MVP
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael McCulloch committed Jul 18, 2024
1 parent c944d21 commit ba92627
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 6 deletions.
16 changes: 16 additions & 0 deletions llmx/configs/config.default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,26 @@ model:

# list of supported providers.
providers:
anthropic:
name: Anthropic
description: Anthropic's Claude models.
models:
- name: claude-3-5-sonnet-20240620
max_tokens: 8192
model:
provider: anthropic
parameters:
model: claude-3-5-sonnet-20240620
openai:
name: OpenAI
description: OpenAI's and AzureOpenAI GPT-3 and GPT-4 models.
models:
- name: gpt-4o # general model name, can be anything
max_tokens: 4096 # max supported tokens
model:
provider: openai
parameters:
model: gpt-4o
- name: gpt-4 # general model name, can be anything
max_tokens: 8192 # max supported tokens
model:
Expand Down
129 changes: 129 additions & 0 deletions llmx/generators/text/anthropic_textgen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from typing import Union, List, Dict
import os
import anthropic
from dataclasses import asdict

from .base_textgen import TextGenerator
from ...datamodel import TextGenerationConfig, TextGenerationResponse, Message
from ...utils import cache_request, get_models_maxtoken_dict, num_tokens_from_messages


class AnthropicTextGenerator(TextGenerator):
    """Text generator backed by Anthropic's Messages API (Claude models)."""

    def __init__(
        self,
        api_key: str = None,
        provider: str = "anthropic",
        model: str = None,
        models: Dict = None,
    ):
        """Create a generator for Anthropic models.

        Args:
            api_key: Anthropic API key; falls back to the ANTHROPIC_API_KEY
                environment variable.
            provider: provider name passed through to the base class.
            model: default model name; falls back to claude-3-5-sonnet-20240620.
            models: provider model config used to build the max-token lookup.

        Raises:
            ValueError: if no API key is supplied or found in the environment.
        """
        super().__init__(provider=provider)
        api_key = api_key or os.environ.get("ANTHROPIC_API_KEY", None)
        if api_key is None:
            raise ValueError(
                "Anthropic API key is not set. Please set the ANTHROPIC_API_KEY environment variable."
            )
        # NOTE(review): the beta header appears intended to raise the output
        # token limit for claude-3-5-sonnet — confirm against Anthropic's
        # current beta-header documentation before relying on it.
        self.client = anthropic.Anthropic(
            api_key=api_key,
            default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"},
        )
        self.model_max_token_dict = get_models_maxtoken_dict(models)
        self.model_name = model or "claude-3-5-sonnet-20240620"

    def format_messages(self, messages):
        """Return copies of *messages* reduced to their role/content keys."""
        return [
            {"role": message["role"], "content": message["content"]}
            for message in messages
        ]

    def generate(
        self,
        messages: Union[List[Dict], str],
        config: TextGenerationConfig = TextGenerationConfig(),
        **kwargs,
    ) -> TextGenerationResponse:
        """Generate a completion from Anthropic's Messages API.

        Args:
            messages: either a list of {"role", "content"} dicts or a bare
                string, which is treated as a single user message.
            config: generation parameters (model, temperature, top_p,
                max_tokens, use_cache).

        Returns:
            A TextGenerationResponse with the assistant message and usage.

        Raises:
            ValueError: if no non-system message is present.
        """
        use_cache = config.use_cache
        model = config.model or self.model_name

        # Accept a bare string, as the type hint advertises; the original
        # implementation crashed on str input.
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        prompt_tokens = num_tokens_from_messages(messages)
        # Leave a small safety margin below the model limit, but never
        # request fewer than 200 output tokens.
        max_tokens = max(
            self.model_max_token_dict.get(model, 8192) - prompt_tokens - 10, 200
        )

        # Separate system messages from the conversation. Work on copies so
        # the caller's message dicts are not mutated (the original stripped
        # content in place, a visible side effect on the caller's data).
        system_message = None
        other_messages = []
        for message in messages:
            content = message["content"].strip()
            if message["role"] == "system":
                # Concatenate multiple system messages into one.
                if system_message is None:
                    system_message = content
                else:
                    system_message += "\n" + content
            else:
                other_messages.append({"role": message["role"], "content": content})

        if not other_messages:
            raise ValueError("At least one message is required")

        # The Messages API expects the conversation to start with a user
        # turn; invert roles when it starts with an assistant turn.
        needs_inversion = other_messages[0]["role"] == "assistant"
        if needs_inversion:
            other_messages = self.invert_messages(other_messages)

        anthropic_config = {
            "model": model,
            "max_tokens": config.max_tokens or max_tokens,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "messages": other_messages,
        }

        # "system" is a top-level parameter in the Messages API, not a message.
        if system_message:
            anthropic_config["system"] = system_message

        self.model_name = model
        cache_key_params = anthropic_config.copy()
        cache_key_params["messages"] = messages  # Keep original messages for caching

        if use_cache:
            response = cache_request(cache=self.cache, params=cache_key_params)
            if response:
                return TextGenerationResponse(**response)

        anthropic_response = self.client.messages.create(**anthropic_config)

        response_content = anthropic_response.content[0].text

        # Always strip "Human: " prefix, regardless of inversion
        if response_content.startswith("Human: "):
            response_content = response_content[7:]

        response = TextGenerationResponse(
            text=[Message(role="assistant", content=response_content)],
            logprobs=[],
            config=anthropic_config,
            usage={
                "prompt_tokens": anthropic_response.usage.input_tokens,
                "completion_tokens": anthropic_response.usage.output_tokens,
                "total_tokens": anthropic_response.usage.input_tokens
                + anthropic_response.usage.output_tokens,
            },
            response=anthropic_response,
        )

        cache_request(
            cache=self.cache, params=cache_key_params, values=asdict(response)
        )
        return response

    def invert_messages(self, messages):
        """Swap roles so even-indexed turns are user and odd-indexed are assistant."""
        inverted = []
        for i, message in enumerate(messages):
            role = "user" if i % 2 == 0 else "assistant"
            inverted.append({"role": role, "content": message["content"]})
        return inverted

    def count_tokens(self, text) -> int:
        """Approximate the token count of *text* using the shared tokenizer helper."""
        return num_tokens_from_messages(text)
16 changes: 10 additions & 6 deletions llmx/generators/text/textgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,26 @@
from .openai_textgen import OpenAITextGenerator
from .palm_textgen import PalmTextGenerator
from .cohere_textgen import CohereTextGenerator
from .anthropic_textgen import AnthropicTextGenerator
import logging

logger = logging.getLogger("llmx")


def sanitize_provider(provider: str) -> str:
    """Normalize a provider alias to its canonical name.

    Args:
        provider: case-insensitive provider name or alias.

    Returns:
        One of "openai", "palm", "cohere", "hf", or "anthropic".

    Raises:
        ValueError: if the provider is not recognized.
    """
    # Lowercase once instead of calling .lower() on every comparison.
    normalized = provider.lower()
    if normalized in ("openai", "default", "azureopenai", "azureoai"):
        return "openai"
    elif normalized in ("palm", "google"):
        return "palm"
    elif normalized == "cohere":
        return "cohere"
    elif normalized in ("hf", "huggingface"):
        return "hf"
    elif normalized in ("anthropic", "claude"):
        return "anthropic"
    else:
        raise ValueError(
            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
        )


Expand Down Expand Up @@ -54,6 +56,8 @@ def llm(provider: str = None, **kwargs):
return PalmTextGenerator(**kwargs)
elif provider.lower() == "cohere":
return CohereTextGenerator(**kwargs)
elif provider.lower() == "anthropic":
return AnthropicTextGenerator(**kwargs)
elif provider.lower() == "hf":
try:
import transformers
Expand All @@ -67,7 +71,7 @@ def llm(provider: str = None, **kwargs):
import torch
except ImportError:
raise ImportError(
"Please install the `torch` package to use the HFTextGenerator class. pip install llmx[transformers]"
"Please install the `torch` package to use the HFTextGenerator class. pip install llmx[transformers]"
)

from .hf_textgen import HFTextGenerator
Expand All @@ -76,5 +80,5 @@ def llm(provider: str = None, **kwargs):

else:
raise ValueError(
f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', and 'cohere'."
)
f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
)
9 changes: 9 additions & 0 deletions tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@
{"role": "user",
"content": "What is the capital of France? Only respond with the exact answer"}]

def test_anthropic():
    """Smoke-test the Anthropic provider end-to-end against the live API."""
    anthropic_gen = llm(provider="anthropic", api_key=os.environ.get("ANTHROPIC_API_KEY", None))
    # Pin the model under test, but restore the shared module-level `config`
    # afterwards so this setting does not leak into the other provider tests.
    previous_model = config.model
    config.model = "claude-3-5-sonnet-20240620"  # or any other Anthropic model you want to test
    try:
        anthropic_response = anthropic_gen.generate(messages, config=config)
    finally:
        config.model = previous_model
    answer = anthropic_response.text[0].content
    print(answer)

    assert ("paris" in answer.lower())
    assert len(anthropic_response.text) == 1

def test_openai():
openai_gen = llm(provider="openai")
openai_response = openai_gen.generate(messages, config=config)
Expand Down

0 comments on commit ba92627

Please sign in to comment.