Skip to content

Commit 60c792b

Browse files
authored
Merge branch 'main' into feature/outline_refactoring
2 parents 7eb6baa + c5c55b3 commit 60c792b

File tree

13 files changed

+224
-69
lines changed

13 files changed

+224
-69
lines changed

.github/release.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ changelog:
44
- release-ignore
55
authors:
66
- pre-commit-ci
7+
- pre-commit-ci[bot]
78
categories:
89
- title: Added
910
labels:

.mypy.ini

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
[mypy]
22
python_version = 3.10
3-
plugins = numpy.typing.mypy_plugin
43

54
ignore_errors = False
65
warn_redundant_casts = True

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@ ci:
99
skip: []
1010
repos:
1111
- repo: https://github.com/rbubley/mirrors-prettier
12-
rev: v3.5.3
12+
rev: v3.6.2
1313
hooks:
1414
- id: prettier
1515
- repo: https://github.com/astral-sh/ruff-pre-commit
16-
rev: v0.11.9
16+
rev: v0.13.1
1717
hooks:
1818
- id: ruff
1919
args: [--fix, --exit-non-zero-on-fix]
2020
- id: ruff-format
2121
- repo: https://github.com/pre-commit/mirrors-mypy
22-
rev: v1.15.0
22+
rev: v1.18.2
2323
hooks:
2424
- id: mypy
2525
additional_dependencies: [numpy, types-requests]

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
info = metadata("spatialdata-plot")
2222
project_name = info["Name"]
2323
author = info["Author"]
24-
copyright = f"{datetime.now():%Y}, {author}."
24+
copyright = f"{datetime.now():%Y}, {author}"
2525
version = info["Version"]
2626

2727
# repository_url = f"https://github.com/scverse/{project_name}"

src/spatialdata_plot/pl/basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,7 @@ def show(
906906
cs_contents.query(f"cs == '{cs}'").iloc[0, :].values.tolist()
907907
)
908908
ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i]
909+
assert isinstance(ax, Axes)
909910

910911
wants_images = False
911912
wants_labels = False

