All downloads are free. Search and download functionality uses the official Maven repository.

smile.neighbor.KDTree Maven / Gradle / Ivy

The newest version!
/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *   
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *  
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package smile.neighbor;

import java.util.List;
import smile.math.Math;
import smile.sort.HeapSelect;

/**
 * A KD-tree (short for k-dimensional tree) is a space-partitioning dataset
 * structure for organizing points in a k-dimensional space. KD-trees are
 * a useful dataset structure for nearest neighbor searches. The kd-tree is a
 * binary tree in which every node is a k-dimensional point. Every non-leaf
 * node generates a splitting hyperplane that divides the space into two
 * subspaces. Points left to the hyperplane represent the left sub-tree of
 * that node and the points right to the hyperplane by the right sub-tree.
 * The hyperplane direction is chosen in the following way: every node split
 * to sub-trees is associated with one of the k-dimensions, such that the
 * hyperplane is perpendicular to that dimension vector. So, for example, if
 * for a particular split the "x" axis is chosen, all points in the subtree
 * with a smaller "x" value than the node will appear in the left subtree and
 * all points with larger "x" value will be in the right sub tree.
 * 

* KD-trees are not suitable for efficiently finding the nearest neighbor * in high dimensional spaces. As a general rule, if the dimensionality is D, * then number of points in the dataset, N, should be N >> 2D. * Otherwise, when kd-trees are used with high-dimensional dataset, most of the * points in the tree will be evaluated and the efficiency is no better than * exhaustive search, and approximate nearest-neighbor methods should be used * instead. * * @param the type of data objects in the tree. * * @author Haifeng Li */ public class KDTree implements NearestNeighborSearch, KNNSearch, RNNSearch { /** * The root in the KD-tree. */ class Node { /** * Number of dataset stored in this node. */ int count; /** * The smallest point index stored in this node. */ int index; /** * The index of coordinate used to split this node. */ int split; /** * The cutoff used to split the specific coordinate. */ double cutoff; /** * The child node which values of split coordinate is less than the cutoff value. */ Node lower; /** * The child node which values of split coordinate is greater than or equal to the cutoff value. */ Node upper; /** * If the node is a leaf node. */ boolean isLeaf() { return lower == null && upper == null; } } /** * The keys of data objects. */ private double[][] keys; /** * The data objects. */ private E[] data; /** * The root node of KD-Tree. */ private Node root; /** * The index of objects in each nodes. */ private int[] index; /** * Whether to exclude query object self from the neighborhood. */ private boolean identicalExcluded = true; /** * Constructor. * @param key the keys of data objects. * @param data the data objects. 
*/ public KDTree(double[][] key, E[] data) { if (key.length != data.length) { throw new IllegalArgumentException("The array size of keys and data are different."); } this.keys = key; this.data = data; int n = key.length; index = new int[n]; for (int i = 0; i < n; i++) { index[i] = i; } // Build the tree root = buildNode(0, n); } @Override public String toString() { return "KD-Tree"; } /** * Build a k-d tree from the given set of dataset. */ private Node buildNode(int begin, int end) { int d = keys[0].length; // Allocate the node Node node = new Node(); // Fill in basic info node.count = end - begin; node.index = begin; // Calculate the bounding box double[] lowerBound = new double[d]; double[] upperBound = new double[d]; for (int i = 0; i < d; i++) { lowerBound[i] = keys[index[begin]][i]; upperBound[i] = keys[index[begin]][i]; } for (int i = begin + 1; i < end; i++) { for (int j = 0; j < d; j++) { double c = keys[index[i]][j]; if (lowerBound[j] > c) { lowerBound[j] = c; } if (upperBound[j] < c) { upperBound[j] = c; } } } // Calculate bounding box stats double maxRadius = -1; for (int i = 0; i < d; i++) { double radius = (upperBound[i] - lowerBound[i]) / 2; if (radius > maxRadius) { maxRadius = radius; node.split = i; node.cutoff = (upperBound[i] + lowerBound[i]) / 2; } } // If the max spread is 0, make this a leaf node if (maxRadius == 0) { node.lower = node.upper = null; return node; } // Partition the dataset around the midpoint in this dimension. The // partitioning is done in-place by iterating from left-to-right and // right-to-left in the same way that partioning is done in quicksort. 
int i1 = begin, i2 = end - 1, size = 0; while (i1 <= i2) { boolean i1Good = (keys[index[i1]][node.split] < node.cutoff); boolean i2Good = (keys[index[i2]][node.split] >= node.cutoff); if (!i1Good && !i2Good) { int temp = index[i1]; index[i1] = index[i2]; index[i2] = temp; i1Good = i2Good = true; } if (i1Good) { i1++; size++; } if (i2Good) { i2--; } } // Create the child nodes node.lower = buildNode(begin, begin + size); node.upper = buildNode(begin + size, end); return node; } /** * Set if exclude query object self from the neighborhood. */ public KDTree setIdenticalExcluded(boolean excluded) { identicalExcluded = excluded; return this; } /** * Get whether if query object self be excluded from the neighborhood. */ public boolean isIdenticalExcluded() { return identicalExcluded; } /** * Returns the nearest neighbors of the given target starting from the give * tree node. * * @param q the query key. * @param node the root of subtree. * @param neighbor the current nearest neighbor. */ private void search(double[] q, Node node, Neighbor neighbor) { if (node.isLeaf()) { // look at all the instances in this leaf for (int idx = node.index; idx < node.index + node.count; idx++) { if (q == keys[index[idx]] && identicalExcluded) { continue; } double distance = Math.squaredDistance(q, keys[index[idx]]); if (distance < neighbor.distance) { neighbor.key = keys[index[idx]]; neighbor.value = data[index[idx]]; neighbor.index = index[idx]; neighbor.distance = distance; } } } else { Node nearer, further; double diff = q[node.split] - node.cutoff; if (diff < 0) { nearer = node.lower; further = node.upper; } else { nearer = node.upper; further = node.lower; } search(q, nearer, neighbor); // now look in further half if (neighbor.distance >= diff * diff) { search(q, further, neighbor); } } } /** * Returns (in the supplied heap object) the k nearest * neighbors of the given target starting from the give * tree node. * * @param q the query key. * @param node the root of subtree. 
* @param k the number of neighbors to find. * @param heap the heap object to store/update the kNNs found during the search. */ private void search(double[] q, Node node, HeapSelect> heap) { if (node.isLeaf()) { // look at all the instances in this leaf for (int idx = node.index; idx < node.index + node.count; idx++) { if (q == keys[index[idx]] && identicalExcluded) { continue; } double distance = Math.squaredDistance(q, keys[index[idx]]); Neighbor datum = heap.peek(); if (distance < datum.distance) { datum.distance = distance; datum.index = index[idx]; datum.key = keys[index[idx]]; datum.value = data[index[idx]]; heap.heapify(); } } } else { Node nearer, further; double diff = q[node.split] - node.cutoff; if (diff < 0) { nearer = node.lower; further = node.upper; } else { nearer = node.upper; further = node.lower; } search(q, nearer, heap); // now look in further half if (heap.peek().distance >= diff * diff) { search(q, further, heap); } } } /** * Returns the neighbors in the given range of search target from the give * tree node. * * @param q the query key. * @param node the root of subtree. * @param radius the radius of search range from target. * @param neighbors the list of found neighbors in the range. 
*/ private void search(double[] q, Node node, double radius, List> neighbors) { if (node.isLeaf()) { // look at all the instances in this leaf for (int idx = node.index; idx < node.index + node.count; idx++) { if (q == keys[index[idx]] && identicalExcluded) { continue; } double distance = Math.distance(q, keys[index[idx]]); if (distance <= radius) { neighbors.add(new Neighbor(keys[index[idx]], data[index[idx]], index[idx], distance)); } } } else { Node nearer, further; double diff = q[node.split] - node.cutoff; if (diff < 0) { nearer = node.lower; further = node.upper; } else { nearer = node.upper; further = node.lower; } search(q, nearer, radius, neighbors); // now look in further half if (radius >= diff * diff) { search(q, further, radius, neighbors); } } } @Override public Neighbor nearest(double[] q) { Neighbor neighbor = new Neighbor(null, null, 0, Double.MAX_VALUE); search(q, root, neighbor); neighbor.distance = Math.sqrt(neighbor.distance); return neighbor; } @Override public Neighbor[] knn(double[] q, int k) { if (k <= 0) { throw new IllegalArgumentException("Invalid k: " + k); } if (k > keys.length) { throw new IllegalArgumentException("Neighbor array length is larger than the dataset size"); } Neighbor neighbor = new Neighbor(null, null, 0, Double.MAX_VALUE); @SuppressWarnings("unchecked") Neighbor[] neighbors = (Neighbor[]) java.lang.reflect.Array.newInstance(neighbor.getClass(), k); HeapSelect> heap = new HeapSelect>(neighbors); for (int i = 0; i < k; i++) { heap.add(neighbor); neighbor = new Neighbor(null, null, 0, Double.MAX_VALUE); } search(q, root, heap); heap.sort(); for (int i = 0; i < neighbors.length; i++) { neighbors[i].distance = Math.sqrt(neighbors[i].distance); } return neighbors; } @Override public void range(double[] q, double radius, List> neighbors) { if (radius <= 0.0) { throw new IllegalArgumentException("Invalid radius: " + radius); } search(q, root, radius, neighbors); } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy