From aa35b0788e613fdd45446d267513e6f94fa208ea Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Thu, 19 Sep 2024 11:15:30 -0400 Subject: [PATCH] Trigger state change event on iterators (#9299) * Fix render async * add changeset * Fix regression * tests * Add code --------- Co-authored-by: gradio-pr-bot Co-authored-by: Abubakar Abid --- .changeset/deep-memes-cheat.md | 7 +++++++ client/js/src/helpers/api_info.ts | 5 ++++- demo/state_change/run.ipynb | 2 +- demo/state_change/run.py | 13 +++++++++++++ js/core/src/Blocks.svelte | 6 +++++- js/spa/test/state_change.spec.ts | 11 +++++++++++ 6 files changed, 41 insertions(+), 3 deletions(-) create mode 100644 .changeset/deep-memes-cheat.md diff --git a/.changeset/deep-memes-cheat.md b/.changeset/deep-memes-cheat.md new file mode 100644 index 0000000000000..429d9b4dd69ad --- /dev/null +++ b/.changeset/deep-memes-cheat.md @@ -0,0 +1,7 @@ +--- +"@gradio/client": patch +"@gradio/core": patch +"gradio": patch +--- + +fix:Trigger state change event on iterators diff --git a/client/js/src/helpers/api_info.ts b/client/js/src/helpers/api_info.ts index b500c539fc018..b3e5a655a2dc5 100644 --- a/client/js/src/helpers/api_info.ts +++ b/client/js/src/helpers/api_info.ts @@ -315,7 +315,10 @@ export function handle_message( stage: data.success ? "generating" : "error", code: data.code, progress_data: data.progress_data, - eta: data.average_duration + eta: data.average_duration, + changed_state_ids: data.success + ? data.output.changed_state_ids + : undefined }, data: data.success ? data.output : null }; diff --git a/demo/state_change/run.ipynb b/demo/state_change/run.ipynb index 8b4d611579adb..1fe9958e16699 100644 --- a/demo/state_change/run.ipynb +++ b/demo/state_change/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_change"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", "\n", " with gr.Row():\n", " state_a = gr.State(0)\n", " btn_a = gr.Button(\"Increment A\")\n", " value_a = gr.Number(label=\"Number A\")\n", " btn_a.click(lambda x: x + 1, state_a, state_a)\n", " state_a.change(lambda x: x, state_a, value_a)\n", " with gr.Row():\n", " state_b = gr.State(0)\n", " btn_b = gr.Button(\"Increment B\")\n", " value_b = gr.Number(label=\"Number B\")\n", " btn_b.click(lambda x: x + 1, state_b, state_b)\n", "\n", " @gr.on(inputs=state_b, outputs=value_b)\n", " def identity(x):\n", " return x\n", "\n", " @gr.render(inputs=[state_a, state_b])\n", " def render(a, b):\n", " for x in range(a):\n", " with gr.Row():\n", " for y in range(b):\n", " gr.Button(f\"Button {x}, {y}\")\n", "\n", " list_state = gr.State([])\n", " dict_state = gr.State(dict())\n", " nested_list_state = gr.State([])\n", " set_state = gr.State(set())\n", "\n", " def transform_list(x):\n", " return {n: n for n in x}, [x[:] for _ in range(len(x))], set(x)\n", "\n", " list_state.change(\n", " transform_list,\n", " inputs=list_state,\n", " outputs=[dict_state, nested_list_state, set_state],\n", " )\n", "\n", " all_textbox = gr.Textbox(label=\"Output\")\n", " click_count = gr.Number(label=\"Clicks\")\n", " change_count = gr.Number(label=\"Changes\")\n", " gr.on(\n", " inputs=[change_count, dict_state, nested_list_state, set_state],\n", " triggers=[dict_state.change, nested_list_state.change, set_state.change],\n", " fn=lambda x, *args: (x + 1, \"\\n\".join(str(arg) for arg in args)),\n", " outputs=[change_count, all_textbox],\n", " )\n", "\n", " count_to_3_btn = gr.Button(\"Count to 3\")\n", " count_to_3_btn.click(lambda: [1, 2, 3], outputs=list_state)\n", " zero_all_btn = gr.Button(\"Zero All\")\n", " zero_all_btn.click(lambda x: [0] * len(x), inputs=list_state, outputs=list_state)\n", "\n", " gr.on(\n", " [count_to_3_btn.click, zero_all_btn.click],\n", " lambda x: x + 1,\n", " click_count,\n", " click_count,\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_change"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", "\n", " with gr.Row():\n", " state_a = gr.State(0)\n", " btn_a = gr.Button(\"Increment A\")\n", " value_a = gr.Number(label=\"Number A\")\n", " btn_a.click(lambda x: x + 1, state_a, state_a)\n", " state_a.change(lambda x: x, state_a, value_a)\n", " with gr.Row():\n", " state_b = gr.State(0)\n", " btn_b = gr.Button(\"Increment B\")\n", " value_b = gr.Number(label=\"Number B\")\n", " btn_b.click(lambda x: x + 1, state_b, state_b)\n", "\n", " @gr.on(inputs=state_b, outputs=value_b)\n", " def identity(x):\n", " return x\n", "\n", " @gr.render(inputs=[state_a, state_b])\n", " def render(a, b):\n", " for x in range(a):\n", " with gr.Row():\n", " for y in range(b):\n", " gr.Button(f\"Button {x}, {y}\")\n", "\n", " list_state = gr.State([])\n", " dict_state = gr.State(dict())\n", " nested_list_state = gr.State([])\n", " set_state = gr.State(set())\n", "\n", " def transform_list(x):\n", " return {n: n for n in x}, [x[:] for _ in range(len(x))], set(x)\n", "\n", " list_state.change(\n", " transform_list,\n", " inputs=list_state,\n", " outputs=[dict_state, nested_list_state, set_state],\n", " )\n", "\n", " all_textbox = gr.Textbox(label=\"Output\")\n", " click_count = gr.Number(label=\"Clicks\")\n", " change_count = gr.Number(label=\"Changes\")\n", " gr.on(\n", " inputs=[change_count, dict_state, nested_list_state, set_state],\n", " triggers=[dict_state.change, nested_list_state.change, set_state.change],\n", " fn=lambda x, *args: (x + 1, \"\\n\".join(str(arg) for arg in args)),\n", " outputs=[change_count, all_textbox],\n", " )\n", "\n", " count_to_3_btn = gr.Button(\"Count to 3\")\n", " count_to_3_btn.click(lambda: [1, 2, 3], outputs=list_state)\n", " zero_all_btn = gr.Button(\"Zero All\")\n", " zero_all_btn.click(lambda x: [0] * len(x), inputs=list_state, outputs=list_state)\n", "\n", " gr.on(\n", " [count_to_3_btn.click, zero_all_btn.click],\n", " lambda x: x + 1,\n", " click_count,\n", " click_count,\n", " )\n", "\n", " async def increment(x):\n", " yield x + 1\n", "\n", " n_text = gr.State(0)\n", " add_btn = gr.Button(\"Iterator State Change\")\n", " add_btn.click(increment, n_text, n_text)\n", "\n", " @gr.render(inputs=n_text)\n", " def render_count(count):\n", " for i in range(int(count)):\n", " gr.Markdown(value = f\"Success Box {i} added\", key=i)\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/state_change/run.py b/demo/state_change/run.py index 2846a38ce51c6..565146f170c5f 100644 --- a/demo/state_change/run.py +++ b/demo/state_change/run.py @@ -61,5 +61,18 @@ def transform_list(x): click_count, ) + async def increment(x): + yield x + 1 + + n_text = gr.State(0) + add_btn = gr.Button("Iterator State Change") + add_btn.click(increment, n_text, n_text) + + @gr.render(inputs=n_text) + def render_count(count): + for i in range(int(count)): + gr.Markdown(value = f"Success Box {i} added", key=i) + + if __name__ == "__main__": demo.launch() diff --git a/js/core/src/Blocks.svelte b/js/core/src/Blocks.svelte index 5ffa367100d1a..125282ba6c8b9 100644 --- a/js/core/src/Blocks.svelte +++ b/js/core/src/Blocks.svelte @@ -428,6 +428,7 @@ } } + /* eslint-disable complexity */ function handle_status_update(message: StatusMessage): void { const { fn_index, ...status } = message; if (status.stage === "streaming" && status.time_limit) { @@ -474,7 +475,7 @@ ]; } - if (status.stage === "complete") { + if (status.stage === "complete" || status.stage === "generating") { status.changed_state_ids?.forEach((id) => { dependencies .filter((dep) => dep.targets.some(([_id, _]) => _id === id)) @@ -482,6 +483,8 @@ wait_then_trigger_api_call(dep.id, payload.trigger_id); }); }); + } + if (status.stage === "complete") { dependencies.forEach(async (dep) => { if (dep.trigger_after === fn_index) { wait_then_trigger_api_call(dep.id, payload.trigger_id); @@ -530,6 +533,7 @@ } } } + /* eslint-enable complexity */ function trigger_share(title: string | undefined, description: string): void { if (space_id === null) { diff --git a/js/spa/test/state_change.spec.ts b/js/spa/test/state_change.spec.ts index 04db71eabc7f4..d7ff7c20a1556 100644 --- a/js/spa/test/state_change.spec.ts +++ b/js/spa/test/state_change.spec.ts @@ -56,3 +56,14 @@ test("test datastructure-based state changes", async ({ page }) => { await expect(page.getByLabel("Changes")).toHaveValue("2"); await expect(page.getByLabel("Clicks")).toHaveValue("5"); }); + +test("test generators properly trigger state changes", async ({ page }) => { + await page.getByRole("button", { name: "Iterator State Change" }).click(); + await expect(page.getByTestId("markdown").first()).toHaveText( + "Success Box 0 added" + ); + await page.getByRole("button", { name: "Iterator State Change" }).click(); + await expect(page.getByTestId("markdown").nth(1)).toHaveText( + "Success Box 1 added" + ); +});