Skip to content

Commit

Permalink
Trigger state change event on iterators (#9299)
Browse files Browse the repository at this point in the history
* Fix render async

* add changeset

* Fix regression

* tests

* Add code

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
3 people authored Sep 19, 2024
1 parent 4be0933 commit aa35b07
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 3 deletions.
7 changes: 7 additions & 0 deletions .changeset/deep-memes-cheat.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/client": patch
"@gradio/core": patch
"gradio": patch
---

fix:Trigger state change event on iterators
5 changes: 4 additions & 1 deletion client/js/src/helpers/api_info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,10 @@ export function handle_message(
stage: data.success ? "generating" : "error",
code: data.code,
progress_data: data.progress_data,
eta: data.average_duration
eta: data.average_duration,
changed_state_ids: data.success
? data.output.changed_state_ids
: undefined
},
data: data.success ? data.output : null
};
Expand Down
2 changes: 1 addition & 1 deletion demo/state_change/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_change"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", "\n", " with gr.Row():\n", " state_a = gr.State(0)\n", " btn_a = gr.Button(\"Increment A\")\n", " value_a = gr.Number(label=\"Number A\")\n", " btn_a.click(lambda x: x + 1, state_a, state_a)\n", " state_a.change(lambda x: x, state_a, value_a)\n", " with gr.Row():\n", " state_b = gr.State(0)\n", " btn_b = gr.Button(\"Increment B\")\n", " value_b = gr.Number(label=\"Number B\")\n", " btn_b.click(lambda x: x + 1, state_b, state_b)\n", "\n", " @gr.on(inputs=state_b, outputs=value_b)\n", " def identity(x):\n", " return x\n", "\n", " @gr.render(inputs=[state_a, state_b])\n", " def render(a, b):\n", " for x in range(a):\n", " with gr.Row():\n", " for y in range(b):\n", " gr.Button(f\"Button {x}, {y}\")\n", "\n", " list_state = gr.State([])\n", " dict_state = gr.State(dict())\n", " nested_list_state = gr.State([])\n", " set_state = gr.State(set())\n", "\n", " def transform_list(x):\n", " return {n: n for n in x}, [x[:] for _ in range(len(x))], set(x)\n", "\n", " list_state.change(\n", " transform_list,\n", " inputs=list_state,\n", " outputs=[dict_state, nested_list_state, set_state],\n", " )\n", "\n", " all_textbox = gr.Textbox(label=\"Output\")\n", " click_count = gr.Number(label=\"Clicks\")\n", " change_count = gr.Number(label=\"Changes\")\n", " gr.on(\n", " inputs=[change_count, dict_state, nested_list_state, set_state],\n", " triggers=[dict_state.change, nested_list_state.change, set_state.change],\n", " fn=lambda x, *args: (x + 1, \"\\n\".join(str(arg) for arg in args)),\n", " outputs=[change_count, all_textbox],\n", " )\n", "\n", " count_to_3_btn = gr.Button(\"Count to 3\")\n", " count_to_3_btn.click(lambda: [1, 2, 3], outputs=list_state)\n", " zero_all_btn = gr.Button(\"Zero All\")\n", " zero_all_btn.click(lambda x: [0] * len(x), inputs=list_state, outputs=list_state)\n", "\n", " gr.on(\n", " [count_to_3_btn.click, zero_all_btn.click],\n", " lambda x: x + 1,\n", " click_count,\n", " click_count,\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_change"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", "\n", " with gr.Row():\n", " state_a = gr.State(0)\n", " btn_a = gr.Button(\"Increment A\")\n", " value_a = gr.Number(label=\"Number A\")\n", " btn_a.click(lambda x: x + 1, state_a, state_a)\n", " state_a.change(lambda x: x, state_a, value_a)\n", " with gr.Row():\n", " state_b = gr.State(0)\n", " btn_b = gr.Button(\"Increment B\")\n", " value_b = gr.Number(label=\"Number B\")\n", " btn_b.click(lambda x: x + 1, state_b, state_b)\n", "\n", " @gr.on(inputs=state_b, outputs=value_b)\n", " def identity(x):\n", " return x\n", "\n", " @gr.render(inputs=[state_a, state_b])\n", " def render(a, b):\n", " for x in range(a):\n", " with gr.Row():\n", " for y in range(b):\n", " gr.Button(f\"Button {x}, {y}\")\n", "\n", " list_state = gr.State([])\n", " dict_state = gr.State(dict())\n", " nested_list_state = gr.State([])\n", " set_state = gr.State(set())\n", "\n", " def transform_list(x):\n", " return {n: n for n in x}, [x[:] for _ in range(len(x))], set(x)\n", "\n", " list_state.change(\n", " transform_list,\n", " inputs=list_state,\n", " outputs=[dict_state, nested_list_state, set_state],\n", " )\n", "\n", " all_textbox = gr.Textbox(label=\"Output\")\n", " click_count = gr.Number(label=\"Clicks\")\n", " change_count = gr.Number(label=\"Changes\")\n", " gr.on(\n", " inputs=[change_count, dict_state, nested_list_state, set_state],\n", " triggers=[dict_state.change, nested_list_state.change, set_state.change],\n", " fn=lambda x, *args: (x + 1, \"\\n\".join(str(arg) for arg in args)),\n", " outputs=[change_count, all_textbox],\n", " )\n", "\n", " count_to_3_btn = gr.Button(\"Count to 3\")\n", " count_to_3_btn.click(lambda: [1, 2, 3], outputs=list_state)\n", " zero_all_btn = gr.Button(\"Zero All\")\n", " zero_all_btn.click(lambda x: [0] * len(x), inputs=list_state, outputs=list_state)\n", "\n", " gr.on(\n", " [count_to_3_btn.click, zero_all_btn.click],\n", " lambda x: x + 1,\n", " click_count,\n", " click_count,\n", " )\n", "\n", " async def increment(x):\n", " yield x + 1\n", "\n", " n_text = gr.State(0)\n", " add_btn = gr.Button(\"Iterator State Change\")\n", " add_btn.click(increment, n_text, n_text)\n", "\n", " @gr.render(inputs=n_text)\n", " def render_count(count):\n", " for i in range(int(count)):\n", " gr.Markdown(value = f\"Success Box {i} added\", key=i)\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
13 changes: 13 additions & 0 deletions demo/state_change/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,18 @@ def transform_list(x):
click_count,
)

