1+ import numpy as np
2+ from dataclasses import dataclass
3+ from typing import Optional
4+ from pydantic import BaseModel
5+ import base64
6+ from ...webagent_utils_async .evaluation .feedback import Feedback
7+
@dataclass
class Element:
    """Properties recorded for a single DOM element.

    All fields are required and have no defaults. Their declaration order
    defines the positional constructor, so it must not be rearranged.
    """

    # Textual / structural properties of the element.
    text: str
    tag: str
    id: str
    title: str
    ariaLabel: str
    name: str
    value: str
    placeholder: str
    # Named ``class_name`` because ``class`` is a reserved keyword in Python.
    class_name: str
    role: str
    # Selector string for the element, plus a flag recording whether the
    # selector's uniqueness has been validated.
    unique_selector: str
    selector_uniqueness_validated: bool
23+
class Observation(BaseModel):
    """A text observation with an optional image attached.

    Attributes:
        text: Textual content of the observation.
        image: Raw image bytes, or None when no image is attached.
        image_base64: Base64 text form of ``image``; filled in lazily by
            :meth:`get_base64_image` unless supplied up front.
    """

    text: str
    image: Optional[bytes] = None
    image_base64: Optional[str] = None

    def get_base64_image(self) -> Optional[str]:
        """Return the image as a base64 string, encoding it on first use.

        The encoded value is cached on ``image_base64`` so later calls do
        not re-encode.

        Returns:
            The base64-encoded image as UTF-8 text, or None when neither a
            cached encoding nor raw image bytes are available.
        """
        # Only encode when there is something to encode: ``image`` defaults
        # to None and ``b64encode(None)`` would raise TypeError.
        if self.image_base64 is None and self.image is not None:
            self.image_base64 = base64.b64encode(self.image).decode('utf-8')
        return self.image_base64
