
de.citec.tcs.alignment.learning.LMNNClassifier Maven / Gradle / Ivy
/*
* TCS Alignment Toolbox Version 3
*
* Copyright (C) 2016
* Benjamin Paaßen
* AG Theoretical Computer Science
* Centre of Excellence Cognitive Interaction Technology (CITEC)
* University of Bielefeld
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package de.citec.tcs.alignment.learning;
import java.util.TreeSet;
/**
* This implements a Large Margin Nearest Neighbor classifier as suggested by
* Weinberger, Saul et al. (2009). The classification is made based on the
* LMNN cost function: We calculate the cost function for the given data point
* for each possible class and take the class label with the lowest cost
* function value.
*
* @author Benjamin Paassen - bpaassen(at)techfak.uni-bielefeld.de
*/
public class LMNNClassifier extends AbstractDissimilarityClassifier {

	/** Number of target neighbors considered by the LMNN cost function. */
	private int K = 5;
	/** Margin of safety required between target neighbors and imposters. */
	private double margin = 0.01;

	/**
	 * Creates a new LMNN classifier for the given training labels.
	 *
	 * @param trainingLabels the class labels of all training data points.
	 */
	public LMNNClassifier(int[] trainingLabels) {
		super(trainingLabels);
	}

	/**
	 * Returns the number of considered nearest neighbors in the LMNN cost
	 * function.
	 *
	 * @return the number of considered nearest neighbors in the LMNN cost
	 * function.
	 */
	public int getK() {
		return K;
	}

	/**
	 * Sets the number of considered nearest neighbors in the LMNN cost
	 * function.
	 *
	 * @param K the number of considered nearest neighbors in the LMNN cost
	 * function. Must be positive.
	 *
	 * @throws IllegalArgumentException if K is not positive.
	 */
	public void setK(int K) {
		if (K <= 0) {
			throw new IllegalArgumentException(
					"The number of nearest neighbors must be positive, but was " + K);
		}
		this.K = K;
	}

	/**
	 * Returns the margin of safety that is required by the LMNN cost function.
	 *
	 * @return the margin of safety that is required by the LMNN cost function.
	 */
	public double getMargin() {
		return margin;
	}

	/**
	 * Sets the margin of safety that is required by the LMNN cost function.
	 *
	 * @param margin the margin of safety that is required by the LMNN cost
	 * function. Must be non-negative.
	 *
	 * @throws IllegalArgumentException if margin is negative.
	 */
	public void setMargin(double margin) {
		if (margin < 0) {
			throw new IllegalArgumentException(
					"The margin must be non-negative, but was " + margin);
		}
		this.margin = margin;
	}

	/**
	 * {@inheritDoc}
	 *
	 * Classifies a test data point by evaluating the LMNN cost function for
	 * every possible label and returning the label with the lowest cost.
	 *
	 * @param distances the distances of the test point to all training data
	 * points (not including a self-distance).
	 *
	 * @return the label with the lowest LMNN cost, or -1 if no labels exist.
	 */
	@Override
	public int classifyTest(double[] distances) {
		int minLabel = -1;
		double minCost = Double.POSITIVE_INFINITY;
		for (final Integer label : getLabels()) {
			final double cost = calculateLMNNCostFunctionTest(label, distances);
			if (cost < minCost) {
				minLabel = label;
				minCost = cost;
			}
		}
		return minLabel;
	}

	/**
	 * {@inheritDoc}
	 *
	 * Classifies a training data point by evaluating the LMNN cost function
	 * for every possible label and returning the label with the lowest cost.
	 *
	 * @param i         the index of the data point in the training set.
	 * @param distances the distances of the data point to all training data
	 * points (including the self-distance).
	 *
	 * @return the label with the lowest LMNN cost, or -1 if no labels exist.
	 */
	@Override
	public int classifyTraining(int i, double[] distances) {
		int minLabel = -1;
		double minCost = Double.POSITIVE_INFINITY;
		for (final Integer label : getLabels()) {
			final double cost = calculateLMNNCostFunctionTraining(i, label, distances);
			if (cost < minCost) {
				minLabel = label;
				minCost = cost;
			}
		}
		return minLabel;
	}

	/**
	 * Calculates the value of the LMNN cost function for the given data point,
	 * assuming that it belongs to the class with the given label and that it
	 * has the given distances to all training data points.
	 *
	 * @param i the index of the data point in the training set.
	 * @param label the supposed label of the data point.
	 * @param distances the distances of the data point to all training data
	 * points, including the self-distance.
	 *
	 * @return The value of the LMNN cost function for the data point.
	 */
	public double calculateLMNNCostFunctionTraining(int i, int label, double[] distances) {
		return calculateLMNNCostFunction(
				DistanceIndex.getTargetNeighborsTraining(
						K, i, label, distances, getTrainingLabels()),
				distances);
	}

	/**
	 * Calculates the value of the LMNN cost function for the given data point,
	 * assuming that it belongs to the class with the given label and that it
	 * has the given distances to all training data points.
	 *
	 * This assumes that the data point is not part of the training data set. If
	 * it is, use calculateLMNNCostFunctionTraining.
	 *
	 * @param label the supposed label of the data point.
	 * @param distances the distances of the data point to all training data
	 * points, not including the self-distance.
	 *
	 * @return The value of the LMNN cost function for the data point.
	 */
	public double calculateLMNNCostFunctionTest(int label, double[] distances) {
		return calculateLMNNCostFunction(
				DistanceIndex.getTargetNeighborsTest(
						K, label, distances, getTrainingLabels()),
				distances);
	}

	/**
	 * Calculates the LMNN cost function given the target neighbors of some data
	 * point and its distances to the training data points.
	 *
	 * Given the set of target neighbors N and, for each target neighbor j, the
	 * set of imposters I_j (look up in class DistanceIndex for more detailed
	 * explanations of these terms), the LMNN cost function is defined as:
	 *
	 * E := \sum_{j \in N} d(j)^2
	 *    + \sum_{j \in N} \sum_{k \in I_j} (margin^2 + d(j)^2 - d(k)^2)
	 *
	 * So in essence it punishes distances to the target neighbors and closeness
	 * to imposters, but only while imposters are closer than target neighbors
	 * (including a margin of safety). Thereby the LMNN cost function also
	 * provides an objective to optimize the k-nearest neighbor error.
	 *
	 * @param targetNeighbors the set of target neighbors.
	 * @param distances the distances d of the data point for which the
	 * cost function shall be calculated to all training
	 * data points.
	 *
	 * @return the value of the LMNN cost function given the target neighbors of
	 * some data point and its distances to the training data points.
	 */
	public double calculateLMNNCostFunction(TreeSet<DistanceIndex> targetNeighbors, double[] distances) {
		double cost = 0;
		for (final DistanceIndex tn : targetNeighbors) {
			// add squared distance to target neighbor as cost.
			final double tn_dist_sqrd = tn.distance * tn.distance;
			cost += tn_dist_sqrd;
			// find the imposters for this target neighbor, i.e. points of a
			// different class that are closer than the target neighbor plus
			// the margin of safety.
			final TreeSet<DistanceIndex> imposters = DistanceIndex.getImposters(
					tn.index, distances, getTrainingLabels(), margin);
			for (final DistanceIndex imp : imposters) {
				// add distance inside the margin as cost.
				cost += margin * margin + tn_dist_sqrd - imp.distance * imp.distance;
			}
		}
		return cost;
	}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy