Skip to content

Commit b6cc554

Browse files
committed
Add functions kNearestNeighbors and remove.
1 parent 604f6ca commit b6cc554

File tree

3 files changed

+59
-42
lines changed

3 files changed

+59
-42
lines changed

Data/Trees/KdTree.hs

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class Point p where
1919
-- |dist2 returns the squared distance between two points.
2020
dist2 :: p -> p -> Double
2121
dist2 a b = sum . map diff2 $ [0..dimension a - 1]
22-
where diff2 i = (coord i a - coord i b)^2
22+
where diff2 i = (coord i a - coord i b)^2
2323

2424
-- |compareDistance p a b compares the distances of a and b to p.
2525
compareDistance :: (Point p) => p -> p -> p -> Ordering
@@ -37,9 +37,9 @@ instance Point Point3d where
3737

3838

3939
data KdTree point = KdNode { kdLeft :: KdTree point,
40-
kdPoint :: point,
40+
kdPoint :: point,
4141
kdRight :: KdTree point,
42-
kdAxis :: Int }
42+
kdAxis :: Int }
4343
| KdEmpty
4444
deriving (Eq, Ord, Show)
4545

@@ -50,8 +50,8 @@ instance Functor KdTree where
5050
instance F.Foldable KdTree where
5151
foldr f init KdEmpty = init
5252
foldr f init (KdNode l x r _) = F.foldr f init3 l
53-
where init3 = f x init2
54-
init2 = F.foldr f init r
53+
where init3 = f x init2
54+
init2 = F.foldr f init r
5555

5656
fromList :: Point p => [p] -> KdTree p
5757
fromList points = fromListWithDepth points 0
@@ -62,16 +62,16 @@ fromListWithDepth [] _ = KdEmpty
6262
fromListWithDepth points depth = node
6363
where axis = axisFromDepth (head points) depth
6464

65-
-- Sort point list and choose median as pivot element
66-
sortedPoints =
67-
L.sortBy (\a b -> coord axis a `compare` coord axis b) points
68-
medianIndex = length sortedPoints `div` 2
69-
70-
-- Create node and construct subtrees
71-
node = KdNode { kdLeft = fromListWithDepth (take medianIndex sortedPoints) (depth+1),
72-
kdPoint = sortedPoints !! medianIndex,
73-
kdRight = fromListWithDepth (drop (medianIndex+1) sortedPoints) (depth+1),
74-
kdAxis = axis }
65+
-- Sort point list and choose median as pivot element
66+
sortedPoints =
67+
L.sortBy (\a b -> coord axis a `compare` coord axis b) points
68+
medianIndex = length sortedPoints `div` 2
69+
70+
-- Create node and construct subtrees
71+
node = KdNode { kdLeft = fromListWithDepth (take medianIndex sortedPoints) (depth+1),
72+
kdPoint = sortedPoints !! medianIndex,
73+
kdRight = fromListWithDepth (drop (medianIndex+1) sortedPoints) (depth+1),
74+
kdAxis = axis }
7575

7676
axisFromDepth :: Point p => p -> Int -> Int
7777
axisFromDepth p depth = depth `mod` k
@@ -90,18 +90,18 @@ nearestNeighbor (KdNode KdEmpty p KdEmpty _) probe = Just p
9090
nearestNeighbor (KdNode l p r axis) probe =
9191
if xProbe <= xp then doStuff l r else doStuff r l
9292
where xProbe = coord axis probe
93-
xp = coord axis p
93+
xp = coord axis p
9494
doStuff tree1 tree2 =
95-
let candidates1 = case nearestNeighbor tree1 probe of
96-
Nothing -> [p]
97-
Just best1 -> [best1, p]
98-
sphereIntersectsPlane = (xProbe - xp)^2 <= dist2 probe p
99-
candidates2 = if sphereIntersectsPlane
100-
then candidates1 ++ maybeToList (nearestNeighbor tree2 probe)
101-
else candidates1 in
102-
Just . L.minimumBy (compareDistance probe) $ candidates2
103-
104-
-- |invariant tells whether the KD tree property holds for a given tree and
95+
let candidates1 = case nearestNeighbor tree1 probe of
96+
Nothing -> [p]
97+
Just best1 -> [best1, p]
98+
sphereIntersectsPlane = (xProbe - xp)^2 <= dist2 probe p
99+
candidates2 = if sphereIntersectsPlane
100+
then candidates1 ++ maybeToList (nearestNeighbor tree2 probe)
101+
else candidates1 in
102+
Just . L.minimumBy (compareDistance probe) $ candidates2
103+
104+
-- |invariant tells whether the K-D tree property holds for a given tree and
105105
-- all its subtrees.
106106
-- Specifically, it tests that all points in the left subtree lie to the left
107107
-- of the plane, p is on the plane, and all points in the right subtree lie to
@@ -110,16 +110,33 @@ invariant :: Point p => KdTree p -> Bool
110110
invariant KdEmpty = True
111111
invariant (KdNode l p r axis) = leftIsGood && rightIsGood
112112
where x = coord axis p
113-
leftIsGood = all ((<= x) . coord axis) (toList l)
114-
rightIsGood = all ((>= x) . coord axis) (toList r)
113+
leftIsGood = all ((<= x) . coord axis) (toList l)
114+
rightIsGood = all ((>= x) . coord axis) (toList r)
115115

