88import sys
99from json import JSONDecodeError
1010from pathlib import Path
11+ from typing import Union , Any
1112
1213import chevron
1314from openai .types .chat import ChatCompletionMessageParam
1415from openai .types .chat .chat_completion_tool_param import ChatCompletionToolParam
16+ from pydantic_ai import Agent
17+ from pydantic_ai .models .anthropic import AnthropicModel
18+ from pydantic import BaseModel
1519
1620from patchwork .common .client .llm .protocol import LlmClient
21+ from patchwork .common .client .llm .utils import example_string_to_base_model , example_json_to_base_model
1722from patchwork .common .tools import CodeEditTool , Tool
1823from patchwork .common .tools .agentic_tools import EndTool
1924
@@ -106,22 +111,46 @@ def __init__(self, llm_client: LlmClient, tool_set: dict[str, Tool], system_prom
106111 self .history .append (dict (role = "system" , content = system_prompt ))
107112
108113
114+
115+
116+ class AgentConfig (BaseModel ):
117+ name : str
118+ tool_set : dict [str , Tool ]
119+ system_prompt : str = ''
120+
121+
109122class AgenticStrategy :
110123 def __init__ (
111124 self ,
112- llm_client : LlmClient ,
113- tool_set : dict [str , Tool ],
125+ api_key : str ,
114126 template_data : dict [str , str ],
115127 system_prompt_template : str ,
116128 user_prompt_template : str ,
129+ agent_configs : list [AgentConfig ],
130+ example_json : Union [str , dict [str , Any ]] = '{"output":"output text"}' ,
117131 * args ,
118132 ** kwargs ,
119133 ):
120- self .tool_set = dict (end = EndTool (), ** tool_set )
121134 self .__template_data = template_data
122135 self .__user_prompt_template = user_prompt_template
123- self .__assistant_role = Assistant (llm_client , self .tool_set , self .__render_prompt (system_prompt_template ))
124- self .__user_role = UserProxy (llm_client , dict ())
136+ model = AnthropicModel ("claude-3-5-sonnet-latest" , api_key = api_key )
137+ self .__user_role = Agent (
138+ model ,
139+ system_prompt = self .__render_prompt (system_prompt_template ),
140+ result_type = example_json_to_base_model (example_json ),
141+ )
142+ self .__assistants = []
143+ for assistant_config in agent_configs :
144+ tools = []
145+ for tool in assistant_config .tool_set .values ():
146+ tools .append (tool .to_pydantic_ai_function_tool ())
147+ assistant = Agent (
148+ "claude-3-5-sonnet-latest" ,
149+ system_prompt = self .__render_prompt (assistant_config .system_prompt ),
150+ tools = tools
151+ )
152+
153+ self .__assistants .append (assistant )
125154
126155 def __render_prompt (self , prompt_template : str ) -> str :
127156 chevron .render .__globals__ ["_html_escape" ] = lambda x : x
@@ -133,9 +162,6 @@ def __render_prompt(self, prompt_template: str) -> str:
133162 partials_dict = dict (),
134163 )
135164
136- def __get_initial_prompt (self ) -> list [ChatCompletionMessageParam ]:
137- return [dict (role = "user" , content = self .__render_prompt (self .__user_prompt_template ))]
138-
139165 def __is_session_completed (self ) -> bool :
140166 for message in reversed (self .__assistant_role .history ):
141167 if message .get ("tool" ) is not None :
@@ -149,9 +175,10 @@ def execute(self, limit: int | None = None) -> None:
149175 message = self .__render_prompt (self .__user_prompt_template )
150176 try :
151177 for i in range (limit or self .__limit or sys .maxsize ):
178+ self .__user_role .run_sync (self .__user_prompt_template )
152179 self .run_count = i + 1
153- for role in [self .__assistant_role , self .__user_role ]:
154- message = role .generate_reply (message )
180+ for role in [* self .__assistants , self .__user_role ]:
181+ message = role .run_sync (message )
155182 if self .__is_session_completed ():
156183 break
157184 except Exception as e :
0 commit comments