Skip to content

Commit

Permalink
fix merging-pipe only increasing bug (nottelabs#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
andreakiro authored Dec 14, 2024
1 parent 9246c97 commit 6f57367
Show file tree
Hide file tree
Showing 17 changed files with 349 additions and 191 deletions.
1 change: 1 addition & 0 deletions examples/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def run(self) -> None:
]

while True:
logger.info("> looping in")
resp: str = self.think(messages)
messages.append({"role": "assistant", "content": resp})
logger.info(f"🤖 {resp}")
Expand Down
39 changes: 4 additions & 35 deletions notte/browser/context.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,20 @@
from dataclasses import dataclass
from typing import Any

from PIL import Image

from notte.actions.base import Action
from notte.actions.space import ActionSpace
from notte.browser.node_type import NotteNode
from notte.browser.node_type import InteractionNode, NotteNode
from notte.browser.snapshot import BrowserSnapshot
from notte.utils import image


@dataclass
class Observation:
url: str
screenshot: bytes | None = None
space: ActionSpace | None = None

@property
def clean_url(self) -> str:
# remove anything after ? i.. ?tfs=CBwQARooEgoyMDI0LTEyLTAzagwIAh
return self.url.split("?")[0]

def display_screenshot(self) -> Image.Image | None:
if self.screenshot is None:
return None
return image.image_from_bytes(self.screenshot)

@staticmethod
def from_json(json: dict[str, Any]) -> "Observation":
return Observation(
url=json["url"],
screenshot=json["screenshot"],
space=ActionSpace.from_json(json["space"]),
)


@dataclass
class Context:
node: NotteNode
snapshot: BrowserSnapshot

def interaction_nodes(self) -> list[NotteNode]:
return self.node.flatten(only_interaction=True)
def interaction_nodes(self) -> list[InteractionNode]:
return self.node.interaction_nodes()

def markdown_description(self) -> str:
f = self.format(self.node, indent_level=0)
return f
return self.format(self.node, indent_level=0)

def format(self, node: NotteNode, indent_level: int = 0) -> str:
indent = " " * indent_level
Expand Down
2 changes: 1 addition & 1 deletion notte/browser/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, **kwargs: Unpack[BrowserArgs]) -> None:
self._page: Page | None = None
self._playwright: Playwright | None = None
self.timeout: int = kwargs.get("timeout", DEFAULT_LOADING_TIMEOUT)
self.headless: bool = kwargs.get("headless", True)
self.headless: bool = kwargs.get("headless", False)

async def start(self) -> None:
self.playwright = await async_playwright().start()
Expand Down
15 changes: 14 additions & 1 deletion notte/browser/node_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,9 +438,13 @@ def is_interaction(self) -> bool:
return self.role.category().value == NodeCategory.INTERACTION.value

def flatten(self, only_interaction: bool = False) -> list["NotteNode"]:
base: list["NotteNode"] = [self] if not only_interaction or self.is_interaction() else []
base: list["NotteNode"] = [] if only_interaction and not self.is_interaction() else [self]
return base + [node for child in self.children for node in child.flatten(only_interaction)]

def interaction_nodes(self) -> list["InteractionNode"]:
inodes = self.flatten(only_interaction=True)
return [InteractionNode(**{k: v for k, v in inode.__dict__.items() if k != "subtree_ids"}) for inode in inodes]

def subtree_filter(self, ft: Callable[["NotteNode"], bool]) -> "NotteNode | None":
def inner(node: NotteNode) -> NotteNode | None:
children = node.children
Expand All @@ -459,6 +463,15 @@ def inner(node: NotteNode) -> NotteNode | None:
return inner(self)


class InteractionNode(NotteNode):
id: str # type: ignore

def __post_init__(self) -> None:
if self.id is None:
raise ValueError("InteractionNode must have a valid non-None id")
super().__post_init__()


class A11yNode(TypedDict, total=False):
# from the a11y tree
role: Required[str]
Expand Down
62 changes: 62 additions & 0 deletions notte/browser/observation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from dataclasses import dataclass
from typing import Any

from PIL import Image
from typing_extensions import override

from notte.actions.space import ActionSpace
from notte.utils import image


@dataclass
class Observation:
_url: str
_space: ActionSpace
_screenshot: bytes | None = None

@property
def url(self) -> str:
return self._url

@property
def clean_url(self) -> str:
# remove anything after ? i.. ?tfs=CBwQARooEgoyMDI0LTEyLTAzagwIAh
return self.url.split("?")[0]

@property
def space(self) -> ActionSpace:
return self._space

@property
def screenshot(self) -> bytes | None:
return self._screenshot

def display_screenshot(self) -> Image.Image | None:
if self.screenshot is None:
return None
return image.image_from_bytes(self.screenshot)

