
de.citec.tcs.alignment.learning.LMNNClassifier Maven / Gradle / Ivy
This module is a custom implementation of the Large Margin
Nearest Neighbor classification scheme of Weinberger and Saul (2009).
It contains an implementation of the k-nearest neighbor and LMNN classifiers
as well as (most importantly) gradient calculation schemes for the LMNN
cost function, given a sequential data set and a user's choice of alignment
algorithm. This enables users to learn the parameters of the alignment
distance in question via gradient descent on the LMNN cost function.
More information on this approach can be found in the Master's thesis
"Adaptive Affine Sequence Alignment Using Algebraic Dynamic Programming".
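A minimal usage sketch of the classifier implemented below (toy, made-up labels and distances; in practice the distances would come from one of the toolbox's alignment algorithms; only the constructor and methods shown in the source are used):

import de.citec.tcs.alignment.learning.LMNNClassifier;

public class LMNNClassifierExample {
    public static void main(String[] args) {
        // hypothetical toy data: two classes with two training sequences each.
        int[] trainingLabels = {0, 0, 1, 1};
        final LMNNClassifier classifier = new LMNNClassifier(trainingLabels);
        classifier.setK(2);        // number of considered nearest neighbors
        classifier.setMargin(0.1); // margin of safety

        // made-up alignment distances from a new sequence to all four
        // training sequences (no self-distance, as classifyTest expects).
        final double[] distances = {0.3, 0.5, 1.2, 1.4};
        System.out.println("predicted label: " + classifier.classifyTest(distances));
    }
}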
/*
 * TCS Alignment Toolbox Version 3
 *
 * Copyright (C) 2016
 * Benjamin Paaßen
 * AG Theoretical Computer Science
 * Centre of Excellence Cognitive Interaction Technology (CITEC)
 * University of Bielefeld
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package de.citec.tcs.alignment.learning;

import java.util.TreeSet;
/**
 * This implements a Large Margin Nearest Neighbor classifier as suggested by
 * Weinberger and Saul (2009). The classification is based on the LMNN cost
 * function: we calculate the cost function value for the given data point for
 * each possible class and take the class label with the lowest cost function
 * value.
 *
 * @author Benjamin Paassen - bpaassen(at)techfak.uni-bielefeld.de
 */
public class LMNNClassifier extends AbstractDissimilarityClassifier {

    private int K = 5;
    private double margin = 0.01;

    public LMNNClassifier(int[] trainingLabels) {
        super(trainingLabels);
    }

    /**
     * Returns the number of considered nearest neighbors in the LMNN cost
     * function.
     *
     * @return the number of considered nearest neighbors in the LMNN cost
     * function.
     */
    public int getK() {
        return K;
    }

    /**
     * Sets the number of considered nearest neighbors in the LMNN cost
     * function.
     *
     * @param K the number of considered nearest neighbors in the LMNN cost
     * function.
     */
    public void setK(int K) {
        this.K = K;
    }

    /**
     * Returns the margin of safety that is required by the LMNN cost function.
     *
     * @return the margin of safety that is required by the LMNN cost function.
     */
    public double getMargin() {
        return margin;
    }

    /**
     * Sets the margin of safety that is required by the LMNN cost function.
     *
     * @param margin the margin of safety that is required by the LMNN cost
     * function.
     */
    public void setMargin(double margin) {
        this.margin = margin;
    }
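
    /**
     * Classifies a data point that is not part of the training set: the LMNN
     * cost function is evaluated for every known label, and the label with
     * the lowest cost function value is returned.
     *
     * @param distances the distances of the data point to all training data
     * points, not including the self-distance.
     *
     * @return the predicted label, or -1 if no label is known.
     */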
    @Override
    public int classifyTest(double[] distances) {
        int minLabel = -1;
        double minCost = Double.POSITIVE_INFINITY;
        for (final Integer label : getLabels()) {
            final double cost = calculateLMNNCostFunctionTest(label, distances);
            if (cost < minCost) {
                minLabel = label;
                minCost = cost;
            }
        }
        return minLabel;
    }
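
    /**
     * Classifies the i-th data point of the training set: the LMNN cost
     * function is evaluated for every known label, and the label with the
     * lowest cost function value is returned. In contrast to classifyTest,
     * the given distances are expected to include the self-distance.
     *
     * @param i the index of the data point in the training set.
     * @param distances the distances of the data point to all training data
     * points, including the self-distance.
     *
     * @return the predicted label, or -1 if no label is known.
     */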
    @Override
    public int classifyTraining(int i, double[] distances) {
        int minLabel = -1;
        double minCost = Double.POSITIVE_INFINITY;
        for (final Integer label : getLabels()) {
            final double cost = calculateLMNNCostFunctionTraining(i, label, distances);
            if (cost < minCost) {
                minLabel = label;
                minCost = cost;
            }
        }
        return minLabel;
    }

    /**
     * Calculates the value of the LMNN cost function for the given data point,
     * assuming that it belongs to the class with the given label and that it
     * has the given distances to all training data points.
     *
     * @param i the index of the data point in the training set.
     * @param label the supposed label of the data point.
     * @param distances the distances of the data point to all training data
     * points, including the self-distance.
     *
     * @return the value of the LMNN cost function for the data point.
     */
    public double calculateLMNNCostFunctionTraining(int i, int label, double[] distances) {
        return calculateLMNNCostFunction(
                DistanceIndex.getTargetNeighborsTraining(
                        K, i, label, distances, getTrainingLabels()),
                distances);
    }

    /**
     * Calculates the value of the LMNN cost function for the given data point,
     * assuming that it belongs to the class with the given label and that it
     * has the given distances to all training data points.
     *
     * This assumes that the data point is not part of the training data set.
     * If it is, use calculateLMNNCostFunctionTraining.
     *
     * @param label the supposed label of the data point.
     * @param distances the distances of the data point to all training data
     * points, not including the self-distance.
     *
     * @return the value of the LMNN cost function for the data point.
     */
    public double calculateLMNNCostFunctionTest(int label, double[] distances) {
        return calculateLMNNCostFunction(
                DistanceIndex.getTargetNeighborsTest(
                        K, label, distances, getTrainingLabels()),
                distances);
    }

    /**
     * Calculates the LMNN cost function given the target neighbors of some data
     * point and its distances to the training data points.
     *
     * Given the set of target neighbors N and, for each target neighbor j, its
     * set of imposters I_j (see the class DistanceIndex for more detailed
     * explanations of these terms), the LMNN cost function is defined as:
     *
     * E := \sum_{j \in N} ( d(j)^2 + \sum_{k \in I_j} ( d(j)^2 + margin^2 - d(k)^2 ) )
     *
     * So in essence it penalizes distances to the target neighbors and closeness
     * to imposters, but only while imposters are closer than target neighbors
     * (including a margin of safety). Thereby the LMNN cost function also
     * provides an objective to optimize the k-nearest neighbor error.
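     *
     * For illustration (hypothetical numbers): with margin = 0.1, a single
     * target neighbor at distance d(j) = 1 and one imposter at distance
     * d(k) = 0.9 yield a cost of d(j)^2 + (d(j)^2 + margin^2 - d(k)^2)
     * = 1 + (1 + 0.01 - 0.81) = 1.2.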
     *
     * @param targetNeighbors the set of target neighbors.
     * @param distances the distances d of the data point for which the
     * cost function shall be calculated to all training data points.
     *
     * @return the value of the LMNN cost function given the target neighbors of
     * some data point and its distances to the training data points.
     */
    public double calculateLMNNCostFunction(TreeSet<DistanceIndex> targetNeighbors, double[] distances) {
        double cost = 0;
        for (final DistanceIndex tn : targetNeighbors) {
            // add the squared distance to the target neighbor as cost.
            final double tn_dist_sqrd = tn.distance * tn.distance;
            cost += tn_dist_sqrd;
            // find the imposters for this target neighbor.
            final TreeSet<DistanceIndex> imposters = DistanceIndex.getImposters(
                    tn.index, distances, getTrainingLabels(), margin);
            for (final DistanceIndex imp : imposters) {
                // add the distance inside the margin as cost.
                cost += margin * margin + tn_dist_sqrd - imp.distance * imp.distance;
            }
        }
        return cost;
    }
}