Skip to content

Commit 1c2f571

Browse files
authored
Merge pull request #92 from predict-idlab/figure_dict_input
🔥 add support for figure dict input + propagate _grid_str
2 parents f96a2df + 862bc0a commit 1c2f571

File tree

5 files changed

+174
-10
lines changed

5 files changed

+174
-10
lines changed

plotly_resampler/figure_resampler/figure_resampler.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,18 @@ def __init__(
5555
if isinstance(figure, BaseFigure): # go.FigureWidget or AbstractFigureAggregator
5656
# A base figure object, we first copy the layout and grid ref
5757
f.layout = figure.layout
58+
f._grid_str = figure._grid_str
5859
f._grid_ref = figure._grid_ref
5960
f.add_traces(figure.data)
61+
elif isinstance(figure, dict) and (
62+
"data" in figure or "layout" in figure # or "frames" in figure # TODO
63+
):
64+
# A dict with data, layout or frames
65+
f.layout = figure.get("layout")
66+
f._grid_str = figure.get("_grid_str")
67+
f._grid_ref = figure.get("_grid_ref")
68+
f.add_traces(figure.get("data"))
69+
# f.add_frames(figure.get("frames")) TODO
6070
elif isinstance(figure, (dict, list)):
6171
# A single trace dict or a list of traces
6272
f.add_traces(figure)

plotly_resampler/figure_resampler/figure_resampler_interface.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __init__(
113113
# call __init__ with the correct layout and set the `_grid_ref` of the
114114
# to-be-converted figure
115115
f_ = self._figure_class(layout=figure.layout)
116+
f_._grid_str = figure._grid_str
116117
f_._grid_ref = figure._grid_ref
117118
super().__init__(f_)
118119

plotly_resampler/figure_resampler/figurewidget_resampler.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,17 @@ def __init__(
5959
if isinstance(figure, BaseFigure): # go.Figure or go.FigureWidget or AbstractFigureAggregator
6060
# A base figure object, we first copy the layout and grid ref
6161
f.layout = figure.layout
62+
f._grid_str = figure._grid_str
6263
f._grid_ref = figure._grid_ref
6364
f.add_traces(figure.data)
65+
elif isinstance(figure, dict) and (
66+
"data" in figure or "layout" in figure # or "frames" in figure # TODO
67+
):
68+
f.layout = figure.get("layout")
69+
f._grid_str = figure.get("_grid_str")
70+
f._grid_ref = figure.get("_grid_ref")
71+
f.add_traces(figure.get("data"))
72+
# f.add_frames(figure.get("frames")) TODO
6473
elif isinstance(figure, (dict, list)):
6574
# A single trace dict or a list of traces
6675
f.add_traces(figure)

tests/test_figure_resampler.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ def test_fr_add_empty_trace():
648648
assert len(fig.hf_data[0]["y"]) == 0
649649

650650

651-
def test_fr_from_dict():
651+
def test_fr_from_trace_dict():
652652
y = np.array([1] * 10_000)
653653
base_fig = {
654654
"type": "scatter",
@@ -668,6 +668,24 @@ def test_fr_from_dict():
668668
assert fr_fig.data[0].uid in fr_fig._hf_data
669669

670670

671+
def test_fr_from_figure_dict():
672+
y = np.array([1] * 10_000)
673+
base_fig = go.Figure()
674+
base_fig.add_trace(go.Scatter(y=y))
675+
676+
fr_fig = FigureResampler(base_fig.to_dict(), default_n_shown_samples=1000)
677+
assert len(fr_fig.hf_data) == 1
678+
assert (fr_fig.hf_data[0]["y"] == y).all()
679+
assert len(fr_fig.data) == 1
680+
assert len(fr_fig.data[0]["x"]) == 1_000
681+
assert (fr_fig.data[0]["x"][0] >= 0) & (fr_fig.data[0]["x"][-1] < 10_000)
682+
assert (fr_fig.data[0]["y"] == [1] * 1_000).all()
683+
684+
# assert that all the uuids of data and hf_data match
685+
# this is a proxy for assuring that the dynamic aggregation should work
686+
assert fr_fig.data[0].uid in fr_fig._hf_data
687+
688+
671689
def test_fr_empty_list():
672690
# and empty list -> so no concrete traces were added
673691
fr_fig = FigureResampler([], default_n_shown_samples=1000)
@@ -927,3 +945,57 @@ def test_fr_object_binary_data():
927945
assert fig.hf_data[0]["y"].dtype == "int64"
928946
assert fig.data[0]["y"].dtype == "int64"
929947
assert np.all(fig.data[0]["y"] == binary_series)
948+
949+
950+
def test_fr_copy_grid():
951+
f = make_subplots(rows=2, cols=1)
952+
f.add_scatter(y=np.arange(2_000), row=1, col=1)
953+
f.add_scatter(y=np.arange(2_000), row=2, col=1)
954+
955+
## go.Figure
956+
assert isinstance(f, go.Figure)
957+
assert f._grid_ref is not None
958+
fr = FigureResampler(f)
959+
assert fr._grid_ref is not None
960+
assert fr._grid_ref == f._grid_ref
961+
962+
## go.FigureWidget
963+
fw = go.FigureWidget(f)
964+
assert fw._grid_ref is not None
965+
assert isinstance(fw, go.FigureWidget)
966+
fr = FigureResampler(fw)
967+
assert fr._grid_ref is not None
968+
assert fr._grid_ref == fw._grid_ref
969+
970+
## FigureResampler
971+
fr_ = FigureResampler(f)
972+
assert fr_._grid_ref is not None
973+
assert isinstance(fr_, FigureResampler)
974+
fr = FigureResampler(fr_)
975+
assert fr._grid_ref is not None
976+
assert fr._grid_ref == fr_._grid_ref
977+
978+
## FigureWidgetResampler
979+
from plotly_resampler import FigureWidgetResampler
980+
fwr = FigureWidgetResampler(f)
981+
assert fwr._grid_ref is not None
982+
assert isinstance(fwr, FigureWidgetResampler)
983+
fr = FigureResampler(fwr)
984+
assert fr._grid_ref is not None
985+
assert fr._grid_ref == fwr._grid_ref
986+
987+
## dict (with no _grid_ref)
988+
f_dict = f.to_dict()
989+
assert isinstance(f_dict, dict)
990+
assert f_dict.get("_grid_ref") is None
991+
fr = FigureResampler(f_dict)
992+
assert fr._grid_ref is f_dict.get("_grid_ref") # both are None
993+
994+
## dict (with _grid_ref)
995+
f_dict = f.to_dict()
996+
f_dict["_grid_ref"] = f._grid_ref
997+
assert isinstance(f_dict, dict)
998+
assert f_dict.get("_grid_ref") is not None
999+
fr = FigureResampler(f_dict)
1000+
assert fr._grid_ref is not None
1001+
assert fr._grid_ref == f_dict.get("_grid_ref")

tests/test_figurewidget_resampler.py

Lines changed: 81 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,24 +1534,42 @@ def test_fwr_time_based_data_s():
15341534
assert (text == -hovertext).sum() == 1000
15351535

15361536

1537-
def test_fwr_from_dict():
1537+
def test_fwr_from_trace_dict():
15381538
y = np.array([1] * 10_000)
15391539
base_fig = {
15401540
"type": "scatter",
15411541
"y": y,
15421542
}
15431543

1544-
fr_fig = FigureWidgetResampler(base_fig, default_n_shown_samples=1000)
1545-
assert len(fr_fig.hf_data) == 1
1546-
assert (fr_fig.hf_data[0]["y"] == y).all()
1547-
assert len(fr_fig.data) == 1
1548-
assert len(fr_fig.data[0]["x"]) == 1_000
1549-
assert (fr_fig.data[0]["x"][0] >= 0) & (fr_fig.data[0]["x"][-1] < 10_000)
1550-
assert (fr_fig.data[0]["y"] == [1] * 1_000).all()
1544+
fwr_fig = FigureWidgetResampler(base_fig, default_n_shown_samples=1000)
1545+
assert len(fwr_fig.hf_data) == 1
1546+
assert (fwr_fig.hf_data[0]["y"] == y).all()
1547+
assert len(fwr_fig.data) == 1
1548+
assert len(fwr_fig.data[0]["x"]) == 1_000
1549+
assert (fwr_fig.data[0]["x"][0] >= 0) & (fwr_fig.data[0]["x"][-1] < 10_000)
1550+
assert (fwr_fig.data[0]["y"] == [1] * 1_000).all()
15511551

15521552
# assert that all the uuids of data and hf_data match
15531553
# this is a proxy for assuring that the dynamic aggregation should work
1554-
assert fr_fig.data[0].uid in fr_fig._hf_data
1554+
assert fwr_fig.data[0].uid in fwr_fig._hf_data
1555+
1556+
1557+
def test_fwr_from_figure_dict():
1558+
y = np.array([1] * 10_000)
1559+
base_fig = go.Figure()
1560+
base_fig.add_trace(go.Scatter(y=y))
1561+
1562+
fwr_fig = FigureWidgetResampler(base_fig.to_dict(), default_n_shown_samples=1000)
1563+
assert len(fwr_fig.hf_data) == 1
1564+
assert (fwr_fig.hf_data[0]["y"] == y).all()
1565+
assert len(fwr_fig.data) == 1
1566+
assert len(fwr_fig.data[0]["x"]) == 1_000
1567+
assert (fwr_fig.data[0]["x"][0] >= 0) & (fwr_fig.data[0]["x"][-1] < 10_000)
1568+
assert (fwr_fig.data[0]["y"] == [1] * 1_000).all()
1569+
1570+
# assert that all the uuids of data and hf_data match
1571+
# this is a proxy for assuring that the dynamic aggregation should work
1572+
assert fwr_fig.data[0].uid in fwr_fig._hf_data
15551573

15561574

15571575
def test_fwr_empty_list():
@@ -1796,3 +1814,57 @@ def test_fwr_object_binary_data():
17961814
assert fig.hf_data[0]["y"].dtype == "int64"
17971815
assert fig.data[0]["y"].dtype == "int64"
17981816
assert np.all(fig.data[0]["y"] == binary_series)
1817+
1818+
1819+
def test_fwr_copy_grid():
1820+
f = make_subplots(rows=2, cols=1)
1821+
f.add_scatter(y=np.arange(2_000), row=1, col=1)
1822+
f.add_scatter(y=np.arange(2_000), row=2, col=1)
1823+
1824+
## go.Figure
1825+
assert isinstance(f, go.Figure)
1826+
assert f._grid_ref is not None
1827+
fwr = FigureWidgetResampler(f)
1828+
assert fwr._grid_ref is not None
1829+
assert fwr._grid_ref == f._grid_ref
1830+
1831+
## go.FigureWidget
1832+
fw = go.FigureWidget(f)
1833+
assert fw._grid_ref is not None
1834+
assert isinstance(fw, go.FigureWidget)
1835+
fwr = FigureWidgetResampler(fw)
1836+
assert fwr._grid_ref is not None
1837+
assert fwr._grid_ref == fw._grid_ref
1838+
1839+
## FigureWidgetResampler
1840+
fwr_ = FigureWidgetResampler(f)
1841+
assert fwr_._grid_ref is not None
1842+
assert isinstance(fwr_, FigureWidgetResampler)
1843+
fwr = FigureWidgetResampler(fwr_)
1844+
assert fwr._grid_ref is not None
1845+
assert fwr._grid_ref == fwr_._grid_ref
1846+
1847+
## FigureResampler
1848+
from plotly_resampler import FigureResampler
1849+
fr = FigureResampler(f)
1850+
assert fr._grid_ref is not None
1851+
assert isinstance(fr, FigureResampler)
1852+
fwr = FigureWidgetResampler(fr)
1853+
assert fwr._grid_ref is not None
1854+
assert fwr._grid_ref == fr._grid_ref
1855+
1856+
## dict (with no _grid_ref)
1857+
f_dict = f.to_dict()
1858+
assert isinstance(f_dict, dict)
1859+
assert f_dict.get("_grid_ref") is None
1860+
fwr = FigureWidgetResampler(f_dict)
1861+
assert fwr._grid_ref is f_dict.get("_grid_ref") # both are None
1862+
1863+
## dict (with _grid_ref)
1864+
f_dict = f.to_dict()
1865+
f_dict["_grid_ref"] = f._grid_ref
1866+
assert isinstance(f_dict, dict)
1867+
assert f_dict.get("_grid_ref") is not None
1868+
fwr = FigureWidgetResampler(f_dict)
1869+
assert fwr._grid_ref is not None
1870+
assert fwr._grid_ref == f_dict.get("_grid_ref")

0 commit comments

Comments
 (0)