Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions patchwork/common/client/llm/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing_extensions import AsyncIterator, Dict, Iterable, List, Optional, Union

from patchwork.common.client.llm.anthropic import AnthropicLlmClient
from patchwork.common.client.llm.google import GoogleLlmClient
from patchwork.common.client.llm.google_ import GoogleLlmClient
from patchwork.common.client.llm.openai_ import OpenAiLlmClient
from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
from patchwork.common.constants import DEFAULT_PATCH_URL
Expand All @@ -31,10 +31,10 @@ def __init__(self, *clients: LlmClient):
self.__supported_models = set()
for client in clients:
try:
self.__supported_models.update(client.get_models())
client.test()
self.__clients.append(client)
except Exception:
pass
except Exception as e:
logger.error(f"{client.__class__.__name__} Failed with exception: {e}")

def __get_model(self, model_settings: ModelSettings | None) -> Optional[str]:
if model_settings is None:
Expand All @@ -45,6 +45,9 @@ def __get_model(self, model_settings: ModelSettings | None) -> Optional[str]:

return model_name

def test(self) -> None:
pass

async def request(
self,
messages: list[ModelMessage],
Expand Down Expand Up @@ -94,9 +97,6 @@ def model_name(self) -> str:
def system(self) -> str:
return next(iter(self.__clients)).system

def get_models(self) -> set[str]:
return self.__supported_models

def is_model_supported(self, model: str) -> bool:
return any(client.is_model_supported(model) for client in self.__clients)

Expand Down Expand Up @@ -216,8 +216,9 @@ def create_aio_client(inputs) -> "AioLlmClient" | None:
clients.append(client)

google_key = inputs.get("google_api_key")
if google_key is not None:
client = GoogleLlmClient(google_key, **client_args)
is_gcp = bool(client_args.get("is_gcp") or os.environ.get("GOOGLE_GENAI_USE_VERTEXAI") or False)
if google_key is not None or is_gcp:
client = GoogleLlmClient(api_key=google_key, is_gcp=is_gcp)
clients.append(client)

anthropic_key = inputs.get("anthropic_api_key")
Expand Down
7 changes: 3 additions & 4 deletions patchwork/common/client/llm/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
import time
from functools import cached_property, lru_cache
from functools import cached_property
from pathlib import Path

from anthropic import Anthropic
Expand Down Expand Up @@ -245,9 +245,8 @@ def __adapt_chat_completion_request(

return NotGiven.remove_not_given(input_kwargs)

@lru_cache(maxsize=None)
def get_models(self) -> set[str]:
return self.__definitely_allowed_models.union(set(f"{self.__allowed_model_prefix}*"))
def test(self):
return

def is_model_supported(self, model: str) -> bool:
return model in self.__definitely_allowed_models or model.startswith(self.__allowed_model_prefix)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

import os
import time
from functools import lru_cache
from functools import lru_cache, partial
from pathlib import Path

import magic
import vertexai
from google import genai
from google.auth.exceptions import GoogleAuthError
from google.genai import types
from google.genai.errors import APIError
from google.genai.types import (
CountTokensConfig,
File,
Expand All @@ -26,7 +30,8 @@
from openai.types.chat.chat_completion import ChatCompletion, Choice
from pydantic import BaseModel
from pydantic_ai.messages import ModelMessage, ModelResponse
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
from pydantic_ai.models import Model as PydanticAiModel
from pydantic_ai.models import ModelRequestParameters, StreamedResponse
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import Usage
Expand All @@ -40,9 +45,11 @@
Type,
Union,
)
from vertexai.generative_models import GenerativeModel, SafetySetting

from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
from patchwork.common.client.llm.utils import json_schema_to_model
from patchwork.logger import logger


class GoogleLlmClient(LlmClient):
Expand All @@ -51,30 +58,63 @@ class GoogleLlmClient(LlmClient):
dict(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
dict(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
dict(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
dict(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
]
__VERTEX_SAFETY_SETTINGS = [
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=SafetySetting.HarmBlockThreshold.OFF,
),
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=SafetySetting.HarmBlockThreshold.OFF,
),
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=SafetySetting.HarmBlockThreshold.OFF,
),
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=SafetySetting.HarmBlockThreshold.OFF
),
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
threshold=SafetySetting.HarmBlockThreshold.OFF,
),
]
__MODEL_PREFIX = "models/"

