Skip to content

Commit be97175

Browse files
committed
too many changes to summarize
1 parent 6bcf54f commit be97175

20 files changed

+2061
-1778
lines changed

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ def get_version(package_name):
3030
'umap-learn',
3131
'mygene',
3232
'palettable',
33-
'gseapy'
33+
'gseapy',
34+
'alive_progress',
35+
'python-igraph',
36+
'marsilea'
3437
],
3538
project_urls={
3639
'Documentation': 'https://pysinglecellnet.readthedocs.io/en/latest/',

src/pySingleCellNet/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,32 @@
11
"""PySingleCellNet"""
22

3+
from .config import SCN_CATEGORY_COLOR_DICT
4+
from .config import SCN_DIFFEXP_KEY
35
from . import plotting as pl
6+
from . import utils as ut
47
from .stats import *
5-
from .utils import *
68
from .tsp_rf import *
79
from .scn_train import *
8-
from .scn_assess import *
10+
from .scn_assess import create_classifier_report
911
from .postclass_analysis import *
1012
from .rank_class import *
11-
from .config import SCN_CATEGORY_COLOR_DICT
13+
1214

1315

1416
# Public API
1517
__all__ = [
1618
"__version__",
1719
"pl",
18-
"mito_rib",
19-
"limit_anndata_to_common_genes",
20-
"splitCommonAnnData",
20+
"ut",
21+
"train_classifier",
2122
"scn_train",
2223
"scn_classify",
2324
"graph_from_nodes_and_edges",
2425
"comp_ct_thresh",
2526
"class_by_threshold",
2627
"determine_relationships"
28+
"remove_xist_y_genes",
29+
"create_classifier_report"
2730
]
2831

2932

src/pySingleCellNet/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.2'
1+
__version__ = '0.1.3'

src/pySingleCellNet/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,8 @@
1919

2020
# fonts
2121

22-
# ...
22+
# ...
23+
24+
# Arbitrary strings
25+
26+
SCN_DIFFEXP_KEY = "scnDiffExp"

src/pySingleCellNet/plotting/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
)
88

99
from .dot import (
10+
umi_counts_ranked,
1011
ontogeny_graph,
1112
dotplot_deg,
1213
dotplot_diff_gene,
@@ -15,6 +16,7 @@
1516
)
1617

