Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions patchwork/common/client/llm/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ def create_aio_client(inputs) -> "AioLlmClient" | None:
clients.append(client)

anthropic_key = inputs.get("anthropic_api_key")
if anthropic_key is not None:
client = AnthropicLlmClient(anthropic_key)
if anthropic_key is not None or "is_aws" in client_args.keys():
client = AnthropicLlmClient(anthropic_key, is_aws=client_args.get("is_aws", False))
clients.append(client)

if len(clients) == 0:
Expand Down
41 changes: 32 additions & 9 deletions patchwork/common/client/llm/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from functools import cached_property, lru_cache
from pathlib import Path

from anthropic import Anthropic
import boto3
from anthropic import Anthropic, AnthropicBedrock
from anthropic.types import Message, MessageParam, TextBlockParam
from openai.types.chat import (
ChatCompletion,
Expand All @@ -24,6 +25,7 @@
from pydantic_ai.messages import ModelMessage, ModelResponse
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.bedrock import BedrockConverseModel
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import Usage
from typing_extensions import AsyncIterator, Dict, Iterable, List, Optional, Union
Expand Down Expand Up @@ -74,25 +76,32 @@ def _anthropic_to_openai_response(model: str, anthropic_response: Message) -> Ch


class AnthropicLlmClient(LlmClient):
__allowed_model_prefix = "claude-3-"
__definitely_allowed_models = {"claude-2.0", "claude-2.1", "claude-instant-1.2"}
__non_aws_alias = {"claude-3-7-sonnet-latest", "claude-3-5-haiku-latest", "claude-3-5-sonnet-latest", "claude-3-opus-latest"}
__100k_models = {"claude-2.0", "claude-instant-1.2"}

def __init__(self, api_key: str):
def __init__(self, api_key: Optional[str] = None, is_aws: bool = False):
self.__api_key = api_key
self.__is_aws = is_aws
if self.__api_key is None and not is_aws:
raise ValueError("api_key is required if is_aws is False")

@cached_property
def __client(self):
return Anthropic(api_key=self.__api_key)
if not self.__is_aws:
return Anthropic(api_key=self.__api_key)
else:
return AnthropicBedrock()

def __get_pydantic_model(self, model_settings: ModelSettings | None) -> Model:
if model_settings is None:
raise ValueError("Model settings cannot be None")
model_name = model_settings.get("model")
if model_name is None:
raise ValueError("Model must be set cannot be None")

return AnthropicModel(model_name, api_key=self.__api_key)
if not self.__is_aws:
return AnthropicModel(model_name, api_key=self.__api_key)
else:
return BedrockConverseModel(model_name)

async def request(
self,
Expand Down Expand Up @@ -247,10 +256,24 @@ def __adapt_chat_completion_request(

@lru_cache(maxsize=None)
def get_models(self) -> set[str]:
return self.__definitely_allowed_models.union(set(f"{self.__allowed_model_prefix}*"))
rv = set()
if not self.__is_aws:
rv.update(self.__non_aws_alias)
for model_info in self.__client.models.list():
rv.add(model_info.id)
else:
bedrock = boto3.client(service_name="bedrock")
response = bedrock.list_foundation_models(byProvider="anthropic")
for model_info in response["modelSummaries"]:
rv.add(model_info["modelId"])

return rv

def is_model_supported(self, model: str) -> bool:
return model in self.__definitely_allowed_models or model.startswith(self.__allowed_model_prefix)
if not self.__is_aws:
return model in self.get_models()
else:
return any(True for model_id in self.get_models() if model.endswith(model_id))

def is_prompt_supported(
self,
Expand Down
2 changes: 2 additions & 0 deletions patchwork/steps/FileAgent/FileAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self, inputs):

self.strat_kwargs = dict(
model="claude-3-5-sonnet-latest",
# model="apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
llm_client=AioLlmClient.create_aio_client(inputs),
template_data=dict(),
system_prompt_template=f"""\
Expand All @@ -36,6 +37,7 @@ def __init__(self, inputs):
AgentConfig(
name="Assistant",
model="claude-3-7-sonnet-latest",
# model="apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
tool_set=dict(),
system_prompt="""\
You are a assistant that is supposed to help me with a set of files. These files are commonly tabular formatted like csv, xls or xlsx.
Expand Down
Loading