Add hint tuning tool to evaluate best actions at a given step #285
@@ -7,6 +7,14 @@
 from logging import warning
 from pathlib import Path
 
+from agentlab.llm.chat_api import (
+    AzureChatModel,
+    OpenAIChatModel,
+    OpenRouterChatModel,
+    make_system_message,
+    make_user_message,
+)
+
 import gradio as gr
 import matplotlib.patches as patches
 import matplotlib.pyplot as plt
@@ -15,7 +23,7 @@
 from attr import dataclass
 from browsergym.experiments.loop import StepInfo as BGymStepInfo
 from langchain.schema import BaseMessage, HumanMessage
-from openai import OpenAI
+from openai import AzureOpenAI
 from openai.types.responses import ResponseFunctionToolCall
 from PIL import Image
 
@@ -399,14 +407,27 @@ def run_gradio(results_dir: Path):
             interactive=True,
             elem_id="prompt_tests_textbox",
         )
-        submit_button = gr.Button(value="Submit")
+        with gr.Row():
+            num_queries_input = gr.Number(
+                value=3,
+                label="Number of model queries",
+                minimum=1,
+                maximum=10,
+                step=1,
+                precision=0,
+                interactive=True,
+            )
+            submit_button = gr.Button(value="Submit")
         result_box = gr.Textbox(
-            value="", label="Result", show_label=True, interactive=False
+            value="", label="Result", show_label=True, interactive=False, max_lines=20
         )
+        with gr.Row():
+            # Add plot component for action distribution graph
+            action_plot = gr.Plot(label="Action Distribution", show_label=True)
 
         # Define the interaction
         submit_button.click(
-            fn=submit_action, inputs=prompt_tests_textbox, outputs=result_box
+            fn=submit_action, inputs=[prompt_tests_textbox, num_queries_input], outputs=[result_box, action_plot]
         )
 
         # Handle Events #
@@ -843,9 +864,11 @@ def _page_to_iframe(page: str):
     return page
 
 
-def submit_action(input_text):
+def submit_action(input_text, num_queries=3):
     global info
     agent_info = info.exp_result.steps_info[info.step].agent_info
+    # Get the current step's action string for comparison
+    step_action_str = info.exp_result.steps_info[info.step].action
     chat_messages = deepcopy(agent_info.get("chat_messages", ["No Chat Messages"])[:2])
     if isinstance(chat_messages[1], BaseMessage):  # TODO remove once langchain is deprecated
         assert isinstance(chat_messages[1], HumanMessage), "Second message should be user"
@@ -858,14 +881,102 @@ def submit_action(input_text):
     else:
         raise ValueError("Chat messages should be a list of BaseMessage or dict")
 
-    client = OpenAI()
+    client = AzureChatModel(model_name="gpt-35-turbo", deployment_name="gpt-35-turbo")
     chat_messages[1]["content"] = input_text
-    completion = client.chat.completions.create(
-        model="gpt-4o-mini",
-        messages=chat_messages,
-    )
-    result_text = completion.choices[0].message.content
-    return result_text
 
+    # Query the model N times
+    answers = []
+    actions = []
+    import re
+
+    for _ in range(num_queries):
+        answer = client(chat_messages)
+        content = answer.get("content", "")
+        answers.append(content)
Comment on lines +892 to +895 (Korbit review)

Sequential Model Queries

What is the issue?
Sequential model queries are being made in a loop, causing unnecessary latency accumulation.

Why this matters
Making sequential API calls significantly increases total response time, as each request must wait for the previous one to complete. With multiple queries, this creates a substantial performance bottleneck.

Suggested change
Parallelize the model queries using an async/await pattern or concurrent.futures to make multiple requests simultaneously:

async def get_model_responses(client, chat_messages, num_queries):
    async with asyncio.TaskGroup() as tg:
        tasks = [
            tg.create_task(client.acall(chat_messages))
            for _ in range(num_queries)
        ]
    return [t.result() for t in tasks]
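A note on the suggestion above: the diff calls the client synchronously (answer = client(chat_messages)), and an async acall method is not shown anywhere in this PR, so the asyncio version assumes a client API that may not exist. A thread-pool variant that reuses the existing synchronous call is a minimal sketch of the same idea; it assumes the chat-model object is safe to call from multiple threads.

from concurrent.futures import ThreadPoolExecutor

def get_model_responses_threaded(client, chat_messages, num_queries):
    # Submit the same synchronous client(chat_messages) call used in the diff,
    # once per query, and collect the responses in submission order.
    with ThreadPoolExecutor(max_workers=num_queries) as pool:
        futures = [pool.submit(client, chat_messages) for _ in range(num_queries)]
        return [future.result() for future in futures]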
+
+        # Extract action part using regex
+        action_match = re.search(r'<action>(.*?)</action>', content, re.DOTALL)
Review comment (Korbit)

Fragile Action Extraction Logic

What is the issue?
The code assumes actions are always wrapped in <action> tags, but doesn't handle cases where the model output might be malformed or use different formats.

Why this matters
The action extraction could fail silently when the model output doesn't follow the expected format, leading to incomplete or misleading action distribution analysis.

Suggested change
Add robust action extraction with error handling:

def extract_action(content: str) -> str | None:
    # Try XML-style tags
    action_match = re.search(r'<action>(.*?)</action>', content, re.DOTALL)
    if action_match:
        return action_match.group(1).strip()
    # Try markdown-style formatting
    action_match = re.search(r'`action: (.*?)`', content, re.DOTALL)
    if action_match:
        return action_match.group(1).strip()
    # Try plain text format
    action_match = re.search(r'Action: (.*?)(?:\n|$)', content)
    if action_match:
        return action_match.group(1).strip()
    return None

# Usage in loop:
action = extract_action(content)
if action:
    actions.append(action)
else:
    print(f"Warning: Could not extract action from response: {content[:100]}...")
+        if action_match:
+            actions.append(action_match.group(1).strip())
+
+    # Prepare the aggregate result
+    result = ""
+
+    # Include full responses first
+    result += "\n\n===== FULL MODEL RESPONSES =====\n\n"
+    result += "\n\n===== MODEL RESPONSE SEPARATION =====\n\n".join(answers)
+
+    # Then add aggregated actions
+    result += "\n\n===== EXTRACTED ACTIONS =====\n\n"
+
+    # Create plot for action distribution
+    import matplotlib.pyplot as plt
+    import numpy as np
+    from collections import Counter
+
+    # Create a figure for the action distribution
+    fig = plt.figure(figsize=(10, 6))
+
+    if actions:
+        # Count unique actions
+        action_counts = Counter(actions)
+
+        # Get actions in most_common order to ensure consistency between plot and text output
+        most_common_actions = action_counts.most_common()
+
+        # Prepare data for plotting (using most_common order)
+        labels = [f"Action {i+1}" for i in range(len(most_common_actions))]
+        values = [count for _, count in most_common_actions]
+        percentages = [(count / len(actions)) * 100 for count in values]
+
+        # Create bar chart
+        plt.bar(labels, percentages, color='skyblue')
+        plt.xlabel('Actions')
+        plt.ylabel('Percentage (%)')
+        plt.title(f'Action Distribution (from {num_queries} model queries)')
+        plt.ylim(0, 100)  # Set y-axis from 0 to 100%
+
+        # Add percentage labels on top of each bar
+        for i, v in enumerate(percentages):
+            plt.text(i, v + 2, f"{v:.1f}%", ha='center')
+
+        # Add total counts as text annotation
+        plt.figtext(0.5, 0.01,
+                    f"Total actions extracted: {len(actions)} | Unique actions: {len(action_counts)}",
+                    ha="center", fontsize=10, bbox={"facecolor":"white", "alpha":0.5, "pad":5})
+
+        # Display unique actions and their counts in text result
+        for i, (action, count) in enumerate(action_counts.most_common()):
+            percentage = (count / len(actions)) * 100
+
+            # Check if this action matches the current step's action
+            matches_current_action = step_action_str and action.strip() == step_action_str.strip()
+
+            # Highlight conditions:
+            # 1. If it's the most common action (i==0)
+            # 2. If it matches the current step's action
+            if i == 0 and matches_current_action:
+                result += f"** Predicted Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%) [MATCHES CURRENT ACTION]**:\n**{action}**\n\n"
+            elif i == 0:  # Just the most common
+                result += f"** Predicted Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%)**:\n**{action}**\n\n"
+            elif matches_current_action:  # Matches current action but not most common
+                result += f"** Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%) [MATCHES CURRENT ACTION]**:\n**{action}**\n\n"
+            else:  # Regular action
+                result += f"Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%):\n{action}\n\n"
+    else:
+        result += "No actions found in any of the model responses.\n\n"
+
+        # Create empty plot with message
+        plt.text(0.5, 0.5, "No actions found in model responses",
+                 ha='center', va='center', fontsize=14)
+        plt.axis('off')  # Hide axes
+
+    plt.tight_layout()
+
+    # Return both the text result and the figure
+    return result, fig
+
+
 
 
 def update_prompt_tests():
Review comment (Korbit)

Hardcoded LLM Provider Configuration

What is the issue?
The code hardcodes the Azure model configuration without providing flexibility for different model providers or configurations.

Why this matters
Users won't be able to use different LLM providers or models for hint tuning, which limits the tool's applicability and contradicts the developer's intent of providing a systematic approach for evaluation.

Suggested change
Create a configurable model provider system.
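The code for this suggestion is not included above. As an illustrative sketch only, a small factory over the chat-model classes already imported in this PR (AzureChatModel, OpenAIChatModel, OpenRouterChatModel) could look like the following; the Azure arguments come from the diff, while the keyword arguments for the other providers are assumptions, not confirmed AgentLab signatures.

from agentlab.llm.chat_api import AzureChatModel, OpenAIChatModel, OpenRouterChatModel

def make_hint_tuning_client(provider: str = "azure", model_name: str = "gpt-35-turbo"):
    # Map a provider name to one of the chat-model classes imported in the diff.
    if provider == "azure":
        return AzureChatModel(model_name=model_name, deployment_name=model_name)
    if provider == "openai":
        return OpenAIChatModel(model_name=model_name)  # assumed signature
    if provider == "openrouter":
        return OpenRouterChatModel(model_name=model_name)  # assumed signature
    raise ValueError(f"Unknown provider: {provider}")

# Usage in submit_action, replacing the hardcoded AzureChatModel call:
# client = make_hint_tuning_client(provider="azure", model_name="gpt-35-turbo")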