@staticmethod
def from_json(json: dict[str, Any]) -> "Observation":
url: str | None = json.get("url", None)
if not isinstance(url, str):
raise ValueError("url must be a string")
screenshot: bytes | None = json.get("screenshot", None)
space: ActionSpace | None = json.get("space", None)
if not isinstance(space, dict):
raise ValueError("space must be a dictionary")
return Observation(
_url=url,
_screenshot=screenshot,
_space=ActionSpace.from_json(space),
)


@dataclass
class PreObservation(Observation):
_space: ActionSpace | None = None # type: ignore

@override
@property
def space(self) -> ActionSpace:
raise ValueError("space is not available for pre-observations")
2 changes: 1 addition & 1 deletion notte/common/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
result = await func(*args, **kwargs)
end_time = time.time()
logger.info(f"Function {name} took {end_time - start_time:.4f} seconds")
logger.info(f"function {name} took {end_time - start_time:.4f} seconds")
return result

return wrapper # type: ignore
Expand Down
10 changes: 6 additions & 4 deletions notte/common/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import BaseModel
from typing_extensions import override

from notte.browser.context import Observation
from notte.browser.observation import Observation


class EnvObserveParams(BaseModel):
Expand Down Expand Up @@ -89,7 +89,7 @@ def step(self, text: str) -> EnvStepParams:
if "action-id" not in d:
raise ValueError("No action-id found in action")
action_id = d["action-id"]
params = d["params"] if d["params"] is not None else None
params = d.get("params", None)
return EnvStepParams(action_id=action_id, params=params)

@override
Expand All @@ -104,11 +104,11 @@ def textify(self, obs: Observation) -> str:
s = """
The current URL is: {{url}}
Here are the available actions:
\n{{actions}}
{{actions}}
\n Now think about your current trajectory, and decide what action to take next.
You might need to perform some intermediate actions so be very careful, dont jump to conclusions too quickly.
\nProvide me with the ID of the action you want to take next.
You are allowed to take only exactly ONE action from the list.
You are allowed to take only exactly ONE action from this list (not previous lists)!
If the action is parameterized, provide the value for each parameter.
Use the exact following format:
<action>
Expand All @@ -123,6 +123,8 @@ def textify(self, obs: Observation) -> str:
* You are allowed to take only exactly ONE action from the list.
* Your action should be inside the <action> tag.
* If you're unable to pursue your goal, just say <done/>. Nothing else!
* You are ONLY allowed to pick actions from the latest list of actions!
* You are NOT allowed to pick actions from list of actions in previous messages!c
\n You are allowed to use <url> to navigate to a different url.
"""
return chevron.render(s, {"url": obs.url, "actions": obs.space.markdown("valid")})
86 changes: 43 additions & 43 deletions notte/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from notte.actions.base import Action, ActionParameterValue
from notte.actions.code import process_action_code
from notte.actions.space import ActionSpace
from notte.browser.context import Context, Observation
from notte.browser.context import Context
from notte.browser.driver import BrowserArgs, BrowserDriver
from notte.browser.observation import Observation, PreObservation
from notte.browser.snapshot import BrowserSnapshot
from notte.common.logging import timeit
from notte.common.parser import BaseNotteParser, Parser
Expand Down Expand Up @@ -53,7 +53,6 @@ def __init__(
self._trajectory: list[Observation] = trajectory or []
self._parser: Parser = parser or BaseNotteParser()
self._context: Context | None = None
self._action_space: ActionSpace | None = None
self._context_to_action_space_pipe: ContextToActionSpacePipe = ContextToActionSpacePipe(
llmserve=llmserve,
)
Expand All @@ -65,58 +64,60 @@ def context(self) -> Context:
return self._context

@property
def list_actions(self) -> list[Action] | None:
if self._action_space is None:
def previous_actions(self) -> list[Action] | None:
# This function is always called after trajectory.append(preobs)
# —This means trajectory[-1] is always the "current (pre)observation"
# And trajectory[-2] is the "previous observation" we're interested in.
if len(self._trajectory) <= 1:
return None
if len(self._trajectory) >= 2 and self._trajectory[-1].clean_url != self._trajectory[-2].clean_url:
# If the last two observations are not on the same page, the last action space is invalid.
return None
return self._action_space.actions(status="all")
previous_obs: Observation = self._trajectory[-2]
if isinstance(previous_obs, PreObservation):
return None # we don't have a space for pre-observations
if self.context.snapshot.clean_url != previous_obs.clean_url:
return None # the page has significantly changed
return previous_obs.space.actions(status="all")

# ---------------------------- observe, step functions ----------------------------

async def _observe(self, snapshot: BrowserSnapshot) -> Observation:
def _preobserve(self, snapshot: BrowserSnapshot) -> PreObservation:
self._context = BrowserSnapshotToContextPipe.forward(snapshot)
obs = Observation(url=snapshot.url, screenshot=snapshot.screenshot, space=None)
self._trajectory.append(obs)
preobs = PreObservation(_url=snapshot.url, _screenshot=snapshot.screenshot, _space=None)
self._trajectory.append(preobs)
return preobs

def _obslisting(self, preobs: PreObservation) -> Observation:
space = self._context_to_action_space_pipe.forward(self.context, self.previous_actions)
obs = Observation(_url=preobs.url, _screenshot=preobs.screenshot, _space=space)
self._trajectory[-1] = obs # update the last observation with the new space
return obs

@timeit("goto")
async def goto(self, url: str) -> Observation:
async def goto(self, url: str) -> PreObservation:
snapshot = await self._browser.goto(url)
self._action_space = None
return await self._observe(snapshot)
obs = self._preobserve(snapshot)
return obs

@timeit("observe")
async def observe(self, url: str) -> Observation:
snapshot = await self._browser.goto(url)
obs = await self._observe(snapshot)
self._action_space = await self._context_to_action_space_pipe.forward(self.context, self.list_actions)
obs.space = self._action_space
return obs

async def _execute(
self,
action_id: str,
params: dict[str, str] | str | None = None,
enter: bool | None = None,
) -> Observation:
if self._context is None:
raise ValueError("Need to observe first to get a context.")
action, _params = self._parse_env(action_id, params)
enter = enter if enter is not None else action.id.startswith("I")
snapshot = await ExecutionPipe.forward(action, _params, self._context, self._browser, enter=enter)
return await self._observe(snapshot)
preobs = await self.goto(url)
logger.debug(f"ℹ️ previous actions IDs: {[a.id for a in self.previous_actions or []]}")
logger.debug(f"ℹ️ context inodes IDs: {[node.id for node in self.context.interaction_nodes()]}")
return self._obslisting(preobs)

@timeit("execute")
async def execute(
self,
action_id: str,
params: dict[str, str] | str | None = None,
enter: bool | None = None,
) -> Observation:
self._action_space = None
return await self._execute(action_id, params, enter=enter)
) -> PreObservation:
if action_id not in [inode.id for inode in self.context.interaction_nodes()]:
raise ValueError(f"action {action_id} not found in context")
action, _params = self._parse_env(action_id, params)
enter = enter if enter is not None else action.id.startswith("I")
snapshot = await ExecutionPipe.forward(action, _params, self.context, self._browser, enter=enter)
logger.info(f"🌌 action {action_id} executed in browser")
return self._preobserve(snapshot)

@timeit("step")
async def step(
Expand All @@ -125,17 +126,16 @@ async def step(
params: dict[str, str] | str | None = None,
enter: bool | None = None,
) -> Observation:
obs = await self._execute(action_id, params, enter=enter)
self._action_space = await self._context_to_action_space_pipe.forward(self.context, self.list_actions)
obs.space = self._action_space
return obs
preobs = await self.execute(action_id, params, enter=enter)
logger.debug(f"ℹ️ previous actions IDs: {[a.id for a in self.previous_actions or []]}")
logger.debug(f"ℹ️ context inodes IDs: {[node.id for node in self.context.interaction_nodes()]}")
return self._obslisting(preobs)

@timeit("reset")
async def reset(self, url: str) -> Observation:
async def reset(self, url: str) -> PreObservation:
self._trajectory = []
self._context = None
self._action_space = None
return await self.observe(url)
return await self.goto(url)

# ---------------------------- conversational environment ----------------------------

Expand Down
5 changes: 2 additions & 3 deletions notte/pipe/filtering.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
from typing import final

from notte.actions.base import Action
from notte.actions.space import ActionSpace
from notte.browser.context import Context


@final
class ActionFilteringPipe:

@staticmethod
def forward(context: Context, actions: list[Action]) -> ActionSpace:
def forward(context: Context, actions: list[Action]) -> list[Action]:
for action in actions:
if ActionFilteringPipe.exclude_actions_with_invalid_params(action):
action.status = "excluded"
if ActionFilteringPipe.exclude_actions_with_invalid_category(action):
action.status = "excluded"
if ActionFilteringPipe.exclude_actions_with_invalid_description(action):
action.status = "excluded"
return ActionSpace(_actions=actions)
return actions

@staticmethod
def exclude_actions_with_invalid_params(action: Action) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion notte/pipe/listing.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def forward(self, context: Context, previous_action_list: list[Action] | None =
try:
return self.pipe.forward(context, previous_action_list)
except Exception as e:
logger.error("Failed to get action but retrying...")
logger.warning("failed to parse action list but retrying...")
errors.append(str(e))
raise Exception(f"Failed to get action list after max tries with errors: {errors}")

Expand Down
Loading

0 comments on commit 6f57367

Please sign in to comment.