-
Notifications
You must be signed in to change notification settings - Fork 2
/
get_cliques.py
305 lines (260 loc) · 12 KB
/
get_cliques.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
#!/usr/bin/env python3
#
# get_cliques.py
# author: Christopher JF Cameron
#
"""
Finds cliques (potential consensus particles) of size k in each graph (micrograph)
"""
import itertools
import networkx as nx
import statistics
from repic.utils.common import *
from scipy.sparse import coo_matrix
from scipy.spatial import KDTree
# Subcommand identifier registered with the repic argparse subparser
name = "get_cliques"
"""str: module name (used by argparse subparser)"""
def add_arguments(parser):
    """
    Registers the get_cliques.py command line arguments on the supplied parser

    Args:
        parser (object): argparse parse_args() object

    Returns:
        None
    """
    # required positional arguments: (name, help text)
    positional_specs = (
        ("in_dir",
         "path to input directory containing subdirectories of particle bounding box coordinate files "),
        ("out_dir",
         "path to output directory (WARNING - script will delete directory if it exists)"),
    )
    for arg_name, help_text in positional_specs:
        parser.add_argument(arg_name, help=help_text)
    # box_size is the only typed positional argument
    parser.add_argument("box_size", type=int,
                        help="particle bounding box size (in int[pixels])")
    # optional boolean flags: (flag, help text)
    flag_specs = (
        ("--multi_out",
         "set output of cliques to be members sorted by picker name"),
        ("--get_cc",
         "filters cliques for those in the largest Connected Component (CC)"),
    )
    for flag, help_text in flag_specs:
        parser.add_argument(flag, action="store_true", help=help_text)
def add_nodes_to_graph(graph, node_pairs, node_names, k=3):
    """
    Adds vertices and edges to the graph

    Each pair contributes two nodes (particle bounding boxes) annotated with
    the picker name and confidence weight, joined by an edge weighted by their
    Jaccard index.

    Args:
        graph (obj): NetworkX graph() object
        node_pairs (list): list of paired vertex (particle bounding box) coordinates and their edge weight
        node_names (list): list of node (particle picking algorithm) names

    Keyword Args:
        k (int, default=3): number of methods

    Returns:
        None
    """
    # fix: removed dead "global node_id" declaration - node_id was never read
    # or assigned anywhere in this function
    for coords_1, key_1, weight_1, coords_2, key_2, weight_2, jaccard in node_pairs:
        # key values are i/k for method i; rint(key * k) recovers the method index
        graph.add_node(coords_1,
                       name=node_names[int(np.rint(key_1 * k))],
                       weight=weight_1)
        graph.add_node(coords_2,
                       name=node_names[int(np.rint(key_2 * k))],
                       weight=weight_2)
        # weight attribute used by nx
        graph.add_edge(coords_1, coords_2, weight=jaccard)
def calc_jaccard(x, y, a, b, box_size):
"""
Calculates Jaccard Index (similarity) for particle bounding boxes A (x,y) and B (a,b) with given box size
Args:
x (int): x-coodinate of particle bounding box A
y (int): y-coordinate of particle bounding box A
a (int): x-coordinate of particle bounding box B
b (int): y-coordinate of particle bounding box B
box_size (int): particle bounding box height/width
Returns:
float: Jaccard Index of two particle bounding boxes
"""
x_overlap = max(min(x, a) + box_size - max(x, a), 0)
y_overlap = max(min(y, b) + box_size - max(y, b), 0)
jaccard = x_overlap * y_overlap
return jaccard / ((2 * box_size ** 2) - jaccard)
def find_cliques(graph, k):
    """
    Finds all cliques in graph of size k

    Only maximal cliques whose size is exactly k are kept; each is stored as a
    sorted vertex tuple so equal cliques compare equal in the set.

    Args:
        graph (obj): NetworkX graph() object
        k (int): clique size

    Returns:
        set: set of k-sized cliques in the graph
    """
    return {tuple(sorted(clique))
            for clique in nx.find_cliques(graph)
            if len(clique) == k}