116+
-- |invariant' tells whether the K-D tree property holds for all subtrees.
116117
invariant' :: Point p => KdTree p -> Bool
117118
invariant' = all invariant . subtrees
118119

120+
kNearestNeighbors :: (Eq p, Point p) => KdTree p -> Int -> p -> [p]
121+
kNearestNeighbors KdEmpty _ _ = []
122+
kNearestNeighbors _ k _ | k <= 0 = []
123+
kNearestNeighbors tree k probe = nearest : kNearestNeighbors tree' (k-1) probe
124+
where nearest = fromJust $ nearestNeighbor tree probe
125+
tree' = tree `remove` nearest
126+
127+
remove :: (Eq p, Point p) => KdTree p -> p -> KdTree p
128+
remove KdEmpty _ = KdEmpty
129+
remove (KdNode l p r axis) pKill =
130+
if p == pKill
131+
then fromListWithDepth (toList l ++ toList r) axis
132+
else if coord axis pKill <= coord axis p
133+
then KdNode (remove l pKill) p r axis
134+
else KdNode l p (remove r pKill) axis
135+
119136
instance Arbitrary Point3d where
120137
arbitrary = do
121-
x <- arbitrary
122-
y <- arbitrary
123-
z <- arbitrary
124-
return (Point3d x y z)
138+
x <- arbitrary
139+
y <- arbitrary
140+
z <- arbitrary
141+
return (Point3d x y z)
125142

KdTreeTest.hs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,22 @@ prop_nearestNeighbor :: [Kd.Point3d] -> Kd.Point3d -> Bool
2020
prop_nearestNeighbor points probe =
2121
Kd.nearestNeighbor tree probe == bruteNearestNeighbor points probe
2222
where tree = Kd.fromList points
23+
bruteNearestNeighbor :: [Kd.Point3d] -> Kd.Point3d -> Maybe Kd.Point3d
24+
bruteNearestNeighbor [] _ = Nothing
25+
bruteNearestNeighbor points probe =
26+
Just . head . L.sortBy (Kd.compareDistance probe) $ points
2327

2428
prop_pointsAreClosestToThemselves :: [Kd.Point3d] -> Bool
2529
prop_pointsAreClosestToThemselves points =
2630
map Just points == map (Kd.nearestNeighbor tree) points
2731
where tree = Kd.fromList points
2832

29-
bruteNearestNeighbor :: [Kd.Point3d] -> Kd.Point3d -> Maybe Kd.Point3d
30-
bruteNearestNeighbor [] _ = Nothing
31-
bruteNearestNeighbor points probe =
32-
Just . head . L.sortBy (Kd.compareDistance probe) $ points
33+
prop_kNearestNeighborsMatchesBrute :: [Kd.Point3d] -> Int -> Kd.Point3d -> Bool
34+
prop_kNearestNeighborsMatchesBrute points k p =
35+
L.sort (Kd.kNearestNeighbors tree k p) == L.sort (bruteKnearestNeighbors points k p)
36+
where tree = Kd.fromList points
37+
bruteKnearestNeighbors points k p =
38+
take k . L.sortBy (Kd.compareDistance p) $ points
3339

3440
main = $quickCheckAll
3541

Makefile

Lines changed: 0 additions & 6 deletions
This file was deleted.

0 commit comments

Comments
 (0)