1111import pandas as pd
1212import scanpy as sc
1313import spatialdata as sd
14- import xarray as xr
1514from anndata import AnnData
1615from geopandas import GeoDataFrame
1716from matplotlib import colors
2928 OutlineParams ,
3029 ScalebarParams ,
3130 _decorate_axs ,
31+ _get_colors_for_categorical_obs ,
3232 _get_linear_colormap ,
3333 _map_color_seg ,
3434 _maybe_set_colors ,
@@ -327,7 +327,6 @@ def _render_images(
327327 fig_params : FigParams ,
328328 scalebar_params : ScalebarParams ,
329329 legend_params : LegendParams ,
330- # extent: tuple[float, float, float, float] | None = None,
331330) -> None :
332331 elements = render_params .elements
333332
@@ -345,8 +344,8 @@ def _render_images(
345344 images = [sdata .images [e ] for e in elements ]
346345
347346 for img in images :
348- if (len (img .c ) > 3 or len (img .c ) == 2 ) and render_params .channel is None :
349- raise NotImplementedError ("Only 1 or 3 channels are supported at the moment." )
347+ # if (len(img.c) > 3 or len(img.c) == 2) and render_params.channel is None:
348+ # raise NotImplementedError("Only 1 or 3 channels are supported at the moment.")
350349
351350 if render_params .channel is None :
352351 channels = img .coords ["c" ].values
@@ -357,11 +356,8 @@ def _render_images(
357356
358357 n_channels = len (channels )
359358
359+ # True if user gave n cmaps for n channels
360360 got_multiple_cmaps = isinstance (render_params .cmap_params , list )
361-
362- if not isinstance (render_params .cmap_params , list ):
363- render_params .cmap_params = [render_params .cmap_params ] * n_channels
364-
365361 if got_multiple_cmaps :
366362 logger .warning (
367363 "You're blending multiple cmaps. "
@@ -371,100 +367,113 @@ def _render_images(
371367 "Consider using 'palette' instead."
372368 )
373369
374- if render_params .palette is not None :
375- logger .warning ("Parameter 'palette' is ignored when a 'cmap' is provided." )
370+ # not using got_multiple_cmaps here because of ruff :(
371+ if isinstance (render_params .cmap_params , list ) and len (render_params .cmap_params ) != n_channels :
372+ raise ValueError ("If 'cmap' is provided, its length must match the number of channels." )
373+
374+ # 1) Image has only 1 channel
375+ if n_channels == 1 and not isinstance (render_params .cmap_params , list ):
376+ layer = img .sel (c = channels ).squeeze ()
377+
378+ if render_params .quantiles_for_norm != (None , None ):
379+ layer = _normalize (
380+ layer , pmin = render_params .quantiles_for_norm [0 ], pmax = render_params .quantiles_for_norm [1 ], clip = True
381+ )
382+
383+ if render_params .cmap_params .norm is not None : # type: ignore[attr-defined]
384+ layer = render_params .cmap_params .norm (layer ) # type: ignore[attr-defined]
385+
386+ if render_params .palette is None :
387+ cmap = render_params .cmap_params .cmap # type: ignore[attr-defined]
388+ else :
389+ cmap = _get_linear_colormap ([render_params .palette ], "k" )[0 ]
390+
391+ ax .imshow (
392+ layer , # get rid of the channel dimension
393+ cmap = cmap ,
394+ alpha = render_params .alpha ,
395+ )
376396
377- for idx , channel in enumerate (channels ):
378- layer = img .sel (c = channel )
397+ # 2) Image has any number of channels but 1
398+ else :
399+ layers = {}
400+ for i , c in enumerate (channels ):
401+ layers [c ] = img .sel (c = c ).copy (deep = True ).squeeze ()
379402
380403 if render_params .quantiles_for_norm != (None , None ):
381- layer = _normalize (
382- layer ,
404+ layers [ c ] = _normalize (
405+ layers [ c ] ,
383406 pmin = render_params .quantiles_for_norm [0 ],
384407 pmax = render_params .quantiles_for_norm [1 ],
385408 clip = True ,
386409 )
387410
388- if render_params .cmap_params [idx ].norm is not None :
389- layer = render_params .cmap_params [idx ].norm (layer )
411+ if not isinstance (render_params .cmap_params , list ):
412+ if render_params .cmap_params .norm is not None :
413+ layers [c ] = render_params .cmap_params .norm (layers [c ])
414+ else :
415+ if render_params .cmap_params [i ].norm is not None :
416+ layers [c ] = render_params .cmap_params [i ].norm (layers [c ])
390417
391- ax .imshow (
392- layer ,
393- cmap = render_params .cmap_params [idx ].cmap ,
394- alpha = (1 / n_channels ),
395- )
396- break
418+ # 2A) Image has 3 channels, no palette/cmap info -> use RGB
419+ if n_channels == 3 and render_params .palette is None and not got_multiple_cmaps :
420+ ax .imshow (np .stack ([layers [c ] for c in channels ], axis = - 1 ), alpha = render_params .alpha )
397421
398- if n_channels == 1 :
399- layer = img .sel (c = channels )
422+ # 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
423+ elif render_params .palette is None and not got_multiple_cmaps :
424+ # overwrite if n_channels == 2 for intuitive result
425+ if n_channels == 2 :
426+ seed_colors = ["#ff0000ff" , "#00ff00ff" ]
427+ else :
428+ seed_colors = _get_colors_for_categorical_obs (list (range (n_channels )))
400429
401- if render_params .quantiles_for_norm != (None , None ):
402- layer = _normalize (
403- layer , pmin = render_params .quantiles_for_norm [0 ], pmax = render_params .quantiles_for_norm [1 ], clip = True
404- )
430+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
405431
406- if render_params . cmap_params [ 0 ]. norm is not None :
407- layer = render_params . cmap_params [ 0 ]. norm ( layer )
432+ # Apply cmaps to each channel and add up
433+ colored = np . stack ([ channel_cmaps [ i ]( layers [ c ]) for i , c in enumerate ( channels )], 0 ). sum ( 0 )
408434
409- if render_params .palette is None :
410- ax .imshow (
411- layer .squeeze (), # get rid of the channel dimension
412- cmap = render_params .cmap_params [0 ].cmap ,
413- )
435+ # Remove alpha channel so we can overwrite it from render_params.alpha
436+ colored = colored [:, :, :3 ]
414437
415- else :
416438 ax .imshow (
417- layer . squeeze (), # get rid of the channel dimension
418- cmap = _get_linear_colormap ([ render_params .palette ], "k" )[ 0 ] ,
439+ colored ,
440+ alpha = render_params .alpha ,
419441 )
420442
421- break
443+ # 2C) Image has n channels and palette info
444+ elif render_params .palette is not None and not got_multiple_cmaps :
445+ if len (render_params .palette ) != n_channels :
446+ raise ValueError ("If 'palette' is provided, its length must match the number of channels." )
422447
423- if render_params .palette is not None and n_channels != len (render_params .palette ):
424- raise ValueError ("If 'palette' is provided, its length must match the number of channels." )
448+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in render_params .palette ]
425449
426- if n_channels > 1 : # to capture n_channels = 3 and custom number cases
427- layer = img . sel ( c = channels ). copy ( deep = True )
450+ # Apply cmaps to each channel and add up
451+ colored = np . stack ([ channel_cmaps [ i ]( layers [ c ]) for i , c in enumerate ( channels )], 0 ). sum ( 0 )
428452
429- channel_colors : list [str ] | Any
430- if render_params .palette is None :
431- channel_colors = ["#ff0000ff" , "#00ff00ff" , "#0000ffff" ]
432- else :
433- channel_colors = render_params .palette
453+ # Remove alpha channel so we can overwrite it from render_params.alpha
454+ colored = colored [:, :, :3 ]
455+
456+ ax .imshow (
457+ colored ,
458+ alpha = render_params .alpha ,
459+ )
434460
435- channel_cmaps = _get_linear_colormap ([str (c ) for c in channel_colors [:n_channels ]], "k" )
461+ elif render_params .palette is None and got_multiple_cmaps :
462+ channel_cmaps = [cp .cmap for cp in render_params .cmap_params ] # type: ignore[union-attr]
436463
437- layer_vals = []
438- if render_params .quantiles_for_norm != (None , None ):
439- for i in range (n_channels ):
440- layer_vals .append (
441- _normalize (
442- layer .values [i ],
443- pmin = render_params .quantiles_for_norm [0 ],
444- pmax = render_params .quantiles_for_norm [1 ],
445- clip = True ,
446- )
447- )
464+ # Apply cmaps to each channel, add up and normalize to [0, 1]
465+ colored = np .stack ([channel_cmaps [i ](layers [c ]) for i , c in enumerate (channels )], 0 ).sum (0 ) / n_channels
448466
449- colored = np .stack ([channel_cmaps [i ](layer_vals [i ]) for i in range (n_channels )], 0 ).sum (0 )
467+ # Remove alpha channel so we can overwrite it from render_params.alpha
468+ colored = colored [:, :, :3 ]
450469
451- layer = xr .DataArray (
452- data = colored ,
453- coords = [
454- layer .coords ["y" ],
455- layer .coords ["x" ],
456- ["R" , "G" , "B" , "A" ],
457- ],
458- dims = ["y" , "x" , "c" ],
459- )
460- layer = layer .transpose ("y" , "x" , "c" ) # for plotting
470+ ax .imshow (
471+ colored ,
472+ alpha = render_params .alpha ,
473+ )
461474
462- ax .imshow (
463- layer .data ,
464- cmap = channel_cmaps [0 ],
465- alpha = render_params .alpha ,
466- norm = render_params .cmap_params [0 ].norm ,
467- )
475+ elif render_params .palette is not None and got_multiple_cmaps :
476+ raise ValueError ("If 'palette' is provided, 'cmap' must be None." )
468477
469478
470479@dataclass
0 commit comments