src/agentlab/analyze/agent_xray.py: 135 changes (123 additions, 12 deletions)
@@ -7,6 +7,14 @@
from logging import warning
from pathlib import Path

from agentlab.llm.chat_api import (
AzureChatModel,
OpenAIChatModel,
OpenRouterChatModel,
make_system_message,
make_user_message,
)

import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt
@@ -15,7 +23,7 @@
from attr import dataclass
from browsergym.experiments.loop import StepInfo as BGymStepInfo
from langchain.schema import BaseMessage, HumanMessage
- from openai import OpenAI
+ from openai import AzureOpenAI
from openai.types.responses import ResponseFunctionToolCall
from PIL import Image

@@ -399,14 +407,27 @@ def run_gradio(results_dir: Path):
interactive=True,
elem_id="prompt_tests_textbox",
)
submit_button = gr.Button(value="Submit")
with gr.Row():
num_queries_input = gr.Number(
value=3,
label="Number of model queries",
minimum=1,
maximum=10,
step=1,
precision=0,
interactive=True,
)
submit_button = gr.Button(value="Submit")
result_box = gr.Textbox(
value="", label="Result", show_label=True, interactive=False
value="", label="Result", show_label=True, interactive=False, max_lines=20
)
with gr.Row():
# Add plot component for action distribution graph
action_plot = gr.Plot(label="Action Distribution", show_label=True)

# Define the interaction
submit_button.click(
- fn=submit_action, inputs=prompt_tests_textbox, outputs=result_box
+ fn=submit_action, inputs=[prompt_tests_textbox, num_queries_input], outputs=[result_box, action_plot]
)

# Handle Events #
@@ -843,9 +864,11 @@ def _page_to_iframe(page: str):
return page


- def submit_action(input_text):
+ def submit_action(input_text, num_queries=3):
global info
agent_info = info.exp_result.steps_info[info.step].agent_info
# Get the current step's action string for comparison
step_action_str = info.exp_result.steps_info[info.step].action
chat_messages = deepcopy(agent_info.get("chat_messages", ["No Chat Messages"])[:2])
if isinstance(chat_messages[1], BaseMessage): # TODO remove once langchain is deprecated
assert isinstance(chat_messages[1], HumanMessage), "Second message should be user"
@@ -858,14 +881,102 @@ def submit_action(input_text):
else:
raise ValueError("Chat messages should be a list of BaseMessage or dict")

- client = OpenAI()
+ client = AzureChatModel(model_name="gpt-35-turbo", deployment_name="gpt-35-turbo")

Korbit review: Hardcoded LLM Provider Configuration (category: Design)

What is the issue?

The code hardcodes the Azure model configuration without providing flexibility for different model providers or configurations.

Why this matters

Users won't be able to use different LLM providers or models for hint tuning, which limits the tool's applicability and contradicts the developer's intent of providing a systematic approach for evaluation.

Suggested change:

Create a configurable model provider system:

def get_model_client(provider="azure", model_name="gpt-35-turbo"):
    if provider == "azure":
        return AzureChatModel(model_name=model_name, deployment_name=model_name)
    elif provider == "openai":
        return OpenAIChatModel(model_name=model_name)
    elif provider == "openrouter":
        return OpenRouterChatModel(model_name=model_name)
    else:
        raise ValueError(f"Unsupported provider: {provider}")

# Usage in submit_action:
client = get_model_client(provider="azure", model_name="gpt-35-turbo")
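
Building on that suggestion, the provider and model name could also be exposed in the existing Gradio layout instead of being fixed in code. A minimal sketch, assuming the hypothetical get_model_client helper above and the components already defined in run_gradio (prompt_tests_textbox, num_queries_input, result_box, action_plot); the new component names are illustrative, not from the PR:

# Sketch only: assumes get_model_client() from the suggestion above; the
# dropdown/textbox names below are illustrative and not part of this PR.
with gr.Row():
    provider_input = gr.Dropdown(
        choices=["azure", "openai", "openrouter"], value="azure", label="Provider"
    )
    model_name_input = gr.Textbox(value="gpt-35-turbo", label="Model name")

submit_button.click(
    fn=submit_action,  # submit_action would take (input_text, num_queries, provider, model_name)
    inputs=[prompt_tests_textbox, num_queries_input, provider_input, model_name_input],
    outputs=[result_box, action_plot],
)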

chat_messages[1]["content"] = input_text
- completion = client.chat.completions.create(
- model="gpt-4o-mini",
- messages=chat_messages,
- )
- result_text = completion.choices[0].message.content
- return result_text

# Query the model N times
answers = []
actions = []
import re

for _ in range(num_queries):
answer = client(chat_messages)
content = answer.get("content", "")
answers.append(content)
Comment on lines +892 to +895

Korbit review: Sequential Model Queries (category: Performance)

What is the issue?

Sequential model queries are being made in a loop, causing unnecessary latency accumulation.

Why this matters

Making sequential API calls significantly increases total response time as each request must wait for the previous one to complete. With multiple queries, this creates a substantial performance bottleneck.

Suggested change:

Parallelize the model queries using an async/await pattern or concurrent.futures to make multiple requests simultaneously:

import asyncio  # asyncio.TaskGroup requires Python 3.11+

# Assumes the chat model client exposes an async `acall` method (not shown in this PR).
async def get_model_responses(client, chat_messages, num_queries):
    async with asyncio.TaskGroup() as tg:
        tasks = [
            tg.create_task(client.acall(chat_messages))
            for _ in range(num_queries)
        ]
    return [t.result() for t in tasks]
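
If an async client method is not available (the acall above is an assumption about the client interface, not something this PR adds), concurrent.futures gives a similar wall-clock saving with a blocking client. A minimal sketch, assuming the client is callable on a list of chat messages as in the new submit_action code:

from concurrent.futures import ThreadPoolExecutor

def get_model_responses_threaded(client, chat_messages, num_queries):
    # Run the blocking client calls concurrently in a small thread pool.
    with ThreadPoolExecutor(max_workers=num_queries) as pool:
        futures = [pool.submit(client, chat_messages) for _ in range(num_queries)]
        return [future.result() for future in futures]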


# Extract action part using regex
action_match = re.search(r'<action>(.*?)</action>', content, re.DOTALL)

Korbit review: Fragile Action Extraction Logic (category: Error Handling)

What is the issue?

The code assumes actions are always wrapped in <action> tags, but doesn't handle cases where the model output might be malformed or use a different format.

Why this matters

The action extraction could fail silently when the model output doesn't follow the expected format, leading to incomplete or misleading action distribution analysis.

Suggested change:

Add robust action extraction with error handling:

def extract_action(content: str) -> str | None:
    # Try XML-style tags
    action_match = re.search(r'<action>(.*?)</action>', content, re.DOTALL)
    if action_match:
        return action_match.group(1).strip()
    
    # Try markdown-style formatting
    action_match = re.search(r'`action: (.*?)`', content, re.DOTALL)
    if action_match:
        return action_match.group(1).strip()
        
    # Try plain text format
    action_match = re.search(r'Action: (.*?)(?:\n|$)', content)
    if action_match:
        return action_match.group(1).strip()
    
    return None

# Usage in loop:
action = extract_action(content)
if action:
    actions.append(action)
else:
    print(f"Warning: Could not extract action from response: {content[:100]}...")

if action_match:
actions.append(action_match.group(1).strip())

# Prepare the aggregate result
result = ""

# Include full responses first
result += "\n\n===== FULL MODEL RESPONSES =====\n\n"
result += "\n\n===== MODEL RESPONSE SEPARATION =====\n\n".join(answers)

# Then add aggregated actions
result += "\n\n===== EXTRACTED ACTIONS =====\n\n"

# Create plot for action distribution
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter

# Create a figure for the action distribution
fig = plt.figure(figsize=(10, 6))

if actions:
# Count unique actions
action_counts = Counter(actions)

# Get actions in most_common order to ensure consistency between plot and text output
most_common_actions = action_counts.most_common()

# Prepare data for plotting (using most_common order)
labels = [f"Action {i+1}" for i in range(len(most_common_actions))]
values = [count for _, count in most_common_actions]
percentages = [(count / len(actions)) * 100 for count in values]

# Create bar chart
plt.bar(labels, percentages, color='skyblue')
plt.xlabel('Actions')
plt.ylabel('Percentage (%)')
plt.title(f'Action Distribution (from {num_queries} model queries)')
plt.ylim(0, 100) # Set y-axis from 0 to 100%

# Add percentage labels on top of each bar
for i, v in enumerate(percentages):
plt.text(i, v + 2, f"{v:.1f}%", ha='center')

# Add total counts as text annotation
plt.figtext(0.5, 0.01,
f"Total actions extracted: {len(actions)} | Unique actions: {len(action_counts)}",
ha="center", fontsize=10, bbox={"facecolor":"white", "alpha":0.5, "pad":5})

# Display unique actions and their counts in text result
for i, (action, count) in enumerate(action_counts.most_common()):
percentage = (count / len(actions)) * 100

# Check if this action matches the current step's action
matches_current_action = step_action_str and action.strip() == step_action_str.strip()

# Highlight conditions:
# 1. If it's the most common action (i==0)
# 2. If it matches the current step's action
if i == 0 and matches_current_action:
result += f"** Predicted Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%) [MATCHES CURRENT ACTION]**:\n**{action}**\n\n"
elif i == 0: # Just the most common
result += f"** Predicted Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%)**:\n**{action}**\n\n"
elif matches_current_action: # Matches current action but not most common
result += f"** Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%) [MATCHES CURRENT ACTION]**:\n**{action}**\n\n"
else: # Regular action
result += f"Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%):\n{action}\n\n"
else:
result += "No actions found in any of the model responses.\n\n"

# Create empty plot with message
plt.text(0.5, 0.5, "No actions found in model responses",
ha='center', va='center', fontsize=14)
plt.axis('off') # Hide axes

plt.tight_layout()

# Return both the text result and the figure
return result, fig




def update_prompt_tests():