def __init__(self, api_key: str, location: Optional[str] = None):
def __init__(self, api_key: Optional[str] = None, is_gcp: bool = False):
self.__api_key = api_key
self.__location = location
self.client = genai.Client(api_key=api_key, location=location)
self.__is_gcp = is_gcp
if not self.__is_gcp:
self.client = genai.Client(api_key=api_key)
else:
self.client = genai.Client(api_key=api_key, vertexai=True)
location = os.environ.get("GOOGLE_CLOUD_LOCATION", "global")
vertexai.init(
project=os.environ.get("GOOGLE_CLOUD_PROJECT"),
location=location,
api_endpoint=f"{location}-aiplatform.googleapis.com",
)

@lru_cache(maxsize=1)
def __get_models_info(self) -> list[Model]:
return list(self.client.models.list())
if not self.__is_gcp:
return list(self.client.models.list())
else:
return list()

def __get_pydantic_model(self, model_settings: ModelSettings | None) -> Model:
def __get_pydantic_model(self, model_settings: ModelSettings | None) -> PydanticAiModel:
if model_settings is None:
raise ValueError("Model settings cannot be None")
model_name = model_settings.get("model")
if model_name is None:
raise ValueError("Model must be set cannot be None")

if self.__location is None:
if not self.__is_gcp:
return GeminiModel(model_name, api_key=self.__api_key)

url_template = f"https://{self.__location}-generativelanguage.googleapis.com/v1beta/models/{{model}}:"
return GeminiModel(model_name, api_key=self.__api_key, url_template=url_template)
else:
return GeminiModel(model_name, provider="google-vertex")

async def request(
self,
Expand Down Expand Up @@ -108,12 +148,15 @@ def __get_model_limits(self, model: str) -> int:
return model_info.input_token_limit
return 1_000_000

@lru_cache
def get_models(self) -> set[str]:
return {model_info.name.removeprefix(self.__MODEL_PREFIX) for model_info in self.__get_models_info()}
def test(self):
return

def is_model_supported(self, model: str) -> bool:
return model in self.get_models()
if not self.__is_gcp:
model_names = {model_info.name.removeprefix(self.__MODEL_PREFIX) for model_info in self.__get_models_info()}
return model in model_names
else:
return True

def __upload(self, file: Path | NotGiven) -> Part | File | None:
if isinstance(file, NotGiven):
Expand Down Expand Up @@ -163,6 +206,8 @@ def is_prompt_supported(
top_p: Optional[float] | NotGiven = NOT_GIVEN,
file: Path | NotGiven = NOT_GIVEN,
) -> int:
if self.__is_gcp:
return 1
system, contents = self.__openai_messages_to_google_messages(messages)

file_ref = self.__upload(file)
Expand All @@ -178,7 +223,12 @@ def is_prompt_supported(
),
)
token_count = token_response.total_tokens
except GoogleAuthError:
raise
except APIError:
raise
except Exception as e:
logger.debug(f"Error during token count at GoogleLlmClient: {e}")
return -1
model_limit = self.__get_model_limits(model)
return model_limit - token_count
Expand Down Expand Up @@ -245,15 +295,25 @@ def chat_completion(
if file_ref is not None:
contents.append(file_ref)

response = self.client.models.generate_content(
model=model,
contents=contents,
config=GenerateContentConfig(
system_instruction=system_content,
safety_settings=self.__SAFETY_SETTINGS,
**NotGiven.remove_not_given(generation_dict),
),
)
if not self.__is_gcp:
generate_content_func = partial(
self.client.models.generate_content,
model=model,
config=GenerateContentConfig(
system_instruction=system_content,
safety_settings=self.__SAFETY_SETTINGS,
**NotGiven.remove_not_given(generation_dict),
),
)
else:
vertexai_model = GenerativeModel(model, system_instruction=system_content)
generate_content_func = partial(
vertexai_model.generate_content,
safety_settings=self.__VERTEX_SAFETY_SETTINGS,
generation_config=NotGiven.remove_not_given(generation_dict),
)

response = generate_content_func(contents=contents)
return self.__google_response_to_openai_response(response, model)

@staticmethod
Expand Down
9 changes: 5 additions & 4 deletions patchwork/common/client/llm/openai_.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,18 @@ def __is_not_openai_url(self):
# We mainly use this to skip using the model endpoints.
return self.__base_url is not None and self.__base_url != "https://api.openai.com/v1"

def get_models(self) -> set[str]:
def test(self):
if self.__is_not_openai_url():
return set()
return

return _cached_list_models_from_openai(self.__api_key)
_cached_list_models_from_openai(self.__api_key)
return

def is_model_supported(self, model: str) -> bool:
# might not implement model endpoint
if self.__is_not_openai_url():
return True
return model in self.get_models()
return model in _cached_list_models_from_openai(self.__api_key)

def __get_model_limits(self, model: str) -> int:
return self.__MODEL_LIMITS.get(model, 128_000)
Expand Down
2 changes: 1 addition & 1 deletion patchwork/common/client/llm/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def remove_not_given(obj: Any) -> Union[None, dict[Any, Any], list[Any], Any]:

class LlmClient(Model):
@abstractmethod
def get_models(self) -> set[str]:
def test(self) -> None:
...

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion patchwork/common/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from patchwork.common.client.llm.aio import AioLlmClient
from patchwork.common.client.llm.anthropic import AnthropicLlmClient
from patchwork.common.client.llm.google import GoogleLlmClient
from patchwork.common.client.llm.google_ import GoogleLlmClient
from patchwork.common.client.llm.openai_ import OpenAiLlmClient

app = FastAPI()
Expand Down
11 changes: 7 additions & 4 deletions patchwork/steps/AgenticLLM/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ class AgenticLLMInputs(TypedDict, total=False):
user_prompt: str
max_llm_calls: Annotated[int, StepTypeConfig(is_config=True)]
openai_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "anthropic_api_key"])
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"])
]
anthropic_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "openai_api_key"])
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"])
]
patched_api_key: Annotated[
str,
StepTypeConfig(
is_config=True,
or_op=["openai_api_key", "google_api_key", "anthropic_api_key"],
or_op=["openai_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"],
msg=f"""\
Model API key not found.
Please login at: "{TOKEN_URL}"
Expand All @@ -31,7 +31,10 @@ class AgenticLLMInputs(TypedDict, total=False):
),
]
google_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key"])
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"])
]
client_is_gcp: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"])
]


Expand Down
11 changes: 7 additions & 4 deletions patchwork/steps/CallLLM/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@ class CallLLMInputs(TypedDict, total=False):
model_args: Annotated[str, StepTypeConfig(is_config=True)]
client_args: Annotated[str, StepTypeConfig(is_config=True)]
openai_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "anthropic_api_key"])
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"])
]
anthropic_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "openai_api_key"])
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "client_is_gcp", "openai_api_key"])
]
patched_api_key: Annotated[
str,
StepTypeConfig(
is_config=True,
or_op=["openai_api_key", "google_api_key", "anthropic_api_key"],
or_op=["openai_api_key", "google_api_key", "client_is_gcp", "anthropic_api_key"],
msg=f"""\
Model API key not found.
Please login at: "{TOKEN_URL}"
Expand All @@ -33,7 +33,10 @@ class CallLLMInputs(TypedDict, total=False):
),
]
google_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key"])
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "client_is_gcp"])
]
client_is_gcp: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key", "google_api_key"])
]
file: Annotated[str, StepTypeConfig(is_path=True)]

Expand Down
10 changes: 1 addition & 9 deletions patchwork/steps/FileAgent/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,7 @@ class FileAgentInputs(__ReconcilationAgentRequiredInputs, total=False):
base_path: str
prompt_value: Dict[str, Any]
max_llm_calls: Annotated[int, StepTypeConfig(is_config=True)]
openai_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "anthropic_api_key"])
]
anthropic_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "google_api_key", "openai_api_key"])
]
google_api_key: Annotated[
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key"])
]
anthropic_api_key: Annotated[str, StepTypeConfig(is_config=True)]


class FileAgentOutputs(TypedDict):
Expand Down
Loading