Skip to content

fix image rendering (clipping warning) #471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 59 additions & 14 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def _render_shapes(
cax = None
if aggregate_with_reduction is not None:
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax
vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
# value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
# under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
Expand Down Expand Up @@ -846,20 +846,22 @@ def _render_images(
# 2) Image has any number of channels but 1
else:
layers = {}
for ch_index, c in enumerate(channels):
layers[c] = img.sel(c=c).copy(deep=True).squeeze()

if not isinstance(render_params.cmap_params, list):
if render_params.cmap_params.norm is not None:
layers[c] = render_params.cmap_params.norm(layers[c])
for ch_idx, ch in enumerate(channels):
layers[ch] = img.sel(c=ch).copy(deep=True).squeeze()
if isinstance(render_params.cmap_params, list):
ch_norm = render_params.cmap_params[ch_idx].norm
ch_cmap_is_default = render_params.cmap_params[ch_idx].cmap_is_default
else:
if render_params.cmap_params[ch_index].norm is not None:
layers[c] = render_params.cmap_params[ch_index].norm(layers[c])
ch_norm = render_params.cmap_params.norm
ch_cmap_is_default = render_params.cmap_params.cmap_is_default

if not ch_cmap_is_default and ch_norm is not None:
layers[ch_idx] = ch_norm(layers[ch_idx])

# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list):
if render_params.cmap_params.cmap_is_default: # -> use RGB
stacked = np.stack([layers[c] for c in channels], axis=-1)
stacked = np.stack([layers[ch] for ch in layers], axis=-1)
else: # -> use given cmap for each channel
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
stacked = (
Expand Down Expand Up @@ -892,12 +894,54 @@ def _render_images(
# overwrite if n_channels == 2 for intuitive result
if n_channels == 2:
seed_colors = ["#ff0000ff", "#00ff00ff"]
else:
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
colored = np.stack(
[channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)],
0,
).sum(0)
colored = colored[:, :, :3]
elif n_channels == 3:
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
colored = np.stack(
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
0,
).sum(0)
colored = colored[:, :, :3]
else:
if isinstance(render_params.cmap_params, list):
cmap_is_default = render_params.cmap_params[0].cmap_is_default
else:
cmap_is_default = render_params.cmap_params.cmap_is_default

channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)
colored = colored[:, :, :3]
if cmap_is_default:
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
else:
# Sample n_channels colors evenly from the colormap
if isinstance(render_params.cmap_params, list):
seed_colors = [
render_params.cmap_params[i].cmap(i / (n_channels - 1)) for i in range(n_channels)
]
else:
seed_colors = [render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels)]
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]

# Stack (n_channels, height, width) → (height*width, n_channels)
H, W = next(iter(layers.values())).shape
comp_rgb = np.zeros((H, W, 3), dtype=float)

# For each channel: map to RGBA, apply constant alpha, then add
for ch_idx, ch in enumerate(channels):
layer_arr = layers[ch]
rgba = channel_cmaps[ch_idx](layer_arr)
rgba[..., 3] = render_params.alpha
comp_rgb += rgba[..., :3] * rgba[..., 3][..., None]

colored = np.clip(comp_rgb, 0, 1)
logger.info(
f"Your image has {n_channels} channels. Sampling categorical colors and using "
f"multichannel strategy 'stack' to render."
) # TODO: update when pca is added as strategy

_ax_show_and_transform(
colored,
Expand Down Expand Up @@ -943,6 +987,7 @@ def _render_images(
zorder=render_params.zorder,
)

# 2D) Image has n channels, no palette but cmap info
elif palette is not None and got_multiple_cmaps:
raise ValueError("If 'palette' is provided, 'cmap' must be None.")

Expand Down
48 changes: 36 additions & 12 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2006,7 +2006,7 @@ def _validate_col_for_column_table(
table_name = next(iter(tables))
if len(tables) > 1:
warnings.warn(
f"Multiple tables contain color column, using {table_name}",
f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.",
UserWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -2042,25 +2042,49 @@ def _validate_image_render_params(
element_params[el] = {}
spatial_element = param_dict["sdata"][el]

# robustly get channel names from image or multiscale image
spatial_element_ch = (
spatial_element.c if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c
spatial_element.c.values if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c.values
)
if (channel := param_dict["channel"]) is not None and (
(isinstance(channel[0], int) and max([abs(ch) for ch in channel]) <= len(spatial_element_ch))
or all(ch in spatial_element_ch for ch in channel)
):
channel = param_dict["channel"]
if channel is not None:
# Normalize channel to always be a list of str or a list of int
if isinstance(channel, str):
channel = [channel]

if isinstance(channel, int):
channel = [channel]

# If channel is a list, ensure all elements are the same type
if not (isinstance(channel, list) and channel and all(isinstance(c, type(channel[0])) for c in channel)):
raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.")

invalid = [c for c in channel if c not in spatial_element_ch]
if invalid:
raise ValueError(
f"Invalid channel(s): {', '.join(str(c) for c in invalid)}. Valid choices are: {spatial_element_ch}"
)
element_params[el]["channel"] = channel
else:
element_params[el]["channel"] = None

element_params[el]["alpha"] = param_dict["alpha"]

if isinstance(palette := param_dict["palette"], list):
palette = param_dict["palette"]
assert isinstance(palette, list | type(None)) # if present, was converted to list, just to make sure

if isinstance(palette, list):
# case A: single palette for all channels
if len(palette) == 1:
palette_length = len(channel) if channel is not None else len(spatial_element_ch)
palette = palette * palette_length
if (channel is not None and len(palette) != len(channel)) and len(palette) != len(spatial_element_ch):
palette = None
# case B: one palette per channel (either given or derived from channel length)
channels_to_use = spatial_element_ch if element_params[el]["channel"] is None else channel
if channels_to_use is not None and len(palette) != len(channels_to_use):
raise ValueError(
f"Palette length ({len(palette)}) does not match channel length "
f"({', '.join(str(c) for c in channels_to_use)})."
)
element_params[el]["palette"] = palette
element_params[el]["na_color"] = param_dict["na_color"]

Expand All @@ -2086,7 +2110,7 @@ def _validate_image_render_params(
def _get_wanted_render_elements(
sdata: SpatialData,
sdata_wanted_elements: list[str],
params: (ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams),
params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams,
cs: str,
element_type: Literal["images", "labels", "points", "shapes"],
) -> tuple[list[str], list[str], bool]:
Expand Down Expand Up @@ -2243,7 +2267,7 @@ def _create_image_from_datashader_result(


def _datashader_aggregate_with_function(
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
cvs: Canvas,
spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame,
col_for_color: str | None,
Expand Down Expand Up @@ -2307,7 +2331,7 @@ def _datashader_aggregate_with_function(


def _datshader_get_how_kw_for_spread(
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
) -> str:
# Get the best input for the how argument of ds.tf.spread(), needed for numerical values
reduction = reduction or "sum"
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading