Skip to content

Commit 7935b0f

Browse files
committed
add new endpoint
1 parent 64fc0f9 commit 7935b0f

File tree

13 files changed

+3226
-1
lines changed

13 files changed

+3226
-1
lines changed

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/__init__.py

Whitespace-only changes.

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py

Lines changed: 776 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
import numpy as np
2+
from dataclasses import dataclass
3+
from typing import Optional
4+
from pydantic import BaseModel
5+
import base64
6+
from ...webagent_utils_async.evaluation.feedback import Feedback
7+
8+
@dataclass
9+
class Element:
10+
"""Represents a DOM element with its properties."""
11+
text: str
12+
tag: str
13+
id: str
14+
title: str
15+
ariaLabel: str
16+
name: str
17+
value: str
18+
placeholder: str
19+
class_name: str # Changed from 'class' as it's a reserved keyword
20+
role: str
21+
unique_selector: str
22+
selector_uniqueness_validated: bool
23+
24+
class Observation(BaseModel):
25+
text: str
26+
image: Optional[bytes] = None
27+
image_base64: Optional[str] = None
28+
29+
def get_base64_image(self):
30+
if self.image_base64 is None:
31+
self.image_base64 = base64.b64encode(self.image).decode('utf-8')
32+
return self.image_base64
33+
34+
class LATSNode:
35+
"""
36+
A node class for Language-based Action Tree Search (LATS).
37+
38+
This class implements a tree structure for MCTS-like search algorithms,
39+
specifically designed for language-based action planning in UI interactions.
40+
41+
Attributes:
42+
natural_language_description (str): Human-readable description of the action
43+
action (str): The actual action to be executed
44+
prob (float): Probability or confidence score for this action
45+
element (Element): DOM element associated with this action
46+
goal (str): The target goal state
47+
parent (Optional[LATSNode]): Parent node in the tree
48+
children (list[LATSNode]): Child nodes in the tree
49+
visits (int): Number of times this node has been visited
50+
value (float): Accumulated value/score of this node
51+
depth (int): Depth of this node in the tree
52+
is_terminal (bool): Whether this node is a terminal state
53+
reward (float): Reward received at this node
54+
exhausted (bool): Whether all children have been explored
55+
em (float): Exact match score for evaluation
56+
"""
57+
58+
def __init__(
59+
self,
60+
natural_language_description: str,
61+
action: str,
62+
prob: float,
63+
element: dict, # Using dict instead of Element for backward compatibility
64+
goal: str,
65+
parent: Optional['LATSNode'] = None
66+
) -> None:
67+
"""
68+
Initialize a new LATSNode.
69+
70+
Args:
71+
natural_language_description: Human-readable description of the action
72+
action: The actual action to be executed
73+
prob: Probability or confidence score for this action
74+
element: DOM element associated with this action
75+
goal: The target goal state
76+
parent: Parent node in the tree, if any
77+
"""
78+
self.natural_language_description = natural_language_description
79+
self.action = action
80+
self.prob = prob
81+
self.element = element
82+
self.feedback = ''
83+
self.goal_finish_feedback: Optional[Feedback] = None
84+
self.parent = parent
85+
self.goal = goal
86+
self.children: list[LATSNode] = []
87+
self.visits = 0
88+
self.value = 0.0
89+
self.depth = 0 if parent is None else parent.depth + 1
90+
self.is_terminal = False
91+
self.reward = 0.0
92+
self.exhausted = False # If all children are terminal
93+
self.em = 0.0 # Exact match, evaluation metric
94+
self.observation: Optional[Observation] = None
95+
96+
def uct(self) -> float:
97+
"""
98+
Calculate the UCT (Upper Confidence Bound for Trees) value for this node.
99+
100+
Returns:
101+
float: The UCT value for this node. If the node has never been visited,
102+
returns the node's current value.
103+
"""
104+
if self.visits == 0:
105+
return self.value
106+
return self.value / self.visits + np.sqrt(2 * np.log(self.parent.visits) / self.visits)
107+
108+
def get_best_leaf(self) -> 'LATSNode':
109+
unfinished_children = [c for c in self.children if not c.is_terminal]
110+
if not unfinished_children:
111+
return self
112+
113+
best_child = max(unfinished_children, key=lambda x: x.uct())
114+
return best_child.get_best_leaf()
115+
116+
def get_action_trajectory(self) -> list[dict]:
117+
trajectory = []
118+
node = self
119+
# exclude the root node
120+
while node.parent is not None:
121+
trajectory.append({
122+
"action": node.action,
123+
"natural_language_description": node.natural_language_description,
124+
"element": node.element
125+
})
126+
node = node.parent
127+
return trajectory[::-1]
128+
129+
def get_trajectory(self) -> list[dict]:
130+
trajectory = []
131+
node = self
132+
# exclude the root node
133+
while node.parent is not None:
134+
trajectory.append({
135+
"natural_language_description": node.natural_language_description,
136+
"action": node.action
137+
})
138+
node = node.parent
139+
return trajectory[::-1]
140+
141+
def add_child(self, child: 'LATSNode') -> None:
142+
self.children.append(child)
143+
child.parent = self
144+
child.depth = self.depth + 1
145+
146+
def check_terminal(self) -> bool:
147+
if not self.children or all(child.is_terminal for child in self.children):
148+
self.is_terminal = True
149+
if self.parent:
150+
self.parent.check_terminal()
151+
152+
def __str__(self) -> str:
153+
"""
154+
Get a string representation of the node.
155+
156+
Returns:
157+
str: A string describing the node's key attributes
158+
"""
159+
return (f"Node(depth={self.depth}, value={self.value:.2f}, "
160+
f"visits={self.visits}, action={self.action}, "
161+
f"feedback={self.feedback})")
162+
163+
def to_dict(self) -> dict:
164+
"""
165+
Convert the node and its subtree to a dictionary representation.
166+
167+
Returns:
168+
dict: A dictionary containing all node attributes and recursive
169+
representations of parent and children nodes
170+
"""
171+
return {
172+
'state': self.state,
173+
'question': self.question,
174+
'parent': self.parent.to_dict() if self.parent else None,
175+
'children': [child.to_dict() for child in self.children],
176+
'visits': self.visits,
177+
'value': self.value,
178+
'depth': self.depth,
179+
'is_terminal': self.is_terminal,
180+
'reward': self.reward,
181+
'em': self.em,
182+
}
183+
184+
@property
185+
def state(self) -> dict:
186+
"""
187+
Get the current state representation of the node.
188+
189+
Returns:
190+
dict: A dictionary containing the node's state information
191+
"""
192+
return {
193+
'natural_language_description': self.natural_language_description,
194+
'action': self.action,
195+
'prob': self.prob,
196+
'element': self.element
197+
}
198+
199+
@property
200+
def question(self) -> str:
201+
"""
202+
Get the goal/question associated with this node.
203+
204+
Returns:
205+
str: The goal or question string
206+
"""
207+
return self.goal
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import logging
2+
import time
3+
from typing import Any, Dict, List, Optional
4+
from collections import deque
5+
from datetime import datetime
6+
import os
7+
import json
8+
import subprocess
9+
10+
from openai import OpenAI
11+
from dotenv import load_dotenv
12+
load_dotenv()
13+
import aiohttp
14+
15+
from ...core_async.config import AgentConfig
16+
17+
from ...webagent_utils_async.action.highlevel import HighLevelActionSet
18+
from ...webagent_utils_async.utils.playwright_manager import AsyncPlaywrightManager, setup_playwright
19+
from ...webagent_utils_async.utils.utils import parse_function_args, locate_element
20+
from ...evaluation_async.evaluators import goal_finished_evaluator
21+
from ...replay_async import generate_feedback, playwright_step_execution
22+
from ...webagent_utils_async.action.prompt_functions import extract_top_actions
23+
from ...webagent_utils_async.browser_env.observation import extract_page_info
24+
from .lats_node import LATSNode
25+
from .tree_vis import better_print, print_trajectory, collect_all_nodes, GREEN, RESET, print_entire_tree
26+
from .trajectory_score import create_llm_prompt, score_trajectory_with_openai
27+
from ...webagent_utils_async.utils.utils import urls_to_images
28+
29+
logger = logging.getLogger(__name__)
30+
openai_client = OpenAI()
31+
32+
class MCTSAgent:
33+
def __init__(
34+
self,
35+
starting_url: str,
36+
messages: list[dict[str, Any]],
37+
goal: str,
38+
images: list,
39+
playwright_manager: AsyncPlaywrightManager,
40+
config: AgentConfig,
41+
):
42+
self.starting_url = starting_url
43+
self.goal = goal
44+
self.image_urls = images
45+
self.images = urls_to_images(self.image_urls)
46+
self.messages = messages
47+
self.messages.append({"role": "user", "content": f"The goal is: {self.goal}"})
48+
49+
self.playwright_manager = playwright_manager
50+
51+
self.config = config
52+
53+
self.agent_type = ["bid", "nav", "file", "select_option"]
54+
self.action_set = HighLevelActionSet(
55+
subsets=self.agent_type, strict=False, multiaction=True, demo_mode="default"
56+
)
57+
self.root_node = LATSNode(
58+
natural_language_description=None,
59+
action=None,
60+
prob=None,
61+
element=None,
62+
goal=self.goal,
63+
parent=None
64+
)
65+
self.reset_url = os.environ["ACCOUNT_RESET_URL"]
66+
67+
async def run(self, websocket=None) -> List[Dict[str, Any]]:
68+
pass

0 commit comments

Comments
 (0)