All Downloads are FREE. Search and download functionalities are using the official Maven repository.

de.citec.tcs.alignment.learning.LMNNClassifier Maven / Gradle / Ivy

Go to download

This module is a custom implementation of the Large Margin Nearest Neighbor classification scheme of Weinberger, Saul, et al. (2009). It contains implementations of the k-nearest neighbor and LMNN classifiers as well as (most importantly) gradient calculation schemes for the LMNN cost function, given a sequential data set and a user choice of alignment algorithm. This enables users to learn parameters of the alignment distance in question using gradient descent on the LMNN cost function. More information on this approach can be found in the master's thesis "Adaptive Affine Sequence Alignment Using Algebraic Dynamic Programming".

The newest version!
/* 
 * TCS Alignment Toolbox Version 3
 * 
 * Copyright (C) 2016
 * Benjamin Paaßen
 * AG Theoretical Computer Science
 * Centre of Excellence Cognitive Interaction Technology (CITEC)
 * University of Bielefeld
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 * 
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package de.citec.tcs.alignment.learning;

import java.util.TreeSet;

/**
 * This implements a Large Margin Nearest Neighbor (LMNN) classifier as
 * suggested by Weinberger, Saul et al. (2009). The classification is made
 * based on the LMNN cost function: We calculate the cost function for the
 * given data point for each possible class and take the class label with the
 * lowest cost function value.
 *
 * @author Benjamin Paassen - bpaassen(at)techfak.uni-bielefeld.de
 */
public class LMNNClassifier extends AbstractDissimilarityClassifier {

	/**
	 * The number of target neighbors considered in the LMNN cost function.
	 */
	private int K = 5;
	/**
	 * The margin of safety; it is passed on to the imposter search and added
	 * (squared) to every imposter term of the LMNN cost function.
	 */
	private double margin = 0.01;

	/**
	 * Creates a new LMNN classifier.
	 *
	 * @param trainingLabels the class labels of the training data points,
	 *                       passed on to the superclass.
	 */
	public LMNNClassifier(int[] trainingLabels) {
		super(trainingLabels);
	}

	/**
	 * Returns the number of considered nearest neighbors in the LMNN cost
	 * function.
	 *
	 * @return the number of considered nearest neighbors in the LMNN cost
	 *         function.
	 */
	public int getK() {
		return K;
	}

	/**
	 * Sets the number of considered nearest neighbors in the LMNN cost
	 * function.
	 *
	 * @param K the number of considered nearest neighbors in the LMNN cost
	 *          function. Must be at least 1; with zero target neighbors the
	 *          cost would be 0 for every label and classification would be
	 *          meaningless.
	 *
	 * @throws IllegalArgumentException if K is smaller than 1.
	 */
	public void setK(int K) {
		if (K < 1) {
			throw new IllegalArgumentException(
					"The number of nearest neighbors K must be at least 1, but was " + K);
		}
		this.K = K;
	}

	/**
	 * Returns the margin of safety that is required by the LMNN cost function.
	 *
	 * @return the margin of safety that is required by the LMNN cost function.
	 */
	public double getMargin() {
		return margin;
	}

	/**
	 * Sets the margin of safety that is required by the LMNN cost function.
	 *
	 * @param margin the margin of safety that is required by the LMNN cost
	 *               function.
	 */
	public void setMargin(double margin) {
		this.margin = margin;
	}

	/**
	 * Classifies a data point that is not part of the training data set by
	 * evaluating the LMNN cost function for every known label and returning
	 * the label with the lowest cost.
	 *
	 * Ties are resolved in favor of the label that comes first in the
	 * iteration order of getLabels(), because only a strictly lower cost
	 * replaces the current best label.
	 *
	 * @param distances the distances of the data point to all training data
	 *                  points.
	 *
	 * @return the label with the lowest LMNN cost, or -1 if no label yields a
	 *         cost below positive infinity (e.g. if there are no labels).
	 */
	@Override
	public int classifyTest(double[] distances) {
		int minLabel = -1;
		double minCost = Double.POSITIVE_INFINITY;
		for (final Integer label : getLabels()) {
			final double cost = calculateLMNNCostFunctionTest(label, distances);
			if (cost < minCost) {
				minLabel = label;
				minCost = cost;
			}
		}
		return minLabel;
	}

	/**
	 * Classifies the training data point with index i by evaluating the LMNN
	 * cost function for every known label and returning the label with the
	 * lowest cost.
	 *
	 * Ties are resolved in favor of the label that comes first in the
	 * iteration order of getLabels(), because only a strictly lower cost
	 * replaces the current best label.
	 *
	 * @param i         the index of the data point in the training set.
	 * @param distances the distances of the data point to all training data
	 *                  points, including the self-distance.
	 *
	 * @return the label with the lowest LMNN cost, or -1 if no label yields a
	 *         cost below positive infinity (e.g. if there are no labels).
	 */
	@Override
	public int classifyTraining(int i, double[] distances) {
		int minLabel = -1;
		double minCost = Double.POSITIVE_INFINITY;
		for (final Integer label : getLabels()) {
			final double cost = calculateLMNNCostFunctionTraining(i, label, distances);
			if (cost < minCost) {
				minLabel = label;
				minCost = cost;
			}
		}
		return minLabel;
	}

	/**
	 * Calculates the value of the LMNN cost function for the given data point,
	 * assuming that it belongs to the class with the given label and that it
	 * has the given distances to all training data points.
	 *
	 * This assumes that the data point is part of the training data set. If it
	 * is not, use calculateLMNNCostFunctionTest.
	 *
	 * @param i         the index of the data point in the training set.
	 * @param label     the supposed label of the data point.
	 * @param distances the distances of the data point to all training data
	 *                  points, including the self-distance.
	 *
	 * @return The value of the LMNN cost function for the data point.
	 */
	public double calculateLMNNCostFunctionTraining(int i, int label, double[] distances) {
		return calculateLMNNCostFunction(
				DistanceIndex.getTargetNeighborsTraining(
						K, i, label, distances, getTrainingLabels()),
				distances);
	}

	/**
	 * Calculates the value of the LMNN cost function for the given data point,
	 * assuming that it belongs to the class with the given label and that it
	 * has the given distances to all training data points.
	 *
	 * This assumes that the data point is not part of the training data set. If
	 * it is, use calculateLMNNCostFunctionTraining.
	 *
	 * @param label     the supposed label of the data point.
	 * @param distances the distances of the data point to all training data
	 *                  points, not including the self-distance.
	 *
	 * @return The value of the LMNN cost function for the data point.
	 */
	public double calculateLMNNCostFunctionTest(int label, double[] distances) {
		return calculateLMNNCostFunction(
				DistanceIndex.getTargetNeighborsTest(
						K, label, distances, getTrainingLabels()),
				distances);
	}

	/**
	 * Calculates the LMNN cost function given the target neighbors of some data
	 * point and its distances to the training data points.
	 *
	 * Given the set of target neighbors N and, for each target neighbor j, the
	 * set of imposters I(j) (look up in class DistanceIndex for more detailed
	 * explanations of these terms), the LMNN cost function is defined as:
	 *
	 * E := \sum_{j \in N} ( d(j)^2 + \sum_{k \in I(j)} ( margin^2 + d(j)^2 - d(k)^2 ) )
	 *
	 * So in essence it punishes distances to the target neighbors and closeness
	 * to imposters, but only while imposters are closer than target neighbors
	 * (including a margin of safety). Thereby the LMNN cost function also
	 * provides an objective to optimize the k-nearest neighbor error.
	 *
	 * NOTE(review): the imposter terms are not clamped at zero here; whether
	 * each term is non-negative depends on the criterion that
	 * DistanceIndex.getImposters applies with the margin - confirm.
	 *
	 * @param targetNeighbors the set of target neighbors.
	 * @param distances       the distances d of the data point for which the
	 *                        cost function shall be calculated to all training
	 *                        data points.
	 *
	 * @return the value of the LMNN cost function given the target neighbors of
	 *         some data point and its distances to the training data points.
	 */
	public double calculateLMNNCostFunction(TreeSet<DistanceIndex> targetNeighbors, double[] distances) {
		// the squared margin is the same for every imposter term.
		final double marginSqrd = margin * margin;
		double cost = 0;
		for (final DistanceIndex tn : targetNeighbors) {
			// add squared distance to target neighbor as cost.
			final double tn_dist_sqrd = tn.distance * tn.distance;
			cost += tn_dist_sqrd;
			// find the imposters with respect to this target neighbor.
			final TreeSet<DistanceIndex> imposters = DistanceIndex.getImposters(
					tn.index, distances, getTrainingLabels(), margin);
			for (final DistanceIndex imp : imposters) {
				// add distance inside the margin as cost.
				cost += marginSqrd + tn_dist_sqrd - imp.distance * imp.distance;
			}
		}
		return cost;
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy