1
+ import networkx as nx
2
+ import numpy as np
3
+ from stlearn .utils import _read_graph
4
+
5
def shortest_path_spatial_PAGA(adata, use_label, key="dpt_pseudotime"):
    """Return the spatially shortest PAGA path between the two end clusters.

    The PAGA connectivity graph is oriented along non-decreasing pseudotime,
    every simple path from the start cluster to the end cluster (as reported
    by ``find_min_max_node``) is enumerated, and each candidate is scored by
    the summed centroid-to-centroid length of the global-graph edges whose
    two endpoints both belong to the path's clusters.

    Parameters
    ----------
    adata
        Annotated data object. This function reads
        ``adata.uns["paga"]["connectivities"]`` (must support ``.toarray()``),
        ``adata.uns["centroid_dict"]``, ``adata.uns["split_node"]``,
        ``adata.obs[use_label]``, ``adata.obs[key]`` and whatever
        ``_read_graph(adata, "global_graph")`` returns.
    use_label
        Name of the ``adata.obs`` column holding the cluster labels. Labels
        are compared against ``str(node)`` for each PAGA node.
    key
        Name of the ``adata.obs`` column holding the pseudotime values.

    Returns
    -------
    list of str
        The cluster ids along the selected path, each as a string.

    Raises
    ------
    ValueError
        If no path from the start to the end cluster could be scored.
    """
    # Undirected PAGA graph; from_numpy_array stores each non-zero matrix
    # entry as the edge attribute "weight".
    paga = nx.from_numpy_array(adata.uns["paga"]["connectivities"].toarray())
    weights = nx.get_edge_attributes(paga, "weight")
    paga.remove_edges_from([edge for edge, w in weights.items() if w < 0])
    directed = paga.to_directed()

    min_node, max_node = find_min_max_node(adata, key, use_label)

    # Pseudotime of a cluster = the largest pseudotime among its cells.
    # A boolean mask is used instead of a string-built query so that label
    # column names that are not valid identifiers still work.
    obs = adata.obs
    labels = obs[use_label]
    node_pseudotime = {
        node: obs.loc[labels == str(node), key].max()
        for node in directed.nodes
    }

    # Orient the graph: drop every edge that runs from a later cluster to an
    # earlier one. Comparisons involving NaN are False, so such edges stay.
    backward_edges = [
        (u, v)
        for u, v in directed.edges
        if node_pseudotime[u] - node_pseudotime[v] > 0
    ]
    directed.remove_edges_from(backward_edges)

    # Only paths from min_node to max_node are ever used, so enumerate just
    # those. The membership guard keeps "end cluster absent from the graph"
    # on the same no-candidate path instead of raising a networkx error.
    if min_node in directed and max_node in directed:
        target_paths = list(
            nx.all_simple_paths(directed, source=min_node, target=max_node)
        )
    else:
        target_paths = []

    global_graph = _read_graph(adata, "global_graph")
    global_edges = list(global_graph.edges())  # invariant across paths

    centroid_dict = {
        int(node_id): centroid
        for node_id, centroid in adata.uns["centroid_dict"].items()
    }

    # Total spatial length of every candidate path, keyed by "a,b,c".
    dist_dict = {}
    for path in target_paths:
        # Set membership is O(1); the array returned by get_node is O(n).
        members = set(get_node(path, adata.uns["split_node"]).tolist())
        inner_edges = [
            edge
            for edge in global_edges
            if edge[0] in members and edge[1] in members
        ]
        # Candidates with fewer internal edges than clusters are skipped.
        if len(inner_edges) >= len(path):
            path_name = ",".join(map(str, path))
            dist_dict[path_name] = calculate_total_dist(
                inner_edges, centroid_dict
            )

    if not dist_dict:
        raise ValueError(
            f"No path from cluster {min_node} to cluster {max_node} could be "
            "scored: the pseudotime-directed PAGA graph has no such path, or "
            "no candidate has enough global-graph edges."
        )

    shortest_path = min(dist_dict, key=dist_dict.get)
    return shortest_path.split(",")
67
+
68
+ # get name of cluster by subcluster
69
def get_cluster(search, dictionary):
    """Return the first key of ``dictionary`` whose value contains ``search``.

    Entries are examined in the mapping's iteration order. ``None`` is
    returned when no value contains ``search``.
    """
    owners = (
        name for name, members in dictionary.items() if search in members
    )
    return next(owners, None)
73
+
74
def get_node(node_list, split_node):
    """Collect the integer entries of ``split_node`` for the given nodes.

    Parameters
    ----------
    node_list
        Iterable of node ids; each is converted with ``int()`` before being
        used to index ``split_node``.
    split_node
        Container indexed by integer node id whose values are array-likes
        convertible to integers.

    Returns
    -------
    numpy.ndarray
        One-dimensional integer array with the entries of every requested
        node, concatenated in ``node_list`` order. Empty if ``node_list`` is
        empty.
    """
    # Convert each entry straight to a flat int array and join them once.
    # This avoids the quadratic cost of growing an array with np.append and
    # the lossy detour through a float64 accumulator.
    parts = [
        np.asarray(split_node[int(node)]).astype(int).ravel()
        for node in node_list
    ]
    if not parts:
        # np.concatenate rejects an empty sequence.
        return np.array([], dtype=int)
    return np.concatenate(parts)
79
+
80
def find_min_max_node(adata, key="dpt_pseudotime", use_label="leiden"):
    """Return the clusters of the first cells with pseudotime 0 and 1.

    Parameters
    ----------
    adata
        Object whose ``obs`` table has the columns ``key`` and ``use_label``.
    key
        Column of ``adata.obs`` compared against exactly 0 and exactly 1.
    use_label
        Column of ``adata.obs`` whose values are converted with ``int()``.

    Returns
    -------
    list
        ``[min_cluster, max_cluster]`` as integers.

    Raises
    ------
    IndexError
        If no row has ``key`` equal to 0, or none has ``key`` equal to 1.
    """
    obs = adata.obs
    pseudotime = obs[key]

    # Labels of the rows sitting exactly at the two ends of the pseudotime
    # scale; only the first row of each selection is used.
    start_labels = obs.loc[pseudotime == 0, use_label]
    end_labels = obs.loc[pseudotime == 1, use_label]

    min_cluster = int(start_labels.iloc[0])
    max_cluster = int(end_labels.iloc[0])
    return [min_cluster, max_cluster]
85
+
86
def calculate_total_dist(result, centroid_dict):
    """Sum the Euclidean lengths of the given edges.

    Parameters
    ----------
    result
        Iterable of edges; the first two items of each edge are used as keys
        into ``centroid_dict``.
    centroid_dict
        Mapping from node id to a coordinate sequence accepted by
        ``math.dist``.

    Returns
    -------
    The accumulated length; ``0`` when ``result`` is empty.
    """
    from math import dist

    edge_lengths = (
        dist(centroid_dict[edge[0]], centroid_dict[edge[1]])
        for edge in result
    )

    # Plain left-to-right accumulation, starting from the integer 0.
    total = 0
    for length in edge_lengths:
        total += length
    return total
0 commit comments