33import boto3
44from llmstudio_core .agents .bedrock .data_models import (
55 BedrockAgent ,
6- BedrockResult ,
6+ BedrockCreateAgentRequest ,
77 BedrockRun ,
8+ BedrockRunAgentRequest ,
9+ )
10+ from llmstudio_core .agents .data_models import (
11+ Attachment ,
12+ ImageFile ,
13+ ImageFileContent ,
14+ Message ,
15+ ResultBase ,
16+ RetrieveResultRequest ,
17+ TextContent ,
818)
919from llmstudio_core .agents .manager import AgentManager , agent_manager
20+ from llmstudio_core .exceptions import AgentError
21+ from pydantic import ValidationError
1022
11- SERVICE = "bedrock-agent"
23+ AGENT_SERVICE = "bedrock-agent"
24+ RUNTIME_SERVICE = "bedrock-agent-runtime"
1225
1326
1427@agent_manager
1528class BedrockAgentManager (AgentManager ):
1629 def __init__ (self , ** kwargs ):
1730 super ().__init__ (** kwargs )
1831 self ._client = boto3 .client (
19- SERVICE ,
32+ service_name = AGENT_SERVICE ,
33+ region_name = self .region if self .region else os .getenv ("BEDROCK_REGION" ),
34+ aws_access_key_id = self .access_key
35+ if self .access_key
36+ else os .getenv ("BEDROCK_ACCESS_KEY" ),
37+ aws_secret_access_key = self .secret_key
38+ if self .secret_key
39+ else os .getenv ("BEDROCK_SECRET_KEY" ),
40+ )
41+ self ._runtime_client = boto3 .client (
42+ service_name = RUNTIME_SERVICE ,
2043 region_name = self .region if self .region else os .getenv ("BEDROCK_REGION" ),
2144 aws_access_key_id = self .access_key
2245 if self .access_key
@@ -31,31 +54,245 @@ def _agent_config_name():
3154 return "bedrock"
3255
3356 def _validate_create_request (self , request ):
34- raise NotImplementedError ( "Agents need to implement the method" )
57+ return BedrockCreateAgentRequest ( ** request )
3558
3659 def _validate_run_request (self , request ):
37- raise NotImplementedError ( "Agents need to implement the method" )
60+ return BedrockRunAgentRequest ( ** request )
3861
39- def _validate_create_request (self , request ):
40- raise NotImplementedError ( "Agents need to implement the method" )
62+ def _validate_result_request (self , request ):
63+ return RetrieveResultRequest ( ** request )
4164
42- def create_agent (self , ** kargs ) -> BedrockAgent :
65+ def create_agent (self , ** kwargs ) -> BedrockAgent :
4366 """
44- Creates a new instance of the agent.
67+ This method validates the input parameters, creates a new agent using the client,
68+ waits for the agent to reach the 'NOT_PREPARED' status, adds tools to the agent,
69+ prepares the agent for use, creates an alias for the agent, and waits for the alias
70+ to be prepared.
71+
72+ Args:
73+ **kwargs: Agent creation parameters.
74+
75+ Returns:
76+ BedrockAgent: An instance of the created BedrockAgent.
77+
78+ Raises:
79+ AgentError: If there is a validation error or if an unsupported tool type is provided.
80+
4581 """
4682
47- raise NotImplementedError ("Agents need to implement the 'create' method." )
83+ try :
84+ agent_request = self ._validate_create_request (
85+ dict (
86+ ** kwargs ,
87+ )
88+ )
89+
90+ except ValidationError as e :
91+ raise AgentError (str (e ))
92+
93+ bedrock_create = self ._client .create_agent (
94+ agentName = agent_request .name ,
95+ foundationModel = agent_request .model ,
96+ instruction = agent_request .instructions ,
97+ agentResourceRoleArn = agent_request .agent_resource_role_arn ,
98+ )
99+
100+ agentId = bedrock_create ["agent" ]["agentId" ]
101+
102+ # Wait for agent to reach 'NOT_PREPARED' status
103+ agentStatus = ""
104+ while agentStatus != "NOT_PREPARED" :
105+ response = self ._client .get_agent (agentId = agentId )
106+ agentStatus = response ["agent" ]["agentStatus" ]
107+
108+ # Add tools to the agent
109+ for tool in agent_request .tools :
110+ if tool .type == "code_interpreter" :
111+ response = self ._client .create_agent_action_group (
112+ actionGroupName = "CodeInterpreterAction" ,
113+ actionGroupState = "ENABLED" ,
114+ agentId = agentId ,
115+ agentVersion = "DRAFT" ,
116+ parentActionGroupSignature = "AMAZON.CodeInterpreter" ,
117+ )
118+
119+ actionGroupId = response ["agentActionGroup" ]["actionGroupId" ]
120+
121+ actionGroupStatus = ""
122+ while actionGroupStatus != "ENABLED" :
123+ response = self ._client .get_agent_action_group (
124+ agentId = agentId ,
125+ actionGroupId = actionGroupId ,
126+ agentVersion = "DRAFT" ,
127+ )
128+ actionGroupStatus = response ["agentActionGroup" ]["actionGroupState" ]
129+ else :
130+ raise AgentError (f"Tool { tool .get ('type' )} not supported" )
131+
132+ # Prepare the agent for use
133+ response = self ._client .prepare_agent (agentId = agentId )
134+
135+ # Wait for agent to reach 'PREPARED' status
136+ agentStatus = ""
137+ while agentStatus != "PREPARED" :
138+ response = self ._client .get_agent (agentId = agentId )
139+ agentStatus = response ["agent" ]["agentStatus" ]
140+
141+ # Create an alias for the agent
142+ response = self ._client .create_agent_alias (
143+ agentAliasName = agent_request .agent_alias , agentId = agentId
144+ )
145+
146+ agentAliasId = response ["agentAlias" ]["agentAliasId" ]
147+
148+ # Wait for agent alias to be prepared
149+ agentAliasStatus = ""
150+ while agentAliasStatus != "PREPARED" :
151+ response = self ._client .get_agent_alias (
152+ agentId = agentId , agentAliasId = agentAliasId
153+ )
154+ agentAliasStatus = response ["agentAlias" ]["agentAliasStatus" ]
155+
156+ return BedrockAgent (
157+ id = agentId ,
158+ created_at = int (bedrock_create ["agent" ]["createdAt" ].timestamp ()),
159+ name = bedrock_create ["agent" ]["agentName" ],
160+ description = bedrock_create .get ("agent" , {}).get ("description" , None ),
161+ model = agent_request .model ,
162+ instructions = bedrock_create ["agent" ]["instruction" ],
163+ tools = agent_request .tools ,
164+ agent_arn = bedrock_create ["agent" ]["agentArn" ],
165+ agent_resource_role_arn = bedrock_create ["agent" ]["agentResourceRoleArn" ],
166+ agent_status = bedrock_create ["agent" ]["agentStatus" ],
167+ agent_alias_id = agentAliasId ,
168+ )
48169
49170 def run_agent (self , ** kwargs ) -> BedrockRun :
50171 """
51- Runs the agent
172+ Runs the agent with the provided keyword arguments.
173+
174+ This method validates the run request and invokes the agent using the runtime client.
175+ If the validation fails, an AgentError is raised.
176+
177+ Returns:
178+ BedrockRun: An object containing the agent ID, status, session ID, and response of the run.
179+
180+ Raises:
181+ AgentError: If the run request validation fails.
52182 """
53- raise NotImplementedError (
54- "Agents need to implement the 'create_thread_and_run' method."
183+
184+ try :
185+ run_request = self ._validate_run_request (
186+ dict (
187+ ** kwargs ,
188+ )
189+ )
190+ except ValidationError as e :
191+ raise AgentError (str (e ))
192+
193+ sessionState = {"files" : []}
194+
195+ for attachment in run_request .message .attachments :
196+ if any (tool .type == "code_interpreter" for tool in attachment .tools ):
197+ sessionState ["files" ].append (
198+ {
199+ "name" : attachment .file_name ,
200+ "source" : {
201+ "byteContent" : {
202+ "data" : attachment .file_content ,
203+ "mediaType" : attachment .file_type ,
204+ },
205+ "sourceType" : "BYTE_CONTENT" ,
206+ },
207+ "useCase" : "CODE_INTERPRETER" ,
208+ }
209+ )
210+
211+ if isinstance (run_request .message .content , str ):
212+ input_text = run_request .message .content # Use it directly if it's a string
213+ elif isinstance (run_request .message .content , list ):
214+ input_text = " " .join (
215+ item .text
216+ for item in run_request .message .content
217+ if isinstance (item , TextContent )
218+ )
219+ else :
220+ input_text = "" # Default to an empty string if content is not valid
221+
222+ invoke_request = self ._runtime_client .invoke_agent (
223+ agentId = run_request .agent_id ,
224+ agentAliasId = run_request .agent_alias_id ,
225+ sessionId = run_request .session_id ,
226+ inputText = input_text ,
227+ sessionState = sessionState ,
228+ )
229+
230+ return BedrockRun (
231+ agent_id = run_request .agent_id ,
232+ status = "completed" ,
233+ session_id = run_request .session_id ,
234+ response = invoke_request ,
55235 )
56236
57- def retrieve_result (self , ** kwargs ) -> BedrockResult :
237+ def retrieve_result (self , ** kwargs ) -> ResultBase :
58238 """
59- Retrieves an existing agent.
239+ Retrieve the result based on the provided keyword arguments.
240+ This method validates the result request and processes the event stream to
241+ extract content and attachments. It constructs a message with the extracted
242+ content and attachments and returns it wrapped in a ResultBase object.
243+
244+ Returns:
245+ ResultBase: An object containing the constructed message with content and attachments.
246+ Raises:
247+ AgentError: If the result request validation fails.
60248 """
61- raise NotImplementedError ("Agents need to implement the 'retrieve' method." )
249+
250+ try :
251+ result_request = self ._validate_result_request (
252+ dict (
253+ ** kwargs ,
254+ )
255+ )
256+
257+ except ValidationError as e :
258+ raise AgentError (str (e ))
259+
260+ content = []
261+ attachments = []
262+ event_stream = result_request .run .response .get ("completion" )
263+ for event in event_stream :
264+ if "chunk" in event :
265+ chunk = event ["chunk" ]
266+ if "bytes" in chunk :
267+ content .append (TextContent (text = chunk ["bytes" ].decode ("utf-8" )))
268+
269+ if "files" in event :
270+ files = event ["files" ]["files" ]
271+ for file in files :
272+ if file ["type" ] == "image/png" :
273+ content .append (
274+ ImageFileContent (
275+ image_file = ImageFile (
276+ file_name = file ["name" ],
277+ file_content = file ["bytes" ],
278+ file_type = file ["type" ],
279+ )
280+ )
281+ )
282+ else :
283+ attachments .append (
284+ Attachment (
285+ file_name = file ["name" ],
286+ file_content = file ["bytes" ],
287+ file_type = file ["type" ],
288+ )
289+ )
290+
291+ message = Message (
292+ thread_id = result_request .run .session_id ,
293+ role = "assistant" ,
294+ content = content ,
295+ attachments = attachments ,
296+ )
297+
298+ return ResultBase (message = message )
0 commit comments