All Downloads are FREE. Search and download functionalities are using the official Maven repository.

de.citec.tcs.alignment.learning.KNNClassifier Maven / Gradle / Ivy

/* 
 * TCS Alignment Toolbox
 * 
 * Copyright (C) 2013-2015
 * Benjamin Paaßen, Georg Zentgraf
 * AG Theoretical Computer Science
 * Centre of Excellence Cognitive Interaction Technology (CITEC)
 * University of Bielefeld
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 * 
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package de.citec.tcs.alignment.learning;

import java.util.TreeSet;

/**
 * This implements a very basic k-nearest neighbor classifier: Given a set of
 * data points and a (trained) AlignmentAlgorithm it can determine the k next
 * datapoints for a given new datapoint and calculate the label for it based on
 * the majority of votes.
 *
 * @author Benjamin Paassen - bpaassen(at)techfak.uni-bielefeld.de
 */
public class KNNClassifier extends AbstractDissimilarityClassifier {

	/**
	 * The number of nearest neighbors that take part in each vote. Defaults to
	 * 5.
	 */
	private int k = 5;

	/**
	 * Creates a k-nearest neighbor classifier with k = 5.
	 *
	 * @param trainingLabels the class label of each training data point.
	 */
	public KNNClassifier(int[] trainingLabels) {
		super(trainingLabels);
	}

	/**
	 * Returns the number of nearest neighbors that is considered by this
	 * classifier.
	 *
	 * @return the number of nearest neighbors that is considered by this
	 *         classifier.
	 */
	public int getK() {
		return k;
	}

	/**
	 * Sets the number of nearest neighbors that is considered by this
	 * classifier.
	 *
	 * @param K the number of nearest neighbors that is considered by this
	 *          classifier. Must be at least 1.
	 *
	 * @throws IllegalArgumentException if K is smaller than 1, because then no
	 *                                  neighbor could ever cast a vote.
	 */
	public void setK(int K) {
		if (K < 1) {
			throw new IllegalArgumentException(
					"The number of nearest neighbors must be at least 1, but got "
					+ K + "!");
		}
		this.k = K;
	}

	/**
	 * Returns the number of data points within the k nearest neighbors to the
	 * reference data point, that had some given class label.
	 * The order of class labels in the output array is the same as in the
	 * TreeSet returned by the getLabels method.
	 *
	 * This overload is meant for a reference data point that is not itself
	 * part of the training data; it passes -1 as the reference index.
	 *
	 * @param distances the distances of the reference data point to all
	 *                  training data points considered by this classifier.
	 *
	 * @return the number of data points within the k nearest neighbors to the
	 *         reference data point, that had some given class label.
	 *
	 * @throws IllegalArgumentException if the number of distances does not
	 *                                  match the number of training labels.
	 */
	public int[] calculateVotes(double[] distances) {
		return calculateVotes(-1, distances);
	}

	/**
	 * Returns the number of data points within the k nearest neighbors to the
	 * reference data point, that had some given class label.
	 * The order of class labels in this output array is the same as in the
	 * TreeSet returned by the getLabels method.
	 *
	 * @param i         the index of the reference data point itself, or -1 if
	 *                  it is not part of the training data. Presumably
	 *                  DistanceIndex.getKNearest excludes this index from the
	 *                  neighbor search (TODO confirm).
	 * @param distances the distances of the reference data point to all
	 *                  training data points considered by this classifier.
	 *
	 * @return the number of data points within the k nearest neighbors to the
	 *         reference data point, that had some given class label.
	 *
	 * @throws IllegalArgumentException if the number of distances does not
	 *                                  match the number of training labels.
	 */
	public int[] calculateVotes(int i, double[] distances) {
		if (distances.length != getTrainingLabels().length) {
			throw new IllegalArgumentException(
					"Expected one distance value for each data point, but got "
					+ distances.length + " distance values for "
					+ getTrainingLabels().length + " data points!");
		}
		// retrieve the k nearest neighbors. The element type must be given
		// explicitly: a raw TreeSet would only yield Object in the loop below.
		final TreeSet<DistanceIndex> kNearest = DistanceIndex.getKNearest(k, i, distances);
		// one vote counter per label, in the iteration order of getLabels().
		final int[] votes = new int[getLabels().size()];
		for (final DistanceIndex other : kNearest) {
			// headSet(label) contains every label strictly smaller than the
			// neighbor's label, so its size is that label's position in the
			// sorted label set (assuming the label is contained in getLabels()).
			final int labelIdx = getLabels().headSet(getTrainingLabels()[other.index]).size();
			votes[labelIdx]++;
		}
		return votes;
	}

	/**
	 * {@inheritDoc }
	 *
	 * The test data point is treated as not being part of the training data
	 * (reference index -1).
	 */
	@Override
	public int classifyTest(double[] distances) {
		return classifyTraining(-1, distances);
	}

	/**
	 * {@inheritDoc }
	 *
	 * The result is the label with the most votes among the k nearest
	 * neighbors. On a tie, the label that comes first in the order of
	 * getLabels() wins. If no neighbor cast a vote at all, -1 is returned.
	 */
	@Override
	public int classifyTraining(int i, double[] distances) {
		final int[] votes = calculateVotes(i, distances);
		int labelIdx = 0;
		int bestLabel = -1;
		int mostVotes = 0;
		for (final Integer label : getLabels()) {
			// strict comparison: among tied labels the earliest one is kept.
			if (votes[labelIdx] > mostVotes) {
				mostVotes = votes[labelIdx];
				bestLabel = label;
			}
			labelIdx++;
		}
		return bestLabel;
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy