Skip to content

Transformations are applied before rendering with datashader #378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## [0.2.9] - tbd

### Fixed

- Transformations of Points and Shapes are now applied before rendering with datashader (#378)

## [0.2.8] - 2024-11-26

### Changed
- Support for `xarray.DataTree` (which moved from `datatree.DataTree`) #380

- Support for `xarray.DataTree` (which moved from `datatree.DataTree`) (#380)

## [0.2.7] - 2024-10-24

Expand Down Expand Up @@ -45,10 +52,6 @@ and this project adheres to [Semantic Versioning][].

## [0.2.5] - 2024-08-23

### Added

-

### Changed

- Replaced `outline` parameter in `render_labels` with alpha-based logic (#323)
Expand Down
10 changes: 6 additions & 4 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def render_shapes(
palette: list[str] | str | None = None,
na_color: ColorLike | None = "default",
outline_width: float | int = 1.5,
outline_color: str | list[float] = "#000000ff",
outline_color: str | list[float] = "#000000",
outline_alpha: float | int = 0.0,
cmap: Colormap | str | None = None,
norm: Normalize | None = None,
Expand Down Expand Up @@ -208,9 +208,11 @@ def render_shapes(
won't be shown.
outline_width : float | int, default 1.5
Width of the border.
outline_color : str | list[float], default "#000000ff"
Color of the border. Can either be a named color ("red"), a hex representation ("#000000ff") or a list of
floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0).
outline_color : str | list[float], default "#000000"
Color of the border. Can either be a named color ("red"), a hex representation ("#000000") or a list of
floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). If the hex representation includes alpha, e.g.
"#000000ff", the last two positions are ignored, since the alpha of the outlines is solely controlled by
`outline_alpha`.
outline_alpha : float | int, default 0.0
Alpha value for the outline of shapes. Invisible by default.
cmap : Colormap | str | None, optional
Expand Down
92 changes: 60 additions & 32 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from matplotlib.colors import ListedColormap, Normalize
from scanpy._settings import settings as sc_settings
from spatialdata import get_extent
from spatialdata.models import PointsModel, get_table_keys
from spatialdata.transformations import (
set_transformation,
)
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
from spatialdata.transformations import get_transformation, set_transformation
from spatialdata.transformations.transformations import Identity
from xarray import DataTree

from spatialdata_plot._logging import logger
Expand All @@ -44,6 +43,7 @@
_get_colors_for_categorical_obs,
_get_extent_and_range_for_datashader_canvas,
_get_linear_colormap,
_get_transformation_matrix_for_datashader,
_is_coercable_to_float,
_map_color_seg,
_maybe_set_colors,
Expand Down Expand Up @@ -148,7 +148,7 @@ def _render_shapes(
colorbar = False if col_for_color is None else legend_params.colorbar

# Apply the transformation to the PatchCollection's paths
trans, _ = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)
trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)

shapes = gpd.GeoDataFrame(shapes, geometry="geometry")

Expand All @@ -168,14 +168,6 @@ def _render_shapes(
)

if method == "datashader":
trans += ax.transData

plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas(
sdata_filt.shapes[element], coordinate_system, ax, fig_params
)

cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext)

_geometry = shapes["geometry"]
is_point = _geometry.type == "Point"

Expand All @@ -184,36 +176,48 @@ def _render_shapes(
scale = shapes[is_point]["radius"] * render_params.scale
sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())

# apply transformations to the individual points
element_trans = get_transformation(sdata_filt.shapes[element])
tm = _get_transformation_matrix_for_datashader(element_trans)
transformed_element = sdata_filt.shapes[element].transform(
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2]
)
transformed_element = ShapesModel.parse(
gpd.GeoDataFrame(data=sdata_filt.shapes[element].drop("geometry", axis=1), geometry=transformed_element)
)

plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas(
transformed_element, coordinate_system, ax, fig_params
)

cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext)

