// smile.classification.KNN Maven / Gradle / Ivy
/*
* Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
*
* Smile is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Smile is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Smile. If not, see <https://www.gnu.org/licenses/>.
*/
package smile.classification;
import java.io.Serial;
import java.util.Arrays;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.math.distance.EuclideanDistance;
import smile.math.distance.Metric;
import smile.neighbor.CoverTree;
import smile.neighbor.KDTree;
import smile.neighbor.KNNSearch;
import smile.neighbor.LinearSearch;
import smile.neighbor.Neighbor;
/**
* K-nearest neighbor classifier. The k-nearest neighbor algorithm (k-NN) is
* a method for classifying objects by a majority vote of its neighbors,
* with the object being assigned to the class most common amongst its k
* nearest neighbors (k is a positive integer, typically small).
* k-NN is a type of instance-based learning, or lazy learning where the
* function is only approximated locally and all computation
* is deferred until classification.
*
* The best choice of k depends upon the data; generally, larger values of
* k reduce the effect of noise on the classification, but make boundaries
* between classes less distinct. A good k can be selected by various
* heuristic techniques, e.g. cross-validation. In binary problems, it is
* helpful to choose k to be an odd number as this avoids tied votes.
*
* A drawback to the basic majority voting classification is that the classes
* with the more frequent instances tend to dominate the prediction of the
* new object, as they tend to come up in the k nearest neighbors when
* the neighbors are computed due to their large number. One way to overcome
* this problem is to weight the classification taking into account the
* distance from the test point to each of its k nearest neighbors.
*
* Often, the classification accuracy of k-NN can be improved significantly
* if the distance metric is learned with specialized algorithms such as
* Large Margin Nearest Neighbor or Neighborhood Components Analysis.
*
* Nearest neighbor rules in effect compute the decision boundary in an
* implicit manner. It is also possible to compute the decision boundary
* itself explicitly, and to do so in an efficient manner so that the
* computational complexity is a function of the boundary complexity.
*
* The nearest neighbor algorithm has some strong consistency results. As
* the amount of data approaches infinity, the algorithm is guaranteed to
* yield an error rate no worse than twice the Bayes error rate (the minimum
* achievable error rate given the distribution of the data). k-NN is
* guaranteed to approach the Bayes error rate, for some value of k (where k
* increases as a function of the number of data points).
*
* @param <T> the data type of model input objects.
*
* @author Haifeng Li
*/
public class KNN<T> extends AbstractClassifier<T> {
    @Serial
    private static final long serialVersionUID = 2L;

    /**
     * The data structure for nearest neighbor search.
     */
    private final KNNSearch<T, T> knn;
    /**
     * The labels of training sample.
     */
    private final int[] y;
    /**
     * The number of neighbors for decision.
     */
    private final int k;

    /**
     * Constructor.
     * @param knn k-nearest neighbor search data structure of training instances.
     * @param y training labels.
     * @param k the number of neighbors for classification.
     */
    public KNN(KNNSearch<T, T> knn, int[] y, int k) {
        super(y);
        this.knn = knn;
        this.k = k;
        this.y = y;
    }

    /**
     * Fits the 1-NN classifier.
     * @param x training samples.
     * @param y training labels.
     * @param distance the distance function.
     * @param <T> the data type.
     * @return the model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public static <T> KNN<T> fit(T[] x, int[] y, Distance<T> distance) {
        return fit(x, y, 1, distance);
    }

    /**
     * Fits the K-NN classifier. A cover tree is built when the distance
     * is a {@link Metric}; otherwise a linear search is used.
     * @param x training samples.
     * @param y training labels.
     * @param k the number of neighbors.
     * @param distance the distance function.
     * @param <T> the data type.
     * @return the model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ
     *         in length or {@code k < 1}.
     */
    public static <T> KNN<T> fit(T[] x, int[] y, int k, Distance<T> distance) {
        checkArguments(x.length, y.length, k);

        KNNSearch<T, T> knn;
        if (distance instanceof Metric<T> metric) {
            knn = CoverTree.of(x, metric);
        } else {
            knn = LinearSearch.of(x, distance);
        }

        return new KNN<>(knn, y, k);
    }

    /**
     * Fits the 1-NN classifier with Euclidean distance.
     * @param x training samples.
     * @param y training labels.
     * @return the model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ
     *         in length or {@code x} is empty.
     */
    public static KNN<double[]> fit(double[][] x, int[] y) {
        return fit(x, y, 1);
    }

    /**
     * Fits the K-NN classifier with Euclidean distance. A KD-tree is built
     * when the samples have fewer than 10 dimensions; otherwise a cover
     * tree is used.
     * @param x training samples.
     * @param y training labels.
     * @param k the number of neighbors for classification.
     * @return the model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ
     *         in length, {@code k < 1}, or {@code x} is empty.
     */
    public static KNN<double[]> fit(double[][] x, int[] y, int k) {
        checkArguments(x.length, y.length, k);

        // The dimensionality is read from the first sample, so an empty
        // training set must be rejected before x[0] is touched.
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training set");
        }

        KNNSearch<double[], double[]> knn;
        if (x[0].length < 10) {
            knn = KDTree.of(x);
        } else {
            knn = CoverTree.of(x, new EuclideanDistance());
        }

        return new KNN<>(knn, y, k);
    }

    /**
     * Predicts the class label by majority vote of the nearest neighbors.
     * @param x the instance to classify.
     * @return the predicted class label.
     * @throws IllegalStateException if the search yields no neighbor.
     */
    @Override
    public int predict(T x) {
        Neighbor<T, T>[] neighbors = knn.search(x, k);

        if (k == 1) {
            return y[nearest(neighbors).index];
        }

        int[] count = vote(neighbors);
        int best = MathEx.whichMax(count);
        if (count[best] == 0) {
            throw new IllegalStateException("No neighbor found.");
        }

        return classes.valueOf(best);
    }

    @Override
    public boolean soft() {
        return true;
    }

    /**
     * Predicts the class label and fills in the posteriori probabilities,
     * estimated as the fraction of the found neighbors in each class.
     * @param x the instance to classify.
     * @param posteriori the output array of posteriori probabilities,
     *                   indexed by class index.
     * @return the predicted class label.
     * @throws IllegalStateException if the search yields no neighbor.
     */
    @Override
    public int predict(T x, double[] posteriori) {
        Neighbor<T, T>[] neighbors = knn.search(x, k);

        if (k == 1) {
            int label = y[nearest(neighbors).index];
            Arrays.fill(posteriori, 0.0);
            posteriori[classes.indexOf(label)] = 1.0;
            return label;
        }

        // Same null-tolerant vote as predict(T), so both methods agree
        // when the search returns fewer than k neighbors.
        int[] count = vote(neighbors);
        int best = MathEx.whichMax(count);
        if (count[best] == 0) {
            throw new IllegalStateException("No neighbor found.");
        }

        // Normalize by the number of votes actually cast so that the
        // probabilities sum to one. This equals k when all k are found.
        int found = Arrays.stream(count).sum();
        for (int i = 0; i < count.length; i++) {
            posteriori[i] = (double) count[i] / found;
        }

        return classes.valueOf(best);
    }

    /** Validates the training set sizes and the number of neighbors. */
    private static void checkArguments(int samples, int labels, int k) {
        if (samples != labels) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", samples, labels));
        }

        if (k < 1) {
            throw new IllegalArgumentException("Illegal k = " + k);
        }
    }

    /** Returns the first search result, or throws if there is none. */
    private Neighbor<T, T> nearest(Neighbor<T, T>[] neighbors) {
        if (neighbors.length == 0 || neighbors[0] == null) {
            throw new IllegalStateException("No neighbor found.");
        }
        return neighbors[0];
    }

    /** Counts the votes per class index, skipping null search results. */
    private int[] vote(Neighbor<T, T>[] neighbors) {
        int[] count = new int[classes.size()];
        for (Neighbor<T, T> neighbor : neighbors) {
            if (neighbor != null) {
                count[classes.indexOf(y[neighbor.index])]++;
            }
        }
        return count;
    }
}