-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathkdtreetest2.sml
217 lines (167 loc) · 8.69 KB
/
kdtreetest2.sml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
(* Write text to standard output with no trailing newline. *)
fun putStr text = TextIO.output (TextIO.stdOut, text)
(* Write a line of text to standard output, followed by a newline. *)
fun putStrLn line =
    let
        val out = TextIO.stdOut
    in
        TextIO.output (out, line);
        TextIO.output (out, "\n")
    end
(* Run the thunk `action` and return its result paired with the CPU time it
   consumed (user time plus system time). *)
fun timing action =
    let
        val cpuTimer = Timer.startCPUTimer ()
        val value = action ()
        val {usr, sys, ...} = Timer.checkCPUTimer cpuTimer
    in
        (value, Time.+ (usr, sys))
    end
(* Raised when a point does not have exactly three coordinates. *)
exception Point

(* Squared Euclidean distance between two 3-element coordinate lists.
   Raises Point unless both arguments are lists of exactly three reals. *)
fun distanceSquared3D ([x1,x2,x3], [y1,y2,y3]) =
    let
        (* square of the difference along one axis *)
        fun sq (a, b) = let val d = Real.- (a, b) in Real.* (d, d) end
    in
        Real.+ (Real.+ (sq (x1, y1), sq (x2, y2)), sq (x3, y3))
    end
  | distanceSquared3D _ = raise Point
(* Euclidean distance between two 3-element coordinate lists (raises Point on
   malformed input, via distanceSquared3D). *)
val distance3D = Math.sqrt o distanceSquared3D
val _ = print "starting KDTree tests\n"
(* Two KD-tree instantiations over 3-D points: one over a tensor point space,
   one over a map point space.  Both are given distanceSquared3D (the squared
   Euclidean distance) as their metric.
   NOTE(review): the test suites (TensorKdTreeTest / MapKdTreeTest) are
   instantiated with the unsquared distance3D and filter reference points by
   distance3D <= r; confirm KDTree.nearNeighbors interprets its radius in the
   same units. *)
structure TensorKDTree = KDTreeFn (structure S = TensorPointSpaceFn (val K = 3)
val distance = distanceSquared3D)
structure MapKDTree = KDTreeFn (structure S = MapPointSpaceFn (val K = 3)
val distance = distanceSquared3D)
(* Test harness for a KD-tree implementation.
   KDTree   - the tree structure under test
   distance - metric used to order candidate points and to select the
              reference points lying within a radius *)
functor KdTreeTestFn (structure KDTree: KDTREE
                      val distance : (real list) * (real list) -> real) =
struct

    (* Largest per-coordinate absolute difference at which two points are
       still considered the same point. *)
    val tolerance = 1E~16

    (* True when p and q have the same number of coordinates and every
       coordinate pair agrees within `tolerance` in absolute value.
       ListPair.allEq (not ListPair.all) is used so that lists of different
       lengths compare unequal instead of being truncated to a common prefix. *)
    fun samePoint (p, q) =
        ListPair.allEq (fn (a, b) => Real.<= (Real.abs (Real.- (a, b)), tolerance)) (p, q)

    (* Sort pts by increasing distance from origin (ListMergeSort.sort takes a
       greater-than predicate and yields ascending order). *)
    fun sortPoints (origin,pts) =
        ListMergeSort.sort
            (fn (x,y) => Real.> (distance (origin,x), distance (origin,y)))
            pts

    (* Structural checks on tree t.  Prints KDTree.size t and the length of
       KDTree.toList t, then raises Fail if those disagree, if
       KDTree.isValid t is false, or if KDTree.allSubtreesAreValid t is false. *)
    fun check (t) =
        (putStrLn ("check: size t = " ^ (Int.toString (KDTree.size t)));
         putStrLn ("check: length (toList t) = " ^ (Int.toString (List.length (KDTree.toList t))));
         if not ((List.length (KDTree.toList t)) = (KDTree.size t)) then raise Fail "invalid KDTree size" else ();
         if not (KDTree.isValid t) then raise Fail "invalid KDTree" else ();
         if not (KDTree.allSubtreesAreValid t) then raise Fail "invalid subtree of a KDTree" else ()
        )

    (* Compare the tree's nearest neighbour of x with hd sorted.  `sorted` is
       expected to be the reference points ordered by increasing distance from
       x.  Prints both candidates and their distances to x.
       Raises Fail "nearestNeighbor" if they differ in length or in any
       coordinate, Option if the tree returns no neighbour, and Empty if
       sorted is empty. *)
    fun testNearestNeighbor (sorted,t,x) =
        let
            val nn = valOf (KDTree.nearestNeighbor t x)
            val nn' = KDTree.S.pointList (KDTree.S.point (KDTree.pointSpace t) nn)
            val _ = (print "nn' = "; TensorFile.realListLineWrite TextIO.stdOut nn')
            val _ = (print "distance(x,nn') = "; TensorFile.realWrite TextIO.stdOut (distance (x,nn')))
            val _ = (print "hd sorted = "; TensorFile.realListLineWrite TextIO.stdOut (hd sorted))
            val _ = (print "distance(x,hd sorted) = "; TensorFile.realWrite TextIO.stdOut (distance (x,(hd sorted))))
        in
            if not (samePoint (nn', hd sorted))
            then raise Fail "nearestNeighbor" else ()
        end

    (* Compare KDTree.nearNeighbors t r x with the points of `sorted` whose
       distance to x is at most r.  Both sides are in increasing-distance
       order.  Raises Fail "nearNeighbors" if the two collections differ in
       size or if any corresponding pair of points differs. *)
    fun testNearNeighbors (sorted,t,x,r) =
        let
            val P = KDTree.pointSpace t
            val nns = KDTree.nearNeighbors t r x
            val nns' = sortPoints (x, List.map (fn (nn) => KDTree.S.pointList (KDTree.S.point P nn))
                                               nns)
            val sss = List.filter (fn (p) => Real.<= (distance (p,x), r)) sorted
        in
            if not (ListPair.allEq samePoint (nns', sss))
            then raise Fail "nearNeighbors" else ()
        end
end
(* Build a real tensor of the given shape filled with pseudo-random reals
   drawn from Random.rand (xseed, yseed).  Values are assigned in increasing
   flat-index order: element 0 gets the first draw, element 1 the second, and
   so on.  A shape with zero elements performs no draws and no updates.
   NOTE(review): confirm RTensor.fromArray accepts a zero-length array. *)
fun realRandomTensor (xseed,yseed) shape =
    let
        val length = Index.length shape
        val seed = Random.rand (xseed,yseed)
        (* element 0 receives the first draw through the array initialiser *)
        val a = RTensor.Array.array (length, if length > 0 then Random.randReal seed else 0.0)
        (* fill elements i .. length-1 with successive draws *)
        fun fill i =
            if i < length
            then (RTensor.Array.update (a, i, Random.randReal seed); fill (i + 1))
            else ()
    in
        fill 1;
        RTensor.fromArray (shape, a)
    end
(* Test suites for the two tree implementations, both using the Euclidean
   distance3D as the reference metric. *)
structure TensorKdTreeTest = KdTreeTestFn (structure KDTree = TensorKDTree
val distance = distance3D)
structure MapKdTreeTest = KdTreeTestFn (structure KDTree = MapKDTree
val distance = distance3D)
(* Problem size: M random points, each with N = 3 coordinates. *)
val M = 100000
val N = 3

val _ = print "constructing tensor point space...\n"
(* P is an M x N tensor of pseudo-random reals; row i holds point i. *)
val P = realRandomTensor (13,17) [M,N]

val _ = print "constructing map point space...\n"
(* PM: a map point space holding every row of P, built by inserting rows
   0 .. M-1 in order; the construction is timed. *)
val (PM,ti) =
    timing (fn () =>
               let
                   fun insertRow (i, space) =
                       let
                           val coords = List.tabulate (3, fn j => RTensor.sub (P, [i,j]))
                       in
                           #2 (MapKDTree.S.insert (coords, space))
                       end
               in
                   Loop.foldi (0, M, insertRow, MapKDTree.S.empty)
               end)
val _ = print ("map point space constructed (" ^ Time.toString ti ^ " s)\n")
(* Order pts by increasing `distance` from origin.  ListMergeSort.sort is
   given a "farther than" predicate and therefore sorts ascending. *)
fun sortPoints distance (origin, pts) =
    let
        fun fartherThan (p, q) = Real.> (distance (origin, p), distance (origin, q))
    in
        ListMergeSort.sort fartherThan pts
    end
(* All M points of P as 3-element coordinate lists, in reverse row order:
   row M-1 is the head and row 0 is the last element. *)
val pts =
    let
        fun row i = [RTensor.sub (P, [i,0]),
                     RTensor.sub (P, [i,1]),
                     RTensor.sub (P, [i,2])]
    in
        List.rev (List.tabulate (M, row))
    end
(* Build the trees under test, timing each construction:
     tt  - tensor-backed tree built in one shot from P
     mt2 - map-backed tree built incrementally by folding addPoint over pts
     mt  - map-backed tree built in one shot from PM
   then run the structural consistency checks on each. *)
val _ = print ("constructing tensor tree...\n")
val (tt,ti) = timing (fn () => TensorKDTree.fromPoints P)
val _ = print ("tensor tree constructed (" ^ (Time.toString ti) ^ " s)\n")
val _ = print ("constructing map tree (addPoint)...\n")
(* NOTE(review): the (128,512) arguments are passed straight through to
   MapKDTree.addPoint; their meaning is defined by the KDTree library. *)
val (mt2,ti) = timing (fn () => foldl (fn (p, tree) => MapKDTree.addPoint (128,512) (p, tree))
                                      MapKDTree.empty pts)
val _ = print ("map tree constructed (" ^ (Time.toString ti) ^ " s)\n")
val _ = print ("constructing map tree (fromPoints)...\n")
val (mt,ti) = timing (fn () => MapKDTree.fromPoints PM)
val _ = print ("map tree constructed (" ^ (Time.toString ti) ^ " s)\n")
val _ = print ("tensor tree size = " ^ (Int.toString (TensorKDTree.size tt)) ^ "\n")
(* report the size of the map tree itself, not the tensor tree *)
val _ = print ("map tree size = " ^ (Int.toString (MapKDTree.size mt)) ^ "\n")
val _ = print ("length of tree list = " ^ (Int.toString (List.length (TensorKDTree.toList tt))) ^ "\n")
val _ = TensorKdTreeTest.check tt
val _ = print "tensor tree consistency check passed\n"
val _ = MapKdTreeTest.check mt2
val _ = print "map tree consistency check passed\n"
val _ = MapKdTreeTest.check mt
val _ = print "map tree consistency check passed\n"
(* Randomised nearest-neighbour and near-neighbour trials.  Each trial picks a
   random row of P, offsets it by 0.1 per coordinate to form the query x, and
   checks all three trees against a brute-force distance ordering of pts.
   Every check is timed and reported. *)
val seed = Random.rand (19,21)
val Ntrials = 2
val _ =
    let
        fun recur (i) =
            if i > 0
            then
                let
                    (* row index in [0, M): sample over all M points, not over
                       the N coordinates of a single point *)
                    val xi = Int.mod (Random.randNat seed, M)
                    val x = [Real.+ (0.1, RTensor.sub (P, [xi,0])),
                             Real.- (RTensor.sub (P, [xi,1]), 0.1),
                             Real.+ (0.1, RTensor.sub (P, [xi,2]))]
                    val sorted = sortPoints distance3D (x,pts)
                    val _ = print ("test trial " ^ (Int.toString i) ^ "\n")
                    val (_,ti) = timing (fn () => TensorKdTreeTest.testNearestNeighbor (sorted,tt,x))
                    val _ = print ("(tensor tree) nearest neighbor check passed (" ^ (Time.toString ti) ^ " s)\n")
                    val (_,ti) = timing (fn () => MapKdTreeTest.testNearestNeighbor (sorted,mt,x))
                    val _ = print ("(map tree) nearest neighbor check passed (" ^ (Time.toString ti) ^ " s)\n")
                    val (_,ti) = timing (fn () => MapKdTreeTest.testNearestNeighbor (sorted,mt2,x))
                    val _ = print ("(dynamic map tree) nearest neighbor check passed (" ^ (Time.toString ti) ^ " s)\n")
                    val (_,ti) = timing (fn () => TensorKdTreeTest.testNearNeighbors (sorted,tt,x,0.3))
                    val _ = print ("(tensor tree) near neighbors check passed (" ^ (Time.toString ti) ^ "s)\n")
                    val (_,ti) = timing (fn () => MapKdTreeTest.testNearNeighbors (sorted,mt,x,0.3))
                    val _ = print ("(map tree) near neighbors check passed (" ^ (Time.toString ti) ^ "s)\n")
                    val (_,ti) = timing (fn () => MapKdTreeTest.testNearNeighbors (sorted,mt2,x,0.3))
                    val _ = print ("(dynamic map tree) near neighbors check passed (" ^ (Time.toString ti) ^ "s)\n")
                in
                    recur (i-1)
                end
            else ()
    in
        recur (Ntrials)
    end