Skip to content

Use Image trace and add low-res slices #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 28 additions & 28 deletions dash_3d_viewer/slicer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from plotly.graph_objects import Figure
from plotly.graph_objects import Figure, Image
from dash import Dash
from dash.dependencies import Input, Output, State
from dash_core_components import Graph, Slider, Store
Expand Down Expand Up @@ -29,44 +29,39 @@ def __init__(self, app, volume, axis=0, id=None):
self._id = id

# Get the slice size (width, height), and max index
arr_shape = list(volume.shape)
arr_shape.pop(self._axis)
slice_size = list(reversed(arr_shape))
# arr_shape = list(volume.shape)
# arr_shape.pop(self._axis)
# slice_size = list(reversed(arr_shape))
self._max_index = self._volume.shape[self._axis] - 1

# Prep low-res slices
thumbnails = [
img_array_to_uri(self._slice(i), (32, 32))
for i in range(self._max_index + 1)
]

# Create a placeholder trace
# todo: can add "%{z[0]}", but that would be the scaled value ...
trace = Image(source="", hovertemplate="(%{x}, %{y})<extra></extra>")
# Create the figure object
fig = Figure()
fig = Figure(data=[trace])
fig.update_layout(
template=None,
margin=dict(l=0, r=0, b=0, t=0, pad=4),
)
fig.update_xaxes(
# range=(0, slice_size[0]),
showgrid=False,
range=(0, slice_size[0]),
showticklabels=False,
zeroline=False,
)
fig.update_yaxes(
# range=(slice_size[1], 0), # todo: allow flipping x or y
showgrid=False,
scaleanchor="x",
range=(slice_size[1], 0), # todo: allow flipping x or y
showticklabels=False,
zeroline=False,
)
# Add an empty layout image that we can populate from JS.
fig.add_layout_image(
dict(
source="",
xref="x",
yref="y",
x=0,
y=0,
sizex=slice_size[0],
sizey=slice_size[1],
sizing="contain",
layer="below",
)
)
# Wrap the figure in a graph
# todo: or should the user provide this?
self.graph = Graph(
Expand All @@ -88,6 +83,7 @@ def __init__(self, app, volume, axis=0, id=None):
Store(id=self._subid("slice-index"), data=volume.shape[self._axis] // 2),
Store(id=self._subid("_requested-slice-index"), data=0),
Store(id=self._subid("_slice-data"), data=""),
Store(id=self._subid("_slice-data-lowres"), data=thumbnails),
]

self._create_server_callbacks(app)
Expand All @@ -101,7 +97,8 @@ def _slice(self, index):
"""Sample a slice from the volume."""
indices = [slice(None), slice(None), slice(None)]
indices[self._axis] = index
return self._volume[tuple(indices)]
im = self._volume[tuple(indices)]
return (im.astype(np.float32) * (255 / im.max())).astype(np.uint8)

def _create_server_callbacks(self, app):
"""Create the callbacks that run server-side."""
Expand All @@ -112,7 +109,6 @@ def _create_server_callbacks(self, app):
)
def upload_requested_slice(slice_index):
slice = self._slice(slice_index)
slice = (slice.astype(np.float32) * (255 / slice.max())).astype(np.uint8)
return [slice_index, img_array_to_uri(slice)]

def _create_client_callbacks(self, app):
Expand Down Expand Up @@ -158,7 +154,7 @@ def _create_client_callbacks(self, app):

app.clientside_callback(
"""
function handle_incoming_slice(index, index_and_data, ori_figure) {
function handle_incoming_slice(index, index_and_data, ori_figure, lowres) {
let new_index = index_and_data[0];
let new_data = index_and_data[1];
// Store data in cache
Expand All @@ -167,17 +163,18 @@ def _create_client_callbacks(self, app):
slice_cache[new_index] = new_data;
// Get the data we need *now*
let data = slice_cache[index];
//slice_cache[new_index] = undefined; // todo: disabled cache for now!
// Maybe we do not need an update
if (!data) {
return window.dash_clientside.no_update;
data = lowres[index];
}
if (data == ori_figure.layout.images[0].source) {
if (data == ori_figure.data[0].source) {
return window.dash_clientside.no_update;
}
// Otherwise, perform update
console.log("updating figure");
let figure = {...ori_figure};
figure.layout.images[0].source = data;
figure.data[0].source = data;
return figure;
}
""".replace(
Expand All @@ -188,5 +185,8 @@ def _create_client_callbacks(self, app):
Input(self._subid("slice-index"), "data"),
Input(self._subid("_slice-data"), "data"),
],
[State(self._subid("graph"), "figure")],
[
State(self._subid("graph"), "figure"),
State(self._subid("_slice-data-lowres"), "data"),
],
)
14 changes: 10 additions & 4 deletions dash_3d_viewer/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
import io
import random
import base64

import PIL.Image
import skimage
from plotly.utils import ImageUriValidator


def gen_random_id(n=6):
return "".join(random.choice("abcdefghijklmnopqrtsuvwxyz") for i in range(n))


def img_array_to_uri(img_array):
def img_array_to_uri(img_array, new_size=None):
img_array = skimage.util.img_as_ubyte(img_array)
# todo: leverage this Plotly util once it becomes part of the public API (also drops the Pillow dependency)
# from plotly.express._imshow import _array_to_b64str
# return _array_to_b64str(img_array)
img_pil = PIL.Image.fromarray(img_array)
uri = ImageUriValidator.pil_image_to_uri(img_pil)
return uri
if new_size:
img_pil.thumbnail(new_size)
# The below was taken from plotly.utils.ImageUriValidator.pil_image_to_uri()
f = io.BytesIO()
img_pil.save(f, format="PNG")
base64_str = base64.b64encode(f.getvalue()).decode()
return "data:image/png;base64," + base64_str