
de.citec.tcs.alignment.learning.AbstractDissimilarityClassifier Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of learning Show documentation
This module is a custom implementation of the Large Margin
Nearest Neighbor classification scheme of Weinberger, Saul, et al. (2009).
It contains an implementation of the k-nearest neighbor and LMNN classifier
as well as (most importantly) gradient calculation schemes on the LMNN
cost function given a sequential data set and a user-choice of alignment
algorithm. This enables users to learn parameters of the alignment
distance in question using a gradient descent on the LMNN cost function.
More information on this approach can be found in the Master's thesis
"Adaptive Affine Sequence Alignment Using Algebraic Dynamic Programming".
The newest version!
/*
* TCS Alignment Toolbox Version 3
*
* Copyright (C) 2016
* Benjamin Paaßen
* AG Theoretical Computer Science
* Centre of Excellence Cognitive Interaction Technology (CITEC)
* University of Bielefeld
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package de.citec.tcs.alignment.learning;
import java.util.TreeSet;
/**
 * This is a convenience extension of the DissimilarityClassifier interface,
 * which already implements most of the functionality to make the implementation
 * of DissimilarityClassifiers easier.
 *
 * @author Benjamin Paassen - bpaassen(at)techfak.uni-bielefeld.de
 */
public abstract class AbstractDissimilarityClassifier implements DissimilarityClassifier {

	/** The correct class label for each training data point. */
	private final int[] trainingLabels;
	/**
	 * The set of distinct class labels occurring in the training data.
	 * Parameterized as TreeSet&lt;Integer&gt; instead of the raw TreeSet to
	 * avoid unchecked casts at call sites.
	 */
	private final TreeSet<Integer> labels = new TreeSet<>();

	/**
	 * Sets up a classifier based on the class labels of the training data.
	 *
	 * @param trainingLabels the array of class labels for each training data
	 * point; must not be null.
	 *
	 * @throws IllegalArgumentException if trainingLabels is null.
	 */
	public AbstractDissimilarityClassifier(int[] trainingLabels) {
		if (trainingLabels == null) {
			throw new IllegalArgumentException("The training labels must not be null!");
		}
		this.trainingLabels = trainingLabels;
		for (final int label : trainingLabels) {
			labels.add(label);
		}
	}

	/**
	 * Returns the correct class labels for all given training data
	 * points.
	 *
	 * NOTE(review): this returns a reference to the internal array without a
	 * defensive copy, so callers can mutate the classifier's state. Kept as-is
	 * for backward compatibility.
	 *
	 * @return the array of class labels for each training data point.
	 */
	public int[] getTrainingLabels() {
		return trainingLabels;
	}

	/**
	 * Returns the set of available class labels.
	 *
	 * NOTE(review): this exposes the internal mutable set; callers should
	 * treat it as read-only.
	 *
	 * @return the set of available class labels.
	 */
	public TreeSet<Integer> getLabels() {
		return labels;
	}

	/**
	 * Verifies that one row of a distance matrix has exactly one column per
	 * training data point.
	 *
	 * @param row one row of a distance matrix.
	 *
	 * @throws IllegalArgumentException if the column count does not match the
	 * number of training data points.
	 */
	private void checkColumnCount(double[] row) {
		if (row.length != trainingLabels.length) {
			throw new IllegalArgumentException("Expected the distance matrix "
				+ "to contain one column per training data point. "
				+ "But we had " + trainingLabels.length + " training data points and "
				+ row.length + " columns in the distance matrix.");
		}
	}

	@Override
	public double calculateTrainingAccuracy(double[][] D) {
		if (D.length != trainingLabels.length) {
			throw new IllegalArgumentException("Expected the distance matrix "
				+ "to show the distances from the training to the training data. "
				+ "But we had " + trainingLabels.length + " training data points and "
				+ D.length + " rows in the distance matrix.");
		}
		int correct = 0;
		for (int i = 0; i < D.length; i++) {
			checkColumnCount(D[i]);
			// count a hit whenever the classifier reproduces the known label
			if (classifyTraining(i, D[i]) == trainingLabels[i]) {
				correct++;
			}
		}
		// note: for an empty training set this is 0/0 = NaN, as before
		return (double) correct / (double) D.length;
	}

	@Override
	public double calculateTestAccuracy(int[] testLabels, double[][] D) {
		if (D.length != testLabels.length) {
			throw new IllegalArgumentException("Expected the distance matrix to "
				+ "show the distances from the test to the training data. But we "
				+ "had " + testLabels.length + " test data points and " + D.length
				+ " rows in the distance matrix.");
		}
		int correct = 0;
		for (int i = 0; i < testLabels.length; i++) {
			checkColumnCount(D[i]);
			if (classifyTest(D[i]) == testLabels[i]) {
				correct++;
			}
		}
		// note: for an empty test set this is 0/0 = NaN, as before
		return (double) correct / (double) testLabels.length;
	}

	@Override
	public abstract int classifyTraining(int i, double[] distances);

	@Override
	public abstract int classifyTest(double[] distances);
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy