Skip to content

Commit 38fd93c

Browse files
fixes k nearest
1 parent ec8539f commit 38fd93c

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

WarpField.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@ def getNodes(vertices, radius, n_nodes=np.inf):
1212
return nodes
1313

1414
def k_nearest(verts, nodes, k):
15-
result = np.zeros((verts.shape[0], k))
15+
l = 0
16+
if verts is nodes:
17+
l = 1
18+
result = np.zeros((verts.shape[0], k), dtype=np.int)
1619
result_D = np.zeros((verts.shape[0], k))
17-
max_per_slice = int(np.floor(1e8/nodes.shape[0])) #about 800 MB memory
20+
max_per_slice = int(np.floor(1e7/nodes.shape[0])) #about 800 MB memory
1821
n_slices = int(np.ceil(verts.shape[0]/max_per_slice))
1922
for i in range(n_slices):
2023
cur_slice_length = min(verts.shape[0]-i*max_per_slice, max_per_slice)
2124
D = distance_matrix(verts[i*max_per_slice:i*max_per_slice+cur_slice_length], nodes)
22-
result[i*max_per_slice:i*max_per_slice+cur_slice_length, :k] = np.argsort(D, axis=1)[:, :k]
23-
result_D[i*max_per_slice:i*max_per_slice+cur_slice_length, :k] = np.sort(D, axis=1)[:, :k]
25+
result[i*max_per_slice:i*max_per_slice+cur_slice_length, :] = np.argsort(D, axis=1)[:, l:l+k]
26+
result_D[i*max_per_slice:i*max_per_slice+cur_slice_length, :] = D[np.arange(cur_slice_length, dtype=np.int)[:,np.newaxis], result[i*max_per_slice:i*max_per_slice+cur_slice_length, :]]
2427
return result, result_D
655 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)