Skip to content

Commit aad3465

Browse files
Feat/bedrock agents (#203)
## LLMstudio Version X.X.X ### What was done in this PR: - Add bedrock agents ### How it was tested: - ... ### Additional notes: - Any breaking changes? - Any new dependencies added? - Any performance improvements?
2 parents 38c6880 + c183ed4 commit aad3465

File tree

3 files changed

+301
-35
lines changed

3 files changed

+301
-35
lines changed
Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,29 @@
11
from llmstudio_core.agents.data_models import (
22
AgentBase,
33
CreateAgentRequest,
4-
ResultBase,
54
RunAgentRequest,
65
RunBase,
76
)
87

98

109
class BedrockAgent(AgentBase):
11-
agentResourceRoleArn: str
12-
agentStatus: str
13-
agentVersion: str
14-
agentArn: str
10+
agent_resource_role_arn: str
11+
agent_status: str
12+
agent_arn: str
13+
agent_alias_id: str
1514

1615

1716
class BedrockRun(RunBase):
1817
session_id: str
1918
response: dict
2019

2120

22-
class BedrockResult(ResultBase):
23-
session_id: str
24-
25-
2621
class BedrockCreateAgentRequest(CreateAgentRequest):
27-
agent_resourcerole_arn: str
22+
agent_resource_role_arn: str
2823
agent_alias: str
24+
name: str
2925

3026

3127
class BedrockRunAgentRequest(RunAgentRequest):
3228
session_id: str
29+
agent_alias_id: str

libs/core/llmstudio_core/agents/bedrock/manager.py

Lines changed: 253 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,43 @@
33
import boto3
44
from llmstudio_core.agents.bedrock.data_models import (
55
BedrockAgent,
6-
BedrockResult,
6+
BedrockCreateAgentRequest,
77
BedrockRun,
8+
BedrockRunAgentRequest,
9+
)
10+
from llmstudio_core.agents.data_models import (
11+
Attachment,
12+
ImageFile,
13+
ImageFileContent,
14+
Message,
15+
ResultBase,
16+
RetrieveResultRequest,
17+
TextContent,
818
)
919
from llmstudio_core.agents.manager import AgentManager, agent_manager
20+
from llmstudio_core.exceptions import AgentError
21+
from pydantic import ValidationError
1022

11-
SERVICE = "bedrock-agent"
23+
AGENT_SERVICE = "bedrock-agent"
24+
RUNTIME_SERVICE = "bedrock-agent-runtime"
1225

1326

1427
@agent_manager
1528
class BedrockAgentManager(AgentManager):
1629
def __init__(self, **kwargs):
1730
super().__init__(**kwargs)
1831
self._client = boto3.client(
19-
SERVICE,
32+
service_name=AGENT_SERVICE,
33+
region_name=self.region if self.region else os.getenv("BEDROCK_REGION"),
34+
aws_access_key_id=self.access_key
35+
if self.access_key
36+
else os.getenv("BEDROCK_ACCESS_KEY"),
37+
aws_secret_access_key=self.secret_key
38+
if self.secret_key
39+
else os.getenv("BEDROCK_SECRET_KEY"),
40+
)
41+
self._runtime_client = boto3.client(
42+
service_name=RUNTIME_SERVICE,
2043
region_name=self.region if self.region else os.getenv("BEDROCK_REGION"),
2144
aws_access_key_id=self.access_key
2245
if self.access_key
@@ -31,31 +54,245 @@ def _agent_config_name():
3154
return "bedrock"
3255

3356
def _validate_create_request(self, request):
34-
raise NotImplementedError("Agents need to implement the method")
57+
return BedrockCreateAgentRequest(**request)
3558

3659
def _validate_run_request(self, request):
37-
raise NotImplementedError("Agents need to implement the method")
60+
return BedrockRunAgentRequest(**request)
3861

39-
def _validate_create_request(self, request):
40-
raise NotImplementedError("Agents need to implement the method")
62+
def _validate_result_request(self, request):
63+
return RetrieveResultRequest(**request)
4164

42-
def create_agent(self, **kargs) -> BedrockAgent:
65+
def create_agent(self, **kwargs) -> BedrockAgent:
4366
"""
44-
Creates a new instance of the agent.
67+
This method validates the input parameters, creates a new agent using the client,
68+
waits for the agent to reach the 'NOT_PREPARED' status, adds tools to the agent,
69+
prepares the agent for use, creates an alias for the agent, and waits for the alias
70+
to be prepared.
71+
72+
Args:
73+
**kwargs: Agent creation parameters.
74+
75+
Returns:
76+
BedrockAgent: An instance of the created BedrockAgent.
77+
78+
Raises:
79+
AgentError: If there is a validation error or if an unsupported tool type is provided.
80+
4581
"""
4682

47-
raise NotImplementedError("Agents need to implement the 'create' method.")
83+
try:
84+
agent_request = self._validate_create_request(
85+
dict(
86+
**kwargs,
87+
)
88+
)
89+
90+
except ValidationError as e:
91+
raise AgentError(str(e))
92+
93+
bedrock_create = self._client.create_agent(
94+
agentName=agent_request.name,
95+
foundationModel=agent_request.model,
96+
instruction=agent_request.instructions,
97+
agentResourceRoleArn=agent_request.agent_resource_role_arn,
98+
)
99+
100+
agentId = bedrock_create["agent"]["agentId"]
101+
102+
# Wait for agent to reach 'NOT_PREPARED' status
103+
agentStatus = ""
104+
while agentStatus != "NOT_PREPARED":
105+
response = self._client.get_agent(agentId=agentId)
106+
agentStatus = response["agent"]["agentStatus"]
107+
108+
# Add tools to the agent
109+
for tool in agent_request.tools:
110+
if tool.type == "code_interpreter":
111+
response = self._client.create_agent_action_group(
112+
actionGroupName="CodeInterpreterAction",
113+
actionGroupState="ENABLED",
114+
agentId=agentId,
115+
agentVersion="DRAFT",
116+
parentActionGroupSignature="AMAZON.CodeInterpreter",
117+
)
118+
119+
actionGroupId = response["agentActionGroup"]["actionGroupId"]
120+
121+
actionGroupStatus = ""
122+
while actionGroupStatus != "ENABLED":
123+
response = self._client.get_agent_action_group(
124+
agentId=agentId,
125+
actionGroupId=actionGroupId,
126+
agentVersion="DRAFT",
127+
)
128+
actionGroupStatus = response["agentActionGroup"]["actionGroupState"]
129+
else:
130+
raise AgentError(f"Tool {tool.get('type')} not supported")
131+
132+
# Prepare the agent for use
133+
response = self._client.prepare_agent(agentId=agentId)
134+
135+
# Wait for agent to reach 'PREPARED' status
136+
agentStatus = ""
137+
while agentStatus != "PREPARED":
138+
response = self._client.get_agent(agentId=agentId)
139+
agentStatus = response["agent"]["agentStatus"]
140+
141+
# Create an alias for the agent
142+
response = self._client.create_agent_alias(
143+
agentAliasName=agent_request.agent_alias, agentId=agentId
144+
)
145+
146+
agentAliasId = response["agentAlias"]["agentAliasId"]
147+
148+
# Wait for agent alias to be prepared
149+
agentAliasStatus = ""
150+
while agentAliasStatus != "PREPARED":
151+
response = self._client.get_agent_alias(
152+
agentId=agentId, agentAliasId=agentAliasId
153+
)
154+
agentAliasStatus = response["agentAlias"]["agentAliasStatus"]
155+
156+
return BedrockAgent(
157+
id=agentId,
158+
created_at=int(bedrock_create["agent"]["createdAt"].timestamp()),
159+
name=bedrock_create["agent"]["agentName"],
160+
description=bedrock_create.get("agent", {}).get("description", None),
161+
model=agent_request.model,
162+
instructions=bedrock_create["agent"]["instruction"],
163+
tools=agent_request.tools,
164+
agent_arn=bedrock_create["agent"]["agentArn"],
165+
agent_resource_role_arn=bedrock_create["agent"]["agentResourceRoleArn"],
166+
agent_status=bedrock_create["agent"]["agentStatus"],
167+
agent_alias_id=agentAliasId,
168+
)
48169

49170
def run_agent(self, **kwargs) -> BedrockRun:
50171
"""
51-
Runs the agent
172+
Runs the agent with the provided keyword arguments.
173+
174+
This method validates the run request and invokes the agent using the runtime client.
175+
If the validation fails, an AgentError is raised.
176+
177+
Returns:
178+
BedrockRun: An object containing the agent ID, status, session ID, and response of the run.
179+
180+
Raises:
181+
AgentError: If the run request validation fails.
52182
"""
53-
raise NotImplementedError(
54-
"Agents need to implement the 'create_thread_and_run' method."
183+
184+
try:
185+
run_request = self._validate_run_request(
186+
dict(
187+
**kwargs,
188+
)
189+
)
190+
except ValidationError as e:
191+
raise AgentError(str(e))
192+
193+
sessionState = {"files": []}
194+
195+
for attachment in run_request.message.attachments:
196+
if any(tool.type == "code_interpreter" for tool in attachment.tools):
197+
sessionState["files"].append(
198+
{
199+
"name": attachment.file_name,
200+
"source": {
201+
"byteContent": {
202+
"data": attachment.file_content,
203+
"mediaType": attachment.file_type,
204+
},
205+
"sourceType": "BYTE_CONTENT",
206+
},
207+
"useCase": "CODE_INTERPRETER",
208+
}
209+
)
210+
211+
if isinstance(run_request.message.content, str):
212+
input_text = run_request.message.content # Use it directly if it's a string
213+
elif isinstance(run_request.message.content, list):
214+
input_text = " ".join(
215+
item.text
216+
for item in run_request.message.content
217+
if isinstance(item, TextContent)
218+
)
219+
else:
220+
input_text = "" # Default to an empty string if content is not valid
221+
222+
invoke_request = self._runtime_client.invoke_agent(
223+
agentId=run_request.agent_id,
224+
agentAliasId=run_request.agent_alias_id,
225+
sessionId=run_request.session_id,
226+
inputText=input_text,
227+
sessionState=sessionState,
228+
)
229+
230+
return BedrockRun(
231+
agent_id=run_request.agent_id,
232+
status="completed",
233+
session_id=run_request.session_id,
234+
response=invoke_request,
55235
)
56236

57-
def retrieve_result(self, **kwargs) -> BedrockResult:
237+
def retrieve_result(self, **kwargs) -> ResultBase:
58238
"""
59-
Retrieves an existing agent.
239+
Retrieve the result based on the provided keyword arguments.
240+
This method validates the result request and processes the event stream to
241+
extract content and attachments. It constructs a message with the extracted
242+
content and attachments and returns it wrapped in a ResultBase object.
243+
244+
Returns:
245+
ResultBase: An object containing the constructed message with content and attachments.
246+
Raises:
247+
AgentError: If the result request validation fails.
60248
"""
61-
raise NotImplementedError("Agents need to implement the 'retrieve' method.")
249+
250+
try:
251+
result_request = self._validate_result_request(
252+
dict(
253+
**kwargs,
254+
)
255+
)
256+
257+
except ValidationError as e:
258+
raise AgentError(str(e))
259+
260+
content = []
261+
attachments = []
262+
event_stream = result_request.run.response.get("completion")
263+
for event in event_stream:
264+
if "chunk" in event:
265+
chunk = event["chunk"]
266+
if "bytes" in chunk:
267+
content.append(TextContent(text=chunk["bytes"].decode("utf-8")))
268+
269+
if "files" in event:
270+
files = event["files"]["files"]
271+
for file in files:
272+
if file["type"] == "image/png":
273+
content.append(
274+
ImageFileContent(
275+
image_file=ImageFile(
276+
file_name=file["name"],
277+
file_content=file["bytes"],
278+
file_type=file["type"],
279+
)
280+
)
281+
)
282+
else:
283+
attachments.append(
284+
Attachment(
285+
file_name=file["name"],
286+
file_content=file["bytes"],
287+
file_type=file["type"],
288+
)
289+
)
290+
291+
message = Message(
292+
thread_id=result_request.run.session_id,
293+
role="assistant",
294+
content=content,
295+
attachments=attachments,
296+
)
297+
298+
return ResultBase(message=message)

0 commit comments

Comments
 (0)