Skip to content

Commit 3fc09c6

Browse files
author
linyiqun
committed
算法工具封装类
算法工具封装类
1 parent ab2389a commit 3fc09c6

File tree

1 file changed

+386
-0
lines changed

1 file changed

+386
-0
lines changed
Lines changed: 386 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,386 @@
1+
package DataMining_KDTree;
2+
3+
import java.io.BufferedReader;
4+
import java.io.File;
5+
import java.io.FileReader;
6+
import java.io.IOException;
7+
import java.util.ArrayList;
8+
import java.util.Collections;
9+
import java.util.Comparator;
10+
import java.util.Stack;
11+
12+
/**
13+
* KD树-k维空间关键数据检索算法工具类
14+
*
15+
* @author lyq
16+
*
17+
*/
18+
public class KDTreeTool {
19+
// 空间平面的方向
20+
public static final int DIRECTION_X = 0;
21+
public static final int DIRECTION_Y = 1;
22+
23+
// 输入的测试数据坐标点文件
24+
private String filePath;
25+
// 原始所有数据点数据
26+
private ArrayList<Point> totalDatas;
27+
// KD树根节点
28+
private TreeNode rootNode;
29+
30+
public KDTreeTool(String filePath) {
31+
this.filePath = filePath;
32+
33+
readDataFile();
34+
}
35+
36+
/**
37+
* 从文件中读取数据
38+
*/
39+
private void readDataFile() {
40+
File file = new File(filePath);
41+
ArrayList<String[]> dataArray = new ArrayList<String[]>();
42+
43+
try {
44+
BufferedReader in = new BufferedReader(new FileReader(file));
45+
String str;
46+
String[] tempArray;
47+
while ((str = in.readLine()) != null) {
48+
tempArray = str.split(" ");
49+
dataArray.add(tempArray);
50+
}
51+
in.close();
52+
} catch (IOException e) {
53+
e.getStackTrace();
54+
}
55+
56+
Point p;
57+
totalDatas = new ArrayList<>();
58+
for (String[] array : dataArray) {
59+
p = new Point(array[0], array[1]);
60+
totalDatas.add(p);
61+
}
62+
}
63+
64+
/**
65+
* 创建KD树
66+
*
67+
* @return
68+
*/
69+
public TreeNode createKDTree() {
70+
ArrayList<Point> copyDatas;
71+
72+
rootNode = new TreeNode();
73+
// 根据节点开始时所表示的空间时无限大的
74+
rootNode.range = new Range();
75+
copyDatas = (ArrayList<Point>) totalDatas.clone();
76+
recusiveConstructNode(rootNode, copyDatas);
77+
78+
return rootNode;
79+
}
80+
81+
/**
82+
* 递归进行KD树的构造
83+
*
84+
* @param node
85+
* 当前正在构造的节点
86+
* @param datas
87+
* 该节点对应的正在处理的数据
88+
* @return
89+
*/
90+
private void recusiveConstructNode(TreeNode node, ArrayList<Point> datas) {
91+
int direction = 0;
92+
ArrayList<Point> leftSideDatas;
93+
ArrayList<Point> rightSideDatas;
94+
Point p;
95+
TreeNode leftNode;
96+
TreeNode rightNode;
97+
Range range;
98+
Range range2;
99+
100+
// 如果划分的数据点集合只有1个数据,则不再划分
101+
if (datas.size() == 1) {
102+
node.nodeData = datas.get(0);
103+
return;
104+
}
105+
106+
// 首先在当前的数据点集合中进行分割方向的选择
107+
direction = selectSplitDrc(datas);
108+
// 根据方向取出中位数点作为数据矢量
109+
p = getMiddlePoint(datas, direction);
110+
111+
node.spilt = direction;
112+
node.nodeData = p;
113+
114+
leftSideDatas = getLeftSideDatas(datas, p, direction);
115+
datas.removeAll(leftSideDatas);
116+
// 还要去掉自身
117+
datas.remove(p);
118+
rightSideDatas = datas;
119+
120+
if (leftSideDatas.size() > 0) {
121+
leftNode = new TreeNode();
122+
leftNode.parentNode = node;
123+
range2 = Range.initLeftRange(p, direction);
124+
// 获取父节点的空间矢量,进行交集运算做范围拆分
125+
range = node.range.crossOperation(range2);
126+
leftNode.range = range;
127+
128+
node.leftNode = leftNode;
129+
recusiveConstructNode(leftNode, leftSideDatas);
130+
}
131+
132+
if (rightSideDatas.size() > 0) {
133+
rightNode = new TreeNode();
134+
rightNode.parentNode = node;
135+
range2 = Range.initRightRange(p, direction);
136+
// 获取父节点的空间矢量,进行交集运算做范围拆分
137+
range = node.range.crossOperation(range2);
138+
rightNode.range = range;
139+
140+
node.rightNode = rightNode;
141+
recusiveConstructNode(rightNode, rightSideDatas);
142+
}
143+
}
144+
145+
/**
146+
* 搜索出给定数据点的最近点
147+
*
148+
* @param p
149+
* 待比较坐标点
150+
*/
151+
public Point searchNearestData(Point p) {
152+
// 节点距离给定数据点的距离
153+
TreeNode nearestNode = null;
154+
// 用栈记录遍历过的节点
155+
Stack<TreeNode> stackNodes;
156+
157+
stackNodes = new Stack<>();
158+
findedNearestLeafNode(p, rootNode, stackNodes);
159+
160+
// 取出叶子节点,作为当前找到的最近节点
161+
nearestNode = stackNodes.pop();
162+
nearestNode = dfsSearchNodes(stackNodes, p, nearestNode);
163+
164+
return nearestNode.nodeData;
165+
}
166+
167+
/**
168+
* 深度优先的方式进行最近点的查找
169+
*
170+
* @param stack
171+
* KD树节点栈
172+
* @param desPoint
173+
* 给定的数据点
174+
* @param nearestNode
175+
* 当前找到的最近节点
176+
* @return
177+
*/
178+
private TreeNode dfsSearchNodes(Stack<TreeNode> stack, Point desPoint,
179+
TreeNode nearestNode) {
180+
// 是否碰到父节点边界
181+
boolean isCollision;
182+
double minDis;
183+
double dis;
184+
TreeNode parentNode;
185+
186+
// 如果栈内节点已经全部弹出,则遍历结束
187+
if (stack.isEmpty()) {
188+
return nearestNode;
189+
}
190+
191+
// 获取父节点
192+
parentNode = stack.pop();
193+
194+
minDis = desPoint.ouDistance(nearestNode.nodeData);
195+
dis = desPoint.ouDistance(parentNode.nodeData);
196+
197+
// 如果与当前回溯到的父节点距离更短,则搜索到的节点进行更新
198+
if (dis < minDis) {
199+
minDis = dis;
200+
nearestNode = parentNode;
201+
}
202+
203+
// 默认没有碰撞到
204+
isCollision = false;
205+
// 判断是否触碰到了父节点的空间分割线
206+
if (parentNode.spilt == DIRECTION_X) {
207+
if (parentNode.nodeData.x > desPoint.x - minDis
208+
&& parentNode.nodeData.x < desPoint.x + minDis) {
209+
isCollision = true;
210+
}
211+
} else {
212+
if (parentNode.nodeData.y > desPoint.y - minDis
213+
&& parentNode.nodeData.y < desPoint.y + minDis) {
214+
isCollision = true;
215+
}
216+
}
217+
218+
// 如果触碰到父边界了,并且此节点的孩子节点还未完全遍历完,则可以继续遍历
219+
if (isCollision
220+
&& (!parentNode.leftNode.isVisited || !parentNode.rightNode.isVisited)) {
221+
TreeNode newNode;
222+
// 新建当前的小局部节点栈
223+
Stack<TreeNode> otherStack = new Stack<>();
224+
// 从parentNode的树以下继续寻找
225+
findedNearestLeafNode(desPoint, parentNode, otherStack);
226+
newNode = dfsSearchNodes(otherStack, desPoint, otherStack.pop());
227+
228+
dis = newNode.nodeData.ouDistance(desPoint);
229+
if (dis < minDis) {
230+
nearestNode = newNode;
231+
}
232+
}
233+
234+
// 继续往上回溯
235+
nearestNode = dfsSearchNodes(stack, desPoint, nearestNode);
236+
237+
return nearestNode;
238+
}
239+
240+
/**
241+
* 找到与所给定节点的最近的叶子节点
242+
*
243+
* @param p
244+
* 待比较节点
245+
* @param node
246+
* 当前搜索到的节点
247+
* @param stack
248+
* 遍历过的节点栈
249+
*/
250+
private void findedNearestLeafNode(Point p, TreeNode node,
251+
Stack<TreeNode> stack) {
252+
// 分割方向
253+
int splitDic;
254+
255+
// 将遍历过的节点加入栈中
256+
stack.push(node);
257+
// 标记为访问过
258+
node.isVisited = true;
259+
// 如果此节点没有左右孩子节点说明已经是叶子节点了
260+
if (node.leftNode == null && node.rightNode == null) {
261+
return;
262+
}
263+
264+
splitDic = node.spilt;
265+
// 选择一个符合分割范围的节点继续递归搜寻
266+
if ((splitDic == DIRECTION_X && p.x < node.nodeData.x)
267+
|| (splitDic == DIRECTION_Y && p.y < node.nodeData.y)) {
268+
if (!node.leftNode.isVisited) {
269+
findedNearestLeafNode(p, node.leftNode, stack);
270+
} else {
271+
// 如果左孩子节点已经访问过,则访问另一边
272+
findedNearestLeafNode(p, node.rightNode, stack);
273+
}
274+
} else if ((splitDic == DIRECTION_X && p.x > node.nodeData.x)
275+
|| (splitDic == DIRECTION_Y && p.y > node.nodeData.y)) {
276+
if (!node.rightNode.isVisited) {
277+
findedNearestLeafNode(p, node.rightNode, stack);
278+
} else {
279+
// 如果右孩子节点已经访问过,则访问另一边
280+
findedNearestLeafNode(p, node.leftNode, stack);
281+
}
282+
}
283+
}
284+
285+
/**
286+
* 根据给定的数据点通过计算反差选择的分割点
287+
*
288+
* @param datas
289+
* 部分的集合点集合
290+
* @return
291+
*/
292+
private int selectSplitDrc(ArrayList<Point> datas) {
293+
int direction = 0;
294+
double avgX = 0;
295+
double avgY = 0;
296+
double varianceX = 0;
297+
double varianceY = 0;
298+
299+
for (Point p : datas) {
300+
avgX += p.x;
301+
avgY += p.y;
302+
}
303+
304+
avgX /= datas.size();
305+
avgY /= datas.size();
306+
307+
for (Point p : datas) {
308+
varianceX += (p.x - avgX) * (p.x - avgX);
309+
varianceY += (p.y - avgY) * (p.y - avgY);
310+
}
311+
312+
// 求最后的方差
313+
varianceX /= datas.size();
314+
varianceY /= datas.size();
315+
316+
// 通过比较方差的大小决定分割方向,选择波动较大的进行划分
317+
direction = varianceX > varianceY ? DIRECTION_X : DIRECTION_Y;
318+
319+
return direction;
320+
}
321+
322+
/**
323+
* 根据坐标点方位进行排序,选出中间点的坐标数据
324+
*
325+
* @param datas
326+
* 数据点集合
327+
* @param dir
328+
* 排序的坐标方向
329+
*/
330+
private Point getMiddlePoint(ArrayList<Point> datas, int dir) {
331+
int index = 0;
332+
Point middlePoint;
333+
334+
index = datas.size() / 2;
335+
if (dir == DIRECTION_X) {
336+
Collections.sort(datas, new Comparator<Point>() {
337+
338+
@Override
339+
public int compare(Point o1, Point o2) {
340+
// TODO Auto-generated method stub
341+
return o1.x.compareTo(o2.x);
342+
}
343+
});
344+
} else {
345+
Collections.sort(datas, new Comparator<Point>() {
346+
347+
@Override
348+
public int compare(Point o1, Point o2) {
349+
// TODO Auto-generated method stub
350+
return o1.y.compareTo(o2.y);
351+
}
352+
});
353+
}
354+
355+
// 取出中位数
356+
middlePoint = datas.get(index);
357+
358+
return middlePoint;
359+
}
360+
361+
/**
362+
* 根据方向得到原部分节点集合左侧的数据点
363+
*
364+
* @param datas
365+
* 原始数据点集合
366+
* @param nodeData
367+
* 数据矢量
368+
* @param dir
369+
* 分割方向
370+
* @return
371+
*/
372+
private ArrayList<Point> getLeftSideDatas(ArrayList<Point> datas,
373+
Point nodeData, int dir) {
374+
ArrayList<Point> leftSideDatas = new ArrayList<>();
375+
376+
for (Point p : datas) {
377+
if (dir == DIRECTION_X && p.x < nodeData.x) {
378+
leftSideDatas.add(p);
379+
} else if (dir == DIRECTION_Y && p.y < nodeData.y) {
380+
leftSideDatas.add(p);
381+
}
382+
}
383+
384+
return leftSideDatas;
385+
}
386+
}

0 commit comments

Comments
 (0)