@@ -2,7 +2,7 @@
 
 import json
 import time
-from functools import lru_cache
+from functools import lru_cache, cached_property
 
 from anthropic import Anthropic
 from anthropic.types import Message, MessageParam, TextBlockParam
@@ -14,12 +14,17 @@
     ChatCompletionToolParam,
     completion_create_params,
 )
+from pydantic_ai.messages import ModelMessage, ModelResponse
+from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
+from pydantic_ai.settings import ModelSettings
+from pydantic_ai.usage import Usage
+from pydantic_ai.models.anthropic import AnthropicModel
 from openai.types.chat.chat_completion import Choice, CompletionUsage
 from openai.types.chat.chat_completion_message_tool_call import (
     ChatCompletionMessageToolCall,
     Function,
 )
 from openai.types.completion_usage import CompletionUsage
-from typing_extensions import Dict, Iterable, List, Optional, Union
+from typing_extensions import AsyncIterator, Dict, Iterable, List, Optional, Union
 
 from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
@@ -73,7 +78,49 @@ class AnthropicLlmClient(LlmClient):
     __100k_models = {"claude-2.0", "claude-instant-1.2"}
 
     def __init__(self, api_key: str):
-        self.client = Anthropic(api_key=api_key)
+        self.__api_key = api_key
+
+    @cached_property
+    def __client(self):
+        # Defer building the Anthropic SDK client until first use, then memoize it.
+        return Anthropic(api_key=self.__api_key)
+
+    def __get_pydantic_model(self, model_settings: ModelSettings | None) -> Model:
+        if model_settings is None:
+            raise ValueError("Model settings cannot be None")
+        model_name = model_settings.get("model")
+        if model_name is None:
+            raise ValueError("Model must be set and cannot be None")
+
+        return AnthropicModel(model_name, api_key=self.__api_key)
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        model = self.__get_pydantic_model(model_settings)
+        return await model.request(messages, model_settings, model_request_parameters)
+
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        model = self.__get_pydantic_model(model_settings)
+        # The model's request_stream() is an async context manager; yield the entered StreamedResponse, not the manager.
+        async with model.request_stream(messages, model_settings, model_request_parameters) as streamed_response:
+            yield streamed_response
+
+    @property
+    def model_name(self) -> str:
+        return "Undetermined"  # the model is chosen per-request from ModelSettings
+
+    @property
+    def system(self) -> str:
+        return "anthropic"
 
     def __get_model_limit(self, model: str) -> int:
         # it is observed that the count tokens is not accurate, so we are using a safety margin
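
For orientation (not part of this diff), a minimal sketch of exercising the new pydantic_ai-facing `request` method. `ModelRequest` and `UserPromptPart` come from `pydantic_ai.messages`; the `ModelRequestParameters` field names and the extra `"model"` key in the settings dict are assumptions about the pydantic_ai version in use, since `ModelSettings` is a TypedDict that does not itself define a `model` key and is not validated at runtime.

```python
import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters

async def main() -> None:
    client = AnthropicLlmClient(api_key="sk-ant-...")  # hypothetical key
    messages = [ModelRequest(parts=[UserPromptPart(content="Say hello")])]
    # __get_pydantic_model() reads the extra "model" key out of the settings dict.
    settings = {"model": "claude-3-5-sonnet-latest", "max_tokens": 256}
    params = ModelRequestParameters(
        function_tools=[], allow_text_result=True, result_tools=[]
    )
    response, usage = await client.request(messages, settings, params)
    print(response, usage)

asyncio.run(main())
```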
@@ -248,7 +295,7 @@ def is_prompt_supported(
             for k, v in input_kwargs.items()
             if k in {"messages", "model", "system", "tool_choice", "tools", "beta"}
         }
-        message_token_count = self.client.beta.messages.count_tokens(**count_token_input_kwargs)
+        message_token_count = self.__client.beta.messages.count_tokens(**count_token_input_kwargs)
        return model_limit - message_token_count.input_tokens
 
     def truncate_messages(
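
For reference, a standalone sketch of the token-counting call the hunk above switches to the lazy `self.__client`. It uses the Anthropic SDK's beta `count_tokens` endpoint, which returns an object exposing `input_tokens`; the model name and message are placeholders.

```python
from anthropic import Anthropic

client = Anthropic(api_key="sk-ant-...")  # hypothetical key
count = client.beta.messages.count_tokens(
    model="claude-3-5-sonnet-latest",
    messages=[{"role": "user", "content": "Hello"}],
)
print(count.input_tokens)  # tokens the prompt would consume, before the safety margin
```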
@@ -292,5 +339,5 @@ def chat_completion(
             top_p=top_p,
         )
 
-        response = self.client.messages.create(**input_kwargs)
+        response = self.__client.messages.create(**input_kwargs)
         return _anthropic_to_openai_response(model, response)
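
Finally, a hedged sketch of consuming the streaming path. As fixed above, `request_stream` is an async generator that yields a single entered `StreamedResponse`, which in pydantic_ai can itself be async-iterated for stream events; the settings and parameters carry the same assumptions as the earlier sketch.

```python
import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters

async def stream_demo(client: AnthropicLlmClient) -> None:
    messages = [ModelRequest(parts=[UserPromptPart(content="Stream a haiku")])]
    settings = {"model": "claude-3-5-sonnet-latest"}  # assumed extra "model" key, as above
    params = ModelRequestParameters(
        function_tools=[], allow_text_result=True, result_tools=[]
    )
    async for streamed in client.request_stream(messages, settings, params):
        # Each StreamedResponse is itself an async iterable of stream events.
        async for event in streamed:
            print(event)

asyncio.run(stream_demo(AnthropicLlmClient(api_key="sk-ant-...")))  # hypothetical key
```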