src/spatialdata_plot/pl/render.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def _render_shapes(
137137
if isinstance(groups, list) and color_source_vector is not None:
138138
mask = color_source_vector.isin(groups)
139139
shapes = shapes[mask]
140-
shapes = shapes.reset_index()
140+
shapes = shapes.reset_index(drop=True)
141141
color_source_vector = color_source_vector[mask]
142142
color_vector = color_vector[mask]
143143

@@ -363,8 +363,10 @@ def _render_shapes(
363363
cax = None
364364
if aggregate_with_reduction is not None:
365365
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
366-
vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax
366+
vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax
367367
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
368+
assert norm.vmin is not None
369+
assert norm.vmax is not None
368370
# value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
369371
# under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
370372
vmin = norm.vmin - 0.5
@@ -766,6 +768,8 @@ def _render_points(
766768
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
767769
vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax
768770
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
771+
assert norm.vmin is not None
772+
assert norm.vmax is not None
769773
# value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
770774
# under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
771775
vmin = norm.vmin - 0.5
@@ -922,20 +926,22 @@ def _render_images(
922926
# 2) Image has any number of channels but 1
923927
else:
924928
layers = {}
925-
for ch_index, c in enumerate(channels):
926-
layers[c] = img.sel(c=c).copy(deep=True).squeeze()
927-
928-
if not isinstance(render_params.cmap_params, list):
929-
if render_params.cmap_params.norm is not None:
930-
layers[c] = render_params.cmap_params.norm(layers[c])
929+
for ch_idx, ch in enumerate(channels):
930+
layers[ch] = img.sel(c=ch).copy(deep=True).squeeze()
931+
if isinstance(render_params.cmap_params, list):
932+
ch_norm = render_params.cmap_params[ch_idx].norm
933+
ch_cmap_is_default = render_params.cmap_params[ch_idx].cmap_is_default
931934
else:
932-
if render_params.cmap_params[ch_index].norm is not None:
933-
layers[c] = render_params.cmap_params[ch_index].norm(layers[c])
935+
ch_norm = render_params.cmap_params.norm
936+
ch_cmap_is_default = render_params.cmap_params.cmap_is_default
937+
938+
if not ch_cmap_is_default and ch_norm is not None:
939+
layers[ch_idx] = ch_norm(layers[ch_idx])
934940

935941
# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
936942
if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list):
937943
if render_params.cmap_params.cmap_is_default: # -> use RGB
938-
stacked = np.stack([layers[c] for c in channels], axis=-1)
944+
stacked = np.stack([layers[ch] for ch in layers], axis=-1)
939945
else: # -> use given cmap for each channel
940946
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
941947
stacked = (
@@ -968,12 +974,54 @@ def _render_images(
968974
# overwrite if n_channels == 2 for intuitive result
969975
if n_channels == 2:
970976
seed_colors = ["#ff0000ff", "#00ff00ff"]
971-
else:
977+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
978+
colored = np.stack(
979+
[channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)],
980+
0,
981+
).sum(0)
982+
colored = colored[:, :, :3]
983+
elif n_channels == 3:
972984
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
985+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
986+
colored = np.stack(
987+
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
988+
0,
989+
).sum(0)
990+
colored = colored[:, :, :3]
991+
else:
992+
if isinstance(render_params.cmap_params, list):
993+
cmap_is_default = render_params.cmap_params[0].cmap_is_default
994+
else:
995+
cmap_is_default = render_params.cmap_params.cmap_is_default
973996

974-
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
975-
colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)
976-
colored = colored[:, :, :3]
997+
if cmap_is_default:
998+
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
999+
else:
1000+
# Sample n_channels colors evenly from the colormap
1001+
if isinstance(render_params.cmap_params, list):
1002+
seed_colors = [
1003+
render_params.cmap_params[i].cmap(i / (n_channels - 1)) for i in range(n_channels)
1004+
]
1005+
else:
1006+
seed_colors = [render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels)]
1007+
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
1008+
1009+
# Stack (n_channels, height, width) → (height*width, n_channels)
1010+
H, W = next(iter(layers.values())).shape
1011+
comp_rgb = np.zeros((H, W, 3), dtype=float)
1012+
1013+
# For each channel: map to RGBA, apply constant alpha, then add
1014+
for ch_idx, ch in enumerate(channels):
1015+
layer_arr = layers[ch]
1016+
rgba = channel_cmaps[ch_idx](layer_arr)
1017+
rgba[..., 3] = render_params.alpha
1018+
comp_rgb += rgba[..., :3] * rgba[..., 3][..., None]
1019+
1020+
colored = np.clip(comp_rgb, 0, 1)
1021+
logger.info(
1022+
f"Your image has {n_channels} channels. Sampling categorical colors and using "
1023+
f"multichannel strategy 'stack' to render."
1024+
) # TODO: update when pca is added as strategy
9771025

9781026
_ax_show_and_transform(
9791027
colored,
@@ -1019,6 +1067,7 @@ def _render_images(
10191067
zorder=render_params.zorder,
10201068
)
10211069

1070+
# 2D) Image has n channels, no palette but cmap info
10221071
elif palette is not None and got_multiple_cmaps:
10231072
raise ValueError("If 'palette' is provided, 'cmap' must be None.")
10241073

src/spatialdata_plot/pl/utils.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,8 @@ def _prepare_cmap_norm(
504504

505505
cmap = copy(cmap)
506506

507+
assert isinstance(cmap, Colormap), f"Invalid type of `cmap`: {type(cmap)}, expected `Colormap`."
508+
507509
if norm is None:
508510
norm = Normalize(vmin=None, vmax=None, clip=False)
509511

@@ -824,8 +826,9 @@ def _set_color_source_vec(
824826

825827
color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series`
826828

829+
# TODO check why table_name is not passed here.
827830
color_mapping = _get_categorical_color_mapping(
828-
adata=sdata.table,
831+
adata=sdata["table"],
829832
cluster_key=value_to_plot,
830833
color_source_vector=color_source_vector,
831834
cmap_params=cmap_params,
@@ -1940,7 +1943,8 @@ def _validate_label_render_params(
19401943

19411944
element_params[el]["table_name"] = None
19421945
element_params[el]["color"] = None
1943-
if (color := param_dict["color"]) is not None:
1946+
color = param_dict["color"]
1947+
if color is not None:
19441948
color, table_name = _validate_col_for_column_table(sdata, el, color, param_dict["table_name"], labels=True)
19451949
element_params[el]["table_name"] = table_name
19461950
element_params[el]["color"] = color
@@ -2102,6 +2106,11 @@ def _validate_col_for_column_table(
21022106
if table_name not in tables or (
21032107
col_for_color not in sdata[table_name].obs.columns and col_for_color not in sdata[table_name].var_names
21042108
):
2109+
warnings.warn(
2110+
f"Table '{table_name}' does not annotate element '{element_name}'.",
2111+
UserWarning,
2112+
stacklevel=2,
2113+
)
21052114
table_name = None
21062115
col_for_color = None
21072116
else:
@@ -2115,7 +2124,7 @@ def _validate_col_for_column_table(
21152124
table_name = next(iter(tables))
21162125
if len(tables) > 1:
21172126
warnings.warn(
2118-
f"Multiple tables contain color column, using {table_name}",
2127+
f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.",
21192128
UserWarning,
21202129
stacklevel=2,
21212130
)
@@ -2151,25 +2160,49 @@ def _validate_image_render_params(
21512160
element_params[el] = {}
21522161
spatial_element = param_dict["sdata"][el]
21532162

2163+
# robustly get channel names from image or multiscale image
21542164
spatial_element_ch = (
2155-
spatial_element.c if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c
2165+
spatial_element.c.values if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c.values
21562166
)
2157-
if (channel := param_dict["channel"]) is not None and (
2158-
(isinstance(channel[0], int) and max([abs(ch) for ch in channel]) <= len(spatial_element_ch))
2159-
or all(ch in spatial_element_ch for ch in channel)
2160-
):
2167+
channel = param_dict["channel"]
2168+
if channel is not None:
2169+
# Normalize channel to always be a list of str or a list of int
2170+
if isinstance(channel, str):
2171+
channel = [channel]
2172+
2173+
if isinstance(channel, int):
2174+
channel = [channel]
2175+
2176+
# If channel is a list, ensure all elements are the same type
2177+
if not (isinstance(channel, list) and channel and all(isinstance(c, type(channel[0])) for c in channel)):
2178+
raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.")
2179+
2180+
invalid = [c for c in channel if c not in spatial_element_ch]
2181+
if invalid:
2182+
raise ValueError(
2183+
f"Invalid channel(s): {', '.join(str(c) for c in invalid)}. Valid choices are: {spatial_element_ch}"
2184+
)
21612185
element_params[el]["channel"] = channel
21622186
else:
21632187
element_params[el]["channel"] = None
21642188

21652189
element_params[el]["alpha"] = param_dict["alpha"]
21662190

2167-
if isinstance(palette := param_dict["palette"], list):
2191+
palette = param_dict["palette"]
2192+
assert isinstance(palette, list | type(None)) # if present, was converted to list, just to make sure
2193+
2194+
if isinstance(palette, list):
2195+
# case A: single palette for all channels
21682196
if len(palette) == 1:
21692197
palette_length = len(channel) if channel is not None else len(spatial_element_ch)
21702198
palette = palette * palette_length
2171-
if (channel is not None and len(palette) != len(channel)) and len(palette) != len(spatial_element_ch):
2172-
palette = None
2199+
# case B: one palette per channel (either given or derived from channel length)
2200+
channels_to_use = spatial_element_ch if element_params[el]["channel"] is None else channel
2201+
if channels_to_use is not None and len(palette) != len(channels_to_use):
2202+
raise ValueError(
2203+
f"Palette length ({len(palette)}) does not match channel length "
2204+
f"({', '.join(str(c) for c in channels_to_use)})."
2205+
)
21732206
element_params[el]["palette"] = palette
21742207
element_params[el]["na_color"] = param_dict["na_color"]
21752208

@@ -2473,7 +2506,9 @@ def _get_datashader_trans_matrix_of_single_element(
24732506
# no flipping needed
24742507
return tm
24752508
# for a Translation, we need the transposed transformation matrix
2476-
return tm.T
2509+
tm_T = tm.T
2510+
assert isinstance(tm_T, np.ndarray)
2511+
return tm_T
24772512

24782513

24792514
def _get_transformation_matrix_for_datashader(
-11.2 KB
Loading

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_sdata_multiple_images_diverging_dims():
154154
def sdata_blobs_shapes_annotated() -> SpatialData:
155155
"""Get blobs sdata with continuous annotation of polygons."""
156156
blob = blobs()
157-
blob["table"].obs["region"] = "blobs_polygons"
157+
blob["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * blob["table"].n_obs)
158158
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
159159
blob.shapes["blobs_polygons"]["value"] = [1, 2, 3, 4, 5]
160160
return blob

0 commit comments

Comments
 (0)