Skip to content

Commit

Permalink
Plot Component (#805)
Browse files Browse the repository at this point in the history
* plotly + matplotlib component

* update plot demos and plotly component

* fix gray bg

* format

* pnpm lock file

* add bokeh

* update plot demo

* add bokeh support

* ignore plot file

* fixed demo

* fixed sorting

* update image-plot deprecation warning

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
dawoodkhan82 and abidlabs authored Apr 14, 2022
1 parent b17afde commit 7552e1e
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 19 deletions.
3 changes: 2 additions & 1 deletion demo/outbreak_forecast/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
numpy
matplotlib
matplotlib
bokeh
55 changes: 42 additions & 13 deletions demo/outbreak_forecast/run.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,72 @@
from math import sqrt

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import pandas as pd
import bokeh.plotting as bk
from bokeh.models import ColumnDataSource
from bokeh.embed import json_item

import gradio as gr


def outbreak(r, month, countries, social_distancing):
def outbreak(plot_type, r, month, countries, social_distancing):
months = ["January", "February", "March", "April", "May"]
m = months.index(month)
start_day = 30 * m
final_day = 30 * (m + 1)
x = np.arange(start_day, final_day + 1)
day_count = x.shape[0]
pop_count = {"USA": 350, "Canada": 40, "Mexico": 300, "UK": 120}
r = sqrt(r)
if social_distancing:
r = sqrt(r)
for i, country in enumerate(countries):
series = x ** (r) * (i + 1)
plt.plot(x, series)
plt.title("Outbreak in " + month)
plt.ylabel("Cases")
plt.xlabel("Days since Day 0")
plt.legend(countries)
return plt
df = pd.DataFrame({'day': x})
for country in countries:
df[country] = ( x ** (r) * (pop_count[country] + 1))


if plot_type == "Matplotlib":
fig = plt.figure()
plt.plot(df['day'], df[countries])
plt.title("Outbreak in " + month)
plt.ylabel("Cases")
plt.xlabel("Days since Day 0")
plt.legend(countries)
return fig
elif plot_type == "Plotly":
fig = px.line(df, x='day', y=countries)
fig.update_layout(title="Outbreak in " + month,
xaxis_title="Cases",
yaxis_title="Days Since Day 0")
return fig
else:
source = ColumnDataSource(df)
p = bk.figure(title="Outbreak in " + month, x_axis_label="Cases", y_axis_label="Days Since Day 0")
for country in countries:
p.line(x='day', y=country, line_width=2, source=source)
item_text = json_item(p, "plotDiv")
return item_text



iface = gr.Interface(
outbreak,
[
gr.inputs.Dropdown(
["Matplotlib", "Plotly", "Bokeh"], label="Plot Type"
),
gr.inputs.Slider(1, 4, default=3.2, label="R"),
gr.inputs.Dropdown(
["January", "February", "March", "April", "May"], label="Month"
),
gr.inputs.CheckboxGroup(["USA", "Canada", "Mexico", "UK"], label="Countries"),
gr.inputs.CheckboxGroup(["USA", "Canada", "Mexico", "UK"], label="Countries",
default=["USA", "Canada"]),
gr.inputs.Checkbox(label="Social Distancing?"),
],
"plot",
gr.outputs.Plot(type="auto"),
)

if __name__ == "__main__":
iface.launch()
70 changes: 68 additions & 2 deletions gradio/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
from types import ModuleType
from typing import TYPE_CHECKING, Dict, List, Optional

import matplotlib
import numpy as np
import pandas as pd
import PIL
from black import out
from ffmpy import FFmpeg

from gradio import processing_utils
Expand Down Expand Up @@ -210,12 +212,12 @@ def __init__(
"""
Parameters:
type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image or a remote URL, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
plot (bool): DEPRECATED (Use the new 'plot' component). Whether to expect a plot to be returned by the function.
label (str): component name in interface.
"""
if plot:
warnings.warn(
"The 'plot' parameter has been deprecated. Set parameter 'type' to 'plot' instead.",
"The 'plot' parameter has been deprecated. Use the new 'plot' component instead.",
DeprecationWarning,
)
self.type = "plot"
Expand Down Expand Up @@ -853,6 +855,70 @@ def get_shortcut_implementations(cls):
}


class Plot(OutputComponent):
"""
Used for plot output.
Output type: matplotlib plt or plotly figure
Demos: outbreak_forecast
"""

def __init__(self, type: str = None, label: Optional[str] = None):
"""
Parameters:
type (str): type of plot (matplotlib, plotly)
label (str): component name in interface.
"""
self.type = type
super().__init__(label)

def get_template_context(self):
return {**super().get_template_context()}

@classmethod
def get_shortcut_implementations(cls):
return {
"plot": {},
}

def postprocess(self, y):
"""
Parameters:
y (str): plot data
Returns:
(str): plot type
(str): plot base64 or json
"""
dtype = self.type
if self.type == "plotly":
out_y = y.to_json()
elif self.type == "matplotlib":
out_y = processing_utils.encode_plot_to_base64(y)
elif self.type == "bokeh":
out_y = json.dumps(y)
elif self.type == "auto":
if isinstance(y, (ModuleType, matplotlib.pyplot.Figure)):
dtype = "matplotlib"
out_y = processing_utils.encode_plot_to_base64(y)
elif isinstance(y, dict):
dtype = "bokeh"
out_y = json.dumps(y)
else:
dtype = "plotly"
out_y = y.to_json()
else:
raise ValueError(
"Unknown type. Please choose from: 'plotly', 'matplotlib', 'bokeh'."
)
return {"type": dtype, "plot": out_y}

def deserialize(self, x):
y = processing_utils.decode_base64_to_file(x).name
return y

def save_flagged(self, dir, label, data, encryption_key):
return self.save_flagged_file(dir, label, data, encryption_key)


class Image3D(OutputComponent):
"""
Used for 3d image model output.
Expand Down
3 changes: 2 additions & 1 deletion ui/.prettierignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
packages/app/public/**
pnpm-workspace.yaml
packages/app/dist/**
pnpm-lock.yaml
pnpm-lock.yaml
packages/app/src/components/output/Plot/Plot.svelte
5 changes: 3 additions & 2 deletions ui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
"svelte": "^3.46.3",
"svelte-check": "^2.4.1",
"svelte-i18n": "^3.3.13",
"vitest": "^0.3.2",
"plotly.js-dist-min": "^2.10.1",
"babylonjs": "^4.2.1",
"babylonjs-loaders": "^4.2.1",
"vitest": "^0.3.2"
"babylonjs-loaders": "^4.2.1"
},
"devDependencies": {
"@types/three": "^0.138.0"
Expand Down
2 changes: 2 additions & 0 deletions ui/packages/app/src/components/directory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import OutputTextbox from "./output/Textbox/config.js";
import OutputVideo from "./output/Video/config.js";
import OutputTimeSeries from "./output/TimeSeries/config.js";
import OutputChatbot from "./output/Chatbot/config.js";
import OutputPlot from "./output/Plot/config.js";
import OutputImage3D from "./output/Image3D/config.js";

import StaticButton from "./static/Button/config.js";
Expand Down Expand Up @@ -62,6 +63,7 @@ export const output_component_map = {
timeseries: OutputTimeSeries,
video: OutputVideo,
chatbot: OutputChatbot,
plot: OutputPlot,
image3d: OutputImage3D
};

Expand Down
73 changes: 73 additions & 0 deletions ui/packages/app/src/components/output/Plot/Plot.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
<svelte:head>
<!-- Loading Bokeh from CDN -->
<script src="https://cdn.bokeh.org/bokeh/release/bokeh-2.4.2.min.js" on:load={handleBokehLoaded} ></script>
{#if bokehLoaded}
<script src="https://cdn.pydata.org/bokeh/release/bokeh-widgets-2.4.2.min.js" on:load={() => initializeBokeh(1)} ></script>
<script src="https://cdn.pydata.org/bokeh/release/bokeh-tables-2.4.2.min.js" on:load={() => initializeBokeh(2)}></script>
<script src="https://cdn.pydata.org/bokeh/release/bokeh-gl-2.4.2.min.js" on:load={() => initializeBokeh(3)}></script>
<script src="https://cdn.pydata.org/bokeh/release/bokeh-api-2.4.2.min.js" on:load={() => initializeBokeh(4)}></script>
<script src="https://cdn.pydata.org/bokeh/release/bokeh-api-2.4.2.min.js" on:load={() => initializeBokeh(5)} ></script>
{/if}
</svelte:head>

<script lang="ts">
export let value: string;
export let theme: string;
import { afterUpdate, onMount} from "svelte";
import Plotly from "plotly.js-dist-min";
// Bokeh
let bokehLoaded = false
const resolves = []
const bokehPromises = Array(6).fill(0).map((_, i) => createPromise(i))
const initializeBokeh = (index) => {
if (value["type"] == "bokeh") {
console.log(resolves)
resolves[index]()
}
}
function createPromise(index) {
return new Promise((resolve, reject) => {
resolves[index] = resolve
})
}
function handleBokehLoaded() {
initializeBokeh(0)
bokehLoaded = true
}
Promise.all(bokehPromises).then(() => {
let plotObj = JSON.parse(value["plot"]);
window.Bokeh.embed.embed_item(plotObj, "plotDiv");
})
// Plotly
afterUpdate(() => {
if (value["type"] == "plotly") {
let plotObj = JSON.parse(value["plot"]);
let plotDiv = document.getElementById("plotDiv");
Plotly.newPlot(plotDiv, plotObj["data"], plotObj["layout"]);
} else if (value["type"] == "bokeh") {
let plotObj = JSON.parse(value["plot"]);
window.Bokeh.embed.embed_item(plotObj, "plotDiv");
}
});
</script>

{#if value["type"] == "plotly" || value["type"] == "bokeh" }
<div id="plotDiv" />
{:else}
<div
class="output-image w-full h-80 flex justify-center items-center dark:bg-gray-600 relative"
{theme}
>
<!-- svelte-ignore a11y-missing-attribute -->
<img class="w-full h-full object-contain" src={value["plot"]} />
</div>
{/if}

<style lang="postcss">
</style>
5 changes: 5 additions & 0 deletions ui/packages/app/src/components/output/Plot/config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import Component from "./Plot.svelte";

export default {
component: Component
};
6 changes: 6 additions & 0 deletions ui/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7552e1e

Please sign in to comment.