1718
from .heatmap import (
19+
heatmap_classifier_report,
1820
heatmap_scores,
1921
heatmap_gsea,
2022
heatmap_genes,
@@ -31,11 +33,13 @@
3133
"stackedbar_categories",
3234
"stackedbar_categories_list",
3335
"bar_classifier_f1",
36+
"umi_counts_ranked",
3437
"ontogeny_graph",
3538
"dotplot_deg",
3639
"dotplot_diff_gene",
3740
"dotplot_scn_scores",
3841
"umap_scores",
42+
"heatmap_classifier_report",
3943
"heatmap_scores",
4044
"heatmap_gsea",
4145
"heatmap_genes",

src/pySingleCellNet/plotting/bar.py

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,29 @@
1919
from anndata import AnnData
2020
from scipy.sparse import csr_matrix
2121
from sklearn.metrics import f1_score
22+
2223
# from ..utils import *
2324
from pySingleCellNet.config import SCN_CATEGORY_COLOR_DICT
2425

25-
26-
import numpy as np
27-
import matplotlib.pyplot as plt
26+
from scipy.spatial.distance import pdist, squareform
27+
from scipy.cluster.hierarchy import linkage, leaves_list
2828
from anndata import AnnData
2929

30+
3031
def stackedbar_composition(
3132
adata: AnnData,
3233
groupby: str,
3334
obs_column='SCN_class',
3435
labels=None,
3536
bar_width: float = 0.75,
3637
color_dict=None,
37-
ax=None
38+
ax=None,
39+
order_by_similarity: bool = False,
40+
similarity_metric: str = 'correlation'
3841
):
3942
"""
4043
Plots a stacked bar chart of cell type proportions for a single AnnData object grouped by a specified column.
41-
44+
4245
Args:
4346
adata (anndata.AnnData): An AnnData object.
4447
groupby (str): The column in `.obs` to group by.
@@ -50,39 +53,36 @@ def stackedbar_composition(
5053
color_dict (Dict[str, str], optional): A dictionary mapping categories to specific colors. If not provided,
5154
default colors will be used.
5255
ax (matplotlib.axes.Axes, optional): The axis to plot on. If not provided, a new figure and axis will be created.
53-
56+
order_by_similarity (bool, optional): Whether to order the bars by similarity in composition. Defaults to False.
57+
similarity_metric (str, optional): The metric to use for similarity ordering. Defaults to 'correlation'.
58+
5459
Raises:
5560
ValueError: If the length of `labels` does not match the number of unique groups.
56-
61+
5762
Examples:
5863
>>> stackedbar_composition(adata, groupby='sample', obs_column='your_column_name')
5964
>>> fig, ax = plt.subplots()
6065
>>> stackedbar_composition(adata, groupby='sample', obs_column='your_column_name', ax=ax)
6166
"""
62-
6367
# Ensure the groupby column exists in .obs
6468
if groupby not in adata.obs.columns:
6569
raise ValueError(f"The groupby column '{groupby}' does not exist in the .obs attribute.")
6670

67-
6871
# Check if groupby column is categorical or not
6972
if pd.api.types.is_categorical_dtype(adata.obs[groupby]):
7073
unique_groups = adata.obs[groupby].cat.categories.to_list()
7174
else:
7275
unique_groups = adata.obs[groupby].unique().tolist()
73-
76+
7477
# Extract unique groups and ensure labels are provided or create default ones
75-
unique_groups = adata.obs[groupby].cat.categories.to_list()
76-
77-
7878
if labels is None:
7979
labels = unique_groups
8080
elif len(labels) != len(unique_groups):
8181
raise ValueError("Length of 'labels' must match the number of unique groups.")
8282

8383
if color_dict is None:
8484
color_dict = adata.uns['SCN_class_colors']
85-
85+
8686
# Extracting category proportions per group
8787
category_counts = []
8888
categories = set()
@@ -101,12 +101,21 @@ def stackedbar_composition(
101101
j = categories.index(category)
102102
proportions[j, i] = counts[category]
103103

104+
# Ordering groups by similarity if requested
105+
if order_by_similarity:
106+
dist_matrix = pdist(proportions.T, metric=similarity_metric)
107+
linkage_matrix = linkage(dist_matrix, method='average')
108+
order = leaves_list(linkage_matrix)
109+
proportions = proportions[:, order]
110+
unique_groups = [unique_groups[i] for i in order]
111+
labels = [labels[i] for i in order]
112+
104113
# Plotting
105114
if ax is None:
106115
fig, ax = plt.subplots()
107116
else:
108117
fig = ax.figure
109-
118+
110119
bottom = np.zeros(len(unique_groups))
111120
for i, category in enumerate(categories):
112121
color = color_dict[category] if color_dict and category in color_dict else None
@@ -135,7 +144,9 @@ def stackedbar_composition(
135144
return ax
136145

137146

138-
def stackedbar_composition2(
147+
148+
149+
def stackedbar_composition_old(
139150
adata: AnnData,
140151
groupby: str,
141152
obs_column = 'SCN_class',
@@ -332,25 +343,28 @@ def stackedbar_categories(
332343
adata: AnnData,
333344
scn_classes_to_display = None,
334345
bar_height=0.8,
335-
color_dict = None
346+
color_dict = None,
347+
class_col_name = 'SCN_class_argmax',
348+
category_col_name = 'SCN_class_type'
336349
):
337350
# Copy the obs DataFrame to avoid modifying the original data
338351
df = adata.obs.copy()
339352

340353
# Ensure the columns 'SCN_class' and 'SCN_class_type' exist
341-
if 'SCN_class' not in df.columns or 'SCN_class_type' not in df.columns:
342-
raise KeyError("Columns 'SCN_class' and 'SCN_class_type' must be present in adata.obs")
354+
# if 'SCN_class' not in df.columns or 'SCN_class_type' not in df.columns:
355+
if class_col_name not in df.columns or category_col_name not in df.columns:
356+
raise KeyError(f"Columns '{class_col_name}' and '{category_col_name}' must be present in adata.obs")
343357

344358
# Ensure SCN_class categories are consistent
345-
df['SCN_class'] = df['SCN_class'].astype('category')
346-
df['SCN_class_type'] = df['SCN_class_type'].astype('category')
359+
df[class_col_name] = df[class_col_name].astype('category')
360+
df[category_col_name] = df[category_col_name].astype('category')
347361

348-
df['SCN_class'] = df['SCN_class'].cat.set_categories(df['SCN_class'].cat.categories)
349-
df['SCN_class_type'] = df['SCN_class_type'].cat.set_categories(df['SCN_class_type'].cat.categories)
362+
df[class_col_name] = df[class_col_name].cat.set_categories(df[class_col_name].cat.categories)
363+
df[category_col_name] = df[category_col_name].cat.set_categories(df[category_col_name].cat.categories)
350364

351365
# Group by 'SCN_class' and get value counts for 'SCN_class_type'
352366
try:
353-
counts = df.groupby('SCN_class')['SCN_class_type'].value_counts().unstack().fillna(0)
367+
counts = df.groupby(class_col_name)[category_col_name].value_counts().unstack().fillna(0)
354368
except Exception as e:
355369
print("Error during groupby and value_counts operations:", e)
356370
return
@@ -362,7 +376,7 @@ def stackedbar_categories(
362376
total_counts = counts.sum(axis=1)
363377
total_percent = (total_counts / total_counts.sum() * 100).round(1) # Converts to percentage and round
364378

365-
all_classes = df['SCN_class'].unique()
379+
all_classes = df[class_col_name].unique()
366380
if scn_classes_to_display is not None:
367381
if not all(cls in all_classes for cls in scn_classes_to_display):
368382
raise ValueError("Some values in 'scn_classes_to_display' do not match available 'SCN_class' values in the provided DataFrames.")
@@ -415,8 +429,6 @@ def stackedbar_categories(
415429

416430

417431

418-
419-
420432
def stackedbar_categories_list_old(
421433
ads,
422434
titles=None,
@@ -505,8 +517,6 @@ def stackedbar_categories_list_old(
505517
return fig
506518

507519

508-
509-
510520
def stackedbar_categories_list(
511521
ads,
512522
titles=None,
@@ -593,8 +603,6 @@ def stackedbar_categories_list(
593603

594604

595605

596-
597-
598606
def bar_classifier_f1(adata: AnnData, ground_truth: str = "celltype", class_prediction: str = "SCN_class", bar_height=0.8):
599607
"""
600608
Plots a bar graph of F1 scores per class based on ground truth and predicted classifications.

src/pySingleCellNet/plotting/dot.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,51 @@
2323
from sklearn.metrics import f1_score
2424
from ..utils import *
2525

26+
def umi_counts_ranked(adata, total_counts_column="total_counts"):
27+
"""
28+
Identifies and plors the knee point of the UMI count distribution in an AnnData object.
29+
30+
Parameters:
31+
adata (AnnData): The input AnnData object.
32+
total_counts_column (str): Column in `adata.obs` containing total UMI counts. Default is "total_counts".
33+
show (bool): If True, displays a log-log plot with the knee point. Default is True.
34+
35+
Returns:
36+
float: The UMI count value at the knee point.
37+
"""
38+
# Extract total UMI counts
39+
umi_counts = adata.obs[total_counts_column]
40+
41+
# Sort UMI counts in descending order
42+
sorted_umi_counts = np.sort(umi_counts)[::-1]
43+
44+
# Compute cumulative UMI counts (normalized to a fraction)
45+
cumulative_counts = np.cumsum(sorted_umi_counts)
46+
cumulative_fraction = cumulative_counts / cumulative_counts[-1]
47+
48+
# Compute derivatives to identify the knee point
49+
first_derivative = np.gradient(cumulative_fraction)
50+
second_derivative = np.gradient(first_derivative)
51+
52+
# Find the index of the maximum curvature (knee point)
53+
knee_idx = np.argmax(second_derivative)
54+
knee_point_value = sorted_umi_counts[knee_idx]
55+
56+
# Generate log-log plot
57+
cell_ranks = np.arange(1, len(sorted_umi_counts) + 1)
58+
plt.figure(figsize=(10, 6))
59+
plt.plot(cell_ranks, sorted_umi_counts, marker='o', markersize=2, linestyle='-', linewidth=0.5, label="UMI Counts")
60+
plt.axvline(cell_ranks[knee_idx], color="red", linestyle="--", label=f"Knee Point: {knee_point_value}")
61+
plt.title('UMI Counts Per Cell (Log-Log Scale)', fontsize=14)
62+
plt.xlabel('Cell Rank (Descending)', fontsize=12)
63+
plt.ylabel('Total UMI Counts', fontsize=12)
64+
plt.xscale('log')
65+
plt.yscale('log')
66+
plt.grid(True, linestyle='--', linewidth=0.5)
67+
plt.legend()
68+
plt.tight_layout()
69+
plt.show()
70+
2671

2772
def ontogeny_graph(gra, color_dict):
2873
ig.config['plotting.backend'] = 'matplotlib'
@@ -34,7 +79,8 @@ def ontogeny_graph(gra, color_dict):
3479
v_style["margin"] = (50)
3580

3681
for vertex in gra.vs:
37-
vertex["color"] = convert_color(color_dict.get(vertex["name"], np.array([0.5, 0.5, 0.5])))
82+
# vertex["color"] = convert_color(color_dict.get(vertex["name"], np.array([0.5, 0.5, 0.5])))
83+
vertex["color"] = tuple(color_dict.get(vertex["name"], np.array([0.5, 0.5, 0.5])))
3884

3985
# Normalize node sizes for better visualization
4086
max_size = 50 # Maximum size for visualization

0 commit comments

Comments
 (0)