
de.citec.tcs.alignment.learning.AbstractDissimilarityClassifier Maven / Gradle / Ivy


This module is a custom implementation of the Large Margin Nearest Neighbor (LMNN) classification scheme of Weinberger and Saul (2009). It contains an implementation of the k-nearest neighbor and LMNN classifiers as well as (most importantly) gradient calculation schemes for the LMNN cost function, given a sequential data set and a user-chosen alignment algorithm. This enables users to learn the parameters of the alignment distance in question via gradient descent on the LMNN cost function. More information on this approach can be found in the Master's thesis "Adaptive Affine Sequence Alignment Using Algebraic Dynamic Programming".
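For reference, the LMNN cost function mentioned above combines, in the formulation of Weinberger and Saul (2009), a pull term over target neighbors with a push term over impostors; the notation below is a generic sketch and does not reflect the exact conventions used in this module:

E = \sum_i \sum_{j \rightsquigarrow i} d(x_i, x_j)^2 + \mu \sum_i \sum_{j \rightsquigarrow i} \sum_l (1 - y_{il}) \left[ 1 + d(x_i, x_j)^2 - d(x_i, x_l)^2 \right]_+

Here j \rightsquigarrow i means that x_j is a target neighbor of x_i, y_{il} is 1 if x_i and x_l share the same label and 0 otherwise, [\cdot]_+ denotes the hinge function max(., 0), and \mu trades off the two terms. In this module, d is an alignment distance whose parameters are adapted by gradient descent on E.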

/* 
 * TCS Alignment Toolbox Version 3
 * 
 * Copyright (C) 2016
 * Benjamin Paaßen
 * AG Theoretical Computer Science
 * Centre of Excellence Cognitive Interaction Technology (CITEC)
 * University of Bielefeld
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 * 
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package de.citec.tcs.alignment.learning;

import java.util.TreeSet;

/**
 * This is a convenience extension of the DissimilarityClassifier interface
 * that already implements most of the shared functionality, making it
 * easier to implement new DissimilarityClassifiers.
 *
 * @author Benjamin Paassen - bpaassen(at)techfak.uni-bielefeld.de
 */
public abstract class AbstractDissimilarityClassifier implements DissimilarityClassifier {

	private final int[] trainingLabels;
	private final TreeSet<Integer> labels = new TreeSet<>();

	public AbstractDissimilarityClassifier(int[] trainingLabels) {
		this.trainingLabels = trainingLabels;
		for (final int label : trainingLabels) {
			labels.add(label);
		}
	}

	/**
	 * Returns the correct class labels for all given training data
	 * points.
	 *
	 * @return the array of class labels for each training data point.
	 */
	public int[] getTrainingLabels() {
		return trainingLabels;
	}

	/**
	 * Returns the set of available class labels.
	 *
	 * @return the set of available class labels.
	 */
	public TreeSet<Integer> getLabels() {
		return labels;
	}

	@Override
	public double calculateTrainingAccuracy(double[][] D) {
		if (D.length != trainingLabels.length) {
			throw new IllegalArgumentException("Expected the distance matrix "
					+ "to show the distances from the training to the training data. "
					+ "But we had " + trainingLabels.length + " training data points and "
					+ D.length + " rows in the distance matrix.");
		}
		int correct = 0;
		for (int i = 0; i < D.length; i++) {
			if (D[i].length != trainingLabels.length) {
				throw new IllegalArgumentException("Expected the distance matrix "
						+ "to show the distances from the test to the training data. "
						+ "But we had " + trainingLabels.length + " training data points and "
						+ D[i].length + " columns in the distance matrix.");
			}
			if (classifyTraining(i, D[i]) == trainingLabels[i]) {
				correct++;
			}
		}
		return (double) correct / (double) D.length;
	}

	@Override
	public double calculateTestAccuracy(int[] testLabels, double[][] D) {
		if (D.length != testLabels.length) {
			throw new IllegalArgumentException("Expected the distance matrix to "
					+ "show the distances from the test to the training data. But we "
					+ "had " + testLabels.length + " test data points and " + D.length
					+ " rows in the distance matrix.");
		}
		int correct = 0;
		for (int i = 0; i < testLabels.length; i++) {
			if (D[i].length != trainingLabels.length) {
				throw new IllegalArgumentException("Expected the distance matrix "
						+ "to show the distances from the test to the training data. "
						+ "But we had " + trainingLabels.length + " training data points and "
						+ D[i].length + " columns in the distance matrix.");
			}
			if (classifyTest(D[i]) == testLabels[i]) {
				correct++;
			}
		}
		return (double) correct / (double) testLabels.length;
	}

	@Override
	public abstract int classifyTraining(int i, double[] distances);

	@Override
	public abstract int classifyTest(double[] distances);
}
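
As a usage illustration, the following is a minimal sketch of a concrete subclass that realizes a plain 1-nearest-neighbor decision rule on top of the accuracy bookkeeping above. The class name SimpleNearestNeighborClassifier is hypothetical and not part of the toolbox, and the sketch assumes that the DissimilarityClassifier interface declares no methods beyond those visible in the listing.

import de.citec.tcs.alignment.learning.AbstractDissimilarityClassifier;

/**
 * Hypothetical example: a 1-nearest-neighbor classifier that assigns each
 * point the label of its closest training point according to the given
 * row of the distance matrix.
 */
public class SimpleNearestNeighborClassifier extends AbstractDissimilarityClassifier {

	public SimpleNearestNeighborClassifier(int[] trainingLabels) {
		super(trainingLabels);
	}

	@Override
	public int classifyTraining(int i, double[] distances) {
		// For training points we exclude the point itself (index i),
		// otherwise the zero self-distance would always win.
		return nearestLabel(distances, i);
	}

	@Override
	public int classifyTest(double[] distances) {
		// For test points all training points are valid neighbors.
		return nearestLabel(distances, -1);
	}

	private int nearestLabel(double[] distances, int exclude) {
		final int[] trainingLabels = getTrainingLabels();
		int best = -1;
		for (int j = 0; j < distances.length; j++) {
			if (j == exclude) {
				continue;
			}
			if (best < 0 || distances[j] < distances[best]) {
				best = j;
			}
		}
		// Assumes at least one admissible training point exists.
		return trainingLabels[best];
	}
}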



