Skip to content

Commit 7458295

Browse files
committed
initial KdTree implementation (48/100)
1 parent 2bebc4b commit 7458295

File tree

2 files changed

+174
-23
lines changed

2 files changed

+174
-23
lines changed

week5/balanced-search-trees/assignment-kd-trees/KdTree.java

Lines changed: 173 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,209 @@
1-
import java.util.TreeSet;
21
import java.util.ArrayList;
32
import edu.princeton.cs.algs4.Point2D;
43
import edu.princeton.cs.algs4.RectHV;
54

5+
// 2d-tree implementation. Write a mutable data type KdTree.java that uses a 2d-tree to implement the same API (but replace PointSET with KdTree). A 2d-tree is a generalization of a BST to two-dimensional keys. The idea is to build a BST with points in the nodes, using the x- and y-coordinates of the points as keys in strictly alternating sequence.
6+
67
public class KdTree {
7-
private TreeSet<Point2D> tree;
8+
private class Node {
9+
private Point2D point;
10+
private Node left;
11+
private Node right;
12+
private boolean isVertical;
13+
14+
public Node(Point2D point, boolean isVertical) {
15+
this.point = point;
16+
this.isVertical = isVertical;
17+
}
18+
19+
public int size() {
20+
int size = 1;
21+
if (this.left != null) {
22+
size += this.left.size();
23+
}
24+
if (this.right != null) {
25+
size += this.right.size();
26+
}
27+
return size;
28+
}
29+
30+
public Node insert(Point2D p) {
31+
if (this.point.equals(p)) {
32+
return this;
33+
}
34+
if (this.isVertical) {
35+
if (p.x() < this.point.x()) {
36+
if (this.left == null) {
37+
this.left = new Node(p, false);
38+
} else {
39+
this.left = this.left.insert(p);
40+
}
41+
} else {
42+
if (this.right == null) {
43+
this.right = new Node(p, false);
44+
} else {
45+
this.right = this.right.insert(p);
46+
}
47+
}
48+
} else {
49+
if (p.y() < this.point.y()) {
50+
if (this.left == null) {
51+
this.left = new Node(p, true);
52+
} else {
53+
this.left = this.left.insert(p);
54+
}
55+
} else {
56+
if (this.right == null) {
57+
this.right = new Node(p, true);
58+
} else {
59+
this.right = this.right.insert(p);
60+
}
61+
}
62+
}
63+
return this;
64+
}
65+
66+
public boolean contains(Point2D p) {
67+
if (this.point.equals(p)) {
68+
return true;
69+
}
70+
if (this.isVertical) {
71+
if (p.x() < this.point.x()) {
72+
if (this.left == null) {
73+
return false;
74+
}
75+
return this.left.contains(p);
76+
} else {
77+
if (this.right == null) {
78+
return false;
79+
}
80+
return this.right.contains(p);
81+
}
82+
} else {
83+
if (p.y() < this.point.y()) {
84+
if (this.left == null) {
85+
return false;
86+
}
87+
return this.left.contains(p);
88+
} else {
89+
if (this.right == null) {
90+
return false;
91+
}
92+
return this.right.contains(p);
93+
}
94+
}
95+
}
96+
97+
public void draw() {
98+
this.point.draw();
99+
if (this.left != null) {
100+
this.left.draw();
101+
}
102+
if (this.right != null) {
103+
this.right.draw();
104+
}
105+
}
106+
}
107+
108+
private Node root;
8109

9110
// construct an empty set of points
10111
public KdTree() {
11-
this.tree = new TreeSet<Point2D>();
112+
this.root = null;
12113
}
13114
// is the set empty?
14115
public boolean isEmpty() {
15-
return this.tree.isEmpty();
116+
return this.root == null;
16117
}
17118
// number of points in the set
18119
public int size() {
19-
return tree.size();
120+
if (this.isEmpty()) return 0;
121+
return this.root.size();
20122
}
21123
// add the point to the set (if it is not already in the set)
22124
public void insert(Point2D p) {
23-
tree.add(p);
125+
if (p == null) throw new IllegalArgumentException("insert method called with a null argument");
126+
this.root = this.root == null ? new Node(p, true) : this.root.insert(p);
24127
}
25-
// // does the set contain point p?
128+
// does the set contain point p?
26129
public boolean contains(Point2D p) {
27-
return tree.contains(p);
130+
if (p == null) throw new IllegalArgumentException("contains method called with a null argument");
131+
if (this.isEmpty()) return false;
132+
return this.root.contains(p);
28133
}
29134
// draw all points to standard draw
30135
public void draw() {
31-
for (Point2D p : tree) {
32-
p.draw();
33-
}
136+
if (this.isEmpty()) return;
137+
138+
this.root.draw();
34139
}
35140
// all points that are inside the rectangle (or on the boundary)
36141
public Iterable<Point2D> range(RectHV rect) {
142+
if (rect == null) throw new IllegalArgumentException("range method called with a null argument");
143+
37144
ArrayList<Point2D> points = new ArrayList<Point2D>();
38-
for (Point2D p : tree) {
39-
if (rect.contains(p)) {
40-
points.add(p);
145+
if (this.isEmpty()) return points;
146+
147+
rangeSearch(rect, root, points);
148+
149+
return points;
150+
}
151+
152+
private void rangeSearch(RectHV rect, Node node, ArrayList<Point2D> points) {
153+
if (node == null) return;
154+
155+
if (rect.contains(node.point)) {
156+
points.add(node.point);
157+
}
158+
159+
if (node.isVertical) {
160+
if (rect.xmin() < node.point.x()) {
161+
rangeSearch(rect, node.left, points);
162+
}
163+
if (rect.xmax() >= node.point.x()) {
164+
rangeSearch(rect, node.right, points);
165+
}
166+
} else {
167+
if (rect.ymin() < node.point.y()) {
168+
rangeSearch(rect, node.left, points);
169+
}
170+
if (rect.ymax() >= node.point.y()) {
171+
rangeSearch(rect, node.right, points);
41172
}
42173
}
43-
return points;
44174
}
45175
// // a nearest neighbor in the set to point p; null if the set is empty
46176
public Point2D nearest(Point2D p) {
47-
Point2D nearest = null;
48-
double minDistance = Double.POSITIVE_INFINITY;
49-
for (Point2D point : tree) {
50-
double distance = p.distanceSquaredTo(point);
51-
if (distance < minDistance) {
52-
nearest = point;
53-
minDistance = distance;
177+
return nearestSearch(p, this.root, null, Double.POSITIVE_INFINITY);
178+
}
179+
180+
private Point2D nearestSearch(Point2D p, Node node, Point2D nearest, double minDistance) {
181+
if (node == null) return nearest;
182+
183+
double distance = p.distanceSquaredTo(node.point);
184+
if (distance < minDistance) {
185+
nearest = node.point;
186+
minDistance = distance;
187+
}
188+
189+
if (node.isVertical) {
190+
if (p.x() < node.point.x()) {
191+
nearest = nearestSearch(p, node.left, nearest, minDistance);
192+
nearest = nearestSearch(p, node.right, nearest, nearest.distanceSquaredTo(p));
193+
} else {
194+
nearest = nearestSearch(p, node.right, nearest, minDistance);
195+
nearest = nearestSearch(p, node.left, nearest, nearest.distanceSquaredTo(p));
196+
}
197+
} else {
198+
if (p.y() < node.point.y()) {
199+
nearest = nearestSearch(p, node.left, nearest, minDistance);
200+
nearest = nearestSearch(p, node.right, nearest, nearest.distanceSquaredTo(p));
201+
} else {
202+
nearest = nearestSearch(p, node.right, nearest, minDistance);
203+
nearest = nearestSearch(p, node.left, nearest, nearest.distanceSquaredTo(p));
54204
}
55205
}
206+
56207
return nearest;
57208
}
58209

week5/balanced-search-trees/assignment-kd-trees/PointSET.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public Point2D nearest(Point2D p) {
4747
Point2D nearest = null;
4848
double minDistance = Double.POSITIVE_INFINITY;
4949
for (Point2D point : tree) {
50-
double distance = p.distanceTo(point);
50+
double distance = p.distanceSquaredTo(point);
5151
if (distance < minDistance) {
5252
nearest = point;
5353
minDistance = distance;

0 commit comments

Comments
 (0)