Skip to content

Commit 09c8d7a

Browse files
committed
feature: shortest path finding
1 parent e22bc1f commit 09c8d7a

File tree

3 files changed

+96
-2
lines changed

3 files changed

+96
-2
lines changed

stlearn/spatials/trajectory/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@
1111
from .compare_transitions import compare_transitions
1212

1313
from .set_root import set_root
14+
from .shortest_path_spatial_PAGA import shortest_path_spatial_PAGA

stlearn/spatials/trajectory/pseudotime.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,8 @@ def store_available_paths(adata, threshold, use_label, max_nodes, pseudotime_key
244244
paths = nx.all_simple_paths(H, source=source, target=target)
245245
for i, path in enumerate(paths):
246246
if len(path) < max_nodes:
247-
all_paths[i] = path
247+
all_paths[str(i) + "_" + str(source) + "_" + str(target)] = path
248248

249-
# all_paths = list(map(lambda x: " - ".join(np.array(x).astype(str)),all_paths))
250249

251250
adata.uns["available_paths"] = all_paths
252251
print(
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import networkx as nx
2+
import numpy as np
3+
from stlearn.utils import _read_graph
4+
5+
def shortest_path_spatial_PAGA(adata,use_label,key="dpt_pseudotime",):
6+
# Read original PAGA graph
7+
G = nx.from_numpy_array(adata.uns["paga"]["connectivities"].toarray())
8+
edge_weights = nx.get_edge_attributes(G, "weight")
9+
G.remove_edges_from((e for e, w in edge_weights.items() if w <0))
10+
H = G.to_directed()
11+
12+
# Get min_node and max_node
13+
min_node,max_node = find_min_max_node(adata,key,use_label)
14+
15+
# Calculate pseudotime for each node
16+
node_pseudotime = {}
17+
18+
for node in H.nodes:
19+
node_pseudotime[node] = adata.obs.query(use_label + " == '" + str(node) + "'")[
20+
key
21+
].max()
22+
23+
# Force original PAGA to directed PAGA based on pseudotime
24+
edge_to_remove = []
25+
for edge in H.edges:
26+
if node_pseudotime[edge[0]] - node_pseudotime[edge[1]] > 0:
27+
edge_to_remove.append(edge)
28+
H.remove_edges_from(edge_to_remove)
29+
30+
# Extract all available paths
31+
all_paths = {}
32+
j = 0
33+
for source in H.nodes:
34+
for target in H.nodes:
35+
paths = nx.all_simple_paths(H, source=source, target=target)
36+
for i, path in enumerate(paths):
37+
j+=1
38+
all_paths[j] = path
39+
40+
# Filter the target paths from min_node to max_node
41+
target_paths = []
42+
for path in list(all_paths.values()):
43+
if path[0] == min_node and path[-1] == max_node:
44+
target_paths.append(path)
45+
46+
# Get the global graph
47+
G = _read_graph(adata, "global_graph")
48+
49+
centroid_dict = adata.uns["centroid_dict"]
50+
centroid_dict = {int(key): centroid_dict[key] for key in centroid_dict}
51+
52+
# Generate total length of every path. Store by dictionary
53+
dist_dict = {}
54+
for path in target_paths:
55+
path_name = ",".join(list(map(str,path)))
56+
result = []
57+
query_node = get_node(path, adata.uns["split_node"])
58+
for edge in G.edges():
59+
if (edge[0] in query_node) and (edge[1] in query_node):
60+
result.append(edge)
61+
if len(result) >= len(path):
62+
dist_dict[path_name] = calculate_total_dist(result,centroid_dict)
63+
64+
# Find the shortest path
65+
shortest_path = min(dist_dict, key=lambda x: dist_dict[x])
66+
return shortest_path.split(',')
67+
68+
# get name of cluster by subcluster
69+
def get_cluster(search, dictionary):
70+
for cl, sub in dictionary.items():
71+
if search in sub:
72+
return cl
73+
74+
def get_node(node_list, split_node):
75+
result = np.array([])
76+
for node in node_list:
77+
result = np.append(result, np.array(split_node[int(node)]).astype(int))
78+
return result.astype(int)
79+
80+
def find_min_max_node(adata,key="dpt_pseudotime",use_label="leiden"):
81+
min_cluster = int(adata.obs[adata.obs[key]==0][use_label].values[0])
82+
max_cluster = int(adata.obs[adata.obs[key]==1][use_label].values[0])
83+
84+
return [min_cluster,max_cluster]
85+
86+
def calculate_total_dist(result,centroid_dict):
87+
import math
88+
total_dist = 0
89+
for edge in result:
90+
source = centroid_dict[edge[0]]
91+
target = centroid_dict[edge[1]]
92+
dist =math.dist(source,target)
93+
total_dist += dist
94+
return total_dist

0 commit comments

Comments
 (0)