@@ -82,7 +82,7 @@ def _render_shapes(
8282
8383 for e in elements :
8484 shapes = sdata .shapes [e ]
85- n_shapes = sum ([ len (s ) for s in shapes ] )
85+ n_shapes = sum (len (s ) for s in shapes )
8686
8787 if sdata .table is None :
8888 table = AnnData (None , obs = pd .DataFrame (index = pd .Index (np .arange (n_shapes ), dtype = str )))
@@ -94,11 +94,11 @@ def _render_shapes(
9494 sdata = sdata_filt ,
9595 element = sdata_filt .shapes [e ],
9696 element_name = e ,
97- value_to_plot = render_params .color ,
97+ value_to_plot = render_params .col_for_color ,
9898 layer = render_params .layer ,
9999 groups = render_params .groups ,
100100 palette = render_params .palette ,
101- na_color = render_params .cmap_params .na_color ,
101+ na_color = render_params .color or render_params . cmap_params .na_color ,
102102 alpha = render_params .fill_alpha ,
103103 cmap_params = render_params .cmap_params ,
104104 )
@@ -162,14 +162,18 @@ def _render_shapes(
162162 len (set (color_vector )) == 1 and list (set (color_vector ))[0 ] == to_hex (render_params .cmap_params .na_color )
163163 ):
164164 # necessary in case different shapes elements are annotated with one table
165- if color_source_vector is not None :
165+ if color_source_vector is not None and render_params . col_for_color is not None :
166166 color_source_vector = color_source_vector .remove_unused_categories ()
167+
168+ # False if user specified color-like with 'color' parameter
169+ colorbar = False if render_params .col_for_color is None else legend_params .colorbar
170+
167171 _ = _decorate_axs (
168172 ax = ax ,
169173 cax = cax ,
170174 fig_params = fig_params ,
171175 adata = table ,
172- value_to_plot = render_params .color ,
176+ value_to_plot = render_params .col_for_color ,
173177 color_source_vector = color_source_vector ,
174178 palette = palette ,
175179 alpha = render_params .fill_alpha ,
@@ -179,7 +183,7 @@ def _render_shapes(
179183 legend_loc = legend_params .legend_loc ,
180184 legend_fontoutline = legend_params .legend_fontoutline ,
181185 na_in_legend = legend_params .na_in_legend ,
182- colorbar = legend_params . colorbar ,
186+ colorbar = colorbar ,
183187 scalebar_dx = scalebar_params .scalebar_dx ,
184188 scalebar_units = scalebar_params .scalebar_units ,
185189 )
@@ -194,12 +198,6 @@ def _render_points(
194198 scalebar_params : ScalebarParams ,
195199 legend_params : LegendParams ,
196200) -> None :
197- if render_params .groups is not None :
198- if isinstance (render_params .groups , str ):
199- render_params .groups = [render_params .groups ]
200- if not all (isinstance (g , str ) for g in render_params .groups ):
201- raise TypeError ("All groups must be strings." )
202-
203201 elements = render_params .elements
204202
205203 sdata_filt = sdata .filter_by_coordinate_system (
@@ -214,43 +212,56 @@ def _render_points(
214212
215213 for e in elements :
216214 points = sdata .points [e ]
215+ col_for_color = render_params .col_for_color
216+
217217 coords = ["x" , "y" ]
218- if render_params .color is not None :
219- color = [render_params .color ] if isinstance (render_params .color , str ) else render_params .color
220- coords .extend (color )
218+ if col_for_color is not None :
219+ if col_for_color not in points .columns :
220+ # no error in case there are multiple elements, but onyl some have color key
221+ msg = f"Color key '{ col_for_color } ' for element '{ e } ' not been found, using default colors."
222+ logger .warning (msg )
223+ else :
224+ coords += [col_for_color ]
221225
222226 points = points [coords ].compute ()
223- if render_params .groups is not None :
224- points = points [points [color ].isin (render_params .groups ).values ]
225- points [color [0 ]] = points [color [0 ]].cat .set_categories (render_params .groups )
226- points = dask .dataframe .from_pandas (points , npartitions = 1 )
227- sdata_filt .points [e ] = PointsModel .parse (points , coordinates = {"x" : "x" , "y" : "y" })
228-
229- point_df = points [coords ].compute ()
227+ if render_params .groups is not None and col_for_color is not None :
228+ points = points [points [col_for_color ].isin (render_params .groups )]
230229
231230 # we construct an anndata to hack the plotting functions
232231 adata = AnnData (
233- X = point_df [["x" , "y" ]].values , obs = point_df [coords ].reset_index (), dtype = point_df [["x" , "y" ]].values .dtype
232+ X = points [["x" , "y" ]].values , obs = points [coords ].reset_index (), dtype = points [["x" , "y" ]].values .dtype
234233 )
235- if render_params .color is not None :
236- cols = sc .get .obs_df (adata , render_params .color )
234+
235+ # Convert back to dask dataframe to modify sdata
236+ points = dask .dataframe .from_pandas (points , npartitions = 1 )
237+ sdata_filt .points [e ] = PointsModel .parse (points , coordinates = {"x" : "x" , "y" : "y" })
238+
239+ if render_params .col_for_color is not None :
240+ cols = sc .get .obs_df (adata , render_params .col_for_color )
237241 # maybe set color based on type
238242 if is_categorical_dtype (cols ):
239243 _maybe_set_colors (
240244 source = adata ,
241245 target = adata ,
242- key = render_params .color ,
246+ key = render_params .col_for_color ,
243247 palette = render_params .palette ,
244248 )
245249
250+ # when user specified a single color, we overwrite na with it
251+ default_color = (
252+ render_params .color
253+ if render_params .col_for_color is None and render_params .color is not None
254+ else render_params .cmap_params .na_color
255+ )
256+
246257 color_source_vector , color_vector , _ = _set_color_source_vec (
247258 sdata = sdata_filt ,
248259 element = points ,
249260 element_name = e ,
250- value_to_plot = render_params .color ,
261+ value_to_plot = render_params .col_for_color ,
251262 groups = render_params .groups ,
252263 palette = render_params .palette ,
253- na_color = render_params . cmap_params . na_color ,
264+ na_color = default_color ,
254265 alpha = render_params .alpha ,
255266 cmap_params = render_params .cmap_params ,
256267 )
@@ -278,9 +289,7 @@ def _render_points(
278289 )
279290 cax = ax .add_collection (_cax )
280291
281- if not (
282- len (set (color_vector )) == 1 and list (set (color_vector ))[0 ] == to_hex (render_params .cmap_params .na_color )
283- ):
292+ if len (set (color_vector )) != 1 or list (set (color_vector ))[0 ] != to_hex (render_params .cmap_params .na_color ):
284293 if color_source_vector is None :
285294 palette = ListedColormap (dict .fromkeys (color_vector ))
286295 else :
@@ -291,7 +300,7 @@ def _render_points(
291300 cax = cax ,
292301 fig_params = fig_params ,
293302 adata = adata ,
294- value_to_plot = render_params .color ,
303+ value_to_plot = render_params .col_for_color ,
295304 color_source_vector = color_source_vector ,
296305 palette = palette ,
297306 alpha = render_params .alpha ,
@@ -629,8 +638,8 @@ def _render_labels(
629638 _cax = ax .imshow (
630639 labels_infill ,
631640 rasterized = True ,
632- cmap = render_params . cmap_params . cmap if not categorical else None ,
633- norm = render_params . cmap_params . norm if not categorical else None ,
641+ cmap = None if categorical else render_params . cmap_params . cmap ,
642+ norm = None if categorical else render_params . cmap_params . norm ,
634643 alpha = render_params .fill_alpha ,
635644 origin = "lower" ,
636645 )
@@ -652,14 +661,11 @@ def _render_labels(
652661 _cax = ax .imshow (
653662 labels_contour ,
654663 rasterized = True ,
655- cmap = render_params . cmap_params . cmap if not categorical else None ,
656- norm = render_params . cmap_params . norm if not categorical else None ,
664+ cmap = None if categorical else render_params . cmap_params . cmap ,
665+ norm = None if categorical else render_params . cmap_params . norm ,
657666 alpha = render_params .outline_alpha ,
658667 origin = "lower" ,
659668 )
660- _cax .set_transform (trans_data )
661- cax = ax .add_image (_cax )
662-
663669 else :
664670 # Default: no alpha, contour = infill
665671 label = _map_color_seg (
@@ -676,13 +682,13 @@ def _render_labels(
676682 _cax = ax .imshow (
677683 label ,
678684 rasterized = True ,
679- cmap = render_params . cmap_params . cmap if not categorical else None ,
680- norm = render_params . cmap_params . norm if not categorical else None ,
685+ cmap = None if categorical else render_params . cmap_params . cmap ,
686+ norm = None if categorical else render_params . cmap_params . norm ,
681687 alpha = render_params .fill_alpha ,
682688 origin = "lower" ,
683689 )
684- _cax .set_transform (trans_data )
685- cax = ax .add_image (_cax )
690+ _cax .set_transform (trans_data )
691+ cax = ax .add_image (_cax )
686692
687693 _ = _decorate_axs (
688694 ax = ax ,
0 commit comments