Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF-2273] Implement .setvar special EventHandler #3163

Merged
merged 6 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def _create_event_chain(
if isinstance(value, List):
events: list[EventSpec] = []
for v in value:
if isinstance(v, EventHandler):
if isinstance(v, (EventHandler, EventSpec)):
# Call the event handler to get the event.
try:
event = call_event_handler(v, args_spec)
Expand All @@ -415,9 +415,6 @@ def _create_event_chain(

# Add the event to the chain.
events.append(event)
elif isinstance(v, EventSpec):
# Add the event to the chain.
events.append(v)
elif isinstance(v, Callable):
# Call the lambda to get the event chain.
events.extend(call_event_fn(v, args_spec))
Expand Down
69 changes: 43 additions & 26 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from reflex import constants
from reflex.base import Base
from reflex.utils import console, format
from reflex.utils import format
from reflex.utils.types import ArgsSpec
from reflex.vars import BaseVar, Var

Expand Down Expand Up @@ -142,7 +142,7 @@ def is_background(self) -> bool:
"""
return getattr(self.fn, BACKGROUND_TASK_MARKER, False)

def __call__(self, *args: Var) -> EventSpec:
def __call__(self, *args: Any) -> EventSpec:
"""Pass arguments to the handler to get an event spec.

This method configures event handlers that take in arguments.
Expand Down Expand Up @@ -220,6 +220,34 @@ def with_args(self, args: Tuple[Tuple[Var, Var], ...]) -> EventSpec:
event_actions=self.event_actions.copy(),
)

def add_args(self, *args: Var) -> EventSpec:
"""Add arguments to the event spec.

Args:
*args: The arguments to add positionally.

Returns:
The event spec with the new arguments.

Raises:
TypeError: If the arguments are invalid.
"""
# Get the remaining unfilled function args.
fn_args = inspect.getfullargspec(self.handler.fn).args[1 + len(self.args) :]
fn_args = (Var.create_safe(arg) for arg in fn_args)

# Construct the payload.
values = []
for arg in args:
try:
values.append(Var.create(arg, _var_is_string=isinstance(arg, str)))
except TypeError as e:
raise TypeError(
f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}."
) from e
new_payload = tuple(zip(fn_args, values))
return self.with_args(self.args + new_payload)


class CallableEventSpec(EventSpec):
"""Decorate an EventSpec-returning function to act as both a EventSpec and a function.
Expand Down Expand Up @@ -706,7 +734,8 @@ def get_hydrate_event(state) -> str:


def call_event_handler(
event_handler: EventHandler, arg_spec: Union[Var, ArgsSpec]
event_handler: EventHandler | EventSpec,
arg_spec: ArgsSpec,
) -> EventSpec:
"""Call an event handler to get the event spec.

Expand All @@ -724,33 +753,21 @@ def call_event_handler(
Returns:
The event spec from calling the event handler.
"""
args = inspect.getfullargspec(event_handler.fn).args
parsed_args = parse_args_spec(arg_spec) # type: ignore

# handle new API using lambda to define triggers
if isinstance(arg_spec, ArgsSpec):
parsed_args = parse_args_spec(arg_spec) # type: ignore
if isinstance(event_handler, EventSpec):
# Handle partial application of EventSpec args
return event_handler.add_args(*parsed_args)

if len(args) == len(["self", *parsed_args]):
return event_handler(*parsed_args) # type: ignore
else:
source = inspect.getsource(arg_spec) # type: ignore
raise ValueError(
f"number of arguments in {event_handler.fn.__qualname__} "
f"doesn't match the definition of the event trigger '{source.strip().strip(',')}'"
)
args = inspect.getfullargspec(event_handler.fn).args
if len(args) == len(["self", *parsed_args]):
return event_handler(*parsed_args) # type: ignore
else:
console.deprecate(
feature_name="EVENT_ARG API for triggers",
reason="Replaced by new API using lambda allow arbitrary number of args",
deprecation_version="0.2.8",
removal_version="0.5.0",
source = inspect.getsource(arg_spec) # type: ignore
raise ValueError(
f"number of arguments in {event_handler.fn.__qualname__} "
f"doesn't match the definition of the event trigger '{source.strip().strip(',')}'"
)
if len(args) == 1:
return event_handler()
assert (
len(args) == 2
), f"Event handler {event_handler.fn} must have 1 or 2 arguments."
return event_handler(arg_spec) # type: ignore


def parse_args_spec(arg_spec: ArgsSpec):
Expand Down
68 changes: 68 additions & 0 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,60 @@ def _split_substate_key(substate_key: str) -> tuple[str, str]:
return token, state_name


class EventHandlerSetVar(EventHandler):
"""A special event handler to wrap setvar functionality."""

state_cls: Type[BaseState]

def __init__(self, state_cls: Type[BaseState]):
"""Initialize the EventHandlerSetVar.

Args:
state_cls: The state class that vars will be set on.
"""
super().__init__(
fn=type(self).setvar,
state_full_name=state_cls.get_full_name(),
state_cls=state_cls, # type: ignore
)

def setvar(self, var_name: str, value: Any):
"""Set the state variable to the value of the event.

Note: `self` here will be an instance of the state, not EventHandlerSetVar.

Args:
var_name: The name of the variable to set.
value: The value to set the variable to.
"""
getattr(self, constants.SETTER_PREFIX + var_name)(value)

def __call__(self, *args: Any) -> EventSpec:
"""Performs pre-checks and munging on the provided args that will become an EventSpec.

Args:
*args: The event args.

Returns:
The (partial) EventSpec that will be used to create the event to setvar.

Raises:
AttributeError: If the given Var name does not exist on the state.
ValueError: If the given Var name is not a str
"""
if args:
if not isinstance(args[0], str):
raise ValueError(
f"Var name must be passed as a string, got {args[0]!r}"
)
# Check that the requested Var setter exists on the State at compile time.
if getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) is None:
raise AttributeError(
f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`"
)
return super().__call__(*args)


class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app."""

Expand Down Expand Up @@ -310,6 +364,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Whether the state has ever been touched since instantiation.
_was_touched: bool = False

# A special event handler for setting base vars.
setvar: ClassVar[EventHandler]

def __init__(
self,
*args,
Expand Down Expand Up @@ -500,6 +557,9 @@ def __init_subclass__(cls, **kwargs):
value.__qualname__ = f"{cls.__name__}.{name}"
events[name] = value

# Create the setvar event handler for this state
cls._create_setvar()

for name, fn in events.items():
handler = cls._create_event_handler(fn)
cls.event_handlers[name] = handler
Expand Down Expand Up @@ -833,6 +893,11 @@ def _create_event_handler(cls, fn):
"""
return EventHandler(fn=fn, state_full_name=cls.get_full_name())

@classmethod
def _create_setvar(cls):
"""Create the setvar method for the state."""
cls.setvar = cls.event_handlers["setvar"] = EventHandlerSetVar(state_cls=cls)

@classmethod
def _create_setter(cls, prop: BaseVar):
"""Create a setter for the var.
Expand Down Expand Up @@ -1800,6 +1865,9 @@ def __getstate__(self):
return state


EventHandlerSetVar.update_forward_refs()


class State(BaseState):
"""The app Base State."""

Expand Down
37 changes: 36 additions & 1 deletion tests/test_event.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
from typing import List

import pytest

from reflex import event
from reflex.event import Event, EventHandler, EventSpec, fix_events
from reflex.event import Event, EventHandler, EventSpec, call_event_handler, fix_events
from reflex.state import BaseState
from reflex.utils import format
from reflex.vars import Var
Expand Down Expand Up @@ -91,6 +92,40 @@ def test_fn_with_args(_, arg1, arg2):
handler(test_fn) # type: ignore


def test_call_event_handler_partial():
"""Calling an EventHandler with incomplete args returns an EventSpec that can be extended."""

def test_fn_with_args(_, arg1, arg2):
pass

test_fn_with_args.__qualname__ = "test_fn_with_args"

def spec(a2: str) -> List[str]:
return [a2]

handler = EventHandler(fn=test_fn_with_args)
event_spec = handler(make_var("first"))
event_spec2 = call_event_handler(event_spec, spec)

assert event_spec.handler == handler
assert len(event_spec.args) == 1
assert event_spec.args[0][0].equals(Var.create_safe("arg1"))
assert event_spec.args[0][1].equals(Var.create_safe("first"))
assert format.format_event(event_spec) == 'Event("test_fn_with_args", {arg1:first})'

assert event_spec2 is not event_spec
assert event_spec2.handler == handler
assert len(event_spec2.args) == 2
assert event_spec2.args[0][0].equals(Var.create_safe("arg1"))
assert event_spec2.args[0][1].equals(Var.create_safe("first"))
assert event_spec2.args[1][0].equals(Var.create_safe("arg2"))
assert event_spec2.args[1][1].equals(Var.create_safe("_a2"))
assert (
format.format_event(event_spec2)
== 'Event("test_fn_with_args", {arg1:first,arg2:_a2})'
)


@pytest.mark.parametrize(
("arg1", "arg2"),
(
Expand Down
38 changes: 38 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2845,3 +2845,41 @@ def bar(self) -> str:
assert RxState._potentially_dirty_substates() == {State}
assert State._potentially_dirty_substates() == {C1}
assert C1._potentially_dirty_substates() == set()


@pytest.mark.asyncio
async def test_setvar(mock_app: rx.App, token: str):
"""Test that setvar works correctly.

Args:
mock_app: An app that will be returned by `get_app()`
token: A token.
"""
state = await mock_app.state_manager.get_state(_substate_key(token, TestState))

# Set Var in same state (with Var type casting)
for event in rx.event.fix_events(
[TestState.setvar("num1", 42), TestState.setvar("num2", "4.2")], token
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the way to set the num2 to a string value of "4.2" here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num2 is typed as a float, you can't (shouldn't) set it to a str

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah, should've looked at the hidden part of the code. That's what the type casting is. Sounds good.

):
async for update in state._process(event):
print(update)
assert state.num1 == 42
assert state.num2 == 4.2

# Set Var in parent state
for event in rx.event.fix_events([GrandchildState.setvar("array", [43])], token):
async for update in state._process(event):
print(update)
assert state.array == [43]

# Cannot setvar for non-existant var
with pytest.raises(AttributeError):
TestState.setvar("non_existant_var")

# Cannot setvar for computed vars
with pytest.raises(AttributeError):
TestState.setvar("sum")

# Cannot setvar with non-string
with pytest.raises(ValueError):
TestState.setvar(42, 42)
Loading