# in case we are coloring by a column in table
if col_for_color is not None and col_for_color not in sdata_filt.shapes[element].columns:
sdata_filt.shapes[element][col_for_color] = (
color_vector if color_source_vector is None else color_source_vector
)
if col_for_color is not None and col_for_color not in transformed_element.columns:
transformed_element[col_for_color] = color_vector if color_source_vector is None else color_source_vector
# Render shapes with datashader
color_by_categorical = col_for_color is not None and color_source_vector is not None
aggregate_with_reduction = None
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
if color_by_categorical:
agg = cvs.polygons(
sdata_filt.shapes[element], geometry="geometry", agg=ds.by(col_for_color, ds.count())
)
agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.by(col_for_color, ds.count()))
else:
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "mean"
logger.info(
f'Using the datashader reduction "{reduction_name}". "max" will give an output very close '
"to the matplotlib result."
)
agg = _datashader_aggregate_with_function(
render_params.ds_reduction, cvs, sdata_filt.shapes[element], col_for_color, "shapes"
render_params.ds_reduction, cvs, transformed_element, col_for_color, "shapes"
)
# save min and max values for drawing the colorbar
aggregate_with_reduction = (agg.min(), agg.max())
else:
agg = cvs.polygons(sdata_filt.shapes[element], geometry="geometry", agg=ds.count())
agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.count())
# render outlines if needed
if (render_outlines := render_params.outline_alpha) > 0:
agg_outlines = cvs.line(
sdata_filt.shapes[element],
transformed_element,
geometry="geometry",
line_width=render_params.outline_params.linewidth,
)
Expand Down Expand Up @@ -287,13 +291,23 @@ def _render_shapes(

rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
_cax = _ax_show_and_transform(
rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.fill_alpha
rgba_image,
trans_data,
ax,
zorder=render_params.zorder,
alpha=render_params.fill_alpha,
extent=x_ext + y_ext,
)
# render outline image if needed
if render_outlines:
rgba_image, trans_data = _create_image_from_datashader_result(ds_outlines, factor, ax)
_ax_show_and_transform(
rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.outline_alpha
rgba_image,
trans_data,
ax,
zorder=render_params.zorder,
alpha=render_params.outline_alpha,
extent=x_ext + y_ext,
)

cax = None
Expand Down Expand Up @@ -330,7 +344,7 @@ def _render_shapes(

if not values_are_categorical:
# If the user passed a Normalize object with vmin/vmax we'll use those,
# # if not we'll use the min/max of the color_vector
# if not we'll use the min/max of the color_vector
_cax.set_clim(
vmin=render_params.cmap_params.norm.vmin or min(color_vector),
vmax=render_params.cmap_params.norm.vmax or max(color_vector),
Expand Down Expand Up @@ -468,7 +482,7 @@ def _render_points(
if color_source_vector is None and render_params.transfunc is not None:
color_vector = render_params.transfunc(color_vector)

_, trans_data = _prepare_transformation(sdata.points[element], coordinate_system, ax)
trans, trans_data = _prepare_transformation(sdata.points[element], coordinate_system, ax)

norm = copy(render_params.cmap_params.norm)

Expand All @@ -491,8 +505,15 @@ def _render_points(
# use dpi/100 as a factor for cases where dpi!=100
px = int(np.round(np.sqrt(render_params.size) * (fig_params.fig.dpi / 100)))

# apply transformations
transformed_element = PointsModel.parse(
trans.transform(sdata_filt.points[element][["x", "y"]]),
annotation=sdata_filt.points[element][sdata_filt.points[element].columns.drop(["x", "y"])],
transformations={coordinate_system: Identity()},
)

plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas(
sdata_filt.points[element], coordinate_system, ax, fig_params
transformed_element, coordinate_system, ax, fig_params
)

# use datashader for the visualization of points
Expand All @@ -502,20 +523,20 @@ def _render_points(
aggregate_with_reduction = None
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
if color_by_categorical:
agg = cvs.points(sdata_filt.points[element], "x", "y", agg=ds.by(col_for_color, ds.count()))
agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count()))
else:
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum"
logger.info(
f'Using the datashader reduction "{reduction_name}". "max" will give an output very close '
"to the matplotlib result."
)
agg = _datashader_aggregate_with_function(
render_params.ds_reduction, cvs, sdata_filt.points[element], col_for_color, "points"
render_params.ds_reduction, cvs, transformed_element, col_for_color, "points"
)
# save min and max values for drawing the colorbar
aggregate_with_reduction = (agg.min(), agg.max())
else:
agg = cvs.points(sdata_filt.points[element], "x", "y", agg=ds.count())
agg = cvs.points(transformed_element, "x", "y", agg=ds.count())

if norm.vmin is not None or norm.vmax is not None:
norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin
Expand Down Expand Up @@ -573,7 +594,14 @@ def _render_points(
)

rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
_ax_show_and_transform(rgba_image, trans_data, ax, zorder=render_params.zorder, alpha=render_params.alpha)
_ax_show_and_transform(
rgba_image,
trans_data,
ax,
zorder=render_params.zorder,
alpha=render_params.alpha,
extent=x_ext + y_ext,
)

cax = None
if aggregate_with_reduction is not None:
Expand Down
57 changes: 55 additions & 2 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import matplotlib.transforms as mtransforms
import numpy as np
import numpy.ma as ma
import numpy.typing as npt
import pandas as pd
import shapely
import spatialdata as sd
Expand Down Expand Up @@ -58,8 +59,11 @@
from spatialdata._core.query.relational_query import _locate_value, _ValueOrigin
from spatialdata._types import ArrayLike
from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, SpatialElement, get_model

# from spatialdata.transformations.transformations import Scale
from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Translation
from spatialdata.transformations import Sequence as SDSequence
from spatialdata.transformations.operations import get_transformation
from spatialdata.transformations.transformations import Scale
from xarray import DataArray, DataTree

from spatialdata_plot._logging import logger
Expand Down Expand Up @@ -1977,19 +1981,37 @@ def _ax_show_and_transform(
alpha: float | None = None,
cmap: ListedColormap | LinearSegmentedColormap | None = None,
zorder: int = 0,
extent: list[float] | None = None,
) -> matplotlib.image.AxesImage:
# default extent in mpl:
image_extent = [-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5]
if extent is not None:
# make sure extent is [x_min, x_max, y_min, y_max]
if extent[3] < extent[2]:
extent[2], extent[3] = extent[3], extent[2]
if extent[0] < 0:
x_factor = array.shape[1] / (extent[1] - extent[0])
image_extent[0] = image_extent[0] + (extent[0] * x_factor)
image_extent[1] = image_extent[1] + (extent[0] * x_factor)
if extent[2] < 0:
y_factor = array.shape[0] / (extent[3] - extent[2])
image_extent[2] = image_extent[2] + (extent[2] * y_factor)
image_extent[3] = image_extent[3] + (extent[2] * y_factor)

if not cmap and alpha is not None:
im = ax.imshow(
array,
alpha=alpha,
zorder=zorder,
extent=tuple(image_extent),
)
im.set_transform(trans_data)
else:
im = ax.imshow(
array,
cmap=cmap,
zorder=zorder,
extent=tuple(image_extent),
)
im.set_transform(trans_data)
return im
Expand Down Expand Up @@ -2055,7 +2077,7 @@ def _get_extent_and_range_for_datashader_canvas(

def _create_image_from_datashader_result(
ds_result: ds.transfer_functions.Image, factor: float, ax: Axes
) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.CompositeGenericTransform]:
) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.Transform]:
# create SpatialImage from datashader output to get it back to original size
rgba_image_data = ds_result.to_numpy().base
rgba_image_data = np.transpose(rgba_image_data, (2, 0, 1))
Expand Down Expand Up @@ -2187,3 +2209,34 @@ def _prepare_transformation(
trans_data = trans + ax.transData if ax is not None else None

return trans, trans_data


def _get_datashader_trans_matrix_of_single_element(
trans: Identity | Scale | Affine | MapAxis | Translation,
) -> npt.NDArray[Any]:
flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])
tm: npt.NDArray[Any] = trans.to_affine_matrix(("x", "y"), ("x", "y"))

if isinstance(trans, Identity):
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
if isinstance(trans, (Scale | Affine)):
# idea: "flip the y-axis", apply transformation, flip back
flip_and_transform: npt.NDArray[Any] = flip_matrix @ tm @ flip_matrix
return flip_and_transform
if isinstance(trans, MapAxis):
# no flipping needed
return tm
# for a Translation, we need the transposed transformation matrix
return tm.T


def _get_transformation_matrix_for_datashader(
trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence,
) -> npt.NDArray[Any]:
"""Get the affine matrix needed to transform shapes for rendering with datashader."""
if isinstance(trans, SDSequence):
tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
for x in trans.transformations:
tm = tm @ _get_datashader_trans_matrix_of_single_element(x)
return tm
return _get_datashader_trans_matrix_of_single_element(trans)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading