6 | 6 | import mygene
7 | 7 | import anndata as ad
8 | 8 | import pySingleCellNet as pySCN
| 9 | +from scipy.sparse import issparse
| 10 | +
9 | 11 |
10 | 12 | def convert_ensembl_to_symbol(adata, species = 'mouse', batch_size=1000):
11 | 13 |     mg = mygene.MyGeneInfo()
@@ -90,6 +92,25 @@ def read_gmt(file_path: str) -> dict:
90 | 92 |
91 | 93 |     return gene_sets
92 | 94 |
| 95 | +def filter_gene_list(genelist, min_genes, max_genes):
| 96 | +    """
| 97 | +    Filter the gene lists in the provided dictionary by length.
| 98 | +
| 99 | +    Parameters:
| 100 | +    - genelist : dict
| 101 | +        Dictionary with keys as identifiers and values as lists of genes.
| 102 | +    - min_genes : int
| 103 | +        Minimum number of genes a list must contain.
| 104 | +    - max_genes : int
| 105 | +        Maximum number of genes a list may contain.
| 106 | +
| 107 | +    Returns:
| 108 | +    - dict
| 109 | +        Dictionary restricted to lists whose length is between min_genes and max_genes, inclusive.
| 110 | +    """
| 111 | +    filtered_dict = {key: value for key, value in genelist.items() if min_genes <= len(value) <= max_genes}
| 112 | +    return filtered_dict
| 113 | +
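A brief usage sketch tying this to `read_gmt` above; the `.gmt` file name and the size bounds are illustrative, and it is an assumption of this sketch that both helpers are exposed at the package top level:

```python
import pySingleCellNet as pySCN

# Hypothetical gene-set file; any GMT-formatted file should work.
gene_sets = pySCN.read_gmt("hallmark.gmt")

# Keep only sets with between 10 and 500 genes, inclusive.
gene_sets = pySCN.filter_gene_list(gene_sets, min_genes=10, max_genes=500)
```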
93 | 114 |
94 | 115 | def pull_out_genes(
95 | 116 |     diff_genes_dict: dict,
@@ -172,8 +193,6 @@ def read_broken_geo_mtx(path: str, prefix: str) -> AnnData:
172 | 193 |     adata.var_names = adata.var['gene']
173 | 194 |     return adata
174 | 195 |
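A hypothetical call for the function above; the directory name and file prefix are invented, and the exact semantics of `prefix` are an assumption, since this hunk only shows the function's tail:

```python
# Assumed: `path` is the directory holding the mtx/genes/barcodes triplet,
# and `prefix` is the shared file-name prefix of those files (hypothetical values).
adata = pySCN.read_broken_geo_mtx("GSE12345/", prefix="GSM67890_")
```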
175 | | -
176 | | -
177 | 196 | def mito_rib(adQ: AnnData, species: str = "MM", clean: bool = True) -> AnnData:
178 | 197 |     """
179 | 198 |     Calculate mitochondrial and ribosomal QC metrics and add them to the `.var` attribute of the AnnData object.
@@ -236,7 +255,8 @@ def norm_hvg_scale_pca(
236 | 255 |     min_disp: float = 0.25,
237 | 256 |     scale_max: float = 10,
238 | 257 |     n_comps: int = 100,
239 | | -    gene_scale: bool = False
| 258 | +    gene_scale: bool = False,
| 259 | +    use_hvg: bool = True
240 | 260 | ) -> AnnData:
241 | 261 |     """
242 | 262 |     Normalize, detect highly variable genes, optionally scale, and perform PCA on an AnnData object.
@@ -287,7 +307,7 @@ def norm_hvg_scale_pca(
287 | 307 |         sc.pp.scale(adata, max_value=scale_max)
288 | 308 |
289 | 309 |     # Perform PCA on the data
290 | | -    sc.tl.pca(adata, n_comps=n_comps)
| 310 | +    sc.tl.pca(adata, n_comps=n_comps, use_highly_variable=use_hvg)
291 | 311 |
292 | 312 |     return adata
293 | 313 |
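A minimal sketch of the updated call, assuming the function is invoked through the package namespace; `use_hvg=True` reproduces the new default, while `use_hvg=False` restores PCA over all genes:

```python
# Normalize, find HVGs, and run PCA restricted to highly variable genes (new default).
adata = pySCN.norm_hvg_scale_pca(adata, n_comps=50, use_hvg=True)

# Or run PCA on all genes, matching the previous behavior:
# adata = pySCN.norm_hvg_scale_pca(adata, n_comps=50, use_hvg=False)
```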
@@ -545,6 +565,85 @@ def sample_cells(
545 | 565 |
546 | 566 |     return sampled_adata
547 | 567 |
| 570 | +def compute_mean_expression_per_cluster(
| 571 | +    adata,
| 572 | +    cluster_key
| 573 | +):
| 574 | +    """
| 575 | +    Compute the mean expression of each gene in each cluster, build a new AnnData object of the cluster-level means, and store it in adata.uns['mean_expression'].
| 576 | +
| 577 | +    Parameters:
| 578 | +    - adata : anndata.AnnData
| 579 | +        The input AnnData object with labeled cell clusters.
| 580 | +    - cluster_key : str
| 581 | +        The key in adata.obs where the cluster labels are stored.
| 582 | +
| 583 | +    Returns:
| 584 | +    - None
| 585 | +        adata is modified in place; the clusters-by-genes mean expression AnnData is stored in adata.uns['mean_expression'].
| 586 | +    """
| 587 | +    if cluster_key not in adata.obs.columns:
| 588 | +        raise ValueError(f"{cluster_key} not found in adata.obs")
| 589 | +
| 590 | +    # Extract unique cluster labels
| 591 | +    clusters = adata.obs[cluster_key].unique().tolist()
| 592 | +
| 593 | +    # Compute mean expression for each cluster
| 594 | +    mean_expressions = []
| 595 | +    for cluster in clusters:
| 596 | +        cluster_cells = adata[adata.obs[cluster_key] == cluster, :]
| 597 | +        mean_expression = np.mean(cluster_cells.X, axis=0).A1 if issparse(cluster_cells.X) else np.mean(cluster_cells.X, axis=0)
| 598 | +        mean_expressions.append(mean_expression)
| 599 | +
| 600 | +    # Stack the per-cluster means into a clusters-by-genes matrix
| 601 | +    mean_expression_matrix = np.vstack(mean_expressions)
| 602 | +
| 603 | +    # Create a new AnnData object holding the cluster-level means
| 604 | +    mean_expression_adata = sc.AnnData(X=mean_expression_matrix,
| 605 | +                                       var=pd.DataFrame(index=adata.var_names),
| 606 | +                                       obs=pd.DataFrame(index=clusters))
| 607 | +
| 608 | +    # Store this new AnnData object in adata.uns
| 609 | +    adata.uns['mean_expression'] = mean_expression_adata
| 610 | +
| 611 | +
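A usage sketch grounded in the function body above; the cluster column name "leiden" and the package-level export are illustrative assumptions:

```python
# Populates adata.uns['mean_expression'] in place; nothing is returned.
pySCN.compute_mean_expression_per_cluster(adata, cluster_key="leiden")

pseudobulk = adata.uns["mean_expression"]  # clusters-by-genes AnnData
print(pseudobulk.shape)                    # (n_clusters, n_genes)
```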
| 613 | +def find_elbow(
| 614 | +    adata
| 615 | +):
| 616 | +    """
| 617 | +    Find the "elbow" index in the variance explained by principal components.
| 618 | +
| 619 | +    Parameters:
| 620 | +    - adata : anndata.AnnData
| 621 | +        AnnData object on which PCA has been run; variance ratios are read from adata.uns['pca']['variance_ratio'], typically in decreasing order.
| 622 | +
| 623 | +    Returns:
| 624 | +    - int
| 625 | +        The index corresponding to the "elbow" in the variance-explained plot.
| 626 | +    """
| 627 | +    variance_explained = adata.uns['pca']['variance_ratio']
| 628 | +    # Coordinates of all points
| 629 | +    n_points = len(variance_explained)
| 630 | +    all_coords = np.vstack((range(n_points), variance_explained)).T
| 631 | +    # Line vector from the first to the last point
| 632 | +    line_vec = all_coords[-1] - all_coords[0]
| 633 | +    line_vec_norm = line_vec / np.sqrt(np.sum(line_vec**2))
| 634 | +    # Component of each point's offset that is orthogonal to the line
| 635 | +    vec_from_first = all_coords - all_coords[0]
| 636 | +    scalar_prod = np.sum(vec_from_first * np.tile(line_vec_norm, (n_points, 1)), axis=1)
| 637 | +    vec_from_first_parallel = np.outer(scalar_prod, line_vec_norm)
| 638 | +    vec_to_line = vec_from_first - vec_from_first_parallel
| 639 | +    # Distance of each point to the line
| 640 | +    dist_to_line = np.sqrt(np.sum(vec_to_line ** 2, axis=1))
| 641 | +    # Index of the point with the maximum distance to the line
| 642 | +    elbow_idx = np.argmax(dist_to_line)
| 643 | +    return elbow_idx
| 644 | +
| 645 | +
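A usage sketch for the elbow heuristic; treating the returned 0-based index as implying `elbow + 1` principal components is this sketch's reading, not something the function enforces:

```python
import scanpy as sc

sc.tl.pca(adata, n_comps=100)            # fills adata.uns['pca']['variance_ratio']
n_pcs = pySCN.find_elbow(adata) + 1      # convert 0-based elbow index to a PC count
sc.pp.neighbors(adata, n_pcs=n_pcs)      # use the chosen number of PCs downstream
```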
548 | 647 |
549 | 648 | def ctMerge(sampTab, annCol, ctVect, newName):
550 | 649 |     oldRows=np.isin(sampTab[annCol], ctVect)
@@ -652,7 +751,6 @@ def downSampleW(vector,total=1e5, dThresh=0):
652 | 751 |     res[res<dThresh]=0
653 | 752 |     return res
654 | 753 |
655 | | -
656 | 754 | def weighted_down(expDat, total, dThresh=0):
657 | 755 |     rSums=expDat.sum(axis=1)
658 | 756 |     dVector=np.divide(total, rSums)