diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index ae1477359f..577187076f 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -532,6 +532,7 @@ async def agent_step( task, step, browser_state, + organization, ) detailed_agent_step_output.scraped_page = scraped_page detailed_agent_step_output.extract_action_prompt = extract_action_prompt @@ -890,6 +891,7 @@ async def _scrape_with_type( step: Step, browser_state: BrowserState, scrape_type: ScrapeType, + organization: Organization | None = None, ) -> ScrapedPage | None: if scrape_type == ScrapeType.NORMAL: pass @@ -912,7 +914,7 @@ async def _scrape_with_type( return await scrape_website( browser_state, task.url, - app.AGENT_FUNCTION.cleanup_element_tree, + app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step, organization=organization), scrape_exclude=app.scrape_exclude, ) @@ -921,6 +923,7 @@ async def _build_and_record_step_prompt( task: Task, step: Step, browser_state: BrowserState, + organization: Organization | None = None, ) -> tuple[ScrapedPage, str]: # start the async tasks while running scrape_website self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape) @@ -934,7 +937,11 @@ async def _build_and_record_step_prompt( for idx, scrape_type in enumerate(SCRAPE_TYPE_ORDER): try: scraped_page = await self._scrape_with_type( - task=task, step=step, browser_state=browser_state, scrape_type=scrape_type + task=task, + step=step, + browser_state=browser_state, + scrape_type=scrape_type, + organization=organization, ) break except FailedToTakeScreenshot as e: diff --git a/skyvern/forge/agent_functions.py b/skyvern/forge/agent_functions.py index 480d1f2290..d377660699 100644 --- a/skyvern/forge/agent_functions.py +++ b/skyvern/forge/agent_functions.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Awaitable, Callable from playwright.async_api import Page @@ -9,6 +9,8 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus from 
skyvern.webeye.browser_factory import BrowserState +CleanupElementTreeFunc = Callable[[str, list[dict]], Awaitable[list[dict]]] + def _remove_rect(element: dict) -> None: if "rect" in element: @@ -64,28 +66,32 @@ def generate_async_operations( ) -> list[AsyncOperation]: return [] - async def cleanup_element_tree( + def cleanup_element_tree_factory( self, - url: str, - element_tree: List[Dict], - ) -> List[Dict]: - """ - Remove rect and attribute.unique_id from the elements. - The reason we're doing it is to - 1. reduce unnecessary data so that llm get less distrction - TODO later: 2. reduce tokens sent to llm to save money - :param elements: List of elements to remove xpaths from. - :return: List of elements without xpaths. - """ - queue = [] - for element in element_tree: - queue.append(element) - while queue: - queue_ele = queue.pop(0) - _remove_rect(queue_ele) - # TODO: we can come back to test removing the unique_id - # from element attributes to make sure this won't increase hallucination - # _remove_unique_id(queue_ele) - if "children" in queue_ele: - queue.extend(queue_ele["children"]) - return element_tree + task: Task, + step: Step, + organization: Organization | None = None, + ) -> CleanupElementTreeFunc: + async def cleanup_element_tree_func(url: str, element_tree: list[dict]) -> list[dict]: + """ + Remove rect and attribute.unique_id from the elements. + The reason we're doing it is to + 1. reduce unnecessary data so that the LLM gets less distraction + TODO later: 2. reduce tokens sent to llm to save money + :param element_tree: List of elements to remove xpaths from. + :return: List of elements without xpaths. 
+ """ + queue = [] + for element in element_tree: + queue.append(element) + while queue: + queue_ele = queue.pop(0) + _remove_rect(queue_ele) + # TODO: we can come back to test removing the unique_id + # from element attributes to make sure this won't increase hallucination + # _remove_unique_id(queue_ele) + if "children" in queue_ele: + queue.extend(queue_ele["children"]) + return element_tree + + return cleanup_element_tree_func diff --git a/skyvern/webeye/actions/handler.py b/skyvern/webeye/actions/handler.py index adfc9970fa..15bc4f4341 100644 --- a/skyvern/webeye/actions/handler.py +++ b/skyvern/webeye/actions/handler.py @@ -326,7 +326,7 @@ async def handle_input_text_action( await asyncio.sleep(5) incremental_element = await incremental_scraped.get_incremental_element_tree( - app.AGENT_FUNCTION.cleanup_element_tree + app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step) ) if len(incremental_element) == 0: LOG.info( @@ -593,7 +593,7 @@ async def handle_select_option_action( is_open = True incremental_element = await incremental_scraped.get_incremental_element_tree( - app.AGENT_FUNCTION.cleanup_element_tree + app.AGENT_FUNCTION.cleanup_element_tree_factory(step=step, task=task) ) if len(incremental_element) == 0: raise NoIncrementalElementFoundForCustomSelection(element_id=action.element_id) @@ -887,7 +887,7 @@ async def select_from_dropdown( ) trimmed_element_tree = await incremental_scraped.get_incremental_element_tree( - app.AGENT_FUNCTION.cleanup_element_tree + app.AGENT_FUNCTION.cleanup_element_tree_factory(step=step, task=task) ) if dropdown_menu_element: # if there's a dropdown menu detected, only elements in the dropdown should be sent to LLM