
de.citec.tcs.alignment.ParallelGradientEngine Maven / Gradle / Ivy
/*
* TCS Alignment Toolbox Version 3
*
* Copyright (C) 2016
* Benjamin Paaßen
* AG Theoretical Computer Science
* Centre of Excellence Cognitive Interaction Technology (CITEC)
* University of Bielefeld
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package de.citec.tcs.alignment;
import de.citec.tcs.alignment.comparators.DerivableComparator;
import de.citec.tcs.alignment.parallel.Engine;
import de.citec.tcs.alignment.parallel.MatrixEngine;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
/**
*
* This allows parallel processing of gradient calculations.
*
* @param the class of elements in the left input sequences.
* @param the class of elements in the right input sequences.
*
* @author Benjamin Paassen - bpaassen(at)techfak.uni-bielefeld.de
*/
public class ParallelGradientEngine extends MatrixEngine {
private final HashMap> distances;
/**
* The DerivableComparator with respect to which the gradient shall be computed.
*
* @param comparator The DerivableComparator with respect to which the gradient shall be
* computed.
*
* @return The DerivableComparator with respect to which the gradient shall be computed.
*/
@Getter
@Setter
@NonNull
private DerivableComparator comparator;
/**
* Creates a ParallelGradientEngine that computes gradients for several
* DeriableAlignmentDistance objects in parallel.
*
* @param derivableMatrixEntries a map of MatrixCoordinates to DeriableAlignmentDistance
* objects.
* @param M the number of rows in the original distance matrix the given distance objects belong
* to.
* @param N the number of columns in the original distance matrix the given distance objects
* belong to.
*/
/**
* Creates a ParallelGradientEngine that computes gradients for several
* DeriableAlignmentDistance objects in parallel.
*
* @param derivableMatrixEntries a map of MatrixCoordinates to DeriableAlignmentDistance
* objects.
* @param M the number of rows in the original distance matrix the given distance objects belong
* to.
* @param N the number of columns in the original distance matrix the given distance objects
* belong to.
* @param comparator The DerivableComparator with respect to which the gradient shall be
* computed.
*/
public ParallelGradientEngine(
@NonNull final Map> derivableMatrixEntries,
int M, int N, @NonNull final DerivableComparator comparator) {
super(M, N, double[].class);
this.distances = new HashMap<>(derivableMatrixEntries);
this.comparator = comparator;
}
/**
* Creates a ParallelGradientEngine that computes gradients for several
* DeriableAlignmentDistance objects in parallel.
*
* @param results a set of MatrixCoordinates with DeriableAlignmentDistance objects.
* @param M the number of rows in the original distance matrix the given distance objects belong
* to.
* @param N the number of columns in the original distance matrix the given distance objects
* belong to.
* @param comparator The DerivableComparator with respect to which the gradient shall be
* computed.
*/
public ParallelGradientEngine(
@NonNull final Collection>> results,
int M, int N, @NonNull final DerivableComparator comparator) {
super(M, N, double[].class);
this.comparator = comparator;
this.distances = new HashMap<>();
for (final Engine.CalculationResult extends MatrixEngine.MatrixCoordinate, ? extends DerivableAlignmentDistance> result : results) {
distances.put(result.ident, result.result);
}
}
/**
* Creates a ParallelGradientEngine that computes gradients for several
* DeriableAlignmentDistance objects in parallel.
*
* @param derivableMatrixEntries a matrix of DeriableAlignmentDistance objects.
* @param comparator The DerivableComparator with respect to which the gradient shall be
* computed.
*/
public ParallelGradientEngine(@NonNull final DerivableAlignmentDistance[][] derivableMatrixEntries,
@NonNull final DerivableComparator comparator) {
super(derivableMatrixEntries.length,
MatrixEngine.extractNumberOfColumns(derivableMatrixEntries),
double[].class);
this.comparator = comparator;
this.distances = new HashMap<>();
for (int i = 0; i < getM(); i++) {
if (derivableMatrixEntries[i] == null) {
continue;
}
if (derivableMatrixEntries[i].length != getN()) {
throw new IllegalArgumentException("The number of columns in the input matrix is inconsistent!");
}
for (int j = 0; j < getN(); j++) {
if (derivableMatrixEntries[i][j] == null) {
continue;
}
this.distances.put(new MatrixCoordinate(i, j), derivableMatrixEntries[i][j]);
}
}
}
/**
* Returns the DerivableAlignmentDistance objects used for derivative calculation.
*
* @return the DerivableAlignmentDistance objects used for derivative calculation.
*/
public Map> getDistances() {
return distances;
}
/**
* Overrides the setFull() method of matrix engine. This issues a
* calculation task for every available DerivableAlignmentDistance.
*/
@Override
public void setFull() {
setSpecificTasks(distances.keySet());
}
@Override
public Callable createCallable(MatrixCoordinate ident) {
final DerivableAlignmentDistance dist = distances.get(ident);
if (dist == null) {
throw new IllegalArgumentException("No derivable matrix entry was given for the matrix coordinate " + ident);
}
if (comparator == null) {
throw new IllegalArgumentException("No comparator was given for which the gradient can be calculated.");
}
return new DerivativeCallable(dist);
}
private class DerivativeCallable implements Callable {
private final DerivableAlignmentDistance dist;
public DerivativeCallable(DerivableAlignmentDistance dist) {
this.dist = dist;
}
@Override
public double[] call() throws Exception {
return dist.computeGradient(comparator);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy