
Commit 070fce1

pydantic ai model

1 parent 2ba316f · commit 070fce1

5 files changed: +219 −18 lines

patchwork/common/client/llm/aio.py

Lines changed: 63 additions & 2 deletions

@@ -9,8 +9,11 @@
     ChatCompletionToolParam,
     completion_create_params,
 )
-from typing_extensions import Dict, Iterable, List, Optional, Union
-
+from typing_extensions import Dict, Iterable, List, Optional, Union, AsyncIterator
+from pydantic_ai.messages import ModelMessage, ModelResponse
+from pydantic_ai.models import ModelRequestParameters, StreamedResponse, Model
+from pydantic_ai.settings import ModelSettings
+from pydantic_ai.usage import Usage
 from patchwork.common.client.llm.anthropic import AnthropicLlmClient
 from patchwork.common.client.llm.google import GoogleLlmClient
 from patchwork.common.client.llm.openai_ import OpenAiLlmClient
@@ -31,6 +34,64 @@ def __init__(self, *clients: LlmClient):
         except Exception:
             pass
 
+    def __get_model(self, model_settings: ModelSettings | None) -> str:
+        if model_settings is None:
+            raise ValueError("Model settings cannot be None")
+        model_name = model_settings.get("model")
+        if model_name is None:
+            raise ValueError("Model must be set and cannot be None")
+
+        return model_name
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        model = self.__get_model(model_settings)
+        if model is None:
+            raise ValueError("Model cannot be unset")
+
+        for client in self.__clients:
+            if client.is_model_supported(model):
+                return await client.request(messages, model_settings, model_request_parameters)
+
+        client_names = [client.__class__.__name__ for client in self.__original_clients]
+        raise ValueError(
+            f"Model {model} is not supported by {client_names} clients. "
+            f"Please ensure that the respective API keys are correct."
+        )
+
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        model = self.__get_model(model_settings)
+        if model is None:
+            raise ValueError("Model cannot be unset")
+
+        for client in self.__clients:
+            if client.is_model_supported(model):
+                yield client.request_stream(messages, model_settings, model_request_parameters)
+                return
+
+        client_names = [client.__class__.__name__ for client in self.__original_clients]
+        raise ValueError(
+            f"Model {model} is not supported by {client_names} clients. "
+            f"Please ensure that the respective API keys are correct."
+        )
+
+    @property
+    def model_name(self) -> str:
+        return "Undetermined"
+
+    @property
+    def system(self) -> str:
+        return next(iter(self.__clients)).system
+
     def get_models(self) -> set[str]:
         return self.__supported_models
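For context on how this routing is meant to be driven, here is a minimal usage sketch. It assumes the aggregate class defined in aio.py is named AioLlmClient (the class statement sits outside these hunks) and that the ModelRequestParameters field names match the early pydantic-ai releases this commit appears to target. Note that "model" is not a declared key of pydantic-ai's ModelSettings TypedDict; the commit smuggles the model name through it and reads it back via model_settings.get("model").

# Hypothetical usage sketch -- AioLlmClient and the placeholder key are
# assumptions, not shown in this diff.
import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters

from patchwork.common.client.llm.aio import AioLlmClient
from patchwork.common.client.llm.openai_ import OpenAiLlmClient


async def main() -> None:
    client = AioLlmClient(OpenAiLlmClient(api_key="sk-..."))
    # The model name rides inside model_settings; __get_model() extracts it,
    # and the first wrapped client whose is_model_supported() returns True wins.
    response, usage = await client.request(
        messages=[ModelRequest(parts=[UserPromptPart(content="Say hello")])],
        model_settings={"model": "gpt-4o-mini"},  # "model" is a patchwork-specific key
        model_request_parameters=ModelRequestParameters(
            function_tools=[], allow_text_result=True, result_tools=[]
        ),
    )
    print(response, usage)


asyncio.run(main())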

patchwork/common/client/llm/anthropic.py

Lines changed: 49 additions & 5 deletions

@@ -2,7 +2,7 @@
 
 import json
 import time
-from functools import lru_cache
+from functools import lru_cache, cached_property
 
 from anthropic import Anthropic
 from anthropic.types import Message, MessageParam, TextBlockParam
@@ -14,13 +14,18 @@
     ChatCompletionToolParam,
     completion_create_params,
 )
+from pydantic_ai.messages import ModelMessage, ModelResponse
+from pydantic_ai.models import ModelRequestParameters, StreamedResponse, Model
+from pydantic_ai.settings import ModelSettings
+from pydantic_ai.usage import Usage
+from pydantic_ai.models.anthropic import AnthropicModel
 from openai.types.chat.chat_completion import Choice, CompletionUsage
 from openai.types.chat.chat_completion_message_tool_call import (
     ChatCompletionMessageToolCall,
     Function,
 )
 from openai.types.completion_usage import CompletionUsage
-from typing_extensions import Dict, Iterable, List, Optional, Union
+from typing_extensions import Dict, Iterable, List, Optional, Union, AsyncIterator
 
 from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
 
@@ -73,7 +78,46 @@ class AnthropicLlmClient(LlmClient):
     __100k_models = {"claude-2.0", "claude-instant-1.2"}
 
     def __init__(self, api_key: str):
-        self.client = Anthropic(api_key=api_key)
+        self.__api_key = api_key
+
+    @cached_property
+    def __client(self):
+        return Anthropic(api_key=self.__api_key)
+
+    def __get_pydantic_model(self, model_settings: ModelSettings | None) -> Model:
+        if model_settings is None:
+            raise ValueError("Model settings cannot be None")
+        model_name = model_settings.get("model")
+        if model_name is None:
+            raise ValueError("Model must be set and cannot be None")
+
+        return AnthropicModel(model_name, api_key=self.__api_key)
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        model = self.__get_pydantic_model(model_settings)
+        return await model.request(messages, model_settings, model_request_parameters)
+
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        model = self.__get_pydantic_model(model_settings)
+        yield model.request_stream(messages, model_settings, model_request_parameters)
+
+    @property
+    def model_name(self) -> str:
+        return "Undetermined"
+
+    @property
+    def system(self) -> str:
+        return "anthropic"
 
     def __get_model_limit(self, model: str) -> int:
         # it is observed that the count tokens is not accurate, so we are using a safety margin
@@ -248,7 +292,7 @@ def is_prompt_supported(
             for k, v in input_kwargs.items()
             if k in {"messages", "model", "system", "tool_choice", "tools", "beta"}
         }
-        message_token_count = self.client.beta.messages.count_tokens(**count_token_input_kwargs)
+        message_token_count = self.__client.beta.messages.count_tokens(**count_token_input_kwargs)
         return model_limit - message_token_count.input_tokens
 
     def truncate_messages(
@@ -292,5 +336,5 @@ def chat_completion(
             top_p=top_p,
         )
 
-        response = self.client.messages.create(**input_kwargs)
+        response = self.__client.messages.create(**input_kwargs)
         return _anthropic_to_openai_response(model, response)
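The __init__ change above is worth calling out: instead of building an Anthropic() SDK object eagerly on every AnthropicLlmClient instantiation, the commit stores only the key and defers construction to a cached_property, so the SDK object is created on first use and reused afterwards. A minimal sketch of the pattern; the print call and object() stand-in are illustrative only:

from functools import cached_property


class LazyClient:
    def __init__(self, api_key: str):
        self.__api_key = api_key  # cheap: no SDK object built yet

    @cached_property
    def __client(self):
        # runs once, on first access; the result is cached in the instance __dict__
        print("building SDK client")
        return object()  # stand-in for Anthropic(api_key=self.__api_key)

    def use(self):
        return self.__client  # later calls reuse the cached object

This keeps instantiation free, which matters because aio.py constructs one client per provider up front and only some of them will ever be used.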

patchwork/common/client/llm/google.py

Lines changed: 41 additions & 1 deletion

@@ -20,8 +20,13 @@
     ChatCompletionToolParam,
     completion_create_params,
 )
+from pydantic_ai.messages import ModelMessage, ModelResponse
+from pydantic_ai.models import ModelRequestParameters, StreamedResponse, Model
+from pydantic_ai.settings import ModelSettings
+from pydantic_ai.usage import Usage
+from pydantic_ai.models.gemini import GeminiModel
 from openai.types.chat.chat_completion import ChatCompletion, Choice
-from typing_extensions import Any, Dict, Iterable, List, Optional, Union
+from typing_extensions import Any, Dict, Iterable, List, Optional, Union, AsyncIterator
 
 from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
 from patchwork.common.client.llm.utils import json_schema_to_model
@@ -45,6 +50,41 @@ def __init__(self, api_key: str):
         self.__api_key = api_key
         generativeai.configure(api_key=api_key)
 
+    def __get_pydantic_model(self, model_settings: ModelSettings | None) -> Model:
+        if model_settings is None:
+            raise ValueError("Model settings cannot be None")
+        model_name = model_settings.get("model")
+        if model_name is None:
+            raise ValueError("Model must be set and cannot be None")
+
+        return GeminiModel(model_name, api_key=self.__api_key)
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        model = self.__get_pydantic_model(model_settings)
+        return await model.request(messages, model_settings, model_request_parameters)
+
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        model = self.__get_pydantic_model(model_settings)
+        yield model.request_stream(messages, model_settings, model_request_parameters)
+
+    @property
+    def model_name(self) -> str:
+        return "Undetermined"
+
+    @property
+    def system(self) -> str:
+        return "google-gla"
+
     def __get_model_limits(self, model: str) -> int:
         for model_info in _cached_list_model_from_google():
             if model_info.name == f"{self.__MODEL_PREFIX}{model}":
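The system strings introduced across these three clients are pydantic-ai's provider identifiers rather than free-form labels: "anthropic" and "openai" match the corresponding pydantic-ai model classes, and "google-gla" denotes Gemini served over the Google Generative Language API, which pydantic-ai distinguishes from "google-vertex" for the same models on Vertex AI.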

patchwork/common/client/llm/openai_.py

Lines changed: 52 additions & 7 deletions

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import functools
+from functools import cached_property
 
 import tiktoken
 from openai import OpenAI
@@ -11,7 +12,12 @@
     ChatCompletionToolParam,
     completion_create_params,
 )
-from typing_extensions import Dict, Iterable, List, Optional, Union
+from pydantic_ai.messages import ModelMessage, ModelResponse
+from pydantic_ai.models import ModelRequestParameters, StreamedResponse, Model
+from pydantic_ai.settings import ModelSettings
+from pydantic_ai.usage import Usage
+from pydantic_ai.models.openai import OpenAIModel
+from typing_extensions import Dict, Iterable, List, Optional, Union, AsyncIterator
 
 from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
 from patchwork.logger import logger
@@ -41,20 +47,59 @@ class OpenAiLlmClient(LlmClient):
     }
 
     def __init__(self, api_key: str, base_url=None, **kwargs):
-        self.api_key = api_key
-        self.base_url = base_url
-        self.client = OpenAI(api_key=api_key, base_url=base_url, **kwargs)
+        self.__api_key = api_key
+        self.__base_url = base_url
+        self.__kwargs = kwargs
+
+    @cached_property
+    def __client(self) -> OpenAI:
+        return OpenAI(api_key=self.__api_key, base_url=self.__base_url, **self.__kwargs)
+
+    def __get_pydantic_model(self, model_settings: ModelSettings | None) -> Model:
+        if model_settings is None:
+            raise ValueError("Model settings cannot be None")
+        model_name = model_settings.get("model")
+        if model_name is None:
+            raise ValueError("Model must be set and cannot be None")
+
+        return OpenAIModel(model_name, base_url=self.__base_url, api_key=self.__api_key)
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, Usage]:
+        model = self.__get_pydantic_model(model_settings)
+        return await model.request(messages, model_settings, model_request_parameters)
+
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        model = self.__get_pydantic_model(model_settings)
+        yield model.request_stream(messages, model_settings, model_request_parameters)
+
+    @property
+    def model_name(self) -> str:
+        return "Undetermined"
+
+    @property
+    def system(self) -> str | None:
+        return "openai"
 
     def __is_not_openai_url(self):
         # Some providers/apis only implement the chat completion endpoint.
         # We mainly use this to skip using the model endpoints.
-        return self.base_url is not None and self.base_url != "https://api.openai.com/v1"
+        return self.__base_url is not None and self.__base_url != "https://api.openai.com/v1"
 
     def get_models(self) -> set[str]:
        if self.__is_not_openai_url():
            return set()
 
-        return _cached_list_models_from_openai(self.api_key)
+        return _cached_list_models_from_openai(self.__api_key)
 
     def is_model_supported(self, model: str) -> bool:
         # might not implement model endpoint
@@ -144,4 +189,4 @@ def chat_completion(
             top_p=top_p,
         )
 
-        return self.client.chat.completions.create(**NotGiven.remove_not_given(input_kwargs))
+        return self.__client.chat.completions.create(**NotGiven.remove_not_given(input_kwargs))
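A quick sketch of what the __is_not_openai_url() gate above does in practice; the localhost URL is a made-up example of an OpenAI-compatible endpoint, and the keys are placeholders:

from patchwork.common.client.llm.openai_ import OpenAiLlmClient

# Pointing at a non-OpenAI, OpenAI-compatible server: get_models() short-circuits
# to an empty set, because such providers often expose only /chat/completions
# and not /models.
local = OpenAiLlmClient(api_key="unused", base_url="http://localhost:11434/v1")
assert local.get_models() == set()

# Against api.openai.com (base_url=None) the cached model listing is consulted.
real = OpenAiLlmClient(api_key="sk-...")
models = real.get_models()  # calls _cached_list_models_from_openai(api_key)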

patchwork/common/client/llm/protocol.py

Lines changed: 14 additions & 3 deletions

@@ -1,20 +1,25 @@
 from __future__ import annotations
 
+from abc import abstractmethod
+from typing import Dict, Any, List
+
+from pydantic_ai.models import Model
+
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionMessageParam,
     ChatCompletionToolChoiceOptionParam,
     ChatCompletionToolParam,
     completion_create_params,
 )
-from typing_extensions import Any, Dict, Iterable, List, Optional, Protocol, Union
+from typing_extensions import Any, Dict, Iterable, List, Optional, Union
 
 
 class NotGiven:
     ...
 
     @staticmethod
-    def remove_not_given(obj: Any) -> Any:
+    def remove_not_given(obj: Any) -> Union[None, dict[Any, Any], list[Any], Any]:
         if isinstance(obj, NotGiven):
             return None
         if isinstance(obj, dict):
@@ -27,13 +32,16 @@ def remove_not_given(obj: Any) -> Any:
 NOT_GIVEN = NotGiven()
 
 
-class LlmClient(Protocol):
+class LlmClient(Model):
+    @abstractmethod
     def get_models(self) -> set[str]:
         ...
 
+    @abstractmethod
     def is_model_supported(self, model: str) -> bool:
         ...
 
+    @abstractmethod
     def is_prompt_supported(
         self,
         messages: Iterable[ChatCompletionMessageParam],
@@ -54,6 +62,7 @@ def is_prompt_supported(
     ) -> int:
         ...
 
+    @abstractmethod
     def truncate_messages(
         self, messages: Iterable[ChatCompletionMessageParam], model: str
     ) -> Iterable[ChatCompletionMessageParam]:
@@ -118,6 +127,7 @@ def __truncate_message(message, direction_callback, min_guess, max_guess):
 
         return LlmClient.__truncate_message(message, direction_callback, min_guess, max_guess)
 
+    @abstractmethod
    def chat_completion(
         self,
         messages: Iterable[ChatCompletionMessageParam],
@@ -137,3 +147,4 @@ def chat_completion(
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         ...
+
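The net effect of the protocol.py change: LlmClient stops being a structural Protocol and becomes an abstract subclass of pydantic_ai.models.Model, so every concrete client now carries both the patchwork surface (chat_completion, get_models, and so on) and the pydantic-ai surface (request, request_stream, model_name, system), and any of them can be handed straight to pydantic-ai. A hedged sketch, assuming an Agent from the same early pydantic-ai vintage that accepts a Model instance and a run-time model_settings argument:

from pydantic_ai import Agent

from patchwork.common.client.llm.openai_ import OpenAiLlmClient

# Any LlmClient is now a pydantic_ai Model, so it can back an Agent directly.
client = OpenAiLlmClient(api_key="sk-...")  # placeholder key
agent = Agent(model=client)

# Remember the "model" key convention from __get_pydantic_model(): the target
# model must be supplied through model_settings at run time, since model_name
# on these clients is just "Undetermined".
result = agent.run_sync("Say hello", model_settings={"model": "gpt-4o-mini"})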
