Skip to content

Commit 96193e0

Browse files
committed
nearest() better performance (54/100)
1 parent 460c472 commit 96193e0

File tree

1 file changed

+28
-19
lines changed
  • week5/balanced-search-trees/assignment-kd-trees

1 file changed

+28
-19
lines changed

week5/balanced-search-trees/assignment-kd-trees/KdTree.java

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import java.util.ArrayList;
2+
23
import edu.princeton.cs.algs4.Point2D;
34
import edu.princeton.cs.algs4.RectHV;
45

@@ -172,40 +173,48 @@ private void rangeSearch(RectHV rect, Node node, ArrayList<Point2D> points) {
172173
}
173174
}
174175
}
175-
// // a nearest neighbor in the set to point p; null if the set is empty
176+
177+
// a nearest neighbor in the set to point p; null if the set is empty
176178
public Point2D nearest(Point2D p) {
177179
if (p == null) throw new IllegalArgumentException("nearest method called with a null argument");
178-
return nearestSearch(p, this.root, null, Double.POSITIVE_INFINITY);
180+
return nearestSearch(this.root, p, null, Double.POSITIVE_INFINITY);
179181
}
180182

181-
private Point2D nearestSearch(Point2D p, Node node, Point2D nearest, double minDistance) {
182-
if (node == null) return nearest;
183+
private Point2D nearestSearch(Node node, Point2D target, Point2D nearestPoint, double nearestDistance) {
184+
if (node == null) return null;
183185

184-
double distance = p.distanceSquaredTo(node.point);
185-
if (distance < minDistance) {
186-
nearest = node.point;
187-
minDistance = distance;
186+
double distance = node.point.distanceSquaredTo(target);
187+
if (distance < nearestDistance) {
188+
nearestPoint = node.point;
189+
nearestDistance = distance;
188190
}
189191

192+
Node first, second;
190193
if (node.isVertical) {
191-
if (p.x() < node.point.x()) {
192-
nearest = nearestSearch(p, node.left, nearest, minDistance);
193-
nearest = nearestSearch(p, node.right, nearest, nearest.distanceSquaredTo(p));
194+
if (target.x() < node.point.x()) {
195+
first = node.left;
196+
second = node.right;
194197
} else {
195-
nearest = nearestSearch(p, node.right, nearest, minDistance);
196-
nearest = nearestSearch(p, node.left, nearest, nearest.distanceSquaredTo(p));
198+
first = node.right;
199+
second = node.left;
197200
}
198201
} else {
199-
if (p.y() < node.point.y()) {
200-
nearest = nearestSearch(p, node.left, nearest, minDistance);
201-
nearest = nearestSearch(p, node.right, nearest, nearest.distanceSquaredTo(p));
202+
if (target.y() < node.point.y()) {
203+
first = node.left;
204+
second = node.right;
202205
} else {
203-
nearest = nearestSearch(p, node.right, nearest, minDistance);
204-
nearest = nearestSearch(p, node.left, nearest, nearest.distanceSquaredTo(p));
206+
first = node.right;
207+
second = node.left;
205208
}
206209
}
207210

208-
return nearest;
211+
nearestPoint = nearestSearch(first, target, nearestPoint, nearestDistance);
212+
213+
if (second != null && second.point.distanceSquaredTo(target) < nearestDistance) {
214+
nearestPoint = nearestSearch(second, target, nearestPoint, nearestDistance);
215+
}
216+
217+
return nearestPoint;
209218
}
210219

211220
// unit testing of the methods (optional)

0 commit comments

Comments
 (0)