Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix dataset update #8581

Merged
merged 13 commits into from
Jun 19, 2024
7 changes: 7 additions & 0 deletions .changeset/soft-worms-remain.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/dataset": patch
"gradio": patch
"website": patch
---

fix:fix dataset update
6 changes: 3 additions & 3 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1723,12 +1723,12 @@ async def postprocess_data(
) from err

if block.stateful:
if not utils.is_update(predictions[i]):
if not utils.is_prop_update(predictions[i]):
state[block._id] = predictions[i]
output.append(None)
else:
prediction_value = predictions[i]
if utils.is_update(
if utils.is_prop_update(
prediction_value
): # if update is passed directly (deprecated), remove Nones
prediction_value = utils.delete_none(
Expand All @@ -1738,7 +1738,7 @@ async def postprocess_data(
if isinstance(prediction_value, Block):
prediction_value = prediction_value.constructor_args.copy()
prediction_value["__type__"] = "update"
if utils.is_update(prediction_value):
if utils.is_prop_update(prediction_value):
kwargs = state[block._id].constructor_args.copy()
kwargs.update(prediction_value)
kwargs.pop("value", None)
Expand Down
34 changes: 23 additions & 11 deletions gradio/components/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import warnings
from typing import Any, Literal

from gradio_client.documentation import document
Expand All @@ -17,7 +18,8 @@
@document()
class Dataset(Component):
"""
Creates a gallery or table to display data samples. This component is designed for internal use to display examples.
Creates a gallery or table to display data samples. This component is primarily designed for internal use to display examples.
However, it can also be used directly to display a dataset and let users select examples.
"""

EVENTS = [Events.click, Events.select]
Expand All @@ -26,7 +28,7 @@ def __init__(
self,
*,
label: str | None = None,
components: list[Component] | list[str],
components: list[Component] | list[str] | None = None,
component_props: list[dict[str, Any]] | None = None,
samples: list[list[Any]] | None = None,
headers: list[str] | None = None,
Expand Down Expand Up @@ -70,7 +72,7 @@ def __init__(
self.container = container
self.scale = scale
self.min_width = min_width
self._components = [get_component_instance(c) for c in components]
self._components = [get_component_instance(c) for c in components or []]
if component_props is None:
self.component_props = [
component.recover_kwargs(
Expand Down Expand Up @@ -131,29 +133,39 @@ def get_config(self):

return config

def preprocess(self, payload: int) -> int | list | None:
def preprocess(self, payload: int | None) -> int | list | None:
    """
    Parameters:
        payload: the index of the selected example in the dataset, or `None` if there is no selection
    Returns:
        Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "values") or as an `int` index (if `type` is "index"). Returns `None` when `payload` is `None`.
    """
    # Nothing to look up when no index was supplied.
    if payload is None:
        return None
    if self.type == "index":
        return payload
    if self.type == "values":
        return self.samples[payload]
    # Any other `type` used to fall off the end of the function; the result
    # is still None, but it is now stated explicitly.
    return None

def postprocess(self, samples: list[list]) -> dict:
def postprocess(self, sample: int | list | None) -> int | None:
    """
    Parameters:
        sample: Expects an `int` index or `list` of sample data.
    Returns:
        Returns the index of the sample in the dataset, or `None` if `sample` is `None`, is not an `int` or `list`, or is a `list` that is not found in the dataset.
    """
    # An index (or the absence of one) is passed through unchanged.
    if sample is None or isinstance(sample, int):
        return sample
    if isinstance(sample, list):
        try:
            return self.samples.index(sample)
        except ValueError:
            # A list that matches no existing sample is most likely an attempt
            # to replace the dataset contents, which this method does not do;
            # point the caller at the supported way of doing that.
            warnings.warn(
                "The `Dataset` component does not support updating the dataset data by providing "
                "a set of list values. Instead, you should return a new Dataset(samples=...) object."
            )
            return None
    # Any other type is not a valid sample reference.
    return None

def example_payload(self) -> Any:
return 0
Expand Down
2 changes: 1 addition & 1 deletion gradio/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def flag(
) / client_utils.strip_invalid_filename_characters(
getattr(component, "label", None) or f"component {idx}"
)
if utils.is_update(sample):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a refactor for clarity

if utils.is_prop_update(sample):
csv_data.append(str(sample))
else:
data = (
Expand Down
2 changes: 1 addition & 1 deletion gradio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ def load_from_cache(self, example_id: int) -> list[Any]:
component, components.File
):
value_to_use = value_as_dict
if not utils.is_update(value_as_dict):
if not utils.is_prop_update(value_as_dict):
raise TypeError("value wasn't an update") # caught below
output.append(value_as_dict)
except (ValueError, TypeError, SyntaxError):
Expand Down
2 changes: 1 addition & 1 deletion gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def validate_url(possible_url: str) -> bool:
return False


def is_update(val):
def is_prop_update(val):
    """
    Returns True if `val` is a dict whose "__type__" entry contains "update", and False otherwise.
    """
    if not isinstance(val, dict):
        return False
    try:
        return "update" in val.get("__type__", "")
    except TypeError:
        # A "__type__" value that does not support `in` (e.g. None or an int)
        # cannot mark an update; report False rather than raising.
        return False


Expand Down
35 changes: 35 additions & 0 deletions js/_website/src/lib/templates/gradio/03_components/dataset.svx
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,40 @@ def predict(···) -> list[list]
<DemosSection demos={obj.demos} />
{/if}

### Examples

**Updating a Dataset**

In this example, we display a text dataset using `gr.Dataset` and then update it when the user clicks a button:

```py
import gradio as gr

# Samples the dataset is created with.
philosophy_quotes = [
    ["I think therefore I am."],
    ["The unexamined life is not worth living."]
]

# Samples that are swapped in when the button is clicked.
startup_quotes = [
    ["Ideas are easy. Implementation is hard"],
    ["Make mistakes faster."]
]

def show_startup_quotes():
    # Return a new Dataset carrying the replacement samples.
    return gr.Dataset(samples=startup_quotes)

with gr.Blocks() as demo:
    textbox = gr.Textbox()
    dataset = gr.Dataset(components=[textbox], samples=philosophy_quotes)
    button = gr.Button()

    # The handler takes no inputs; its return value is sent to `dataset`.
    button.click(fn=show_startup_quotes, inputs=None, outputs=dataset)

demo.launch()
```



{#if obj.fns && obj.fns.length > 0}
<!--- Event Listeners -->
### Event Listeners
Expand All @@ -97,3 +131,4 @@ def predict(···) -> list[list]
### Guides
<GuidesSection guides={obj.guides}/>
{/if}

5 changes: 3 additions & 2 deletions js/dataset/Index.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
>;
export let label = "Examples";
export let headers: string[];
export let samples: any[][];
export let samples: any[][] | null = null;
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
Expand All @@ -34,7 +34,7 @@
: `${root}/file=`;
let page = 0;
$: gallery = components.length < 2;
let paginate = samples.length > samples_per_page;
let paginate = samples ? samples.length > samples_per_page : false;

let selected_samples: any[][];
let page_count: number;
Expand All @@ -51,6 +51,7 @@
}

$: {
samples = samples || [];
paginate = samples.length > samples_per_page;
if (paginate) {
visible_pages = [];
Expand Down
27 changes: 27 additions & 0 deletions test/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,33 @@ def infer(a, b):
):
await demo.postprocess_data(demo.fns[0], predictions=(1, 2), state=None)

@pytest.mark.asyncio
async def test_dataset_is_updated(self):
    """
    Posts the same click request twice in one session to a demo whose handler
    returns a new `gr.Dataset`, and checks that the first output contains
    "Original" on the first call and "New" on the second.
    """

    def update(value):
        return value, gr.Dataset(samples=[["New A"], ["New B"]])

    with gr.Blocks() as demo:
        with gr.Row():
            textbox = gr.Textbox()
            dataset = gr.Dataset(
                components=["text"], samples=[["Original"]], label="Saved Prompts"
            )
            dataset.click(update, inputs=[dataset], outputs=[textbox, dataset])
    app, _, _ = demo.launch(prevent_thread_lock=True)

    client = TestClient(app)

    # Both calls select sample 0 under the same session hash.
    request_body = {"data": [0], "session_hash": "1", "fn_index": 0}

    first_response = client.post("/api/predict/", json=request_body)
    assert "Original" in first_response.json()["data"][0]

    second_response = client.post("/api/predict/", json=request_body)
    assert "New" in second_response.json()["data"][0]


class TestStateHolder:
@pytest.mark.asyncio
Expand Down