|
| 1 | +package DataMining_KDTree; |
| 2 | + |
| 3 | +import java.io.BufferedReader; |
| 4 | +import java.io.File; |
| 5 | +import java.io.FileReader; |
| 6 | +import java.io.IOException; |
| 7 | +import java.util.ArrayList; |
| 8 | +import java.util.Collections; |
| 9 | +import java.util.Comparator; |
| 10 | +import java.util.Stack; |
| 11 | + |
| 12 | +/** |
| 13 | + * KD树-k维空间关键数据检索算法工具类 |
| 14 | + * |
| 15 | + * @author lyq |
| 16 | + * |
| 17 | + */ |
| 18 | +public class KDTreeTool { |
| 19 | + // 空间平面的方向 |
| 20 | + public static final int DIRECTION_X = 0; |
| 21 | + public static final int DIRECTION_Y = 1; |
| 22 | + |
| 23 | + // 输入的测试数据坐标点文件 |
| 24 | + private String filePath; |
| 25 | + // 原始所有数据点数据 |
| 26 | + private ArrayList<Point> totalDatas; |
| 27 | + // KD树根节点 |
| 28 | + private TreeNode rootNode; |
| 29 | + |
| 30 | + public KDTreeTool(String filePath) { |
| 31 | + this.filePath = filePath; |
| 32 | + |
| 33 | + readDataFile(); |
| 34 | + } |
| 35 | + |
| 36 | + /** |
| 37 | + * 从文件中读取数据 |
| 38 | + */ |
| 39 | + private void readDataFile() { |
| 40 | + File file = new File(filePath); |
| 41 | + ArrayList<String[]> dataArray = new ArrayList<String[]>(); |
| 42 | + |
| 43 | + try { |
| 44 | + BufferedReader in = new BufferedReader(new FileReader(file)); |
| 45 | + String str; |
| 46 | + String[] tempArray; |
| 47 | + while ((str = in.readLine()) != null) { |
| 48 | + tempArray = str.split(" "); |
| 49 | + dataArray.add(tempArray); |
| 50 | + } |
| 51 | + in.close(); |
| 52 | + } catch (IOException e) { |
| 53 | + e.getStackTrace(); |
| 54 | + } |
| 55 | + |
| 56 | + Point p; |
| 57 | + totalDatas = new ArrayList<>(); |
| 58 | + for (String[] array : dataArray) { |
| 59 | + p = new Point(array[0], array[1]); |
| 60 | + totalDatas.add(p); |
| 61 | + } |
| 62 | + } |
| 63 | + |
| 64 | + /** |
| 65 | + * 创建KD树 |
| 66 | + * |
| 67 | + * @return |
| 68 | + */ |
| 69 | + public TreeNode createKDTree() { |
| 70 | + ArrayList<Point> copyDatas; |
| 71 | + |
| 72 | + rootNode = new TreeNode(); |
| 73 | + // 根据节点开始时所表示的空间时无限大的 |
| 74 | + rootNode.range = new Range(); |
| 75 | + copyDatas = (ArrayList<Point>) totalDatas.clone(); |
| 76 | + recusiveConstructNode(rootNode, copyDatas); |
| 77 | + |
| 78 | + return rootNode; |
| 79 | + } |
| 80 | + |
| 81 | + /** |
| 82 | + * 递归进行KD树的构造 |
| 83 | + * |
| 84 | + * @param node |
| 85 | + * 当前正在构造的节点 |
| 86 | + * @param datas |
| 87 | + * 该节点对应的正在处理的数据 |
| 88 | + * @return |
| 89 | + */ |
| 90 | + private void recusiveConstructNode(TreeNode node, ArrayList<Point> datas) { |
| 91 | + int direction = 0; |
| 92 | + ArrayList<Point> leftSideDatas; |
| 93 | + ArrayList<Point> rightSideDatas; |
| 94 | + Point p; |
| 95 | + TreeNode leftNode; |
| 96 | + TreeNode rightNode; |
| 97 | + Range range; |
| 98 | + Range range2; |
| 99 | + |
| 100 | + // 如果划分的数据点集合只有1个数据,则不再划分 |
| 101 | + if (datas.size() == 1) { |
| 102 | + node.nodeData = datas.get(0); |
| 103 | + return; |
| 104 | + } |
| 105 | + |
| 106 | + // 首先在当前的数据点集合中进行分割方向的选择 |
| 107 | + direction = selectSplitDrc(datas); |
| 108 | + // 根据方向取出中位数点作为数据矢量 |
| 109 | + p = getMiddlePoint(datas, direction); |
| 110 | + |
| 111 | + node.spilt = direction; |
| 112 | + node.nodeData = p; |
| 113 | + |
| 114 | + leftSideDatas = getLeftSideDatas(datas, p, direction); |
| 115 | + datas.removeAll(leftSideDatas); |
| 116 | + // 还要去掉自身 |
| 117 | + datas.remove(p); |
| 118 | + rightSideDatas = datas; |
| 119 | + |
| 120 | + if (leftSideDatas.size() > 0) { |
| 121 | + leftNode = new TreeNode(); |
| 122 | + leftNode.parentNode = node; |
| 123 | + range2 = Range.initLeftRange(p, direction); |
| 124 | + // 获取父节点的空间矢量,进行交集运算做范围拆分 |
| 125 | + range = node.range.crossOperation(range2); |
| 126 | + leftNode.range = range; |
| 127 | + |
| 128 | + node.leftNode = leftNode; |
| 129 | + recusiveConstructNode(leftNode, leftSideDatas); |
| 130 | + } |
| 131 | + |
| 132 | + if (rightSideDatas.size() > 0) { |
| 133 | + rightNode = new TreeNode(); |
| 134 | + rightNode.parentNode = node; |
| 135 | + range2 = Range.initRightRange(p, direction); |
| 136 | + // 获取父节点的空间矢量,进行交集运算做范围拆分 |
| 137 | + range = node.range.crossOperation(range2); |
| 138 | + rightNode.range = range; |
| 139 | + |
| 140 | + node.rightNode = rightNode; |
| 141 | + recusiveConstructNode(rightNode, rightSideDatas); |
| 142 | + } |
| 143 | + } |
| 144 | + |
| 145 | + /** |
| 146 | + * 搜索出给定数据点的最近点 |
| 147 | + * |
| 148 | + * @param p |
| 149 | + * 待比较坐标点 |
| 150 | + */ |
| 151 | + public Point searchNearestData(Point p) { |
| 152 | + // 节点距离给定数据点的距离 |
| 153 | + TreeNode nearestNode = null; |
| 154 | + // 用栈记录遍历过的节点 |
| 155 | + Stack<TreeNode> stackNodes; |
| 156 | + |
| 157 | + stackNodes = new Stack<>(); |
| 158 | + findedNearestLeafNode(p, rootNode, stackNodes); |
| 159 | + |
| 160 | + // 取出叶子节点,作为当前找到的最近节点 |
| 161 | + nearestNode = stackNodes.pop(); |
| 162 | + nearestNode = dfsSearchNodes(stackNodes, p, nearestNode); |
| 163 | + |
| 164 | + return nearestNode.nodeData; |
| 165 | + } |
| 166 | + |
| 167 | + /** |
| 168 | + * 深度优先的方式进行最近点的查找 |
| 169 | + * |
| 170 | + * @param stack |
| 171 | + * KD树节点栈 |
| 172 | + * @param desPoint |
| 173 | + * 给定的数据点 |
| 174 | + * @param nearestNode |
| 175 | + * 当前找到的最近节点 |
| 176 | + * @return |
| 177 | + */ |
| 178 | + private TreeNode dfsSearchNodes(Stack<TreeNode> stack, Point desPoint, |
| 179 | + TreeNode nearestNode) { |
| 180 | + // 是否碰到父节点边界 |
| 181 | + boolean isCollision; |
| 182 | + double minDis; |
| 183 | + double dis; |
| 184 | + TreeNode parentNode; |
| 185 | + |
| 186 | + // 如果栈内节点已经全部弹出,则遍历结束 |
| 187 | + if (stack.isEmpty()) { |
| 188 | + return nearestNode; |
| 189 | + } |
| 190 | + |
| 191 | + // 获取父节点 |
| 192 | + parentNode = stack.pop(); |
| 193 | + |
| 194 | + minDis = desPoint.ouDistance(nearestNode.nodeData); |
| 195 | + dis = desPoint.ouDistance(parentNode.nodeData); |
| 196 | + |
| 197 | + // 如果与当前回溯到的父节点距离更短,则搜索到的节点进行更新 |
| 198 | + if (dis < minDis) { |
| 199 | + minDis = dis; |
| 200 | + nearestNode = parentNode; |
| 201 | + } |
| 202 | + |
| 203 | + // 默认没有碰撞到 |
| 204 | + isCollision = false; |
| 205 | + // 判断是否触碰到了父节点的空间分割线 |
| 206 | + if (parentNode.spilt == DIRECTION_X) { |
| 207 | + if (parentNode.nodeData.x > desPoint.x - minDis |
| 208 | + && parentNode.nodeData.x < desPoint.x + minDis) { |
| 209 | + isCollision = true; |
| 210 | + } |
| 211 | + } else { |
| 212 | + if (parentNode.nodeData.y > desPoint.y - minDis |
| 213 | + && parentNode.nodeData.y < desPoint.y + minDis) { |
| 214 | + isCollision = true; |
| 215 | + } |
| 216 | + } |
| 217 | + |
| 218 | + // 如果触碰到父边界了,并且此节点的孩子节点还未完全遍历完,则可以继续遍历 |
| 219 | + if (isCollision |
| 220 | + && (!parentNode.leftNode.isVisited || !parentNode.rightNode.isVisited)) { |
| 221 | + TreeNode newNode; |
| 222 | + // 新建当前的小局部节点栈 |
| 223 | + Stack<TreeNode> otherStack = new Stack<>(); |
| 224 | + // 从parentNode的树以下继续寻找 |
| 225 | + findedNearestLeafNode(desPoint, parentNode, otherStack); |
| 226 | + newNode = dfsSearchNodes(otherStack, desPoint, otherStack.pop()); |
| 227 | + |
| 228 | + dis = newNode.nodeData.ouDistance(desPoint); |
| 229 | + if (dis < minDis) { |
| 230 | + nearestNode = newNode; |
| 231 | + } |
| 232 | + } |
| 233 | + |
| 234 | + // 继续往上回溯 |
| 235 | + nearestNode = dfsSearchNodes(stack, desPoint, nearestNode); |
| 236 | + |
| 237 | + return nearestNode; |
| 238 | + } |
| 239 | + |
| 240 | + /** |
| 241 | + * 找到与所给定节点的最近的叶子节点 |
| 242 | + * |
| 243 | + * @param p |
| 244 | + * 待比较节点 |
| 245 | + * @param node |
| 246 | + * 当前搜索到的节点 |
| 247 | + * @param stack |
| 248 | + * 遍历过的节点栈 |
| 249 | + */ |
| 250 | + private void findedNearestLeafNode(Point p, TreeNode node, |
| 251 | + Stack<TreeNode> stack) { |
| 252 | + // 分割方向 |
| 253 | + int splitDic; |
| 254 | + |
| 255 | + // 将遍历过的节点加入栈中 |
| 256 | + stack.push(node); |
| 257 | + // 标记为访问过 |
| 258 | + node.isVisited = true; |
| 259 | + // 如果此节点没有左右孩子节点说明已经是叶子节点了 |
| 260 | + if (node.leftNode == null && node.rightNode == null) { |
| 261 | + return; |
| 262 | + } |
| 263 | + |
| 264 | + splitDic = node.spilt; |
| 265 | + // 选择一个符合分割范围的节点继续递归搜寻 |
| 266 | + if ((splitDic == DIRECTION_X && p.x < node.nodeData.x) |
| 267 | + || (splitDic == DIRECTION_Y && p.y < node.nodeData.y)) { |
| 268 | + if (!node.leftNode.isVisited) { |
| 269 | + findedNearestLeafNode(p, node.leftNode, stack); |
| 270 | + } else { |
| 271 | + // 如果左孩子节点已经访问过,则访问另一边 |
| 272 | + findedNearestLeafNode(p, node.rightNode, stack); |
| 273 | + } |
| 274 | + } else if ((splitDic == DIRECTION_X && p.x > node.nodeData.x) |
| 275 | + || (splitDic == DIRECTION_Y && p.y > node.nodeData.y)) { |
| 276 | + if (!node.rightNode.isVisited) { |
| 277 | + findedNearestLeafNode(p, node.rightNode, stack); |
| 278 | + } else { |
| 279 | + // 如果右孩子节点已经访问过,则访问另一边 |
| 280 | + findedNearestLeafNode(p, node.leftNode, stack); |
| 281 | + } |
| 282 | + } |
| 283 | + } |
| 284 | + |
| 285 | + /** |
| 286 | + * 根据给定的数据点通过计算反差选择的分割点 |
| 287 | + * |
| 288 | + * @param datas |
| 289 | + * 部分的集合点集合 |
| 290 | + * @return |
| 291 | + */ |
| 292 | + private int selectSplitDrc(ArrayList<Point> datas) { |
| 293 | + int direction = 0; |
| 294 | + double avgX = 0; |
| 295 | + double avgY = 0; |
| 296 | + double varianceX = 0; |
| 297 | + double varianceY = 0; |
| 298 | + |
| 299 | + for (Point p : datas) { |
| 300 | + avgX += p.x; |
| 301 | + avgY += p.y; |
| 302 | + } |
| 303 | + |
| 304 | + avgX /= datas.size(); |
| 305 | + avgY /= datas.size(); |
| 306 | + |
| 307 | + for (Point p : datas) { |
| 308 | + varianceX += (p.x - avgX) * (p.x - avgX); |
| 309 | + varianceY += (p.y - avgY) * (p.y - avgY); |
| 310 | + } |
| 311 | + |
| 312 | + // 求最后的方差 |
| 313 | + varianceX /= datas.size(); |
| 314 | + varianceY /= datas.size(); |
| 315 | + |
| 316 | + // 通过比较方差的大小决定分割方向,选择波动较大的进行划分 |
| 317 | + direction = varianceX > varianceY ? DIRECTION_X : DIRECTION_Y; |
| 318 | + |
| 319 | + return direction; |
| 320 | + } |
| 321 | + |
| 322 | + /** |
| 323 | + * 根据坐标点方位进行排序,选出中间点的坐标数据 |
| 324 | + * |
| 325 | + * @param datas |
| 326 | + * 数据点集合 |
| 327 | + * @param dir |
| 328 | + * 排序的坐标方向 |
| 329 | + */ |
| 330 | + private Point getMiddlePoint(ArrayList<Point> datas, int dir) { |
| 331 | + int index = 0; |
| 332 | + Point middlePoint; |
| 333 | + |
| 334 | + index = datas.size() / 2; |
| 335 | + if (dir == DIRECTION_X) { |
| 336 | + Collections.sort(datas, new Comparator<Point>() { |
| 337 | + |
| 338 | + @Override |
| 339 | + public int compare(Point o1, Point o2) { |
| 340 | + // TODO Auto-generated method stub |
| 341 | + return o1.x.compareTo(o2.x); |
| 342 | + } |
| 343 | + }); |
| 344 | + } else { |
| 345 | + Collections.sort(datas, new Comparator<Point>() { |
| 346 | + |
| 347 | + @Override |
| 348 | + public int compare(Point o1, Point o2) { |
| 349 | + // TODO Auto-generated method stub |
| 350 | + return o1.y.compareTo(o2.y); |
| 351 | + } |
| 352 | + }); |
| 353 | + } |
| 354 | + |
| 355 | + // 取出中位数 |
| 356 | + middlePoint = datas.get(index); |
| 357 | + |
| 358 | + return middlePoint; |
| 359 | + } |
| 360 | + |
| 361 | + /** |
| 362 | + * 根据方向得到原部分节点集合左侧的数据点 |
| 363 | + * |
| 364 | + * @param datas |
| 365 | + * 原始数据点集合 |
| 366 | + * @param nodeData |
| 367 | + * 数据矢量 |
| 368 | + * @param dir |
| 369 | + * 分割方向 |
| 370 | + * @return |
| 371 | + */ |
| 372 | + private ArrayList<Point> getLeftSideDatas(ArrayList<Point> datas, |
| 373 | + Point nodeData, int dir) { |
| 374 | + ArrayList<Point> leftSideDatas = new ArrayList<>(); |
| 375 | + |
| 376 | + for (Point p : datas) { |
| 377 | + if (dir == DIRECTION_X && p.x < nodeData.x) { |
| 378 | + leftSideDatas.add(p); |
| 379 | + } else if (dir == DIRECTION_Y && p.y < nodeData.y) { |
| 380 | + leftSideDatas.add(p); |
| 381 | + } |
| 382 | + } |
| 383 | + |
| 384 | + return leftSideDatas; |
| 385 | + } |
| 386 | +} |
0 commit comments