348 changes: 273 additions & 75 deletions libs/python/agent/agent/adapters/cua_adapter.py

Large diffs are not rendered by default.

53 changes: 41 additions & 12 deletions libs/python/agent/agent/agent.py
@@ -892,7 +892,7 @@ async def run(
"messages": preprocessed_messages,
"model": self.model,
"tools": self.tool_schemas,
"stream": False,
"stream": stream,
"computer_handler": self.computer_handler,
"max_retries": self.max_retries,
"use_prompt_caching": self.use_prompt_caching,
@@ -930,32 +930,61 @@ def contains_image_content(msgs):
# ---------------------------------

# Run agent loop iteration
result = await self.agent_loop.predict_step(
predict_result = await self.agent_loop.predict_step(
**loop_kwargs,
_on_api_start=self._on_api_start,
_on_api_end=self._on_api_end,
_on_usage=self._on_usage,
_on_screenshot=self._on_screenshot,
)
result = get_json(result)

# Lifecycle hook: Postprocess messages after the LLM call
# Use cases:
# - PII deanonymization (if you want tool calls to see PII)
result["output"] = await self._on_llm_end(result.get("output", []))
await self._on_responses(loop_kwargs, result)
# Handle streaming vs non-streaming response
if stream and hasattr(predict_result, "__aiter__"):
# Streaming: iterate over async generator
accumulated_output: List[Dict[str, Any]] = []
async for chunk in predict_result:
chunk = get_json(chunk)
chunk_output = chunk.get("output", [])

# Postprocess chunk output
chunk["output"] = await self._on_llm_end(chunk_output)

# Yield streaming chunk
yield chunk

# Accumulate output for handling computer actions
accumulated_output.extend(chunk["output"])

# Check for completed status
if chunk.get("status") == "completed":
break

# Set result for handling computer actions
result = {"output": accumulated_output, "usage": chunk.get("usage", {})}

# Call the responses callback for trajectory saving (same as non-streaming path)
await self._on_responses(loop_kwargs, result)
else:
# Non-streaming: single response
result = get_json(predict_result)

# Lifecycle hook: Postprocess messages after the LLM call
# Use cases:
# - PII deanonymization (if you want tool calls to see PII)
result["output"] = await self._on_llm_end(result.get("output", []))
await self._on_responses(loop_kwargs, result)

# Yield agent response
yield result
# Yield agent response
yield result

# Add agent response to new_items
new_items += result.get("output")
new_items += result.get("output", [])

# Get output call ids
output_call_ids = get_output_call_ids(result.get("output", []))

# Handle computer actions
for item in result.get("output"):
for item in result.get("output", []):
partial_items = await self._handle_item(
item, self.computer_handler, ignore_call_ids=output_call_ids
)
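For reference, a minimal sketch of how a caller might consume the streaming path added to run() above. The ComputerAgent class name, its constructor arguments, and the stream keyword on run() are assumptions for illustration, inferred from the diff rather than confirmed by it:

# Hypothetical usage sketch; ComputerAgent, its constructor arguments, and the
# stream keyword on run() are assumptions inferred from the agent.py diff above.
import asyncio


async def main() -> None:
    from agent import ComputerAgent  # assumed import path

    agent = ComputerAgent(model="anthropic/claude-sonnet-4")  # assumed model id
    messages = [{"role": "user", "content": "Open the settings app"}]

    # run() is an async generator. With stream=True it yields partial chunks
    # (status "streaming") followed by a final chunk that carries the
    # accumulated usage (status "completed"); computer actions are still
    # executed internally on the accumulated output.
    async for chunk in agent.run(messages, stream=True):
        for item in chunk.get("output", []):
            print(item.get("type"), item)


asyncio.run(main())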
262 changes: 239 additions & 23 deletions libs/python/agent/agent/loops/anthropic.py
@@ -4,7 +4,7 @@

import asyncio
import json
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple, Union

import litellm
from litellm.responses.litellm_completion_transformation.transformation import (
@@ -1493,16 +1493,17 @@ async def predict_step(
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
_on_api_start: Optional[Callable] = None,
_on_api_end: Optional[Callable] = None,
_on_usage: Optional[Callable] = None,
_on_screenshot: Optional[Callable] = None,
**kwargs,
) -> Dict[str, Any]:
) -> Union[Dict[str, Any], AsyncGenerator[Dict[str, Any], None]]:
"""
Anthropic hosted tools agent loop using liteLLM acompletion.

Supports Anthropic's computer use models with hosted tools.
When stream=True, returns an AsyncGenerator that yields partial results.
"""
tools = tools or []

@@ -1538,28 +1539,243 @@
if _on_api_start:
await _on_api_start(api_kwargs)

# Use liteLLM acompletion
if stream:
return self._predict_step_streaming(
api_kwargs=api_kwargs,
_on_api_end=_on_api_end,
_on_usage=_on_usage,
)
else:
# Non-streaming path
response = await litellm.acompletion(**api_kwargs)

# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)

# Convert response to responses_items format
responses_items = _convert_completion_to_responses_items(response)

# Extract usage information
responses_usage = {
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(
response.usage
).model_dump(),
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(responses_usage)

# Return in AsyncAgentConfig format
return {"output": responses_items, "usage": responses_usage}

async def _predict_step_streaming(
self,
api_kwargs: Dict[str, Any],
_on_api_end: Optional[Callable] = None,
_on_usage: Optional[Callable] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Internal streaming implementation for predict_step.

Yields partial results as they arrive from the API.
"""
# Use liteLLM acompletion with streaming
response = await litellm.acompletion(**api_kwargs)

# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)
collected_content: List[Dict[str, Any]] = []
collected_tool_calls: List[Dict[str, Any]] = []
collected_usage: Dict[str, Any] = {}

async for chunk in response:
# Process streaming chunk
if hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
delta = getattr(choice, "delta", None)

if delta:
# Handle text content
if hasattr(delta, "content") and delta.content:
text_item = make_output_text_item(delta.content)
yield {"output": [text_item], "usage": {}, "status": "streaming"}

# Handle tool calls
if hasattr(delta, "tool_calls") and delta.tool_calls:
for tool_call_delta in delta.tool_calls:
# Accumulate tool call data
if hasattr(tool_call_delta, "index"):
idx = tool_call_delta.index
while len(collected_tool_calls) <= idx:
collected_tool_calls.append(
{"id": "", "function": {"name": "", "arguments": ""}}
)

# Convert response to responses_items format
responses_items = _convert_completion_to_responses_items(response)
if hasattr(tool_call_delta, "id") and tool_call_delta.id:
collected_tool_calls[idx]["id"] = tool_call_delta.id

if hasattr(tool_call_delta, "function"):
func = tool_call_delta.function
if hasattr(func, "name") and func.name:
collected_tool_calls[idx]["function"]["name"] = func.name
if hasattr(func, "arguments") and func.arguments:
collected_tool_calls[idx]["function"][
"arguments"
] += func.arguments

# Collect usage from final chunk
if hasattr(chunk, "usage") and chunk.usage:
collected_usage = {
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(
chunk.usage
).model_dump(),
"response_cost": (
chunk._hidden_params.get("response_cost", 0.0)
if hasattr(chunk, "_hidden_params")
else 0.0
),
}

# Extract usage information
responses_usage = {
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(
response.usage
).model_dump(),
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(responses_usage)
# Process any accumulated tool calls into response items
final_items: List[Dict[str, Any]] = []
for tc in collected_tool_calls:
if tc["function"]["name"] == "computer":
try:
args = json.loads(tc["function"]["arguments"])
action_type = args.get("action")
call_id = tc["id"]

# Convert tool call to computer_call format
item = self._convert_tool_call_to_computer_call(action_type, args, call_id)
if item:
final_items.append(item)
except json.JSONDecodeError:
pass
else:
# Function call
from ..responses import make_function_call_item

try:
args_dict = json.loads(tc["function"]["arguments"])
except json.JSONDecodeError:
args_dict = {}
final_items.append(
make_function_call_item(
function_name=tc["function"]["name"],
arguments=args_dict,
call_id=tc["id"],
)
)

if final_items:
yield {"output": final_items, "usage": {}, "status": "streaming"}

# Return in AsyncAgentConfig format
return {"output": responses_items, "usage": responses_usage}
# Call API end hook with None response for streaming
if _on_api_end:
await _on_api_end(api_kwargs, None)

# Call usage hook
if _on_usage and collected_usage:
await _on_usage(collected_usage)

# Final yield with usage information
yield {"output": [], "usage": collected_usage, "status": "completed"}

def _convert_tool_call_to_computer_call(
self, action_type: str, args: Dict[str, Any], call_id: str
) -> Optional[Dict[str, Any]]:
"""Convert a tool call to a computer_call response item."""
if action_type == "screenshot":
return make_screenshot_item(call_id=call_id)
elif action_type in ["click", "left_click"]:
coordinate = args.get("coordinate", [0, 0])
return make_click_item(
x=coordinate[0] if len(coordinate) > 0 else 0,
y=coordinate[1] if len(coordinate) > 1 else 0,
call_id=call_id,
)
elif action_type in ["type", "type_text"]:
return make_type_item(text=args.get("text", ""), call_id=call_id)
elif action_type in ["key", "keypress", "hotkey"]:
return make_keypress_item(
keys=args.get("text", "").replace("+", "-").split("-"),
call_id=call_id,
)
elif action_type in ["mouse_move", "move_cursor", "move"]:
coordinate = args.get("coordinate", [0, 0])
return make_move_item(
x=coordinate[0] if len(coordinate) > 0 else 0,
y=coordinate[1] if len(coordinate) > 1 else 0,
call_id=call_id,
)
elif action_type == "scroll":
coordinate = args.get("coordinate", [0, 0])
direction = args.get("scroll_direction", "down")
amount = args.get("scroll_amount", 3)
scroll_x = amount if direction == "left" else -amount if direction == "right" else 0
scroll_y = amount if direction == "up" else -amount if direction == "down" else 0
return make_scroll_item(
x=coordinate[0] if len(coordinate) > 0 else 0,
y=coordinate[1] if len(coordinate) > 1 else 0,
scroll_x=scroll_x,
scroll_y=scroll_y,
call_id=call_id,
)
elif action_type in ["left_click_drag", "drag"]:
start_coord = args.get("start_coordinate", [0, 0])
end_coord = args.get("end_coordinate", [0, 0])
return make_drag_item(
path=[
{
"x": start_coord[0] if len(start_coord) > 0 else 0,
"y": start_coord[1] if len(start_coord) > 1 else 0,
},
{
"x": end_coord[0] if len(end_coord) > 0 else 0,
"y": end_coord[1] if len(end_coord) > 1 else 0,
},
],
call_id=call_id,
)
elif action_type == "right_click":
coordinate = args.get("coordinate", [0, 0])
return make_click_item(
x=coordinate[0] if len(coordinate) > 0 else 0,
y=coordinate[1] if len(coordinate) > 1 else 0,
button="right",
call_id=call_id,
)
elif action_type == "middle_click":
coordinate = args.get("coordinate", [0, 0])
return make_click_item(
x=coordinate[0] if len(coordinate) > 0 else 0,
y=coordinate[1] if len(coordinate) > 1 else 0,
button="wheel",
call_id=call_id,
)
elif action_type == "double_click":
coordinate = args.get("coordinate", [0, 0])
return make_double_click_item(
x=coordinate[0] if len(coordinate) > 0 else 0,
y=coordinate[1] if len(coordinate) > 1 else 0,
call_id=call_id,
)
elif action_type == "left_mouse_down":
coordinate = args.get("coordinate", [None, None])
return make_left_mouse_down_item(
x=coordinate[0] if len(coordinate) > 0 else None,
y=coordinate[1] if len(coordinate) > 1 else None,
call_id=call_id,
)
elif action_type == "left_mouse_up":
coordinate = args.get("coordinate", [None, None])
return make_left_mouse_up_item(
x=coordinate[0] if len(coordinate) > 0 else None,
y=coordinate[1] if len(coordinate) > 1 else None,
call_id=call_id,
)
elif action_type == "wait":
return make_wait_item(call_id=call_id)
return None

async def predict_click(
self, model: str, image_b64: str, instruction: str, **kwargs
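Likewise, a minimal sketch of how the Union return type of predict_step (a plain dict when stream=False, an async generator when stream=True) can be normalized by a caller, mirroring the handling added in agent.py. The helper below is hypothetical and not part of the PR:

# Hypothetical helper, not part of the diff: normalizes the Union return type
# of predict_step into a single {"output": ..., "usage": ...} dict.
from typing import Any, AsyncGenerator, Dict, List, Union


async def collect_result(
    result: Union[Dict[str, Any], AsyncGenerator[Dict[str, Any], None]],
) -> Dict[str, Any]:
    if hasattr(result, "__aiter__"):
        # Streaming: accumulate partial outputs until the "completed" chunk,
        # which is the one that carries the usage information.
        output: List[Dict[str, Any]] = []
        usage: Dict[str, Any] = {}
        async for chunk in result:  # type: ignore[union-attr]
            output.extend(chunk.get("output", []))
            if chunk.get("usage"):
                usage = chunk["usage"]
            if chunk.get("status") == "completed":
                break
        return {"output": output, "usage": usage}
    # Non-streaming: already a complete result dict.
    return result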