de.citec.tcs.alignment.learning.LMNNGradientCalculator

/* 
 * TCS Alignment Toolbox Version 3
 * 
 * Copyright (C) 2016
 * Benjamin Paaßen
 * AG Theoretical Computer Science
 * Centre of Excellence Cognitive Interaction Technology (CITEC)
 * University of Bielefeld
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 * 
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

package de.citec.tcs.alignment.learning;

import de.citec.tcs.alignment.AlignmentAlgorithm;
import de.citec.tcs.alignment.DerivableAlignmentDistance;
import de.citec.tcs.alignment.comparators.DerivableComparator;
import de.citec.tcs.alignment.parallel.CommandLineProgressReporter;
import de.citec.tcs.alignment.parallel.Engine;
import de.citec.tcs.alignment.parallel.ProgressReporter;
import java.util.List;
import java.util.TreeSet;
import java.util.concurrent.Callable;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;

/**
 * This implements the Large Margin Nearest Neighbor Metric Learning approach
 * by Weinberger et al. (2009) with respect to alignment distance metric
 * learning.
 *
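 * A minimal usage sketch (hypothetical: sequences of Strings are used
 * here, and seqs, labels, algo, comp, and D stand for suitably prepared
 * training sequences, labels, a configured algorithm and comparator, and
 * the precomputed distance matrix):
 *
 * <pre>{@code
 * LMNNGradientCalculator<String> calc
 *         = new LMNNGradientCalculator<>(seqs, labels, algo);
 * calc.setK(3);          // consider the 3 nearest neighbors per point
 * calc.setMargin(0.01);  // margin of safety for the hinge loss
 * double[] gradient = calc.computeGradient(comp, D);
 * }</pre>
 *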
 * @author Benjamin Paassen - bpaassen(at)techfak.uni-bielefeld.de
 * @param <X> the class of the elements in the input sequences.
 */
public class LMNNGradientCalculator<X> {

	/**
	 * The sequences used for training.
	 *
	 * @return The sequences used for training.
	 */
	@Getter
	private final List<List<X>> trainingSeqs;
	/**
	 * The labels for the training sequences.
	 *
	 * @return The labels for the training sequences.
	 */
	@Getter
	private final int[] trainingLabels;
	/**
	 * The algorithm used to compute pairwise DerivableAlignmentDistances of the input data.
	 *
	 * @return The algorithm used to compute pairwise DerivableAlignmentDistances of the input data.
	 */
	@Getter
	private final AlignmentAlgorithm<X, ? extends DerivableAlignmentDistance<X>> algo;
	/**
	 * The number of considered nearest neighbors in the LMNN cost function.
	 *
	 * @param K The number of considered nearest neighbors in the LMNN cost function.
	 *
	 * @return The number of considered nearest neighbors in the LMNN cost function.
	 */
	@Getter
	@Setter
	private int K = 5;
	/**
	 * The margin of safety that is required by the LMNN cost function.
	 *
	 * @param margin The margin of safety that is required by the LMNN cost function.
	 *
	 * @return The margin of safety that is required by the LMNN cost function.
	 */
	@Getter
	@Setter
	private double margin = 0.01;
	/**
	 * The number of threads used in the parallel computation of the gradient on the LMNN cost
	 * function.
	 *
	 * @param numberOfThreads The number of threads used in the parallel
	 * computation of the gradient on the LMNN cost function.
	 *
	 * @return The number of threads used in the parallel computation of the gradient on the LMNN
	 * cost function.
	 */
	@Getter
	@Setter
	private int numberOfThreads = Engine.DEFAULT_NUMBER_OF_THREADS;
	/**
	 * The ProgressReporter that is used to report progress. By default, this
	 * is a CommandLineProgressReporter. If it is set to null, progress is
	 * not reported.
	 *
	 * @param reporter The ProgressReporter that is used to report progress.
	 *
	 * @return The ProgressReporter that is used to report progress.
	 */
	@Getter
	@Setter
	private ProgressReporter reporter = new CommandLineProgressReporter();

	public LMNNGradientCalculator(
			@NonNull List<List<X>> data,
			@NonNull int[] trainingLabels,
			@NonNull AlignmentAlgorithm<X, ? extends DerivableAlignmentDistance<X>> algo) {
		this.trainingLabels = trainingLabels;
		this.trainingSeqs = data;
		this.algo = algo;
		if (trainingLabels.length != data.size()) {
			throw new IllegalArgumentException(
					"Expected one label for each training data point, but got "
					+ data.size() + " training sequences and "
					+ trainingLabels.length + " labels!");
		}
	}

	/**
	 * Calculates the gradient of the LMNN cost function with respect to the
	 * parameters of the given comparator.
	 *
	 * Note that this method relies on the intra-training distance matrix
	 * already being present. It can be calculated using a
	 * ParallelProcessingEngine with the sequences returned by
	 * getTrainingSeqs(), the score algorithm corresponding to the algorithm
	 * returned by getAlgo(), and the full setting (see getFull()) enabled,
	 * such that the complete distance matrix is computed.
	 *
	 * Given that matrix, this gradient calculation is fairly fast and runs
	 * in linear time as long as there are not many imposters. The worst-case
	 * complexity of this method is still quadratic, but it tries to do as
	 * few computations as possible.
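	 *
	 * A schematic, sequential way to set the matrix up for a calculator
	 * calc built on the training sequences seqs (a sketch only: the
	 * getDistance() accessor on the alignment result is an assumption, and
	 * in practice the parallel engine should be preferred):
	 *
	 * <pre>{@code
	 * final int N = calc.getTrainingSeqs().size();
	 * final double[][] D = new double[N][N];
	 * for (int i = 0; i < N; i++) {
	 *     for (int j = 0; j < N; j++) {
	 *         // assumption: the alignment result exposes its distance value
	 *         D[i][j] = calc.getAlgo().calculateAlignment(
	 *                 seqs.get(i), seqs.get(j)).getDistance();
	 *     }
	 * }
	 * }</pre>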
	 *
	 * @param comp the comparator itself.
	 * @param D given N training data points, this should be an N x N matrix
	 * of alignment distances computed with the same distance scheme as is
	 * implemented by the algorithm of this LMNNGradientCalculator. This
	 * distance matrix serves as the basis for evaluating the LMNN cost
	 * function.
	 *
	 * @return the gradient of the LMNN cost function with respect to the
	 * parameters of the given comparator. The return format is defined by
	 * the respective comparator.
	 */
	public double[] computeGradient(@NonNull DerivableComparator<X, ?> comp, @NonNull double[][] D) {
		// check the input
		if (trainingSeqs.size() != D.length) {
			throw new IllegalArgumentException(
					"Expected distances for each training data point, but had "
					+ trainingSeqs.size() + " data points and " + D.length
					+ " rows in the given distance matrix.");
		}
		for (int i = 0; i < trainingSeqs.size(); i++) {
			if (trainingSeqs.size() != D[i].length) {
				throw new IllegalArgumentException(
						"Expected distances to each training data point, but had "
						+ trainingSeqs.size() + " data points and " + D[i].length
						+ " columns in row " + i + " of the given distance matrix."
				);
			}
		}

		// calculate the gradient.
		final double[] gradient = new double[comp.getNumberOfParameters()];
		final ParallelGradientEngine engine = new ParallelGradientEngine(comp, D);
		engine.setNumberOfThreads(numberOfThreads);
		engine.setReporter(reporter);
		// iterate over all data points.
		for (int i = 0; i < trainingSeqs.size(); i++) {
			// create a parallel processing job for it.
			engine.addTask(i);
		}
		// calculate
		engine.calculate();
		// sum up the resulting part-gradients.
		for (final Engine.CalculationResult<Integer, double[]> res : engine.getResults()) {
			for (int p = 0; p < gradient.length; p++) {
				gradient[p] += res.result[p];
			}
		}
		return gradient;
	}

	private class ParallelGradientEngine extends Engine<Integer, double[]> {

		private final DerivableComparator<X, ?> comp;
		private final double[][] D;

		public ParallelGradientEngine(DerivableComparator<X, ?> comp, double[][] D) {
			super(Integer.class, double[].class);
			this.comp = comp;
			this.D = D;
		}

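		// one task per training data point: the key is the index of the
		// data point, and the result is its partial gradient contribution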
		@Override
		public Callable<double[]> createCallable(Integer i) {
			return new LMNNGradientJob(comp, i, D[i]);
		}

	}

	private class LMNNGradientJob implements Callable<double[]> {

		private final DerivableComparator<X, ?> comp;
		private final int i;
		private final double[] distances;

		public LMNNGradientJob(DerivableComparator<X, ?> comp, int i, double[] distances) {
			this.comp = comp;
			this.i = i;
			this.distances = distances;
		}

		@Override
		public double[] call() throws Exception {
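			// LMNN cost for data point i: target neighbors (the K closest
			// points with the same label) are pulled closer, while imposters
			// (differently labeled points that violate the margin around a
			// target neighbor) are pushed away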
			final double[] gradient = new double[comp.getNumberOfParameters()];
			// get the target neighbors.
			final TreeSet<DistanceIndex> targetNeighbors
					= DistanceIndex.getTargetNeighborsTraining(
							K, i, distances, trainingLabels);
			for (final DistanceIndex tn : targetNeighbors) {
				// add the gradients for all target neighbors.
				final double[] tnGrad;
				{
					final DerivableAlignmentDistance<X> tnDist = algo.calculateAlignment(
							trainingSeqs.get(i), trainingSeqs.get(tn.index));
					tnGrad = tnDist.computeGradient(comp);
				}
				// add it to the existing gradient.
				for (int d = 0; d < gradient.length; d++) {
					gradient[d] += tn.distance * tnGrad[d];
				}
				// then add the (negative) gradient for the imposters.
				final TreeSet<DistanceIndex> imposters
						= DistanceIndex.getImposters(
								tn.index, distances, trainingLabels, margin);
				for (final DistanceIndex im : imposters) {
					final double[] imGrad;
					{
						final DerivableAlignmentDistance<X> imDist = algo.calculateAlignment(
								trainingSeqs.get(i), trainingSeqs.get(im.index));
						imGrad = imDist.computeGradient(comp);
					}
					// add it to the existing gradient.
					for (int d = 0; d < gradient.length; d++) {
						gradient[d] += tn.distance * tnGrad[d] - im.distance * imGrad[d];
					}
				}
			}
			return gradient;
		}

	}

}