
de.citec.tcs.alignment.learning.KNNClassifier Maven / Gradle / Ivy
/*
* TCS Alignment Toolbox
*
* Copyright (C) 2013-2015
* Benjamin Paaßen, Georg Zentgraf
* AG Theoretical Computer Science
* Centre of Excellence Cognitive Interaction Technology (CITEC)
* University of Bielefeld
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package de.citec.tcs.alignment.learning;
import java.util.TreeSet;
/**
 * This implements a very basic k-nearest neighbor classifier: Given the
 * distances of a data point to all training data points, it determines the k
 * nearest training data points and assigns the class label that received the
 * majority of votes among them.
 *
 * @author Benjamin Paassen - bpaassen(at)techfak.uni-bielefeld.de
 */
public class KNNClassifier extends AbstractDissimilarityClassifier {

    /**
     * The number of nearest neighbors that is considered by this classifier.
     */
    private int k = 5;

    /**
     * Creates a k-nearest neighbor classifier that considers 5 nearest
     * neighbors by default.
     *
     * @param trainingLabels the class label of each training data point.
     */
    public KNNClassifier(int[] trainingLabels) {
        super(trainingLabels);
    }

    /**
     * Returns the number of nearest neighbors that is considered by this
     * classifier.
     *
     * @return the number of nearest neighbors that is considered by this
     * classifier.
     */
    public int getK() {
        return k;
    }

    /**
     * Sets the number of nearest neighbors that is considered by this
     * classifier.
     *
     * @param K the number of nearest neighbors that is considered by this
     * classifier. Must be at least 1.
     *
     * @throws IllegalArgumentException if K is smaller than 1, because then
     * no neighbor could ever vote and every classification would yield -1.
     */
    public void setK(int K) {
        if (K < 1) {
            throw new IllegalArgumentException(
                    "The number of nearest neighbors must be at least 1, but was "
                    + K + "!");
        }
        this.k = K;
    }

    /**
     * Returns the number of data points within the k nearest neighbors to the
     * reference data point, that had some given class label.
     * The order of class labels in the output array is the same as in the
     * TreeSet returned by the getLabels method.
     *
     * This variant passes -1 as the reference index, i.e. the reference data
     * point is not treated as one of the training data points.
     *
     * @param distances the distances of the reference data point to all
     * training data points considered by this classifier.
     *
     * @return the number of data points within the k nearest neighbors to the
     * reference data point, that had some given class label.
     *
     * @throws IllegalArgumentException if the number of distances does not
     * match the number of training data points.
     */
    public int[] calculateVotes(double[] distances) {
        return calculateVotes(-1, distances);
    }

    /**
     * Returns the number of data points within the k nearest neighbors to the
     * reference data point, that had some given class label.
     * The order of class labels in this output array is the same as in the
     * TreeSet returned by the getLabels method.
     *
     * @param i the index of the reference data point itself. It is passed on
     * to DistanceIndex.getKNearest (presumably so that a training point does
     * not vote for itself - TODO confirm); -1 refers to no training point.
     * @param distances the distances of the reference data point to all
     * training data points considered by this classifier.
     *
     * @return the number of data points within the k nearest neighbors to the
     * reference data point, that had some given class label.
     *
     * @throws IllegalArgumentException if the number of distances does not
     * match the number of training data points.
     */
    public int[] calculateVotes(int i, double[] distances) {
        if (distances.length != getTrainingLabels().length) {
            throw new IllegalArgumentException(
                    "Expected one distance value for each data point, but got "
                    + distances.length + " distance values for "
                    + getTrainingLabels().length + " data points!");
        }
        // retrieve the k nearest neighbors.
        final TreeSet<DistanceIndex> kNearest
                = DistanceIndex.getKNearest(k, i, distances);
        // every neighbor casts one vote for its own label. The vote slot of a
        // label is its position in the sorted label set, which equals the
        // number of labels strictly smaller than it (headSet is exclusive).
        final int[] votes = new int[getLabels().size()];
        for (final DistanceIndex other : kNearest) {
            final int labelIdx
                    = getLabels().headSet(getTrainingLabels()[other.index]).size();
            votes[labelIdx]++;
        }
        return votes;
    }

    /**
     * {@inheritDoc }
     *
     * The reference data point is not treated as one of the training data
     * points (reference index -1).
     */
    @Override
    public int classifyTest(double[] distances) {
        return classifyTraining(-1, distances);
    }

    /**
     * {@inheritDoc }
     *
     * The label with the most votes among the k nearest neighbors is returned.
     * On ties, the label that comes first in the sorted label set wins. If no
     * votes were cast at all, -1 is returned.
     */
    @Override
    public int classifyTraining(int i, double[] distances) {
        final int[] votes = calculateVotes(i, distances);
        int labelIdx = 0;
        int bestLabel = -1;
        int mostVotes = 0;
        for (final Integer label : getLabels()) {
            // strict comparison keeps the earliest label on ties.
            if (votes[labelIdx] > mostVotes) {
                mostVotes = votes[labelIdx];
                bestLabel = label;
            }
            labelIdx++;
        }
        return bestLabel;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy