@@ -19,7 +19,7 @@ class Point p where
19
19
-- | dist2 returns the squared distance between two points.
20
20
dist2 :: p -> p -> Double
21
21
dist2 a b = sum . map diff2 $ [0 .. dimension a - 1 ]
22
- where diff2 i = (coord i a - coord i b)^ 2
22
+ where diff2 i = (coord i a - coord i b)^ 2
23
23
24
24
-- | compareDistance p a b compares the distances of a and b to p.
25
25
compareDistance :: (Point p ) => p -> p -> p -> Ordering
@@ -37,9 +37,9 @@ instance Point Point3d where
37
37
38
38
39
39
data KdTree point = KdNode { kdLeft :: KdTree point ,
40
- kdPoint :: point ,
40
+ kdPoint :: point ,
41
41
kdRight :: KdTree point ,
42
- kdAxis :: Int }
42
+ kdAxis :: Int }
43
43
| KdEmpty
44
44
deriving (Eq , Ord , Show )
45
45
@@ -50,8 +50,8 @@ instance Functor KdTree where
50
50
instance F. Foldable KdTree where
51
51
foldr f init KdEmpty = init
52
52
foldr f init (KdNode l x r _) = F. foldr f init3 l
53
- where init3 = f x init2
54
- init2 = F. foldr f init r
53
+ where init3 = f x init2
54
+ init2 = F. foldr f init r
55
55
56
56
fromList :: Point p => [p ] -> KdTree p
57
57
fromList points = fromListWithDepth points 0
@@ -62,16 +62,16 @@ fromListWithDepth [] _ = KdEmpty
62
62
fromListWithDepth points depth = node
63
63
where axis = axisFromDepth (head points) depth
64
64
65
- -- Sort point list and choose median as pivot element
66
- sortedPoints =
67
- L. sortBy (\ a b -> coord axis a `compare` coord axis b) points
68
- medianIndex = length sortedPoints `div` 2
69
-
70
- -- Create node and construct subtrees
71
- node = KdNode { kdLeft = fromListWithDepth (take medianIndex sortedPoints) (depth+ 1 ),
72
- kdPoint = sortedPoints !! medianIndex,
73
- kdRight = fromListWithDepth (drop (medianIndex+ 1 ) sortedPoints) (depth+ 1 ),
74
- kdAxis = axis }
65
+ -- Sort point list and choose median as pivot element
66
+ sortedPoints =
67
+ L. sortBy (\ a b -> coord axis a `compare` coord axis b) points
68
+ medianIndex = length sortedPoints `div` 2
69
+
70
+ -- Create node and construct subtrees
71
+ node = KdNode { kdLeft = fromListWithDepth (take medianIndex sortedPoints) (depth+ 1 ),
72
+ kdPoint = sortedPoints !! medianIndex,
73
+ kdRight = fromListWithDepth (drop (medianIndex+ 1 ) sortedPoints) (depth+ 1 ),
74
+ kdAxis = axis }
75
75
76
76
axisFromDepth :: Point p => p -> Int -> Int
77
77
axisFromDepth p depth = depth `mod` k
@@ -90,18 +90,18 @@ nearestNeighbor (KdNode KdEmpty p KdEmpty _) probe = Just p
90
90
nearestNeighbor (KdNode l p r axis) probe =
91
91
if xProbe <= xp then doStuff l r else doStuff r l
92
92
where xProbe = coord axis probe
93
- xp = coord axis p
93
+ xp = coord axis p
94
94
doStuff tree1 tree2 =
95
- let candidates1 = case nearestNeighbor tree1 probe of
96
- Nothing -> [p]
97
- Just best1 -> [best1, p]
98
- sphereIntersectsPlane = (xProbe - xp)^ 2 <= dist2 probe p
99
- candidates2 = if sphereIntersectsPlane
100
- then candidates1 ++ maybeToList (nearestNeighbor tree2 probe)
101
- else candidates1 in
102
- Just . L. minimumBy (compareDistance probe) $ candidates2
103
-
104
- -- | invariant tells whether the KD tree property holds for a given tree and
95
+ let candidates1 = case nearestNeighbor tree1 probe of
96
+ Nothing -> [p]
97
+ Just best1 -> [best1, p]
98
+ sphereIntersectsPlane = (xProbe - xp)^ 2 <= dist2 probe p
99
+ candidates2 = if sphereIntersectsPlane
100
+ then candidates1 ++ maybeToList (nearestNeighbor tree2 probe)
101
+ else candidates1 in
102
+ Just . L. minimumBy (compareDistance probe) $ candidates2
103
+
104
+ -- | invariant tells whether the K-D tree property holds for a given tree and
105
105
-- all its subtrees.
106
106
-- Specifically, it tests that all points in the left subtree lie to the left
107
107
-- of the plane, p is on the plane, and all points in the right subtree lie to
@@ -110,16 +110,33 @@ invariant :: Point p => KdTree p -> Bool
110
110
invariant KdEmpty = True
111
111
invariant (KdNode l p r axis) = leftIsGood && rightIsGood
112
112
where x = coord axis p
113
- leftIsGood = all ((<= x) . coord axis) (toList l)
114
- rightIsGood = all ((>= x) . coord axis) (toList r)
113
+ leftIsGood = all ((<= x) . coord axis) (toList l)
114
+ rightIsGood = all ((>= x) . coord axis) (toList r)
115
115
116
+ -- | invariant' tells whether the K-D tree property holds for all subtrees.
116
117
invariant' :: Point p => KdTree p -> Bool
117
118
invariant' = all invariant . subtrees
118
119
120
+ kNearestNeighbors :: (Eq p , Point p ) => KdTree p -> Int -> p -> [p ]
121
+ kNearestNeighbors KdEmpty _ _ = []
122
+ kNearestNeighbors _ k _ | k <= 0 = []
123
+ kNearestNeighbors tree k probe = nearest : kNearestNeighbors tree' (k- 1 ) probe
124
+ where nearest = fromJust $ nearestNeighbor tree probe
125
+ tree' = tree `remove` nearest
126
+
127
+ remove :: (Eq p , Point p ) => KdTree p -> p -> KdTree p
128
+ remove KdEmpty _ = KdEmpty
129
+ remove (KdNode l p r axis) pKill =
130
+ if p == pKill
131
+ then fromListWithDepth (toList l ++ toList r) axis
132
+ else if coord axis pKill <= coord axis p
133
+ then KdNode (remove l pKill) p r axis
134
+ else KdNode l p (remove r pKill) axis
135
+
119
136
instance Arbitrary Point3d where
120
137
arbitrary = do
121
- x <- arbitrary
122
- y <- arbitrary
123
- z <- arbitrary
124
- return (Point3d x y z)
138
+ x <- arbitrary
139
+ y <- arbitrary
140
+ z <- arbitrary
141
+ return (Point3d x y z)
125
142
0 commit comments