Skip to content

Commit

Permalink
Refactor plots to drop altair and use vega.js directly (gradio-ap…
Browse files Browse the repository at this point in the history
…p#8807)

* changes

* add changeset

* changes

* changes

* changes

* add changeset

* changes

* add changeset

* changes

* add changeset

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* add changeset

* changes

* changes

* Update gradio/components/native_plot.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update gradio/components/native_plot.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update gradio/blocks.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* changes

* changes

* changes

* Update gradio/components/native_plot.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update gradio/components/native_plot.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* changes

* changes

* changes

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
4 people authored Jul 22, 2024
1 parent 914b193 commit a238af4
Show file tree
Hide file tree
Showing 34 changed files with 1,199 additions and 683 deletions.
8 changes: 8 additions & 0 deletions .changeset/tangy-beds-guess.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
"@gradio/app": minor
"@gradio/datetime": minor
"@gradio/nativeplot": minor
"gradio": minor
---

feat:Refactor plots to drop `altair` and use `vega.js` directly
2 changes: 1 addition & 1 deletion demo/blocks_xray/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: blocks_xray"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "disease_values = [0.25, 0.5, 0.75]\n", "\n", "def xray_model(diseases, img):\n", " return [{disease: disease_values[idx] for idx,disease in enumerate(diseases)}]\n", "\n", "\n", "def ct_model(diseases, img):\n", " return [{disease: 0.1 for disease in diseases}]\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\n", " \"\"\"\n", "# Detect Disease From Scan\n", "With this model you can lorem ipsum\n", "- ipsum 1\n", "- ipsum 2\n", "\"\"\"\n", " )\n", " gr.DuplicateButton()\n", " disease = gr.CheckboxGroup(\n", " info=\"Select the diseases you want to scan for.\",\n", " choices=[\"Covid\", \"Malaria\", \"Lung Cancer\"], label=\"Disease to Scan For\"\n", " )\n", " slider = gr.Slider(0, 100)\n", "\n", " with gr.Tab(\"X-ray\") as x_tab:\n", " with gr.Row():\n", " xray_scan = gr.Image()\n", " xray_results = gr.JSON()\n", " xray_run = gr.Button(\"Run\")\n", " xray_run.click(\n", " xray_model,\n", " inputs=[disease, xray_scan],\n", " outputs=xray_results,\n", " api_name=\"xray_model\"\n", " )\n", "\n", " with gr.Tab(\"CT Scan\"):\n", " with gr.Row():\n", " ct_scan = gr.Image()\n", " ct_results = gr.JSON()\n", " ct_run = gr.Button(\"Run\")\n", " ct_run.click(\n", " ct_model,\n", " inputs=[disease, ct_scan],\n", " outputs=ct_results,\n", " api_name=\"ct_model\"\n", " )\n", "\n", " upload_btn = gr.Button(\"Upload Results\", variant=\"primary\")\n", " upload_btn.click(\n", " lambda ct, xr: None,\n", " inputs=[ct_results, xray_results],\n", " outputs=[],\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: blocks_xray"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "disease_values = [0.25, 0.5, 0.75]\n", "\n", "def xray_model(diseases, img):\n", " return [{disease: disease_values[idx] for idx,disease in enumerate(diseases)}]\n", "\n", "\n", "def ct_model(diseases, img):\n", " return [{disease: 0.1 for disease in diseases}]\n", "\n", "with gr.Blocks(fill_width=True) as demo:\n", " gr.Markdown(\n", " \"\"\"\n", "# Detect Disease From Scan\n", "With this model you can lorem ipsum\n", "- ipsum 1\n", "- ipsum 2\n", "\"\"\"\n", " )\n", " gr.DuplicateButton()\n", " disease = gr.CheckboxGroup(\n", " info=\"Select the diseases you want to scan for.\",\n", " choices=[\"Covid\", \"Malaria\", \"Lung Cancer\"], label=\"Disease to Scan For\"\n", " )\n", " slider = gr.Slider(0, 100)\n", "\n", " with gr.Tab(\"X-ray\") as x_tab:\n", " with gr.Row():\n", " xray_scan = gr.Image()\n", " xray_results = gr.JSON()\n", " xray_run = gr.Button(\"Run\")\n", " xray_run.click(\n", " xray_model,\n", " inputs=[disease, xray_scan],\n", " outputs=xray_results,\n", " api_name=\"xray_model\"\n", " )\n", "\n", " with gr.Tab(\"CT Scan\"):\n", " with gr.Row():\n", " ct_scan = gr.Image()\n", " ct_results = gr.JSON()\n", " ct_run = gr.Button(\"Run\")\n", " ct_run.click(\n", " ct_model,\n", " inputs=[disease, ct_scan],\n", " outputs=ct_results,\n", " api_name=\"ct_model\"\n", " )\n", "\n", " upload_btn = gr.Button(\"Upload Results\", variant=\"primary\")\n", " upload_btn.click(\n", " lambda ct, xr: None,\n", " inputs=[ct_results, xray_results],\n", " outputs=[],\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
2 changes: 1 addition & 1 deletion demo/blocks_xray/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def xray_model(diseases, img):
def ct_model(diseases, img):
return [{disease: 0.1 for disease in diseases}]

with gr.Blocks() as demo:
with gr.Blocks(fill_width=True) as demo:
gr.Markdown(
"""
# Detect Disease From Scan
Expand Down
166 changes: 66 additions & 100 deletions demo/native_plots/bar_plot_demo.py
Original file line number Diff line number Diff line change
@@ -1,111 +1,77 @@
import gradio as gr
import pandas as pd
import numpy as np
from data import temp_sensor_data, food_rating_data

from vega_datasets import data
with gr.Blocks() as bar_plots:
with gr.Row():
start = gr.DateTime("2021-01-01 00:00:00", label="Start")
end = gr.DateTime("2021-01-05 00:00:00", label="End")
apply_btn = gr.Button("Apply", scale=0)
with gr.Row():
group_by = gr.Radio(["None", "30m", "1h", "4h", "1d"], value="None", label="Group by")
aggregate = gr.Radio(["sum", "mean", "median", "min", "max"], value="sum", label="Aggregation")

barley = data.barley()
simple = pd.DataFrame({
'a': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'],
'b': [28, 55, 43, 91, 81, 53, 19, 87, 52]
})
temp_by_time = gr.BarPlot(
temp_sensor_data,
x="time",
y="temperature",
)
temp_by_time_location = gr.BarPlot(
temp_sensor_data,
x="time",
y="temperature",
color="location",
)

def bar_plot_fn(display):
if display == "simple":
return gr.BarPlot(
simple,
x="a",
y="b",
color=None,
group=None,
title="Simple Bar Plot with made up data",
tooltip=['a', 'b'],
y_lim=[20, 100],
x_title=None,
y_title=None,
vertical=True,
)
elif display == "stacked":
return gr.BarPlot(
barley,
x="variety",
y="yield",
color="site",
group=None,
title="Barley Yield Data",
tooltip=['variety', 'site'],
y_lim=None,
x_title=None,
y_title=None,
vertical=True,
)
elif display == "grouped":
return gr.BarPlot(
barley.astype({"year": str}),
x="year",
y="yield",
color="year",
group="site",
title="Barley Yield by Year and Site",
tooltip=["yield", "site", "year"],
y_lim=None,
x_title=None,
y_title=None,
vertical=True,
time_graphs = [temp_by_time, temp_by_time_location]
group_by.change(
lambda group: [gr.BarPlot(x_bin=None if group == "None" else group)] * len(time_graphs),
group_by,
time_graphs
)
aggregate.change(
lambda aggregate: [gr.BarPlot(y_aggregate=aggregate)] * len(time_graphs),
aggregate,
time_graphs
)


def rescale(select: gr.SelectData):
return select.index
rescale_evt = gr.on([plot.select for plot in time_graphs], rescale, None, [start, end])

for trigger in [apply_btn.click, rescale_evt.then]:
trigger(
lambda start, end: [gr.BarPlot(x_lim=[start, end])] * len(time_graphs), [start, end], time_graphs
)
elif display == "simple-horizontal":
return gr.BarPlot(
simple,
x="a",
y="b",
color=None,
group=None,
title="Simple Bar Plot with made up data",
tooltip=['a', 'b'],
y_lim=[20, 100],
x_title="Variable A",
y_title="Variable B",
vertical=False,

with gr.Row():
price_by_cuisine = gr.BarPlot(
food_rating_data,
x="cuisine",
y="price",
)
elif display == "stacked-horizontal":
return gr.BarPlot(
barley,
x="variety",
y="yield",
color="site",
group=None,
title="Barley Yield Data",
tooltip=['variety', 'site'],
y_lim=None,
x_title=None,
y_title=None,
vertical=False,
with gr.Column(scale=0):
gr.Button("Sort $ > $$$").click(lambda: gr.BarPlot(sort="y"), None, price_by_cuisine)
gr.Button("Sort $$$ > $").click(lambda: gr.BarPlot(sort="-y"), None, price_by_cuisine)
gr.Button("Sort A > Z").click(lambda: gr.BarPlot(sort=["Chinese", "Italian", "Mexican"]), None, price_by_cuisine)

with gr.Row():
price_by_rating = gr.BarPlot(
food_rating_data,
x="rating",
y="price",
x_bin=1,
)
elif display == "grouped-horizontal":
return gr.BarPlot(
barley.astype({"year": str}),
x="year",
y="yield",
color="year",
group="site",
title="Barley Yield by Year and Site",
group_title="",
tooltip=["yield", "site", "year"],
y_lim=None,
x_title=None,
y_title=None,
vertical=False
price_by_rating_color = gr.BarPlot(
food_rating_data,
x="rating",
y="price",
color="cuisine",
x_bin=1,
color_map={"Italian": "red", "Mexican": "green", "Chinese": "blue"},
)


with gr.Blocks() as bar_plot:
display = gr.Dropdown(
choices=["simple", "stacked", "grouped", "simple-horizontal", "stacked-horizontal", "grouped-horizontal"],
value="simple",
label="Type of Bar Plot"
)
plot = gr.BarPlot(show_label=False)
display.change(bar_plot_fn, inputs=display, outputs=plot)
bar_plot.load(fn=bar_plot_fn, inputs=display, outputs=plot)

if __name__ == "__main__":
bar_plot.launch()
bar_plots.launch()
20 changes: 20 additions & 0 deletions demo/native_plots/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pandas as pd
from random import randint, choice, random

temp_sensor_data = pd.DataFrame(
{
"time": pd.date_range("2021-01-01", end="2021-01-05", periods=200),
"temperature": [randint(50 + 10 * (i % 2), 65 + 15 * (i % 2)) for i in range(200)],
"humidity": [randint(50 + 10 * (i % 2), 65 + 15 * (i % 2)) for i in range(200)],
"location": ["indoor", "outdoor"] * 100,
}
)

food_rating_data = pd.DataFrame(
{
"cuisine": [["Italian", "Mexican", "Chinese"][i % 3] for i in range(100)],
"rating": [random() * 4 + 0.5 * (i % 3) for i in range(100)],
"price": [randint(10, 50) + 4 * (i % 3) for i in range(100)],
"wait": [random() for i in range(100)],
}
)
131 changes: 59 additions & 72 deletions demo/native_plots/line_plot_demo.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,69 @@
import gradio as gr
from vega_datasets import data
import numpy as np
from data import temp_sensor_data, food_rating_data

stocks = data.stocks()
gapminder = data.gapminder()
gapminder = gapminder.loc[
gapminder.country.isin(["Argentina", "Australia", "Afghanistan"])
]
climate = data.climate()
seattle_weather = data.seattle_weather()
with gr.Blocks() as line_plots:
with gr.Row():
start = gr.DateTime("2021-01-01 00:00:00", label="Start")
end = gr.DateTime("2021-01-05 00:00:00", label="End")
apply_btn = gr.Button("Apply", scale=0)
with gr.Row():
group_by = gr.Radio(["None", "30m", "1h", "4h", "1d"], value="None", label="Group by")
aggregate = gr.Radio(["sum", "mean", "median", "min", "max"], value="sum", label="Aggregation")

temp_by_time = gr.LinePlot(
temp_sensor_data,
x="time",
y="temperature",
)
temp_by_time_location = gr.LinePlot(
temp_sensor_data,
x="time",
y="temperature",
color="location",
)

time_graphs = [temp_by_time, temp_by_time_location]
group_by.change(
lambda group: [gr.LinePlot(x_bin=None if group == "None" else group)] * len(time_graphs),
group_by,
time_graphs
)
aggregate.change(
lambda aggregate: [gr.LinePlot(y_aggregate=aggregate)] * len(time_graphs),
aggregate,
time_graphs
)

def line_plot_fn(dataset):
if dataset == "stocks":
return gr.LinePlot(
stocks,
x="date",
y="price",
color="symbol",
x_lim=None,
y_lim=None,
stroke_dash=None,
tooltip=['date', 'price', 'symbol'],
overlay_point=False,
title="Stock Prices",
stroke_dash_legend_title=None,
)
elif dataset == "climate":
return gr.LinePlot(
climate,
x="DATE",
y="HLY-TEMP-NORMAL",
color=None,
x_lim=None,
y_lim=[250, 500],
stroke_dash=None,
tooltip=['DATE', 'HLY-TEMP-NORMAL'],
overlay_point=False,
title="Climate",
stroke_dash_legend_title=None,
)
elif dataset == "seattle_weather":
return gr.LinePlot(
seattle_weather,
x="date",
y="temp_min",
color=None,
x_lim=None,
y_lim=None,
stroke_dash=None,
tooltip=["weather", "date"],
overlay_point=True,
title="Seattle Weather",
stroke_dash_legend_title=None,
)
elif dataset == "gapminder":
return gr.LinePlot(
gapminder,
x="year",
y="life_expect",
color="country",
x_lim=[1950, 2010],
y_lim=None,
stroke_dash="cluster",
tooltip=['country', 'life_expect'],
overlay_point=False,
title="Life expectancy for countries",
)

def rescale(select: gr.SelectData):
return select.index
rescale_evt = gr.on([plot.select for plot in time_graphs], rescale, None, [start, end])

for trigger in [apply_btn.click, rescale_evt.then]:
trigger(
lambda start, end: [gr.LinePlot(x_lim=[start, end])] * len(time_graphs), [start, end], time_graphs
)

with gr.Blocks() as line_plot:
dataset = gr.Dropdown(
choices=["stocks", "climate", "seattle_weather", "gapminder"],
value="stocks",
price_by_cuisine = gr.LinePlot(
food_rating_data,
x="cuisine",
y="price",
)
plot = gr.LinePlot()
dataset.change(line_plot_fn, inputs=dataset, outputs=plot)
line_plot.load(fn=line_plot_fn, inputs=dataset, outputs=plot)
with gr.Row():
price_by_rating = gr.LinePlot(
food_rating_data,
x="rating",
y="price",
)
price_by_rating_color = gr.LinePlot(
food_rating_data,
x="rating",
y="price",
color="cuisine",
color_map={"Italian": "red", "Mexican": "green", "Chinese": "blue"},
)


if __name__ == "__main__":
line_plot.launch()
line_plots.launch()
2 changes: 1 addition & 1 deletion demo/native_plots/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/bar_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/line_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plot\n", "from line_plot_demo import line_plot\n", "from bar_plot_demo import bar_plot\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plot.render()\n", " with gr.TabItem(\"Line Plot\"):\n", " line_plot.render()\n", " with gr.TabItem(\"Bar Plot\"):\n", " bar_plot.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/bar_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/data.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/line_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plots\n", "from line_plot_demo import line_plots\n", "from bar_plot_demo import bar_plots\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Line Plot\"):\n", " line_plots.render()\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plots.render()\n", " with gr.TabItem(\"Bar Plot\"):\n", " bar_plots.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
14 changes: 7 additions & 7 deletions demo/native_plots/run.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import gradio as gr

from scatter_plot_demo import scatter_plot
from line_plot_demo import line_plot
from bar_plot_demo import bar_plot
from scatter_plot_demo import scatter_plots
from line_plot_demo import line_plots
from bar_plot_demo import bar_plots


with gr.Blocks() as demo:
with gr.Tabs():
with gr.TabItem("Scatter Plot"):
scatter_plot.render()
with gr.TabItem("Line Plot"):
line_plot.render()
line_plots.render()
with gr.TabItem("Scatter Plot"):
scatter_plots.render()
with gr.TabItem("Bar Plot"):
bar_plot.render()
bar_plots.render()

if __name__ == "__main__":
demo.launch()
Loading

0 comments on commit a238af4

Please sign in to comment.