Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved callback_context #1952

Merged
merged 16 commits into from
Apr 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ This project adheres to [Semantic Versioning](https://semver.org/).
## [Unreleased]

### Added
- [#1952](https://github.com/plotly/dash/pull/1952) Improved callback_context
- Closes [#1818](https://github.com/plotly/dash/issues/1818) Closes [#1054](https://github.com/plotly/dash/issues/1054)
- adds `dash.ctx`, a more concise name for `dash.callback_context`
- adds `ctx.triggered_prop_ids`, a dictionary of the component ids and props that triggered the callback.
- adds `ctx.triggered_id`, the `id` of the component that triggered the callback.
- adds `ctx.args_grouping`, a dict of the inputs used with flexible callback signatures.

- [#2009](https://github.com/plotly/dash/pull/2009) Add support for Promises within Client-side callbacks as requested in [#1364](https://github.com/plotly/dash/pull/1364).

Expand Down
1 change: 1 addition & 0 deletions dash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
get_relative_path,
strip_relative_path,
)
ctx = callback_context
148 changes: 146 additions & 2 deletions dash/_callback_context.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import functools
import warnings

import json
from copy import deepcopy
import flask

from . import exceptions
from ._utils import stringify_id, AttributeDict


def has_context(func):
Expand Down Expand Up @@ -46,16 +48,158 @@ def states(self):
@property
@has_context
def triggered(self):
"""
Returns a list of all the Input props that changed and caused the callback to execute. It is empty when the
callback is called on initial load, unless an Input prop got its value from another initial callback.
Callbacks triggered by user actions typically have one item in triggered, unless the same action changes
two props at once or the callback has several Input props that are all modified by another callback based on
a single user action.

Example: To get the id of the component that triggered the callback:
`component_id = ctx.triggered[0]['prop_id'].split('.')[0]`

Example: To detect initial call, empty triggered is not really empty, it's falsy so that you can use:
`if ctx.triggered:`
"""
# For backward compatibility: previously `triggered` always had a
# value - to avoid breaking existing apps, add a dummy item but
# make the list still look falsy. So `if ctx.triggered` will make it
# look empty, but you can still do `triggered[0]["prop_id"].split(".")`
return getattr(flask.g, "triggered_inputs", []) or falsy_triggered

@property
@has_context
def triggered_prop_ids(self):
"""
Returns a dictionary of all the Input props that changed and caused the callback to execute. It is empty when
the callback is called on initial load, unless an Input prop got its value from another initial callback.
Callbacks triggered by user actions typically have one item in triggered, unless the same action changes
two props at once or the callback has several Input props that are all modified by another callback based
on a single user action.

triggered_prop_ids (dict):
- keys (str) : the triggered "prop_id" composed of "component_id.component_property"
- values (str or dict): the id of the component that triggered the callback. Will be the dict id for pattern matching callbacks

Example - regular callback
{"btn-1.n_clicks": "btn-1"}

Example - pattern matching callbacks:
{'{"index":0,"type":"filter-dropdown"}.value': {"index":0,"type":"filter-dropdown"}}

Example usage:
`if "btn-1.n_clicks" in ctx.triggered_prop_ids:
do_something()`
"""
triggered = getattr(flask.g, "triggered_inputs", [])
ids = AttributeDict({})
for item in triggered:
component_id, _, _ = item["prop_id"].rpartition(".")
ids[item["prop_id"]] = component_id
if component_id.startswith("{"):
ids[item["prop_id"]] = AttributeDict(json.loads(component_id))
return ids

@property
@has_context
def triggered_id(self):
"""
Returns the component id (str or dict) of the Input component that triggered the callback.

Note - use `triggered_prop_ids` if you need both the component id and the prop that triggered the callback or if
multiple Inputs triggered the callback.

Example usage:
`if "btn-1" == ctx.triggered_id:
do_something()`

"""
component_id = None
if self.triggered:
prop_id = self.triggered_prop_ids.first()
component_id = self.triggered_prop_ids[prop_id]
return component_id

@property
@has_context
def args_grouping(self):
return getattr(flask.g, "args_grouping", [])
"""
args_grouping is a dict of the inputs used with flexible callback signatures. The keys are the variable names
and the values are dictionaries containing:
- “id”: (string or dict) the component id. If it’s a pattern matching id, it will be a dict.
- “id_str”: (str) for pattern matching ids, it’s the strigified dict id with no white spaces.
- “property”: (str) The component property used in the callback.
- “value”: the value of the component property at the time the callback was fired.
- “triggered”: (bool)Whether this input triggered the callback.

Example usage:
@app.callback(
Output("container", "children"),
inputs=dict(btn1=Input("btn-1", "n_clicks"), btn2=Input("btn-2", "n_clicks")),
)
def display(btn1, btn2):
c = ctx.args_grouping
if c.btn1.triggered:
return f"Button 1 clicked {btn1} times"
elif c.btn2.triggered:
return f"Button 2 clicked {btn2} times"
else:
return "No clicks yet"

"""
triggered = getattr(flask.g, "triggered_inputs", [])
triggered = [item["prop_id"] for item in triggered]
grouping = getattr(flask.g, "args_grouping", {})

def update_args_grouping(g):
if isinstance(g, dict) and "id" in g:
str_id = stringify_id(g["id"])
prop_id = f"{str_id}.{g['property']}"

new_values = {
"value": g.get("value"),
"str_id": str_id,
"triggered": prop_id in triggered,
"id": AttributeDict(g["id"])
if isinstance(g["id"], dict)
else g["id"],
}
g.update(new_values)

def recursive_update(g):
if isinstance(g, (tuple, list)):
for i in g:
update_args_grouping(i)
recursive_update(i)
if isinstance(g, dict):
for i in g.values():
update_args_grouping(i)
recursive_update(i)

recursive_update(grouping)

return grouping

# todo not sure whether we need this, but it removes a level of nesting so
# you don't need to use `.value` to get the value.
@property
@has_context
def args_grouping_values(self):
grouping = getattr(flask.g, "args_grouping", {})
grouping = deepcopy(grouping)

def recursive_update(g):
if isinstance(g, (tuple, list)):
for i in g:
recursive_update(i)
if isinstance(g, dict):
for k, v in g.items():
if isinstance(v, dict) and "id" in v:
g[k] = v["value"]
recursive_update(v)

recursive_update(grouping)
return grouping

@property
@has_context
Expand Down
5 changes: 3 additions & 2 deletions dash/_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""
from dash.exceptions import InvalidCallbackReturnValue
from ._utils import AttributeDict


def flatten_grouping(grouping, schema=None):
Expand Down Expand Up @@ -123,14 +124,14 @@ def map_grouping(fn, grouping):
return [map_grouping(fn, g) for g in grouping]

if isinstance(grouping, dict):
return {k: map_grouping(fn, g) for k, g in grouping.items()}
return AttributeDict({k: map_grouping(fn, g) for k, g in grouping.items()})

return fn(grouping)


def make_grouping_by_key(schema, source, default=None):
"""
Create a grouping from a schema by ujsing the schema's scalar values to look up
Create a grouping from a schema by using the schema's scalar values to look up
items in the provided source object.

:param schema: A grouping of potential keys in source
Expand Down
14 changes: 13 additions & 1 deletion dash/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def first(self, *names):
value = self.get(name)
if value:
return value
if not names:
return next(iter(self), {})


def create_callback_id(output):
Expand Down Expand Up @@ -152,7 +154,7 @@ def stringify_id(id_):


def inputs_to_dict(inputs_list):
inputs = {}
inputs = AttributeDict()
for i in inputs_list:
inputsi = i if isinstance(i, list) else [i]
for ii in inputsi:
Expand All @@ -161,6 +163,16 @@ def inputs_to_dict(inputs_list):
return inputs


def convert_to_AttributeDict(nested_list):
new_dict = []
for i in nested_list:
if isinstance(i, dict):
new_dict.append(AttributeDict(i))
else:
new_dict.append([AttributeDict(ii) for ii in i])
return new_dict


def inputs_to_vals(inputs):
return [
[ii.get("value") for ii in i] if isinstance(i, list) else i.get("value")
Expand Down
5 changes: 5 additions & 0 deletions dash/_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ def validate_and_group_input_args(flat_args, arg_index_grouping):
if isinstance(arg_index_grouping, dict):
func_args = []
func_kwargs = args_grouping
for key in func_kwargs:
if not key.isidentifier():
raise exceptions.CallbackException(
f"{key} is not a valid Python variable name"
)
elif isinstance(arg_index_grouping, (tuple, list)):
func_args = list(args_grouping)
func_kwargs = {}
Expand Down
5 changes: 5 additions & 0 deletions dash/dash.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
patch_collections_abc,
split_callback_id,
to_json,
convert_to_AttributeDict,
gen_salt,
)
from . import _callback
Expand Down Expand Up @@ -1297,6 +1298,7 @@ def callback(_triggers, user_store_data, user_callback_args):

def dispatch(self):
body = flask.request.get_json()

flask.g.inputs_list = inputs = body.get( # pylint: disable=assigning-non-slot
"inputs", []
)
Expand Down Expand Up @@ -1331,9 +1333,12 @@ def dispatch(self):
# Add args_grouping
inputs_state_indices = cb["inputs_state_indices"]
inputs_state = inputs + state
inputs_state = convert_to_AttributeDict(inputs_state)

args_grouping = map_grouping(
lambda ind: inputs_state[ind], inputs_state_indices
)

flask.g.args_grouping = args_grouping # pylint: disable=assigning-non-slot
flask.g.using_args_grouping = ( # pylint: disable=assigning-non-slot
not isinstance(inputs_state_indices, int)
Expand Down
58 changes: 57 additions & 1 deletion tests/integration/callbacks/test_callback_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import operator
import pytest

from dash import Dash, Input, Output, html, dcc, callback_context
from dash import Dash, ALL, Input, Output, html, dcc, callback_context, ctx

from dash.exceptions import PreventUpdate, MissingCallbackContextException
import dash.testing.wait as wait
Expand Down Expand Up @@ -330,3 +330,59 @@ def update_results(n1, n2, nsum):
assert len(keys1) == 2
assert "sum-number.value" in keys1
assert "input-number-2.value" in keys1


def test_cbcx007_triggered_id(dash_duo):
app = Dash(__name__)

btns = ["btn-{}".format(x) for x in range(1, 6)]

app.layout = html.Div(
[html.Div([html.Button(btn, id=btn) for btn in btns]), html.Div(id="output")]
)

@app.callback(Output("output", "children"), [Input(x, "n_clicks") for x in btns])
def on_click(*args):
if not ctx.triggered:
raise PreventUpdate
for btn in btns:
if btn in ctx.triggered_prop_ids.values():
assert btn == ctx.triggered_id
return f"Just clicked {btn}"

dash_duo.start_server(app)

for i in range(1, 5):
for btn in btns:
dash_duo.find_element("#" + btn).click()
dash_duo.wait_for_text_to_equal("#output", f"Just clicked {btn}")


def test_cbcx008_triggered_id_pmc(dash_duo):

app = Dash()
app.layout = html.Div(
[
html.Button("Click me", id={"type": "btn", "index": "myindex"}),
html.Div(id="output"),
]
)

@app.callback(
Output("output", "children"), Input({"type": "btn", "index": ALL}, "n_clicks")
)
def func(n_clicks):
if ctx.triggered:
triggered_id, dict_id = next(iter(ctx.triggered_prop_ids.items()))

assert dict_id == ctx.triggered_id

if dict_id == {"type": "btn", "index": "myindex"}:
return dict_id["index"]

dash_duo.start_server(app)

dash_duo.find_element(
'#\\{\\"index\\"\\:\\"myindex\\"\\,\\"type\\"\\:\\"btn\\"\\}'
).click()
dash_duo.wait_for_text_to_equal("#output", "myindex")
Loading