Skip to content

Commit

Permalink
Merge pull request #4470 from plotly/pass-b64
Browse files Browse the repository at this point in the history
Use plotly.js `base64` API to store and pass typed arrays declared by numpy, pandas, etc.
  • Loading branch information
marthacryan authored Oct 21, 2024
2 parents 7c24d87 + f481af7 commit 8c75004
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 36 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
All notable changes to this project will be documented in this file.
This project adheres to [Semantic Versioning](http://semver.org/).

### Updated

- Updated plotly.py to use base64 encoding of arrays in plotly JSON to improve performance.

## [5.24.1] - 2024-09-12

### Updated
Expand Down
106 changes: 105 additions & 1 deletion packages/python/plotly/_plotly_utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,115 @@
import base64
import decimal
import json as _json
import sys
import re
from functools import reduce

from _plotly_utils.optional_imports import get_module
from _plotly_utils.basevalidators import ImageUriValidator
from _plotly_utils.basevalidators import (
ImageUriValidator,
copy_to_readonly_numpy_array,
is_homogeneous_array,
)


int8min = -128
int8max = 127
int16min = -32768
int16max = 32767
int32min = -2147483648
int32max = 2147483647

uint8max = 255
uint16max = 65535
uint32max = 4294967295

plotlyjsShortTypes = {
"int8": "i1",
"uint8": "u1",
"int16": "i2",
"uint16": "u2",
"int32": "i4",
"uint32": "u4",
"float32": "f4",
"float64": "f8",
}


def to_typed_array_spec(v):
"""
Convert numpy array to plotly.js typed array spec
If not possible return the original value
"""
v = copy_to_readonly_numpy_array(v)

np = get_module("numpy", should_load=False)
if not np or not isinstance(v, np.ndarray):
return v

dtype = str(v.dtype)

# convert default Big Ints until we could support them in plotly.js
if dtype == "int64":
max = v.max()
min = v.min()
if max <= int8max and min >= int8min:
v = v.astype("int8")
elif max <= int16max and min >= int16min:
v = v.astype("int16")
elif max <= int32max and min >= int32min:
v = v.astype("int32")
else:
return v

elif dtype == "uint64":
max = v.max()
min = v.min()
if max <= uint8max and min >= 0:
v = v.astype("uint8")
elif max <= uint16max and min >= 0:
v = v.astype("uint16")
elif max <= uint32max and min >= 0:
v = v.astype("uint32")
else:
return v

dtype = str(v.dtype)

if dtype in plotlyjsShortTypes:
arrObj = {
"dtype": plotlyjsShortTypes[dtype],
"bdata": base64.b64encode(v).decode("ascii"),
}

if v.ndim > 1:
arrObj["shape"] = str(v.shape)[1:-1]

return arrObj

return v


def is_skipped_key(key):
"""
Return whether the key is skipped for conversion to the typed array spec
"""
skipped_keys = ["geojson", "layer", "layers", "range"]
return any(skipped_key == key for skipped_key in skipped_keys)


def convert_to_base64(obj):
if isinstance(obj, dict):
for key, value in obj.items():
if is_skipped_key(key):
continue
elif is_homogeneous_array(value):
obj[key] = to_typed_array_spec(value)
else:
convert_to_base64(value)
elif isinstance(obj, list) or isinstance(obj, tuple):
for value in obj:
convert_to_base64(value)


def cumsum(x):
Expand Down
4 changes: 4 additions & 0 deletions packages/python/plotly/plotly/basedatatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
display_string_positions,
chomp_empty_strings,
find_closest_string,
convert_to_base64,
)
from _plotly_utils.exceptions import PlotlyKeyError
from .optional_imports import get_module
Expand Down Expand Up @@ -3310,6 +3311,9 @@ def to_dict(self):
if frames:
res["frames"] = frames

# Add base64 conversion before sending to the front-end
convert_to_base64(res)

return res

def to_plotly_json(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import json
from unittest import TestCase
import numpy as np
from plotly.tests.test_optional.optional_utils import NumpyTestUtilsMixin
import plotly.graph_objs as go


class TestShouldNotUseBase64InUnsupportedKeys(NumpyTestUtilsMixin, TestCase):
def test_np_geojson(self):
normal_coordinates = [
[
[-87, 35],
[-87, 30],
[-85, 30],
[-85, 35],
]
]

numpy_coordinates = np.array(normal_coordinates)

data = [
{
"type": "choropleth",
"locations": ["AL"],
"featureidkey": "properties.id",
"z": np.array([10]),
"geojson": {
"type": "Feature",
"properties": {"id": "AL"},
"geometry": {"type": "Polygon", "coordinates": numpy_coordinates},
},
}
]

fig = go.Figure(data=data)

assert (
json.loads(fig.to_json())["data"][0]["geojson"]["geometry"]["coordinates"]
== normal_coordinates
)

def test_np_layers(self):
layout = {
"mapbox": {
"layers": [
{
"sourcetype": "geojson",
"type": "line",
"line": {"dash": np.array([2.5, 1])},
"source": {
"type": "FeatureCollection",
"features": [
{
"type": "Feature",
"geometry": {
"type": "LineString",
"coordinates": np.array(
[[0.25, 52], [0.75, 50]]
),
},
}
],
},
},
],
"center": {"lon": 0.5, "lat": 51},
},
}
data = [{"type": "scattermapbox"}]

fig = go.Figure(data=data, layout=layout)

assert (fig.layout["mapbox"]["layers"][0]["line"]["dash"] == (2.5, 1)).all()

assert json.loads(fig.to_json())["layout"]["mapbox"]["layers"][0]["source"][
"features"
][0]["geometry"]["coordinates"] == [[0.25, 52], [0.75, 50]]

def test_np_range(self):
layout = {"xaxis": {"range": np.array([0, 1])}}

fig = go.Figure(data=[{"type": "scatter"}], layout=layout)

assert json.loads(fig.to_json())["layout"]["xaxis"]["range"] == [0, 1]
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ def _compare_figures(go_trace, px_fig):
def test_pie_like_px():
# Pie
labels = ["Oxygen", "Hydrogen", "Carbon_Dioxide", "Nitrogen"]
values = [4500, 2500, 1053, 500]
values = np.array([4500, 2500, 1053, 500])

fig = px.pie(names=labels, values=values)
trace = go.Pie(labels=labels, values=values)
_compare_figures(trace, fig)

labels = ["Eve", "Cain", "Seth", "Enos", "Noam", "Abel", "Awan", "Enoch", "Azura"]
parents = ["", "Eve", "Eve", "Seth", "Seth", "Eve", "Eve", "Awan", "Eve"]
values = [10, 14, 12, 10, 2, 6, 6, 4, 4]
values = np.array([10, 14, 12, 10, 2, 6, 6, 4, 4])
# Sunburst
fig = px.sunburst(names=labels, parents=parents, values=values)
trace = go.Sunburst(labels=labels, parents=parents, values=values)
Expand All @@ -45,7 +45,7 @@ def test_pie_like_px():

# Funnel
x = ["A", "B", "C"]
y = [3, 2, 1]
y = np.array([3, 2, 1])
fig = px.funnel(y=y, x=x)
trace = go.Funnel(y=y, x=x)
_compare_figures(trace, fig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,38 +372,6 @@ def test_invalid_encode_exception(self):
with self.assertRaises(TypeError):
_json.dumps({"a": {1}}, cls=utils.PlotlyJSONEncoder)

def test_fast_track_finite_arrays(self):
# if NaN or Infinity is found in the json dump
# of a figure, it is decoded and re-encoded to replace these values
# with null. This test checks that NaN and Infinity values are
# indeed converted to null, and that the encoding of figures
# without inf or nan is faster (because we can avoid decoding
# and reencoding).
z = np.random.randn(100, 100)
x = np.arange(100.0)
fig_1 = go.Figure(go.Heatmap(z=z, x=x))
t1 = time()
json_str_1 = _json.dumps(fig_1, cls=utils.PlotlyJSONEncoder)
t2 = time()
x[0] = np.nan
x[1] = np.inf
fig_2 = go.Figure(go.Heatmap(z=z, x=x))
t3 = time()
json_str_2 = _json.dumps(fig_2, cls=utils.PlotlyJSONEncoder)
t4 = time()
assert t2 - t1 < t4 - t3
assert "null" in json_str_2
assert "NaN" not in json_str_2
assert "Infinity" not in json_str_2
x = np.arange(100.0)
fig_3 = go.Figure(go.Heatmap(z=z, x=x))
fig_3.update_layout(title_text="Infinity")
t5 = time()
json_str_3 = _json.dumps(fig_3, cls=utils.PlotlyJSONEncoder)
t6 = time()
assert t2 - t1 < t6 - t5
assert "Infinity" in json_str_3


class TestNumpyIntegerBaseType(TestCase):
def test_numpy_integer_import(self):
Expand Down

0 comments on commit 8c75004

Please sign in to comment.