19
19
from anndata import AnnData
20
20
from scipy .sparse import csr_matrix
21
21
from sklearn .metrics import f1_score
22
+
22
23
# from ..utils import *
23
24
from pySingleCellNet .config import SCN_CATEGORY_COLOR_DICT
24
25
25
-
26
- import numpy as np
27
- import matplotlib .pyplot as plt
26
+ from scipy .spatial .distance import pdist , squareform
27
+ from scipy .cluster .hierarchy import linkage , leaves_list
28
28
from anndata import AnnData
29
29
30
+
30
31
def stackedbar_composition (
31
32
adata : AnnData ,
32
33
groupby : str ,
33
34
obs_column = 'SCN_class' ,
34
35
labels = None ,
35
36
bar_width : float = 0.75 ,
36
37
color_dict = None ,
37
- ax = None
38
+ ax = None ,
39
+ order_by_similarity : bool = False ,
40
+ similarity_metric : str = 'correlation'
38
41
):
39
42
"""
40
43
Plots a stacked bar chart of cell type proportions for a single AnnData object grouped by a specified column.
41
-
44
+
42
45
Args:
43
46
adata (anndata.AnnData): An AnnData object.
44
47
groupby (str): The column in `.obs` to group by.
@@ -50,39 +53,36 @@ def stackedbar_composition(
50
53
color_dict (Dict[str, str], optional): A dictionary mapping categories to specific colors. If not provided,
51
54
default colors will be used.
52
55
ax (matplotlib.axes.Axes, optional): The axis to plot on. If not provided, a new figure and axis will be created.
53
-
56
+ order_by_similarity (bool, optional): Whether to order the bars by similarity in composition. Defaults to False.
57
+ similarity_metric (str, optional): The metric to use for similarity ordering. Defaults to 'correlation'.
58
+
54
59
Raises:
55
60
ValueError: If the length of `labels` does not match the number of unique groups.
56
-
61
+
57
62
Examples:
58
63
>>> stackedbar_composition(adata, groupby='sample', obs_column='your_column_name')
59
64
>>> fig, ax = plt.subplots()
60
65
>>> stackedbar_composition(adata, groupby='sample', obs_column='your_column_name', ax=ax)
61
66
"""
62
-
63
67
# Ensure the groupby column exists in .obs
64
68
if groupby not in adata .obs .columns :
65
69
raise ValueError (f"The groupby column '{ groupby } ' does not exist in the .obs attribute." )
66
70
67
-
68
71
# Check if groupby column is categorical or not
69
72
if pd .api .types .is_categorical_dtype (adata .obs [groupby ]):
70
73
unique_groups = adata .obs [groupby ].cat .categories .to_list ()
71
74
else :
72
75
unique_groups = adata .obs [groupby ].unique ().tolist ()
73
-
76
+
74
77
# Extract unique groups and ensure labels are provided or create default ones
75
- unique_groups = adata .obs [groupby ].cat .categories .to_list ()
76
-
77
-
78
78
if labels is None :
79
79
labels = unique_groups
80
80
elif len (labels ) != len (unique_groups ):
81
81
raise ValueError ("Length of 'labels' must match the number of unique groups." )
82
82
83
83
if color_dict is None :
84
84
color_dict = adata .uns ['SCN_class_colors' ]
85
-
85
+
86
86
# Extracting category proportions per group
87
87
category_counts = []
88
88
categories = set ()
@@ -101,12 +101,21 @@ def stackedbar_composition(
101
101
j = categories .index (category )
102
102
proportions [j , i ] = counts [category ]
103
103
104
+ # Ordering groups by similarity if requested
105
+ if order_by_similarity :
106
+ dist_matrix = pdist (proportions .T , metric = similarity_metric )
107
+ linkage_matrix = linkage (dist_matrix , method = 'average' )
108
+ order = leaves_list (linkage_matrix )
109
+ proportions = proportions [:, order ]
110
+ unique_groups = [unique_groups [i ] for i in order ]
111
+ labels = [labels [i ] for i in order ]
112
+
104
113
# Plotting
105
114
if ax is None :
106
115
fig , ax = plt .subplots ()
107
116
else :
108
117
fig = ax .figure
109
-
118
+
110
119
bottom = np .zeros (len (unique_groups ))
111
120
for i , category in enumerate (categories ):
112
121
color = color_dict [category ] if color_dict and category in color_dict else None
@@ -135,7 +144,9 @@ def stackedbar_composition(
135
144
return ax
136
145
137
146
138
- def stackedbar_composition2 (
147
+
148
+
149
+ def stackedbar_composition_old (
139
150
adata : AnnData ,
140
151
groupby : str ,
141
152
obs_column = 'SCN_class' ,
@@ -332,25 +343,28 @@ def stackedbar_categories(
332
343
adata : AnnData ,
333
344
scn_classes_to_display = None ,
334
345
bar_height = 0.8 ,
335
- color_dict = None
346
+ color_dict = None ,
347
+ class_col_name = 'SCN_class_argmax' ,
348
+ category_col_name = 'SCN_class_type'
336
349
):
337
350
# Copy the obs DataFrame to avoid modifying the original data
338
351
df = adata .obs .copy ()
339
352
340
353
# Ensure the columns 'SCN_class' and 'SCN_class_type' exist
341
- if 'SCN_class' not in df .columns or 'SCN_class_type' not in df .columns :
342
- raise KeyError ("Columns 'SCN_class' and 'SCN_class_type' must be present in adata.obs" )
354
+ # if 'SCN_class' not in df.columns or 'SCN_class_type' not in df.columns:
355
+ if class_col_name not in df .columns or category_col_name not in df .columns :
356
+ raise KeyError (f"Columns '{ class_col_name } ' and '{ category_col_name } ' must be present in adata.obs" )
343
357
344
358
# Ensure SCN_class categories are consistent
345
- df ['SCN_class' ] = df ['SCN_class' ].astype ('category' )
346
- df ['SCN_class_type' ] = df ['SCN_class_type' ].astype ('category' )
359
+ df [class_col_name ] = df [class_col_name ].astype ('category' )
360
+ df [category_col_name ] = df [category_col_name ].astype ('category' )
347
361
348
- df ['SCN_class' ] = df ['SCN_class' ].cat .set_categories (df ['SCN_class' ].cat .categories )
349
- df ['SCN_class_type' ] = df ['SCN_class_type' ].cat .set_categories (df ['SCN_class_type' ].cat .categories )
362
+ df [class_col_name ] = df [class_col_name ].cat .set_categories (df [class_col_name ].cat .categories )
363
+ df [category_col_name ] = df [category_col_name ].cat .set_categories (df [category_col_name ].cat .categories )
350
364
351
365
# Group by 'SCN_class' and get value counts for 'SCN_class_type'
352
366
try :
353
- counts = df .groupby ('SCN_class' )[ 'SCN_class_type' ].value_counts ().unstack ().fillna (0 )
367
+ counts = df .groupby (class_col_name )[ category_col_name ].value_counts ().unstack ().fillna (0 )
354
368
except Exception as e :
355
369
print ("Error during groupby and value_counts operations:" , e )
356
370
return
@@ -362,7 +376,7 @@ def stackedbar_categories(
362
376
total_counts = counts .sum (axis = 1 )
363
377
total_percent = (total_counts / total_counts .sum () * 100 ).round (1 ) # Converts to percentage and round
364
378
365
- all_classes = df ['SCN_class' ].unique ()
379
+ all_classes = df [class_col_name ].unique ()
366
380
if scn_classes_to_display is not None :
367
381
if not all (cls in all_classes for cls in scn_classes_to_display ):
368
382
raise ValueError ("Some values in 'scn_classes_to_display' do not match available 'SCN_class' values in the provided DataFrames." )
@@ -415,8 +429,6 @@ def stackedbar_categories(
415
429
416
430
417
431
418
-
419
-
420
432
def stackedbar_categories_list_old (
421
433
ads ,
422
434
titles = None ,
@@ -505,8 +517,6 @@ def stackedbar_categories_list_old(
505
517
return fig
506
518
507
519
508
-
509
-
510
520
def stackedbar_categories_list (
511
521
ads ,
512
522
titles = None ,
@@ -593,8 +603,6 @@ def stackedbar_categories_list(
593
603
594
604
595
605
596
-
597
-
598
606
def bar_classifier_f1 (adata : AnnData , ground_truth : str = "celltype" , class_prediction : str = "SCN_class" , bar_height = 0.8 ):
599
607
"""
600
608
Plots a bar graph of F1 scores per class based on ground truth and predicted classifications.
0 commit comments