@@ -137,7 +137,7 @@ def _render_shapes(
137137 if isinstance (groups , list ) and color_source_vector is not None :
138138 mask = color_source_vector .isin (groups )
139139 shapes = shapes [mask ]
140- shapes = shapes .reset_index ()
140+ shapes = shapes .reset_index (drop = True )
141141 color_source_vector = color_source_vector [mask ]
142142 color_vector = color_vector [mask ]
143143
@@ -363,8 +363,10 @@ def _render_shapes(
363363 cax = None
364364 if aggregate_with_reduction is not None :
365365 vmin = aggregate_with_reduction [0 ].values if norm .vmin is None else norm .vmin
366- vmax = aggregate_with_reduction [1 ].values if norm .vmin is None else norm .vmax
366+ vmax = aggregate_with_reduction [1 ].values if norm .vmax is None else norm .vmax
367367 if (norm .vmin is not None or norm .vmax is not None ) and norm .vmin == norm .vmax :
368+ assert norm .vmin is not None
369+ assert norm .vmax is not None
368370 # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
369371 # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
370372 vmin = norm .vmin - 0.5
@@ -766,6 +768,8 @@ def _render_points(
766768 vmin = aggregate_with_reduction [0 ].values if norm .vmin is None else norm .vmin
767769 vmax = aggregate_with_reduction [1 ].values if norm .vmax is None else norm .vmax
768770 if (norm .vmin is not None or norm .vmax is not None ) and norm .vmin == norm .vmax :
771+ assert norm .vmin is not None
772+ assert norm .vmax is not None
769773 # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
770774 # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
771775 vmin = norm .vmin - 0.5
@@ -922,20 +926,22 @@ def _render_images(
922926 # 2) Image has any number of channels but 1
923927 else :
924928 layers = {}
925- for ch_index , c in enumerate (channels ):
926- layers [c ] = img .sel (c = c ).copy (deep = True ).squeeze ()
927-
928- if not isinstance (render_params .cmap_params , list ):
929- if render_params .cmap_params .norm is not None :
930- layers [c ] = render_params .cmap_params .norm (layers [c ])
929+ for ch_idx , ch in enumerate (channels ):
930+ layers [ch ] = img .sel (c = ch ).copy (deep = True ).squeeze ()
931+ if isinstance (render_params .cmap_params , list ):
932+ ch_norm = render_params .cmap_params [ch_idx ].norm
933+ ch_cmap_is_default = render_params .cmap_params [ch_idx ].cmap_is_default
931934 else :
932- if render_params .cmap_params [ch_index ].norm is not None :
933- layers [c ] = render_params .cmap_params [ch_index ].norm (layers [c ])
935+ ch_norm = render_params .cmap_params .norm
936+ ch_cmap_is_default = render_params .cmap_params .cmap_is_default
937+
938+ if not ch_cmap_is_default and ch_norm is not None :
939+ layers [ch_idx ] = ch_norm (layers [ch_idx ])
934940
935941 # 2A) Image has 3 channels, no palette info, and no/only one cmap was given
936942 if palette is None and n_channels == 3 and not isinstance (render_params .cmap_params , list ):
937943 if render_params .cmap_params .cmap_is_default : # -> use RGB
938- stacked = np .stack ([layers [c ] for c in channels ], axis = - 1 )
944+ stacked = np .stack ([layers [ch ] for ch in layers ], axis = - 1 )
939945 else : # -> use given cmap for each channel
940946 channel_cmaps = [render_params .cmap_params .cmap ] * n_channels
941947 stacked = (
@@ -968,12 +974,54 @@ def _render_images(
968974 # overwrite if n_channels == 2 for intuitive result
969975 if n_channels == 2 :
970976 seed_colors = ["#ff0000ff" , "#00ff00ff" ]
971- else :
977+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
978+ colored = np .stack (
979+ [channel_cmaps [ch_ind ](layers [ch ]) for ch_ind , ch in enumerate (channels )],
980+ 0 ,
981+ ).sum (0 )
982+ colored = colored [:, :, :3 ]
983+ elif n_channels == 3 :
972984 seed_colors = _get_colors_for_categorical_obs (list (range (n_channels )))
985+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
986+ colored = np .stack (
987+ [channel_cmaps [ind ](layers [ch ]) for ind , ch in enumerate (channels )],
988+ 0 ,
989+ ).sum (0 )
990+ colored = colored [:, :, :3 ]
991+ else :
992+ if isinstance (render_params .cmap_params , list ):
993+ cmap_is_default = render_params .cmap_params [0 ].cmap_is_default
994+ else :
995+ cmap_is_default = render_params .cmap_params .cmap_is_default
973996
974- channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
975- colored = np .stack ([channel_cmaps [ind ](layers [ch ]) for ind , ch in enumerate (channels )], 0 ).sum (0 )
976- colored = colored [:, :, :3 ]
997+ if cmap_is_default :
998+ seed_colors = _get_colors_for_categorical_obs (list (range (n_channels )))
999+ else :
1000+ # Sample n_channels colors evenly from the colormap
1001+ if isinstance (render_params .cmap_params , list ):
1002+ seed_colors = [
1003+ render_params .cmap_params [i ].cmap (i / (n_channels - 1 )) for i in range (n_channels )
1004+ ]
1005+ else :
1006+ seed_colors = [render_params .cmap_params .cmap (i / (n_channels - 1 )) for i in range (n_channels )]
1007+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
1008+
1009+ # Stack (n_channels, height, width) → (height*width, n_channels)
1010+ H , W = next (iter (layers .values ())).shape
1011+ comp_rgb = np .zeros ((H , W , 3 ), dtype = float )
1012+
1013+ # For each channel: map to RGBA, apply constant alpha, then add
1014+ for ch_idx , ch in enumerate (channels ):
1015+ layer_arr = layers [ch ]
1016+ rgba = channel_cmaps [ch_idx ](layer_arr )
1017+ rgba [..., 3 ] = render_params .alpha
1018+ comp_rgb += rgba [..., :3 ] * rgba [..., 3 ][..., None ]
1019+
1020+ colored = np .clip (comp_rgb , 0 , 1 )
1021+ logger .info (
1022+ f"Your image has { n_channels } channels. Sampling categorical colors and using "
1023+ f"multichannel strategy 'stack' to render."
1024+ ) # TODO: update when pca is added as strategy
9771025
9781026 _ax_show_and_transform (
9791027 colored ,
@@ -1019,6 +1067,7 @@ def _render_images(
10191067 zorder = render_params .zorder ,
10201068 )
10211069
1070+ # 2D) Image has n channels, no palette but cmap info
10221071 elif palette is not None and got_multiple_cmaps :
10231072 raise ValueError ("If 'palette' is provided, 'cmap' must be None." )
10241073
0 commit comments