def main(args):
    """
    Builds NetworkX graph from file set and finds all k-sized cliques

    For every micrograph, coordinates from all pickers are loaded, overlapping
    inter-picker boxes become weighted graph edges, k-sized cliques are
    collected, and the ILP inputs (weight vector, consensus coordinates,
    confidences, constraint matrix) are pickled to args.out_dir.

    Args:
        args (obj): argparse command line argument object
    """
    # ensure input directory exists
    assert (os.path.exists(args.in_dir)
            ), "Error - input directory does not exist"
    # set up output directory (removes any pre-existing copy first)
    del_dir(args.out_dir)
    # NOTE(review): 'exclude' is never read in this function - confirm it is
    # required elsewhere before removing
    exclude = ["box_size", "out_dir", "multi_out", "get_cc"]
    # get method subdirectories (one per particle picking algorithm)
    methods = sorted([os.path.basename(val) for val in glob.glob(os.path.join(args.in_dir, '*'))
                      if os.path.isdir(val)], key=str)
    create_dir(args.out_dir)
    # determine method with shortest naming convention
    start_method = None
    num_methods = len(methods)
    for method in methods:
        # collect example box file from each method subdirectory
        for box_file in glob.glob(os.path.join(args.in_dir, method, "*.box")):
            # identify basename of file and use it to find matching files in other subdirectories
            tmp = os.path.basename(box_file).replace(".box", '')
            tmp = f"*{tmp}*"
            # count matching files across all method subdirectories
            n = len(sum([glob.glob(os.path.join(args.in_dir, method, tmp))
                         for method in methods], []))
            break
        # if the current method's naming convention can be used to identify pairs, keep it
        if n == num_methods:
            start_method = method
            break
    assert (not start_method ==
            None), "Error - particle bounding box file names cannot be paired across methods"
    del box_file, tmp, n
    print(f"Using {start_method} BOX files as starting point")
    # iterate over grouped particle bounding box files
    k = len(methods)  # number of methods/clique size
    # NOTE(review): this iterates methods[0] rather than start_method, which is
    # only reported above - confirm intentional
    for in_file in glob.glob(os.path.join(args.in_dir, methods[0], "*.box")):
        start = time.time()
        # determine basename of particle bounding box file
        basename = os.path.basename(in_file).replace(".box", '')
        print(f"\n--- {basename} ---\n")
        basename = f"*{basename}*"
        print("Loading particle bounding box coordinates into memory ... ")
        try:
            # get coords for each provided picker
            # key = i/k for method i of k methods
            coords = np.array(get_box_coords(
                in_file, key=0., return_weights=True))
            for i, method in enumerate(methods[1:], 1):
                coords = np.concatenate((coords,
                                         np.asarray(get_box_coords(os.path.join(args.in_dir, method, basename),
                                                                   key=i /
                                                                   float(k),
                                                                   return_weights=True))))
            del i, method
        except (UnboundLocalError, IndexError) as e:
            # create empty BOX file if particle bounding boxes are not picked by all methods
            print("Skipping micrograph - not all methods have picked particles")
            out_file = os.path.join(args.out_dir, ''.join(
                [basename[1:-1], ".box"]))
            with open(out_file, 'wt') as o:
                pass
            continue
        print("Calculating Jaccard indices ... ")
        # build k-d tree from x, y, and z coordinates with method key values
        kd_tree = KDTree(coords[:, :4])
        # get pairs of particle bounding boxes within distance threshold r
        # k-d tree uses Minkowski distance (default p == 2 is Euclidean distance)
        r = args.box_size + 1.  # +1 for key column
        pairs = kd_tree.query_pairs(r)
        # calculate Jaccard indices between pairs
        data = []
        for i, j in pairs:
            x, y, z, key_1, weight_1, id_1 = coords[i]
            a, b, c, key_2, weight_2, id_2 = coords[j]
            # keep only inter-method pairs with Jaccard overlap above 0.3
            if (not key_1 == key_2) and ((jaccard := calc_jaccard(x, y, a, b, args.box_size)) > 0.3):
                data.append(
                    tuple([(x, y, id_1), key_1, weight_1, (a, b, id_2), key_2, weight_2, jaccard]))
        # NOTE(review): this del raises NameError if 'pairs' is empty (loop
        # variables never bound) - confirm that case cannot occur here
        del coords, kd_tree, x, y, z, key_1, weight_1, id_1, a, b, c, key_2, weight_2, id_2, jaccard
        print("Building graph ... ")
        # build graph weighted pairs
        graph = nx.Graph()
        # [add_nodes_to_graph(graph, vals, methods) for vals in data]
        add_nodes_to_graph(graph, data, methods, k=k)
        del data
        # list connected component stats
        components = [len(val) for val in nx.connected_components(graph)]
        print("\tnumber of CCs:", len(components))
        print("\tlargest CC length:", max(components))
        print("\tmean CC length:", np.mean(components))
        if args.get_cc:
            # replace graph of all particle bounding boxes in largest CC
            for cc in sorted(nx.connected_components(graph), key=len, reverse=True):
                graph = graph.subgraph(cc)
                break
        print("Finding cliques ... ")
        # find cliques
        all_cliques = find_cliques(graph, k)
        n = len(all_cliques)
        # sorted list of vertices in cliques
        v = sorted(set(sum(all_cliques, ())))
        print(f"\t{n} cliques found with {len(v)} unique vertices")
        print("Building ILP data structures ... ")
        # cliques confidences - median(clique confidences)
        confidence = np.zeros(n, dtype=np.float32)
        w = np.zeros(n, dtype=np.float32)  # weight vector of cliques
        # iterate over cliques, fill in w, retain vertex-to-clique assignments
        cliques, rows, cols = [], [], []
        for j, clique in enumerate(all_cliques):
            subgraph = graph.subgraph(clique)
            if args.multi_out:
                # return all nodes in clique sorted by picker name
                cliques.append(sorted(subgraph.nodes(),
                                      key=lambda x: subgraph.nodes[x]["name"]))
            else:
                # determine best particle bounding box identification in clique based on
                # overlap with other members
                cliques.append(max(subgraph.degree(weight="weight"),
                                   key=lambda x: x[1])[0])
            # calculate ILP weight for clique
            # median(Jaccard of members/edges) * median(members/nodes confidence)
            confidence[j] = statistics.median(
                list(nx.get_node_attributes(subgraph, "weight").values()))
            w[j] = confidence[j] * statistics.median(list(nx.get_edge_attributes(
                subgraph, "weight").values()))
            # retain row / col indices for sparse matrix
            cols.extend([j] * k)
            rows.extend([v.index(val) for val in clique])
        # NOTE(review): raises NameError when all_cliques is empty and these
        # names were not bound earlier - the empty case is handled only below
        del j, clique, subgraph
        assert (len(cliques) == len(
            w)), "Error - concensus coordinates and ILP weight vector are not equal lengths"
        assert (len(w) == len(confidence)
                ), "Error - cliques weights and confidences are not equal lengths"
        assert (len(cols) == len(
            rows)), "Error - ILP sparse matrix indices (rows / cols) are not equal lengths"
        assert (len(cliques) * k == len(cols)
                ), "Error - consensus coordinates or ILP sparse matrix indices (rows / cols) missing"
        # binary constraint matrix: A[row, col] == 1 iff vertex 'row' is a member of clique 'col'
        A = coo_matrix(([1] * len(cols), (rows, cols)), shape=(len(v), n))
        del n, v, rows, cols
        if len(cliques) == 0:
            # no cliques found
            print("Skipping micrograph - no cliques found")
            # write empty BOX file
            out_file = os.path.join(args.out_dir, ''.join(
                [basename[1:-1], ".box"]))
            with open(out_file, 'w') as o:
                pass
            del out_file, o
            continue
        # write structures to storage for ILP optimization
        # add multi-out header
        cliques = [methods] + cliques if args.multi_out else cliques
        for label, val in zip(
                ["weight_vector", "consensus_coords",
                 "consensus_confidences", "constraint_matrix"],
                [w, cliques, confidence, A]):
            out_file = os.path.join(args.out_dir, ''.join(
                [basename[1:-1], '_', label, ".pickle"]))
            with open(out_file, 'wb') as o:
                pickle.dump(val, o, protocol=pickle.HIGHEST_PROTOCOL)
        out_file = os.path.join(args.out_dir, ''.join(
            [basename[1:-1], "_runtime.tsv"]))
        with open(out_file, 'wt') as o:
            # runtime (in seconds), largest CC, number of CC
            o.write('\t'.join([str(val) for val in [time.time() - start,
                                                    np.max(components), len(components)]]) + '\n')
if __name__ == '__main__':
    # stand-alone invocation: build the CLI, parse arguments, and run
    parser = argparse.ArgumentParser()
    """obj: argparse parse_args() object"""
    add_arguments(parser)
    args = parser.parse_args()
    main(args)