33+
class LATSNode:
    """
    A node class for Language-based Action Tree Search (LATS).

    This class implements a tree structure for MCTS-like search algorithms,
    specifically designed for language-based action planning in UI interactions.

    Attributes:
        natural_language_description (str): Human-readable description of the action
        action (str): The actual action to be executed
        prob (float): Probability or confidence score for this action
        element (dict): DOM element associated with this action
        feedback (str): Feedback text attached to this node (initially '')
        goal_finish_feedback (Optional[Feedback]): Goal-completion feedback,
            None until assigned
        goal (str): The target goal state
        parent (Optional[LATSNode]): Parent node in the tree
        children (list[LATSNode]): Child nodes in the tree
        visits (int): Number of times this node has been visited
        value (float): Accumulated value/score of this node
        depth (int): Depth of this node in the tree (0 for a node without parent)
        is_terminal (bool): Whether this node is a terminal state
        reward (float): Reward received at this node
        exhausted (bool): Exhaustion flag; initialised to False and not
            updated by this class itself
        em (float): Exact match score for evaluation
        observation (Optional[Observation]): Observation attached to this
            node, None until assigned
    """

    def __init__(
        self,
        natural_language_description: str,
        action: str,
        prob: float,
        element: dict,  # Using dict instead of Element for backward compatibility
        goal: str,
        parent: Optional['LATSNode'] = None
    ) -> None:
        """
        Initialize a new LATSNode.

        Note that passing ``parent`` only sets the back-reference and the
        depth; the node is not appended to ``parent.children`` (use
        :meth:`add_child` for that).

        Args:
            natural_language_description: Human-readable description of the action
            action: The actual action to be executed
            prob: Probability or confidence score for this action
            element: DOM element associated with this action
            goal: The target goal state
            parent: Parent node in the tree, if any
        """
        self.natural_language_description = natural_language_description
        self.action = action
        self.prob = prob
        self.element = element
        self.feedback = ''
        self.goal_finish_feedback: Optional[Feedback] = None
        self.parent = parent
        self.goal = goal
        self.children: list[LATSNode] = []
        self.visits = 0
        self.value = 0.0
        self.depth = 0 if parent is None else parent.depth + 1
        self.is_terminal = False
        self.reward = 0.0
        self.exhausted = False  # If all children are terminal
        self.em = 0.0  # Exact match, evaluation metric
        self.observation: Optional[Observation] = None

    def uct(self) -> float:
        """
        Calculate the UCT (Upper Confidence Bound for Trees) value for this node.

        Returns:
            float: ``value / visits`` plus the exploration bonus
                ``sqrt(2 * ln(parent.visits) / visits)``. If the node has
                never been visited, returns the node's current value. A node
                without a parent has no exploration bonus.
        """
        if self.visits == 0:
            return self.value

        exploitation = self.value / self.visits
        if self.parent is None:
            # The root has no parent visit count to draw a bonus from.
            return exploitation

        # Clamp to 1 so an unvisited parent yields ln(1) = 0 instead of
        # ln(0) = -inf, which would turn the score into NaN.
        parent_visits = max(self.parent.visits, 1)
        exploration = np.sqrt(2 * np.log(parent_visits) / self.visits)
        return float(exploitation + exploration)

    def get_best_leaf(self) -> 'LATSNode':
        """
        Descend from this node, always following the non-terminal child with
        the highest UCT value.

        Returns:
            LATSNode: The first node reached that has no non-terminal
                children (possibly this node itself).
        """
        node = self
        while True:
            candidates = [c for c in node.children if not c.is_terminal]
            if not candidates:
                return node
            node = max(candidates, key=lambda c: c.uct())

    def _path_from_root(self) -> list['LATSNode']:
        """Return the nodes from just below the root down to this node."""
        path = []
        node = self
        # exclude the root node
        while node.parent is not None:
            path.append(node)
            node = node.parent
        path.reverse()
        return path

    def get_action_trajectory(self) -> list[dict]:
        """
        Get the actions leading to this node, root excluded, in root-to-node order.

        Returns:
            list[dict]: One dict per node with keys ``action``,
                ``natural_language_description`` and ``element``.
        """
        return [
            {
                "action": node.action,
                "natural_language_description": node.natural_language_description,
                "element": node.element,
            }
            for node in self._path_from_root()
        ]

    def get_trajectory(self) -> list[dict]:
        """
        Get the steps leading to this node, root excluded, in root-to-node order.

        Returns:
            list[dict]: One dict per node with keys
                ``natural_language_description`` and ``action``.
        """
        return [
            {
                "natural_language_description": node.natural_language_description,
                "action": node.action,
            }
            for node in self._path_from_root()
        ]

    def add_child(self, child: 'LATSNode') -> None:
        """
        Attach ``child`` to this node.

        Appends the child to ``children`` and rewrites the child's ``parent``
        and ``depth``. Depths of the child's own descendants are not touched.

        Args:
            child: The node to attach
        """
        self.children.append(child)
        child.parent = self
        child.depth = self.depth + 1

    def check_terminal(self) -> bool:
        """
        Mark this node terminal when it has no children or all of its
        children are terminal, then re-check the parent chain.

        Returns:
            bool: The node's ``is_terminal`` flag after the check
        """
        # all() over an empty list is True, which covers the childless case.
        if all(child.is_terminal for child in self.children):
            self.is_terminal = True
        if self.parent is not None:
            self.parent.check_terminal()
        return self.is_terminal

    def __str__(self) -> str:
        """
        Get a string representation of the node.

        Returns:
            str: A string describing the node's key attributes
        """
        return (f"Node(depth={self.depth}, value={self.value:.2f}, "
                f"visits={self.visits}, action={self.action}, "
                f"feedback={self.feedback})")

    def to_dict(self) -> dict:
        """
        Convert the node and its subtree to a dictionary representation.

        Children are serialised recursively. The parent is represented only
        by its ``state`` dict: serialising it in full would walk back into
        this node through the parent's children and recurse without end.

        Returns:
            dict: A dictionary with the node's state, goal, statistics, a
                shallow parent reference (or None) and the serialised children
        """
        return {
            'state': self.state,
            'question': self.question,
            'parent': self.parent.state if self.parent is not None else None,
            'children': [child.to_dict() for child in self.children],
            'visits': self.visits,
            'value': self.value,
            'depth': self.depth,
            'is_terminal': self.is_terminal,
            'reward': self.reward,
            'em': self.em,
        }

    @property
    def state(self) -> dict:
        """
        Get the current state representation of the node.

        Returns:
            dict: A dictionary containing the node's description, action,
                probability and element
        """
        return {
            'natural_language_description': self.natural_language_description,
            'action': self.action,
            'prob': self.prob,
            'element': self.element
        }

    @property
    def question(self) -> str:
        """
        Get the goal/question associated with this node.

        Returns:
            str: The goal or question string
        """
        return self.goal
0 commit comments