1111import pandas as pd
1212import scanpy as sc
1313import spatialdata as sd
14- import xarray as xr
1514from anndata import AnnData
1615from geopandas import GeoDataFrame
1716from matplotlib import colors
@@ -328,7 +327,6 @@ def _render_images(
328327 fig_params : FigParams ,
329328 scalebar_params : ScalebarParams ,
330329 legend_params : LegendParams ,
331- # extent: tuple[float, float, float, float] | None = None,
332330) -> None :
333331 elements = render_params .elements
334332
@@ -346,9 +344,6 @@ def _render_images(
346344 images = [sdata .images [e ] for e in elements ]
347345
348346 for img in images :
349- if (len (img .c ) > 3 or len (img .c ) == 2 ) and render_params .channel is None :
350- raise NotImplementedError ("Only 1 or 3 channels are supported at the moment." )
351-
352347 if render_params .channel is None :
353348 channels = img .coords ["c" ].values
354349 else :
@@ -358,11 +353,8 @@ def _render_images(
358353
359354 n_channels = len (channels )
360355
356+ # True if user gave n cmaps for n channels
361357 got_multiple_cmaps = isinstance (render_params .cmap_params , list )
362-
363- if not isinstance (render_params .cmap_params , list ):
364- render_params .cmap_params = [render_params .cmap_params ] * n_channels
365-
366358 if got_multiple_cmaps :
367359 logger .warning (
368360 "You're blending multiple cmaps. "
@@ -372,102 +364,113 @@ def _render_images(
372364 "Consider using 'palette' instead."
373365 )
374366
375- if render_params .palette is not None :
376- logger .warning ("Parameter 'palette' is ignored when a 'cmap' is provided." )
367+ # not using got_multiple_cmaps here because of ruff :(
368+ if isinstance (render_params .cmap_params , list ) and len (render_params .cmap_params ) != n_channels :
369+ raise ValueError ("If 'cmap' is provided, its length must match the number of channels." )
370+
371+ # 1) Image has only 1 channel
372+ if n_channels == 1 and not isinstance (render_params .cmap_params , list ):
373+ layer = img .sel (c = channels ).squeeze ()
374+
375+ if render_params .quantiles_for_norm != (None , None ):
376+ layer = _normalize (
377+ layer , pmin = render_params .quantiles_for_norm [0 ], pmax = render_params .quantiles_for_norm [1 ], clip = True
378+ )
379+
380+ if render_params .cmap_params .norm is not None : # type: ignore[attr-defined]
381+ layer = render_params .cmap_params .norm (layer ) # type: ignore[attr-defined]
377382
378- for idx , channel in enumerate (channels ):
379- layer = img .sel (c = channel )
383+ if render_params .palette is None :
384+ cmap = render_params .cmap_params .cmap # type: ignore[attr-defined]
385+ else :
386+ cmap = _get_linear_colormap ([render_params .palette ], "k" )[0 ]
387+
388+ ax .imshow (
389+ layer , # get rid of the channel dimension
390+ cmap = cmap ,
391+ alpha = render_params .alpha ,
392+ )
393+
394+ # 2) Image has any number of channels but 1
395+ else :
396+ layers = {}
397+ for i , c in enumerate (channels ):
398+ layers [c ] = img .sel (c = c ).copy (deep = True ).squeeze ()
380399
381400 if render_params .quantiles_for_norm != (None , None ):
382- layer = _normalize (
383- layer ,
401+ layers [ c ] = _normalize (
402+ layers [ c ] ,
384403 pmin = render_params .quantiles_for_norm [0 ],
385404 pmax = render_params .quantiles_for_norm [1 ],
386405 clip = True ,
387406 )
388407
389- if render_params .cmap_params [idx ].norm is not None :
390- layer = render_params .cmap_params [idx ].norm (layer )
408+ if not isinstance (render_params .cmap_params , list ):
409+ if render_params .cmap_params .norm is not None :
410+ layers [c ] = render_params .cmap_params .norm (layers [c ])
411+ else :
412+ if render_params .cmap_params [i ].norm is not None :
413+ layers [c ] = render_params .cmap_params [i ].norm (layers [c ])
391414
392- ax .imshow (
393- layer ,
394- cmap = render_params .cmap_params [idx ].cmap ,
395- alpha = (1 / n_channels ),
396- )
397- break
415+ # 2A) Image has 3 channels, no palette/cmap info -> use RGB
416+ if n_channels == 3 and render_params .palette is None and not got_multiple_cmaps :
417+ ax .imshow (np .stack ([layers [c ] for c in channels ], axis = - 1 ), alpha = render_params .alpha )
398418
399- if n_channels == 1 :
400- layer = img .sel (c = channels )
419+ # 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
420+ elif render_params .palette is None and not got_multiple_cmaps :
421+ # overwrite if n_channels == 2 for intuitive result
422+ if n_channels == 2 :
423+ seed_colors = ["#ff0000ff" , "#00ff00ff" ]
424+ else :
425+ seed_colors = _get_colors_for_categorical_obs (list (range (n_channels )))
401426
402- if render_params .quantiles_for_norm != (None , None ):
403- layer = _normalize (
404- layer , pmin = render_params .quantiles_for_norm [0 ], pmax = render_params .quantiles_for_norm [1 ], clip = True
405- )
427+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
406428
407- if render_params . cmap_params [ 0 ]. norm is not None :
408- layer = render_params . cmap_params [ 0 ]. norm ( layer )
429+ # Apply cmaps to each channel and add up
430+ colored = np . stack ([ channel_cmaps [ i ]( layers [ c ]) for i , c in enumerate ( channels )], 0 ). sum ( 0 )
409431
410- if render_params .palette is None :
411- ax .imshow (
412- layer .squeeze (), # get rid of the channel dimension
413- cmap = render_params .cmap_params [0 ].cmap ,
414- )
432+ # Remove alpha channel so we can overwrite it from render_params.alpha
433+ colored = colored [:, :, :3 ]
415434
416- else :
417435 ax .imshow (
418- layer . squeeze (), # get rid of the channel dimension
419- cmap = _get_linear_colormap ([ render_params .palette ], "k" )[ 0 ] ,
436+ colored ,
437+ alpha = render_params .alpha ,
420438 )
421439
422- break
440+ # 2C) Image has n channels and palette info
441+ elif render_params .palette is not None and not got_multiple_cmaps :
442+ if len (render_params .palette ) != n_channels :
443+ raise ValueError ("If 'palette' is provided, its length must match the number of channels." )
423444
424- if render_params .palette is not None and n_channels != len (render_params .palette ):
425- raise ValueError ("If 'palette' is provided, its length must match the number of channels." )
445+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in render_params .palette ]
426446
427- if n_channels > 1 :
428- layer = img . sel ( c = channels ). copy ( deep = True )
447+ # Apply cmaps to each channel and add up
448+ colored = np . stack ([ channel_cmaps [ i ]( layers [ c ]) for i , c in enumerate ( channels )], 0 ). sum ( 0 )
429449
430- channel_colors : list [str ] | Any
431- if render_params .palette is None :
432- channel_colors = _get_colors_for_categorical_obs (
433- layer .coords ["c" ].values .tolist (), palette = render_params .cmap_params [0 ].cmap
450+ # Remove alpha channel so we can overwrite it from render_params.alpha
451+ colored = colored [:, :, :3 ]
452+
453+ ax .imshow (
454+ colored ,
455+ alpha = render_params .alpha ,
434456 )
435- else :
436- channel_colors = render_params .palette
437457
438- channel_cmaps = _get_linear_colormap ([str (c ) for c in channel_colors [:n_channels ]], "k" )
458+ elif render_params .palette is None and got_multiple_cmaps :
459+ channel_cmaps = [cp .cmap for cp in render_params .cmap_params ] # type: ignore[union-attr]
439460
440- layer_vals = []
441- if render_params .quantiles_for_norm != (None , None ):
442- for i in range (n_channels ):
443- layer_vals .append (
444- _normalize (
445- layer .values [i ],
446- pmin = render_params .quantiles_for_norm [0 ],
447- pmax = render_params .quantiles_for_norm [1 ],
448- clip = True ,
449- )
450- )
461+ # Apply cmaps to each channel, add up and normalize to [0, 1]
462+ colored = np .stack ([channel_cmaps [i ](layers [c ]) for i , c in enumerate (channels )], 0 ).sum (0 ) / n_channels
451463
452- colored = np .stack ([channel_cmaps [i ](layer_vals [i ]) for i in range (n_channels )], 0 ).sum (0 )
464+ # Remove alpha channel so we can overwrite it from render_params.alpha
465+ colored = colored [:, :, :3 ]
453466
454- layer = xr .DataArray (
455- data = colored ,
456- coords = [
457- layer .coords ["y" ],
458- layer .coords ["x" ],
459- ["R" , "G" , "B" , "A" ],
460- ],
461- dims = ["y" , "x" , "c" ],
462- )
463- layer = layer .transpose ("y" , "x" , "c" ) # for plotting
467+ ax .imshow (
468+ colored ,
469+ alpha = render_params .alpha ,
470+ )
464471
465- ax .imshow (
466- layer .data ,
467- cmap = channel_cmaps [0 ],
468- alpha = render_params .alpha ,
469- norm = render_params .cmap_params [0 ].norm ,
470- )
472+ elif render_params .palette is not None and got_multiple_cmaps :
473+ raise ValueError ("If 'palette' is provided, 'cmap' must be None." )
471474
472475
473476@dataclass
0 commit comments