async def increment(x):
yield x + 1

n_text = gr.State(0)
add_btn = gr.Button("Iterator State Change")
add_btn.click(increment, n_text, n_text)

@gr.render(inputs=n_text)
def render_count(count):
for i in range(int(count)):
gr.Markdown(value = f"Success Box {i} added", key=i)


if __name__ == "__main__":
demo.launch()
6 changes: 5 additions & 1 deletion js/core/src/Blocks.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@
}
}
/* eslint-disable complexity */
function handle_status_update(message: StatusMessage): void {
const { fn_index, ...status } = message;
if (status.stage === "streaming" && status.time_limit) {
Expand Down Expand Up @@ -474,14 +475,16 @@
];
}
if (status.stage === "complete") {
if (status.stage === "complete" || status.stage === "generating") {
status.changed_state_ids?.forEach((id) => {
dependencies
.filter((dep) => dep.targets.some(([_id, _]) => _id === id))
.forEach((dep) => {
wait_then_trigger_api_call(dep.id, payload.trigger_id);
});
});
}
if (status.stage === "complete") {
dependencies.forEach(async (dep) => {
if (dep.trigger_after === fn_index) {
wait_then_trigger_api_call(dep.id, payload.trigger_id);
Expand Down Expand Up @@ -530,6 +533,7 @@
}
}
}
/* eslint-enable complexity */
function trigger_share(title: string | undefined, description: string): void {
if (space_id === null) {
Expand Down
11 changes: 11 additions & 0 deletions js/spa/test/state_change.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,14 @@ test("test datastructure-based state changes", async ({ page }) => {
await expect(page.getByLabel("Changes")).toHaveValue("2");
await expect(page.getByLabel("Clicks")).toHaveValue("5");
});

test("test generators properly trigger state changes", async ({ page }) => {
await page.getByRole("button", { name: "Iterator State Change" }).click();
await expect(page.getByTestId("markdown").first()).toHaveText(
"Success Box 0 added"
);
await page.getByRole("button", { name: "Iterator State Change" }).click();
await expect(page.getByTestId("markdown").nth(1)).toHaveText(
"Success Box 1 added"
);
});

0 comments on commit aa35b07

Please sign in to comment.