
/*
 * TCS Alignment Toolbox Version 3
 *
 * Copyright (C) 2016
 * Benjamin Paaßen
 * AG Theoretical Computer Science
 * Centre of Excellence Cognitive Interaction Technology (CITEC)
 * University of Bielefeld
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package de.citec.tcs.alignment.learning;

import de.citec.tcs.alignment.AlignmentAlgorithm;
import de.citec.tcs.alignment.DerivableAlignmentDistance;
import de.citec.tcs.alignment.comparators.DerivableComparator;
import de.citec.tcs.alignment.parallel.CommandLineProgressReporter;
import de.citec.tcs.alignment.parallel.Engine;
import de.citec.tcs.alignment.parallel.ProgressReporter;
import java.util.List;
import java.util.TreeSet;
import java.util.concurrent.Callable;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;

/**
 * This implements the Large Margin Nearest Neighbor Metric Learning approach
 * by Weinberger et al. (2009) with respect to alignment distance metric
 * learning.
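 *
 * A minimal usage sketch (the names trainingSeqs, trainingLabels, algo, comp,
 * and D are assumptions for illustration, not part of this class):
 *
 * <pre>{@code
 * final LMNNGradientCalculator<Character> calc
 *         = new LMNNGradientCalculator<>(trainingSeqs, trainingLabels, algo);
 * calc.setK(3);          // consider three target neighbors per point
 * calc.setMargin(0.05);  // require a larger margin of safety
 * // D is the N x N matrix of pairwise alignment distances on the training
 * // data (see computeGradient for how it can be obtained).
 * final double[] grad = calc.computeGradient(comp, D);
 * // grad can now be fed to any gradient-based optimizer.
 * }</pre>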
 *
 * @author Benjamin Paassen - bpaassen(at)techfak.uni-bielefeld.de
 * @param <X> the class of the elements in the input sequences.
 */
public class LMNNGradientCalculator<X> {

    /**
     * The sequences used for training.
     *
     * @return The sequences used for training.
     */
    @Getter
    private final List<? extends List<X>> trainingSeqs;
    /**
     * The labels for the training sequences.
     *
     * @return The labels for the training sequences.
     */
    @Getter
    private final int[] trainingLabels;
    /**
     * The algorithm used to compute pairwise DerivableAlignmentDistances of
     * the input data.
     *
     * @return The algorithm used to compute pairwise
     * DerivableAlignmentDistances of the input data.
     */
    @Getter
    private final AlignmentAlgorithm<X, ? extends DerivableAlignmentDistance<X>> algo;
    /**
     * The number of considered nearest neighbors in the LMNN cost function.
     *
     * @param K The number of considered nearest neighbors in the LMNN cost
     * function.
     * @return The number of considered nearest neighbors in the LMNN cost
     * function.
     */
    @Getter
    @Setter
    private int K = 5;
    /**
     * The margin of safety that is required by the LMNN cost function.
     *
     * @param margin The margin of safety that is required by the LMNN cost
     * function.
     * @return The margin of safety that is required by the LMNN cost
     * function.
     */
    @Getter
    @Setter
    private double margin = 0.01;
    /**
     * The number of threads used in the parallel computation of the gradient
     * of the LMNN cost function.
     *
     * @param numberOfThreads The number of threads used in the parallel
     * computation of the gradient of the LMNN cost function.
     * @return The number of threads used in the parallel computation of the
     * gradient of the LMNN cost function.
     */
    @Getter
    @Setter
    private int numberOfThreads = Engine.DEFAULT_NUMBER_OF_THREADS;
    /**
     * The ProgressReporter that is used to report progress. This is a
     * CommandLineProgressReporter by default. If it is set to null, no
     * progress is reported.
     *
     * @param reporter The ProgressReporter that is used to report progress.
     * @return The ProgressReporter that is used to report progress.
     */
    @Getter
    @Setter
    private ProgressReporter reporter = new CommandLineProgressReporter();

    public LMNNGradientCalculator(
            @NonNull List<? extends List<X>> data,
            @NonNull int[] trainingLabels,
            @NonNull AlignmentAlgorithm<X, ? extends DerivableAlignmentDistance<X>> algo) {
        this.trainingLabels = trainingLabels;
        this.trainingSeqs = data;
        this.algo = algo;
        if (trainingLabels.length != data.size()) {
            throw new IllegalArgumentException(
                    "Expected one label for each training data point, but got "
                    + data.size() + " training sequences and "
                    + trainingLabels.length + " labels!");
        }
    }

    /**
     * Calculates the gradient of the LMNN cost function with respect to the
     * parameters of the given comparator.
     *
     * Note that this method relies on the intra-training distance matrix
     * already being present. It can be calculated using a
     * ParallelProcessingEngine with the sequences returned by
     * getTrainingSeqs(), the score algorithm corresponding to the algorithm
     * returned by getAlgo(), and the full matrix setting.
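     *
     * A sequential sketch for computing such a matrix with a calculator
     * instance calc (it assumes that the DerivableAlignmentDistance exposes
     * its distance value via getDistance(); a ParallelProcessingEngine
     * achieves the same in parallel):
     *
     * <pre>{@code
     * final int N = calc.getTrainingSeqs().size();
     * final double[][] D = new double[N][N];
     * for (int i = 0; i < N; i++) {
     *     for (int j = 0; j < N; j++) {
     *         D[i][j] = calc.getAlgo().calculateAlignment(
     *                 calc.getTrainingSeqs().get(i),
     *                 calc.getTrainingSeqs().get(j)).getDistance();
     *     }
     * }
     * }</pre>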
     *
     * Given that matrix, this gradient calculation is fairly fast and runs in
     * linear time if there are only few imposters. The worst-case complexity
     * of this method is still quadratic, but it tries to do as few alignment
     * computations as possible.
     *
     * @param comp the comparator with respect to whose parameters the
     * gradient is computed.
     * @param D given N training data points, an N x N matrix of alignment
     * distances computed with the same distance scheme as is implemented by
     * the algorithm of this LMNNGradientCalculator. This distance matrix
     * serves as the basis for evaluating the LMNN cost function.
     *
     * @return the gradient of the LMNN cost function with respect to the
     * parameters of the given comparator. The return format is defined by
     * the respective comparator.
     */
    public double[] computeGradient(@NonNull DerivableComparator<X, ?> comp,
            @NonNull double[][] D) {
        // check the input.
        if (trainingSeqs.size() != D.length) {
            throw new IllegalArgumentException(
                    "Expected distances for each training data point, but had "
                    + trainingSeqs.size() + " data points and " + D.length
                    + " rows in the given distance matrix.");
        }
        for (int i = 0; i < trainingSeqs.size(); i++) {
            if (trainingSeqs.size() != D[i].length) {
                throw new IllegalArgumentException(
                        "Expected distances to each training data point, but had "
                        + trainingSeqs.size() + " data points and " + D[i].length
                        + " columns in row " + i + " of the given distance matrix.");
            }
        }
        // calculate the gradient.
        final double[] gradient = new double[comp.getNumberOfParameters()];
        final ParallelGradientEngine engine = new ParallelGradientEngine(comp, D);
        engine.setNumberOfThreads(numberOfThreads);
        engine.setReporter(reporter);
        // create a parallel processing job for each data point.
        for (int i = 0; i < trainingSeqs.size(); i++) {
            engine.addTask(i);
        }
        // calculate.
        engine.calculate();
        // sum up the resulting partial gradients.
        for (final Engine.CalculationResult<Integer, double[]> res : engine.getResults()) {
            for (int p = 0; p < gradient.length; p++) {
                gradient[p] += res.result[p];
            }
        }
        return gradient;
    }

    private class ParallelGradientEngine extends Engine<Integer, double[]> {

        private final DerivableComparator<X, ?> comp;
        private final double[][] D;

        public ParallelGradientEngine(DerivableComparator<X, ?> comp, double[][] D) {
            super(Integer.class, double[].class);
            this.comp = comp;
            this.D = D;
        }

        @Override
        public Callable<double[]> createCallable(Integer i) {
            return new LMNNGradientJob(comp, i, D[i]);
        }
    }

    private class LMNNGradientJob implements Callable<double[]> {

        private final DerivableComparator<X, ?> comp;
        private final int i;
        private final double[] distances;

        public LMNNGradientJob(DerivableComparator<X, ?> comp, int i,
                double[] distances) {
            this.comp = comp;
            this.i = i;
            this.distances = distances;
        }

        @Override
        public double[] call() throws Exception {
            final double[] gradient = new double[comp.getNumberOfParameters()];
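            // Note on the structure below: the LMNN cost function of
            // Weinberger et al. (2009) pulls the K same-label target
            // neighbors of point i closer and pushes imposters (differently
            // labeled points that intrude into the margin around a target
            // neighbor) away. Accordingly, we accumulate one pull term per
            // target neighbor and one push term per target neighbor/imposter
            // pair.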
            // get the target neighbors.
            final TreeSet<DistanceIndex> targetNeighbors
                    = DistanceIndex.getTargetNeighborsTraining(
                            K, i, distances, trainingLabels);
            for (final DistanceIndex tn : targetNeighbors) {
                // compute the gradient for the current target neighbor.
                final double[] tnGrad;
                {
                    final DerivableAlignmentDistance<X> tnDist = algo.calculateAlignment(
                            trainingSeqs.get(i), trainingSeqs.get(tn.index));
                    tnGrad = tnDist.computeGradient(comp);
                }
                // add it to the existing gradient.
                for (int d = 0; d < gradient.length; d++) {
                    gradient[d] += tn.distance * tnGrad[d];
                }
                // then add the (negative) gradient for the imposters.
                final TreeSet<DistanceIndex> imposters
                        = DistanceIndex.getImposters(
                                tn.index, distances, trainingLabels, margin);
                for (final DistanceIndex im : imposters) {
                    final double[] imGrad;
                    {
                        final DerivableAlignmentDistance<X> imDist = algo.calculateAlignment(
                                trainingSeqs.get(i), trainingSeqs.get(im.index));
                        imGrad = imDist.computeGradient(comp);
                    }
                    // add it to the existing gradient.
                    for (int d = 0; d < gradient.length; d++) {
                        gradient[d] += tn.distance * tnGrad[d] - im.distance * imGrad[d];
                    }
                }
            }
            return gradient;
        }
    }
}