
Commit 238b79c

Commit message: sav
1 parent: f5eef37

10 files changed: +825 −556 lines


patchwork/common/client/llm/utils.py

Lines changed: 19 additions & 11 deletions
@@ -60,12 +60,7 @@ def example_json_to_schema(json_example: str | dict | None) -> ResponseFormat |
     if json_example is None:
         return None
 
-    base_model = None
-    if isinstance(json_example, str):
-        base_model = __example_string_to_base_model(json_example)
-    elif isinstance(json_example, dict):
-        base_model = __example_dict_to_base_model(json_example)
-
+    base_model = example_json_to_base_model(json_example)
     if base_model is None:
         return None
 
@@ -76,25 +71,38 @@ def base_model_to_schema(base_model: Type[BaseModel]) -> ResponseFormat:
     return type_to_response_format_param(base_model)
 
 
-def __example_string_to_base_model(json_example: str) -> Type[BaseModel] | None:
+def example_json_to_base_model(json_example: str | dict | None) -> Type[BaseModel] | None:
+    if json_example is None:
+        return None
+
+    base_model = None
+    if isinstance(json_example, str):
+        base_model = example_string_to_base_model(json_example)
+    elif isinstance(json_example, dict):
+        base_model = example_dict_to_base_model(json_example)
+
+    return base_model
+
+
+def example_string_to_base_model(json_example: str) -> Type[BaseModel] | None:
     try:
         example_data = json.loads(json_example)
     except Exception as e:
         logger.error(f"Failed to parse example json", e)
         return None
 
-    return __example_dict_to_base_model(example_data)
+    return example_dict_to_base_model(example_data)
 
 
-def __example_dict_to_base_model(example_data: dict) -> Type[BaseModel]:
+def example_dict_to_base_model(example_data: dict) -> Type[BaseModel]:
     base_model_field_defs: dict[str, tuple[type | BaseModel, Field]] = dict()
     for example_data_key, example_data_value in example_data.items():
         if isinstance(example_data_value, dict):
-            value_typing = __example_dict_to_base_model(example_data_value)
+            value_typing = example_dict_to_base_model(example_data_value)
         elif isinstance(example_data_value, list):
            nested_value = example_data_value[0]
            if isinstance(nested_value, dict):
-                nested_typing = __example_dict_to_base_model(nested_value)
+                nested_typing = example_dict_to_base_model(nested_value)
            else:
                nested_typing = type(nested_value)
            value_typing = List[nested_typing]
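
Note: the renamed helpers above are now public. A rough usage sketch, not part of this commit; the example payload and field names below are invented:

from patchwork.common.client.llm.utils import example_json_to_base_model

# Build a Pydantic model type from an example payload; per the diff above,
# nested dicts become nested models and lists become typed lists.
ReportModel = example_json_to_base_model({"summary": "text", "files": [{"path": "a.py"}]})
if ReportModel is not None:
    print(ReportModel.model_json_schema())  # the JSON schema the LLM output is expected to match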

patchwork/common/multiturn_strategy/agentic_strategy.py

Lines changed: 37 additions & 10 deletions
@@ -8,12 +8,17 @@
 import sys
 from json import JSONDecodeError
 from pathlib import Path
+from typing import Union, Any
 
 import chevron
 from openai.types.chat import ChatCompletionMessageParam
 from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+from pydantic_ai import Agent
+from pydantic_ai.models.anthropic import AnthropicModel
+from pydantic import BaseModel
 
 from patchwork.common.client.llm.protocol import LlmClient
+from patchwork.common.client.llm.utils import example_string_to_base_model, example_json_to_base_model
 from patchwork.common.tools import CodeEditTool, Tool
 from patchwork.common.tools.agentic_tools import EndTool
 
@@ -106,22 +111,46 @@ def __init__(self, llm_client: LlmClient, tool_set: dict[str, Tool], system_prom
         self.history.append(dict(role="system", content=system_prompt))
 
 
+
+
+class AgentConfig(BaseModel):
+    name: str
+    tool_set: dict[str, Tool]
+    system_prompt: str = ''
+
+
 class AgenticStrategy:
     def __init__(
         self,
-        llm_client: LlmClient,
-        tool_set: dict[str, Tool],
+        api_key: str,
         template_data: dict[str, str],
         system_prompt_template: str,
         user_prompt_template: str,
+        agent_configs: list[AgentConfig],
+        example_json: Union[str, dict[str, Any]] = '{"output":"output text"}',
        *args,
        **kwargs,
    ):
-        self.tool_set = dict(end=EndTool(), **tool_set)
        self.__template_data = template_data
        self.__user_prompt_template = user_prompt_template
-        self.__assistant_role = Assistant(llm_client, self.tool_set, self.__render_prompt(system_prompt_template))
-        self.__user_role = UserProxy(llm_client, dict())
+        model = AnthropicModel("claude-3-5-sonnet-latest", api_key=api_key)
+        self.__user_role = Agent(
+            model,
+            system_prompt=self.__render_prompt(system_prompt_template),
+            result_type=example_json_to_base_model(example_json),
+        )
+        self.__assistants = []
+        for assistant_config in agent_configs:
+            tools = []
+            for tool in assistant_config.tool_set.values():
+                tools.append(tool.to_pydantic_ai_function_tool())
+            assistant = Agent(
+                "claude-3-5-sonnet-latest",
+                system_prompt=self.__render_prompt(assistant_config.system_prompt),
+                tools=tools
+            )
+
+            self.__assistants.append(assistant)
 
    def __render_prompt(self, prompt_template: str) -> str:
        chevron.render.__globals__["_html_escape"] = lambda x: x
@@ -133,9 +162,6 @@ def __render_prompt(self, prompt_template: str) -> str:
            partials_dict=dict(),
        )
 
-    def __get_initial_prompt(self) -> list[ChatCompletionMessageParam]:
-        return [dict(role="user", content=self.__render_prompt(self.__user_prompt_template))]
-
    def __is_session_completed(self) -> bool:
        for message in reversed(self.__assistant_role.history):
            if message.get("tool") is not None:
@@ -149,9 +175,10 @@ def execute(self, limit: int | None = None) -> None:
        message = self.__render_prompt(self.__user_prompt_template)
        try:
            for i in range(limit or self.__limit or sys.maxsize):
+                self.__user_role.run_sync(self.__user_prompt_template)
                self.run_count = i + 1
-                for role in [self.__assistant_role, self.__user_role]:
-                    message = role.generate_reply(message)
+                for role in [*self.__assistants, self.__user_role]:
+                    message = role.run_sync(message)
                if self.__is_session_completed():
                    break
        except Exception as e:
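
Note: the constructor now takes an Anthropic API key plus a list of AgentConfig entries instead of a single LlmClient/tool_set pair. A rough wiring sketch under that assumption; the prompts, template data, and key below are placeholders, and the AgenticLLM step later in this commit is the real caller:

from patchwork.common.multiturn_strategy.agentic_strategy import AgentConfig, AgenticStrategy
from patchwork.common.tools import Tool

strategy = AgenticStrategy(
    api_key="<anthropic-api-key>",                        # placeholder
    template_data={"task": "fix the failing unit test"},  # rendered into the mustache templates
    system_prompt_template="",
    user_prompt_template="Please work on: {{task}}",
    agent_configs=[
        AgentConfig(name="coder", tool_set=Tool.get_tools(path="."), system_prompt="You edit code for {{task}}."),
    ],
    example_json='{"output": "output text"}',
)
strategy.execute(limit=3)  # at most 3 rounds across the assistant agents and the user-role agent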

patchwork/common/tools/code_edit_tools.py

Lines changed: 6 additions & 2 deletions
@@ -1,10 +1,14 @@
 from __future__ import annotations
 
+import os
 from pathlib import Path
-from typing import Literal
+from typing import Literal, Optional
 
-from patchwork.common.tools.tool import Tool
 from patchwork.common.utils.utils import detect_newline
+from pydantic import BaseModel
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.models.test import TestModel
+from pydantic_ai.tools import Tool, ToolDefinition
 
 
 class CodeEditTool(Tool, tool_name="code_edit_tool"):

patchwork/common/tools/tool.py

Lines changed: 10 additions & 0 deletions
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from typing import Type
+from pydantic_ai.tools import ToolDefinition, Tool as PydanticTool, RunContext
 
 
 class Tool(ABC):
@@ -42,3 +43,12 @@ def get_description(tooling: "ToolProtocol") -> str:
     @staticmethod
     def get_parameters(tooling: "ToolProtocol") -> str:
         return ", ".join(tooling.json_schema.get("required", []))
+
+    def to_pydantic_ai_function_tool(self) -> PydanticTool[None]:
+        async def _prep(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition:
+            tool_def.name = self.name
+            tool_def.description = self.json_schema.get("description", "")
+            tool_def.parameters_json_schema = self.json_schema.get("input_schema", {})
+            return tool_def
+
+        return PydanticTool(self.execute, prepare=_prep)
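
Note: the adapter keeps the existing Tool.execute callable and only swaps in the patchwork JSON schema through pydantic_ai's prepare hook, so the pydantic_ai Agent sees the same tool name, description, and parameters as before. A minimal usage sketch, assumed rather than taken from the diff:

from pydantic_ai import Agent
from patchwork.common.tools import Tool

# Convert every discovered patchwork tool into a pydantic_ai Tool and hand the list to an Agent.
pydantic_tools = [t.to_pydantic_ai_function_tool() for t in Tool.get_tools(path=".").values()]
agent = Agent("claude-3-5-sonnet-latest", tools=pydantic_tools, system_prompt="You edit code.")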

patchwork/patchflows/LogAnalysis/LogAnalysis.py

Lines changed: 1 addition & 16 deletions
@@ -1,26 +1,11 @@
-from enum import IntEnum
 from pathlib import Path
 
 import yaml
 
 from patchwork.common.utils.progress_bar import PatchflowProgressBar
 from patchwork.common.utils.step_typing import validate_steps_with_inputs
-from patchwork.logger import logger
 from patchwork.step import Step
-from patchwork.steps import (
-    LLM,
-    PR,
-    CallLLM,
-    CommitChanges,
-    CreatePR,
-    ExtractCode,
-    ExtractModelResponse,
-    ModifyCode,
-    PreparePR,
-    PreparePrompt,
-    ScanSemgrep,
-    ScanSonar, CallSQL, AgenticLLM,
-)
+from patchwork.steps import AgenticLLM, CallSQL
 
 _DEFAULT_INPUT_FILE = Path(__file__).parent / "defaults.yml"
 
patchwork/step.py

Lines changed: 0 additions & 2 deletions
@@ -73,8 +73,6 @@ def __init__(self, inputs: DataPoint):
         self.run = self.__managed_run
 
     def __init_subclass__(cls, input_class: Optional[Type] = None, output_class: Optional[Type] = None, **kwargs):
-        if cls.__name__ == "PreparePR":
-            print(1)
         input_class = input_class or getattr(cls, "input_class", None)
         if input_class is not None and not is_typeddict(input_class):
             input_class = None

patchwork/steps/AgenticLLM/AgenticLLM.py

Lines changed: 7 additions & 4 deletions
@@ -1,7 +1,7 @@
 from pathlib import Path
 
 from patchwork.common.client.llm.aio import AioLlmClient
-from patchwork.common.multiturn_strategy.agentic_strategy import AgenticStrategy
+from patchwork.common.multiturn_strategy.agentic_strategy import AgenticStrategy, AgentConfig
 from patchwork.common.tools import Tool
 from patchwork.step import Step
 from patchwork.steps.AgenticLLM.typed import AgenticLLMInputs, AgenticLLMOutputs
@@ -15,13 +15,16 @@ def __init__(self, inputs):
         base_path = str(Path.cwd())
         self.conversation_limit = int(int(inputs.get("max_llm_calls", 2)) / 2)
         self.agentic_strategy = AgenticStrategy(
-            llm_client=AioLlmClient.create_aio_client(inputs),
-            tool_set=Tool.get_tools(path=base_path),
+            api_key=inputs.get("anthropic_api_key"),
             template_data=inputs.get("prompt_value"),
-            system_prompt_template=inputs.get("system_prompt"),
+            system_prompt_template="",
             user_prompt_template=inputs.get("user_prompt"),
+            agent_configs=[
+                AgentConfig(name="", tool_set=Tool.get_tools(path=base_path), system_prompt=inputs.get("system_prompt"))
+            ]
         )
 
+
     def run(self) -> dict:
         self.agentic_strategy.execute(limit=self.conversation_limit)
         return dict(
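
Note: the step now expects an anthropic_api_key input instead of a generic LLM client configuration, and the system_prompt input becomes the single agent's system prompt. A hypothetical inputs dict to illustrate the call shape; key names are taken from the diff, values are placeholders:

from patchwork.steps.AgenticLLM.AgenticLLM import AgenticLLM

inputs = {
    "anthropic_api_key": "<anthropic-api-key>",
    "prompt_value": {"task": "triage the flaky test"},         # template_data for the prompts
    "system_prompt": "You are a careful software engineer.",   # forwarded into the AgentConfig
    "user_prompt": "Investigate {{task}} and propose a fix.",
    "max_llm_calls": 4,                                        # halved into the conversation limit
}
result = AgenticLLM(inputs).run()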

patchwork/steps/CallShell/CallShell.py

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@
 from patchwork.common.utils.utils import mustache_render
 from patchwork.logger import logger
 from patchwork.step import Step, StepStatus
-from patchwork.steps import CallSQL
 from patchwork.steps.CallShell.typed import CallShellInputs, CallShellOutputs
 
