All Downloads are FREE. Search and download functionalities are using the official Maven repository.

de.citec.tcs.alignment.SoftAffinePathModel Maven / Gradle / Ivy

/* 
 * TCS Alignment Toolbox
 * 
 * Copyright (C) 2013-2015
 * Benjamin Paaßen, Georg Zentgraf
 * AG Theoretical Computer Science
 * Centre of Excellence Cognitive Interaction Technology (CITEC)
 * University of Bielefeld
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 * 
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package de.citec.tcs.alignment;

import de.citec.tcs.alignment.AbstractAffineAlignmentAlgorithm.Recurrence;
import de.citec.tcs.alignment.comparators.DerivableComparator;
import de.citec.tcs.alignment.comparators.OperationType;
import de.citec.tcs.alignment.comparators.SkipComparator;
import de.citec.tcs.alignment.comparators.SparseDerivableComparator;
import de.citec.tcs.alignment.comparators.SparseLocalDerivative;
import de.citec.tcs.alignment.parallel.MatrixEngine.MatrixCoordinate;
import de.citec.tcs.alignment.sequence.Sequence;
import de.citec.tcs.alignment.sequence.Value;
import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Stack;

/**
 * This is basically a storage for the DP matrix of the
 * SoftAffineAlignmentAlgorithm. It has the added value of being able to
 * calculate the soft derivative with respect to weighting and parameters.
 *
 * @author Benjamin Paassen - [email protected]
 */
public class SoftAffinePathModel implements AlignmentDerivativeAlgorithm {

	/** Softness parameter of the softmin function (see getBeta()). */
	private final double beta;
	/** The specification the alignment was computed with; the constructor checks that all comparators are SkipComparators. */
	private final AlignmentSpecification specification;
	/** Number of skips consumed at once when an alignment starts skipping in the middle. */
	private final int minMiddleSkips;
	/** The alignment distance as passed to the constructor (see getDistance()). */
	private final double distance;
	/** The dynamic programming tables, keyed by Recurrence; the ALI table is checked to be (M+1) x (N+1). */
	private final EnumMap dp_tables;
	/** M x N replacement costs; entry (i, j) is the cost of replacing left node i with right node j. */
	private final double[][] compareMatrix;
	/** Deletion cost per node of the left sequence (length M). */
	private final double[] deletionMatrix;
	/** Insertion cost per node of the right sequence (length N). */
	private final double[] insertionMatrix;
	/** Skip-deletion cost per node of the left sequence (length M). */
	private final double[] skipDeletionMatrix;
	/** Skip-insertion cost per node of the right sequence (length N). */
	private final double[] skipInsertionMatrix;
	/** The left (first) input sequence; its node count is referred to as M. */
	private final Sequence leftSequence;
	/** The right (second) input sequence; its node count is referred to as N. */
	private final Sequence rightSequence;

	public SoftAffinePathModel(double beta, AlignmentSpecification specification,
			int minMiddleSkips, double distance,
			EnumMap dp_tables,
			double[][] compareMatrix,
			double[] deletionMatrix, double[] insertionMatrix,
			double[] skipDeletionMatrix, double[] skipInsertionMatrix,
			Sequence leftSequence, Sequence rightSequence) {
		this.beta = beta;
		this.specification = specification;
		this.minMiddleSkips = minMiddleSkips;
		this.distance = distance;
		this.dp_tables = dp_tables;
		this.compareMatrix = compareMatrix;
		this.deletionMatrix = deletionMatrix;
		this.insertionMatrix = insertionMatrix;
		this.skipDeletionMatrix = skipDeletionMatrix;
		this.skipInsertionMatrix = skipInsertionMatrix;
		this.leftSequence = leftSequence;
		this.rightSequence = rightSequence;

		if (leftSequence.getNodeSpecification() != specification.getNodeSpecification()
				&& !leftSequence.getNodeSpecification().equals(specification.getNodeSpecification())) {
			throw new IllegalArgumentException(
					"The first input sequence has an unexpected node specification!");
		}
		if (leftSequence.getNodeSpecification() != rightSequence.getNodeSpecification()
				&& !leftSequence.getNodeSpecification().equals(rightSequence.getNodeSpecification())) {
			throw new IllegalArgumentException(
					"The node specifications of both input sequences to not match!");
		}
		if (leftSequence.getNodes().size() != dp_tables.get(Recurrence.ALI).length - 1) {
			throw new IllegalArgumentException(
					"The given PathMatrix does not fit the given sequences!");
		}
		if (leftSequence.getNodes().size() != compareMatrix.length) {
			throw new IllegalArgumentException(
					"The given compareMatrix does not fit the given sequences!");
		}
		if (leftSequence.getNodes().size() != deletionMatrix.length) {
			throw new IllegalArgumentException(
					"The given deletionMatrix does not fit the given sequences!");
		}
		if (leftSequence.getNodes().size() != skipDeletionMatrix.length) {
			throw new IllegalArgumentException(
					"The given skipDeletionMatrix does not fit the given sequences!");
		}
		if (rightSequence.getNodes().size() != insertionMatrix.length) {
			throw new IllegalArgumentException(
					"The given insertionMatrix does not fit the given sequences!");
		}
		if (rightSequence.getNodes().size() != skipInsertionMatrix.length) {
			throw new IllegalArgumentException(
					"The given skipInsertionMatrix does not fit the given sequences!");
		}
		for (int i = 0; i < dp_tables.get(Recurrence.ALI).length; i++) {
			if (rightSequence.getNodes().size() != dp_tables.get(Recurrence.ALI)[i].length - 1) {
				throw new IllegalArgumentException(
						"The given PathMatrix does not fit the given sequences!");
			}
		}
		for (int i = 0; i < compareMatrix.length; i++) {
			if (rightSequence.getNodes().size() != compareMatrix[i].length) {
				throw new IllegalArgumentException(
						"The given compareMatrix does not fit the given sequences!");
			}
		}

		//check validity of comparators.
		for (int k = 0; k < specification.size(); k++) {
			if (!(specification.getComparator(k) instanceof SkipComparator)) {
				throw new UnsupportedOperationException("The comparator for keyword "
						+ specification.getKeyword(k) + " does not support skips!");
			}
		}
	}

	/**
	 * Gives the softness parameter beta of the softmin function. The larger
	 * beta gets, the closer the soft alignment comes to the strict alignment.
	 * For beta = 0 every possible alignment counts equally and softmin yields
	 * the average. A low beta may give only a rough approximation, and longer
	 * sequences need a higher beta as well.
	 *
	 * @return the softness parameter beta.
	 */
	public double getBeta() {
		return this.beta;
	}

	/**
	 * Gives the AlignmentSpecification this path model was built for, i.e.
	 * the one handed to the constructor.
	 *
	 * @return the AlignmentSpecification used for this SoftAffinePathModel.
	 */
	public AlignmentSpecification getSpecification() {
		return this.specification;
	}

	/**
	 * Gives the minimum number of skips an alignment has to perform at once
	 * when it skips in the middle, i.e. neither at its start nor at its end.
	 * Below that, only deletions and insertions are available in the middle.
	 * NOTE(review): previously documented as 4 by default; that default is
	 * chosen by the creating algorithm, not by this class.
	 *
	 * @return the minimum number of skips in the middle of an alignment.
	 */
	public int getMinMiddleSkips() {
		return this.minMiddleSkips;
	}

	/**
	 * Gives the dynamic programming tables of the (soft) affine alignment of
	 * both input sequences, one table per recurrence. This is intended mainly
	 * for debugging. The internal map is returned, not a copy.
	 *
	 * @return the dynamic programming tables of the alignment.
	 */
	public EnumMap getDp_tables() {
		return this.dp_tables;
	}

	/**
	 * Gives the replacement costs between both input sequences: entry (i, j)
	 * is the cost of replacing node i of the left sequence by node j of the
	 * right sequence. The internal array is returned, not a copy.
	 *
	 * @return the matrix of comparison costs between both input sequences.
	 */
	public double[][] getCompareMatrix() {
		return this.compareMatrix;
	}

	/**
	 * Gives the deletion costs: entry i is the cost of deleting node i of the
	 * left input sequence. The internal array is returned, not a copy.
	 *
	 * @return the vector of deletion costs.
	 */
	public double[] getDeletionMatrix() {
		return this.deletionMatrix;
	}

	/**
	 * Gives the insertion costs: entry j is the cost of inserting node j of
	 * the right input sequence into the left one. The internal array is
	 * returned, not a copy.
	 *
	 * @return the vector of insertion costs.
	 */
	public double[] getInsertionMatrix() {
		return this.insertionMatrix;
	}

	/**
	 * Gives the skip-deletion costs: entry i is the cost of skip-deleting
	 * node i of the left input sequence. The internal array is returned, not
	 * a copy.
	 *
	 * @return the vector of skip-deletion costs.
	 */
	public double[] getSkipDeletionMatrix() {
		return this.skipDeletionMatrix;
	}

	/**
	 * Gives the skip-insertion costs: entry j is the cost of skip-inserting
	 * node j of the right input sequence into the left one. The internal
	 * array is returned, not a copy.
	 *
	 * @return the vector of skip-insertion costs.
	 */
	public double[] getSkipInsertionMatrix() {
		return this.skipInsertionMatrix;
	}

	/**
	 * {@inheritDoc }
	 * <p>
	 * This is the value handed to the constructor; it is not recomputed.
	 */
	@Override
	public double getDistance() {
		return this.distance;
	}

	/**
	 * {@inheritDoc }
	 * <p>
	 * This is the left sequence handed to the constructor.
	 */
	@Override
	public Sequence getLeft() {
		return this.leftSequence;
	}

	/**
	 * {@inheritDoc }
	 * <p>
	 * This is the right sequence handed to the constructor.
	 */
	@Override
	public Sequence getRight() {
		return this.rightSequence;
	}

	/**
	 * {@inheritDoc }
	 * <p>
	 * The recursion is emulated with an explicit stack, starting at
	 * (SKIPDEL_START, 0, 0). Every visited cell of the dynamic programming
	 * tables is memoized in a sparse map, so only cells on contributing
	 * alignment paths are evaluated, and each one only once.
	 *
	 * @throws UnsupportedOperationException if comp is not the comparator the
	 * specification uses for keyword.
	 */
	@Override
	public  double[] calculateRawParameterDerivative(
			DerivableComparator comp, String keyword) {
		final int P = comp.getNumberOfParameters();
		final int k = specification.getKeywordIndex(keyword);
		if (specification.getComparator(k) != comp) {
			throw new UnsupportedOperationException(
					"The given comparator was not used for the given keyword!");
		}
		final double weight = specification.getWeighting()[k];
		if (weight == 0) {
			// a comparator without weight cannot influence the distance.
			return new double[P];
		}

		// support sparse comparators.
		final boolean sparse = comp instanceof SparseDerivableComparator;
		final SparseDerivableComparator sparseComp
				= sparse ? (SparseDerivableComparator) comp : null;

		// pre-calculation caching.
		final int M = leftSequence.getNodes().size();
		final int N = rightSequence.getNodes().size();
		final int origK = specification.getOriginalIndex(k);

		/*
		 * We emulate a recursion by using a stack: we do not have to backtrace
		 * the whole matrix but only the paths that contribute to the result.
		 */
		final HashMap<LocalMatrixCoordinate, double[]> sparseMatrix = new HashMap<>();
		final Stack<LocalMatrixCoordinate> calcStack = new Stack<>();
		final LocalMatrixCoordinate start = new LocalMatrixCoordinate(
				Recurrence.SKIPDEL_START, 0, 0);
		calcStack.push(start);
		while (!calcStack.empty()) {
			final LocalMatrixCoordinate current = calcStack.pop();

			// cells may be pushed several times; compute each only once.
			if (sparseMatrix.containsKey(current)) {
				continue;
			}

			final ArrayList<Recurrence> options = new ArrayList<>();
			final ArrayList<OperationType> operations = new ArrayList<>();
			final ArrayList<Integer> count = new ArrayList<>();
			if (getCurrentOptions(options, operations, count, current, M, N)) {
				// valid end of an alignment: the derivative is zero here.
				sparseMatrix.put(current, new double[P]);
				continue;
			}

			// either computes the cell or pushes its missing predecessors.
			manageCalculationStep(current, sparseMatrix, comp, origK, weight,
					sparseComp, sparse, P, calcStack,
					options, operations, count);
		}
		return sparseMatrix.get(start);
	}

	/**
	 * A position (i, j) in the dynamic programming tables together with the
	 * recurrence (i.e. the table) it refers to. Instances serve as keys of
	 * the sparse memoization maps, so equality and hashing take the
	 * recurrence into account in addition to the position.
	 */
	private static class LocalMatrixCoordinate extends MatrixCoordinate {

		/**
		 * Number of Recurrence constants. It is cached because
		 * Recurrence.values() returns a freshly allocated array on every call,
		 * and hashCode() runs on every map lookup.
		 */
		private static final int RECURRENCE_COUNT = Recurrence.values().length;

		public final Recurrence recurrence;

		public LocalMatrixCoordinate(Recurrence recurrence, int i, int j) {
			super(i, j);
			this.recurrence = recurrence;
		}

		@Override
		public int hashCode() {
			return super.hashCode() * RECURRENCE_COUNT + recurrence.ordinal();
		}

		@Override
		public boolean equals(Object obj) {
			if (this == obj) {
				return true;
			}
			if (obj == null || getClass() != obj.getClass()) {
				return false;
			}
			final LocalMatrixCoordinate other = (LocalMatrixCoordinate) obj;
			return this.recurrence == other.recurrence && super.equals(obj);
		}

		@Override
		public String toString() {
			return recurrence.toString() + super.toString();
		}
	}

	/**
	 * Determines which alignment options are admissible at the given cell of
	 * the dynamic programming tables, according to the grammar rule of the
	 * recurrence the cell belongs to.
	 * <p>
	 * The three output collections are filled in parallel. For every option:
	 * {@code options} receives the recurrence that is continued with;
	 * {@code operations} receives the applied operation ({@code null} if the
	 * option only switches the recurrence without consuming any node); and
	 * {@code count} receives the number of nodes the operation consumes.
	 * <p>
	 * In the grammar comments below, a_i denotes node i of the left and b_j
	 * node j of the right sequence, and '#h' marks the choice over the listed
	 * alternatives.
	 *
	 * @param options output: the recurrence of each admissible option.
	 * @param operations output: the operation of each option, or null.
	 * @param count output: the number of nodes consumed by each option.
	 * @param current the current cell (recurrence and position i, j).
	 * @param M the number of nodes in the left sequence.
	 * @param N the number of nodes in the right sequence.
	 * @return true if the current cell is a valid end of an alignment (no
	 * options are added in that case), false otherwise.
	 * @throws UnsupportedOperationException if the recurrence is used outside
	 * the region of the tables in which it is defined.
	 */
	private boolean getCurrentOptions(Collection options,
			Collection operations,
			Collection count,
			LocalMatrixCoordinate current,
			int M, int N) {
		final int i = current.i;
		final int j = current.j;
		switch (current.recurrence) {
			/*
			 * SKIPDEL_START = skip_del(a_i, SKIPDEL_START) |
			 * SKIPINS_START #h;
			 * This recurrence is only defined in column j = 0.
			 */
			case SKIPDEL_START: {
				if (j > 0) {
					throw new UnsupportedOperationException(
							"Unexpected internal error: "
							+ "SKIPDEL_START was used in an "
							+ "ill-defined region!");
				}
				//skip_del(a_i, SKIPDEL_START)
				if (i < M) {
					/*
					 * if we are not at the end of the first sequence yet,
					 * we can use skipdeletions.
					 */
					options.add(Recurrence.SKIPDEL_START);
					operations.add(OperationType.SKIPDELETION);
					count.add(1);
				}
				//SKIPINS_START
				if (j < N || i == M) {
					/*
					 * If the second sequence is not empty (j is 0 here) or
					 * if we are at the end of the first sequence, we can
					 * switch to the SKIPINS_START recurrence without
					 * consuming a node.
					 */
					options.add(Recurrence.SKIPINS_START);
					operations.add(null);
					count.add(0);
				}
			}
			break;
			/*
			 * SKIPINS_START = skip_ins(b_j, SKIPINS_START) |
			 * rep(a_i, b_j, ALI) |
			 * nil() #h;
			 */
			case SKIPINS_START: {
				if (j == N && i < M) {
					throw new UnsupportedOperationException(
							"Unexpected internal error: "
							+ "SKIPINS_START was used in an "
							+ "ill-defined region!");
				}
				//skip_ins(b_j, SKIPINS_START)
				if (j < N - 1 || (j < N && i == M)) {
					/*
					 * Skip-insert right node j. While left nodes remain
					 * (i < M), at least one right node is kept for the
					 * replacement that has to follow; once the left
					 * sequence is consumed, all remaining right nodes may
					 * be skipped.
					 */
					options.add(Recurrence.SKIPINS_START);
					operations.add(OperationType.SKIPINSERTION);
					count.add(1);
				}
				//rep(a_i, b_j, ALI)
				if (i < M && j < N) {
					/*
					 * if we are neither at the end of the first nor the
					 * second sequence, we can replace.
					 */
					options.add(Recurrence.ALI);
					operations.add(OperationType.REPLACEMENT);
					count.add(1);
				}
				//nil()
				if (i == M && j == N) {
					/*
					 * Both sequences are consumed: this is a valid end of
					 * an alignment, where the caller stores a
					 * zero-initialized vector.
					 */
					return true;
				}
			}
			break;
			/*
			 * ALI = del(a_i, DEL) | ins(b_j, INS) |
			 * rep(a_i, b_j, ALI) |
			 * SKIPDEL_MIDDLE | SKIPINS_MIDDLE | SKIPDEL_END #h;
			 */
			case ALI: {
				if (i < M && j < N) {
					//del(a_i, DEL)
					if (i < M - 1) {
						/*
						 * DEL is only defined for i < M, so deleting needs
						 * at least one more left node afterwards.
						 */
						options.add(Recurrence.DEL);
						operations.add(OperationType.DELETION);
						count.add(1);
					}
					//ins(b_j, INS)
					if (j < N - 1) {
						/*
						 * INS is only defined for j < N, so inserting needs
						 * at least one more right node afterwards.
						 */
						options.add(Recurrence.INS);
						operations.add(OperationType.INSERTION);
						count.add(1);
					}
					//rep(a_i, b_j, ALI)
					/*
					 * if we are neither at the end of the first nor the
					 * second sequence, we can replace.
					 */
					options.add(Recurrence.ALI);
					operations.add(OperationType.REPLACEMENT);
					count.add(1);
					/*
					 * SKIPDEL_MIDDLE, i.e. minMiddleSkips nested skip
					 * deletions in one step:
					 * skip_del(a_i,
					 * skip_del(a_i+1,
					 * ...
					 * SKIPDEL_MIDDLE_LOOP))
					 */
					if (i < M - minMiddleSkips) {
						/*
						 * SKIPDEL_MIDDLE is only defined for i < M, so at
						 * least one left node has to remain after the
						 * minMiddleSkips skip deletions.
						 */
						options.add(Recurrence.SKIPDEL_MIDDLE);
						operations.add(OperationType.SKIPDELETION);
						count.add(minMiddleSkips);
					}
					/*
					 * SKIPINS_MIDDLE, i.e. minMiddleSkips nested skip
					 * insertions in one step:
					 * skip_ins(b_j,
					 * skip_ins(b_j+1,
					 * ...
					 * SKIPINS_MIDDLE_LOOP)) #h;
					 */
					if (j < N - minMiddleSkips) {
						/*
						 * SKIPINS_MIDDLE is only defined for j < N, so at
						 * least one right node has to remain after the
						 * minMiddleSkips skip insertions.
						 */
						options.add(Recurrence.SKIPINS_MIDDLE);
						operations.add(OperationType.SKIPINSERTION);
						count.add(minMiddleSkips);
					}
				}
				//SKIPDEL_END
				//in any case we can start skipping the rest of both sequences.
				options.add(Recurrence.SKIPDEL_END);
				operations.add(null);
				count.add(0);
			}
			break;
			/*
			 * DEL = del(a_i, DEL) | ins(b_j, INS) |
			 * rep(a_i, b_j, ALI) #h;
			 */
			case DEL: {
				if (i == M || j == N) {
					throw new UnsupportedOperationException(
							"Unexpected internal error: "
							+ "DEL was used in an "
							+ "ill-defined region!");
				}
				// after the check above this holds for all in-range cells.
				if (i < M && j < N) {
					//del(a_i, DEL)
					if (i < M - 1) {
						//DEL at i + 1 needs i + 1 < M.
						options.add(Recurrence.DEL);
						operations.add(OperationType.DELETION);
						count.add(1);
					}
					//ins(b_j, INS)
					if (j < N - 1) {
						//INS at j + 1 needs j + 1 < N.
						options.add(Recurrence.INS);
						operations.add(OperationType.INSERTION);
						count.add(1);
					}
					//rep(a_i, b_j, ALI)
					/*
					 * if we are neither at the end of the first nor the
					 * second sequence, we can replace.
					 */
					options.add(Recurrence.ALI);
					operations.add(OperationType.REPLACEMENT);
					count.add(1);
				}
			}
			break;
			/*
			 * INS = ins(b_j, INS) | rep(a_i, b_j, ALI) #h;
			 */
			case INS: {
				if (i == M || j == N) {
					throw new UnsupportedOperationException(
							"Unexpected internal error: "
							+ "INS was used in an "
							+ "ill-defined region!");
				}
				// after the check above this holds for all in-range cells.
				if (i < M && j < N) {
					//ins(b_j, INS)
					if (j < N - 1) {
						//INS at j + 1 needs j + 1 < N.
						options.add(Recurrence.INS);
						operations.add(OperationType.INSERTION);
						count.add(1);
					}
					//rep(a_i, b_j, ALI)
					/*
					 * if we are neither at the end of the first nor the
					 * second sequence, we can replace.
					 */
					options.add(Recurrence.ALI);
					operations.add(OperationType.REPLACEMENT);
					count.add(1);
				}
			}
			break;
			/*
			 * SKIPDEL_MIDDLE_LOOP (implemented by the SKIPDEL_MIDDLE
			 * recurrence) = skip_del(a_i, SKIPDEL_MIDDLE_LOOP) |
			 * rep(a_i, b_j, ALI) #h;
			 */
			case SKIPDEL_MIDDLE: {
				if (i == M || j == N) {
					throw new UnsupportedOperationException(
							"Unexpected internal error: "
							+ "SKIPDEL_MIDDLE was used in an "
							+ "ill-defined region!");
				}
				//skip_del(a_i, SKIPDEL_MIDDLE_LOOP)
				if (i < M && j < N) {
					if (i < M - 1) {
						/*
						 * Skip-delete further left nodes, but keep at least
						 * one for the replacement that ends the middle
						 * skip.
						 */
						options.add(Recurrence.SKIPDEL_MIDDLE);
						operations.add(OperationType.SKIPDELETION);
						count.add(1);
					}
					//rep(a_i, b_j, ALI)
					/*
					 * if we are neither at the end of the first nor the
					 * second sequence, we can replace.
					 */
					options.add(Recurrence.ALI);
					operations.add(OperationType.REPLACEMENT);
					count.add(1);
				}
			}
			break;
			/*
			 * SKIPINS_MIDDLE_LOOP (implemented by the SKIPINS_MIDDLE
			 * recurrence) = skip_ins(b_j, SKIPINS_MIDDLE_LOOP) |
			 * rep(a_i, b_j, ALI) #h;
			 */
			case SKIPINS_MIDDLE: {
				if (i == M || j == N) {
					throw new UnsupportedOperationException(
							"Unexpected internal error: "
							+ "SKIPINS_MIDDLE was used in an "
							+ "ill-defined region!");
				}
				//skip_ins(b_j, SKIPINS_MIDDLE_LOOP)
				if (i < M && j < N) {
					if (j < N - 1) {
						/*
						 * Skip-insert further right nodes, but keep at
						 * least one for the replacement that ends the
						 * middle skip.
						 */
						options.add(Recurrence.SKIPINS_MIDDLE);
						operations.add(OperationType.SKIPINSERTION);
						count.add(1);
					}
					//rep(a_i, b_j, ALI)
					/*
					 * if we are neither at the end of the first nor the
					 * second sequence, we can replace.
					 */
					options.add(Recurrence.ALI);
					operations.add(OperationType.REPLACEMENT);
					count.add(1);
				}
			}
			break;
			/*
			 * SKIPDEL_END = skip_del(a_i, SKIPDEL_END) |
			 * SKIPINS_END #h;
			 */
			case SKIPDEL_END: {
				//skip_del(a_i, SKIPDEL_END)
				if (i < M) {
					/*
					 * While left nodes remain, the only option is to
					 * skip-delete them.
					 */
					options.add(Recurrence.SKIPDEL_END);
					operations.add(OperationType.SKIPDELETION);
					count.add(1);
				} else {
					//SKIPINS_END
					/*
					 * Once the first sequence is consumed, we copy the
					 * SKIPINS_END entry without consuming a node.
					 */
					options.add(Recurrence.SKIPINS_END);
					operations.add(null);
					count.add(0);
				}
			}
			break;
			/*
			 * SKIPINS_END = skip_ins(b_j, SKIPINS_END) |
			 * nil() #h;
			 * This recurrence is only defined for i = M.
			 */
			case SKIPINS_END: {
				if (i < M) {
					throw new UnsupportedOperationException(
							"Unexpected internal error: "
							+ "SKIPINS_END was used in an "
							+ "ill-defined region!");
				}
				//skip_ins(b_j, SKIPINS_END)
				if (j < N) {
					/*
					 * if we are not at the end of the second sequence yet,
					 * we can use skipinsertions.
					 */
					options.add(Recurrence.SKIPINS_END);
					operations.add(OperationType.SKIPINSERTION);
					count.add(1);
				} else {
					//nil()
					/*
					 * Both sequences are consumed: this is a valid end of
					 * an alignment, where the caller stores a
					 * zero-initialized vector.
					 */
					return true;
				}
			}
			break;
			default:
				throw new UnsupportedOperationException(
						"Unknown recurrence! " + current);
		}
		return false;
	}

	/*
	 * Performs one step of the stack-based evaluation of the soft parameter
	 * derivative. It determines the predecessor cells of all options that
	 * contribute to the current cell, i.e. all options with a non-zero softmin
	 * derivative. If some of them have not been calculated yet, the current
	 * cell is pushed back onto the calculation stack, followed by the missing
	 * predecessors, so that the predecessors are processed first. Otherwise the
	 * derivative of the current cell is calculated and stored in sparseMatrix.
	 */
	private  void manageCalculationStep(LocalMatrixCoordinate current,
			HashMap sparseMatrix,
			DerivableComparator comp, int origK, double weight,
			SparseDerivableComparator sparseComp, boolean sparse, int P,
			final Stack calcStack,
			Collection optionsCollection,
			Collection operationsCollection,
			Collection countCollection) {
		final int OPNUM = optionsCollection.size();
		if (OPNUM == 0) {
			throw new UnsupportedOperationException("Unexpected internal error! "
					+ "No alignment choices!");
		}
		//transform input collections to arrays.
		final Recurrence[] options = optionsCollection.toArray(new Recurrence[OPNUM]);
		final OperationType[] operations = operationsCollection.toArray(new OperationType[OPNUM]);
		final int[] counts = new int[OPNUM];
		{
			int o = 0;
			for (Integer count : countCollection) {
				counts[o] = count;
				o++;
			}
		}
		//get coordinates.
		final int i = current.i;
		final int j = current.j;

		//calculate softmin derivatives.
		final double[] softminDerivs = calculateSoftminDerivatives(
				current, options, operations, counts);
		/*
		 * Check whether the predecessors of all contributing options have been
		 * calculated. Softmin derivatives can be negative, so every non-zero
		 * option contributes and needs its predecessor. The same predicate is
		 * used when the result is accumulated, so every predecessor read there
		 * is guaranteed to be present.
		 */
		boolean pushedBack = false;
		for (int o = 0; o < options.length; o++) {
			if (contributesToDerivative(softminDerivs[o])) {
				final int iOld;
				final int jOld;
				if (operations[o] != null) {
					switch (operations[o]) {
						case DELETION:
						case SKIPDELETION:
							iOld = i + counts[o];
							jOld = j;
							break;
						case INSERTION:
						case SKIPINSERTION:
							iOld = i;
							jOld = j + counts[o];
							break;
						case REPLACEMENT:
							iOld = i + counts[o];
							jOld = j + counts[o];
							break;
						default:
							throw new UnsupportedOperationException(
									"Unsupported operation: " + operations[o]);
					}
				} else {
					//a pure recurrence switch stays at the same position.
					iOld = i;
					jOld = j;
				}
				final LocalMatrixCoordinate oldCoord
						= new LocalMatrixCoordinate(
								options[o], iOld, jOld);
				if (!sparseMatrix.containsKey(oldCoord)) {
					//if not, we calculate that first (the stack is LIFO).
					if (!pushedBack) {
						calcStack.add(current);
						pushedBack = true;
					}
					calcStack.add(oldCoord);
				}
			}
		}
		//if we have everything, calculate the result.
		if (!pushedBack) {
			calculateLocalSoftminParameterDerivative(current,
					sparseMatrix, comp, origK, weight, sparseComp, sparse, P,
					softminDerivs,
					options, operations, counts);
		}

	}

	/**
	 * Calculates, for each option o of the current cell, the term
	 *
	 * softmin'(o) := (1 - beta * (cost(o) - softmin)) * p_o
	 *
	 * i.e. the derivative of the softmin over all option costs with respect to
	 * cost(o). Here cost(o) is the DP table entry of the option's predecessor
	 * cell plus the local cost of the nodes consumed by the option. For beta > 0
	 * and p_o > 0 this term is negative whenever cost(o) > softmin + 1 / beta.
	 * More details are to be found in the Softmin class.
	 */
	private double[] calculateSoftminDerivatives(
			LocalMatrixCoordinate current,
			Recurrence[] options, OperationType[] operations, int[] count) {
		final int i = current.i;
		final int j = current.j;
		//calculate the costs for each option.
		final double[] costs = new double[options.length];
		for (int o = 0; o < options.length; o++) {
			if (operations[o] == null) {
				//if there is no operation we can just copy the old cost.
				costs[o] = dp_tables.get(options[o])[i][j];
			} else {
				//find out the direction in which we moved in the dynamic programming matrix.
				final int i2;
				final int j2;
				switch (operations[o]) {
					case DELETION:
					case SKIPDELETION:
						i2 = count[o];
						j2 = 0;
						break;
					case INSERTION:
					case SKIPINSERTION:
						i2 = 0;
						j2 = count[o];
						break;
					case REPLACEMENT:
						i2 = count[o];
						j2 = count[o];
						break;
					default:
						throw new UnsupportedOperationException("Unsupported operation: " + operations[o]);
				}

				//sum up the local cost of all nodes consumed by this option.
				double localCost = 0;
				for (int s = 0; s < count[o]; s++) {
					switch (operations[o]) {
						case DELETION:
							localCost += deletionMatrix[i + s];
							break;
						case SKIPDELETION:
							localCost += skipDeletionMatrix[i + s];
							break;
						case INSERTION:
							localCost += insertionMatrix[j + s];
							break;
						case SKIPINSERTION:
							localCost += skipInsertionMatrix[j + s];
							break;
						case REPLACEMENT:
							localCost += compareMatrix[i + s][j + s];
							break;
						default:
							throw new UnsupportedOperationException("Unsupported operation: " + operations[o]);
					}
				}
				costs[o] = dp_tables.get(options[o])[i + i2][j + j2] + localCost;
			}
		}
		//calculate softmin'
		final double[] softminDerivs = Softmin.calculateSoftminDerivatives(beta, costs);
		//and return it.
		return softminDerivs;
	}

	/**
	 * Calculates the derivative of the current cell with respect to each
	 * parameter p of the comparator as
	 *
	 * \Sum_o softmin'(o) * (recursion(o,p) + weight * localDerivative(o,p))
	 *
	 * where recursion(o,p) is the stored derivative of the predecessor cell of
	 * option o, and localDerivative(o,p) is the sum of the comparator's local
	 * derivatives over all nodes consumed by option o. Options with a zero (or
	 * NaN) softmin derivative are skipped; negative ones are included. The
	 * result is stored in sparseMatrix under the current coordinate.
	 */
	private  void calculateLocalSoftminParameterDerivative(
			LocalMatrixCoordinate current,
			HashMap sparseMatrix,
			DerivableComparator comp, int origK, double weight,
			SparseDerivableComparator sparseComp, boolean sparse, int P,
			double[] softminDerivs,
			Recurrence[] options, OperationType[] operations, int[] count) {
		final int i = current.i;
		final int j = current.j;
		final double[] newDeriv = new double[P];
		//the (weighted local + recursive) derivatives for each option.
		final double[][] unweightedDerivs = new double[options.length][P];
		for (int o = 0; o < options.length; o++) {
			if (!contributesToDerivative(softminDerivs[o])) {
				continue;
			}
			if (operations[o] == null) {
				//if there is no operation just use the old derivative (read only).
				unweightedDerivs[o] = sparseMatrix.get(
						new LocalMatrixCoordinate(options[o], i, j));
			} else {
				//find out the direction in which we moved in the dynamic programming matrix.
				final int i2;
				final int j2;
				switch (operations[o]) {
					case DELETION:
					case SKIPDELETION:
						i2 = count[o];
						j2 = 0;
						break;
					case INSERTION:
					case SKIPINSERTION:
						i2 = 0;
						j2 = count[o];
						break;
					case REPLACEMENT:
						i2 = count[o];
						j2 = count[o];
						break;
					default:
						throw new UnsupportedOperationException("Unsupported operation: " + operations[o]);
				}

				final double[] tmp = new double[P];

				//sum up the local derivatives of all consumed nodes.
				for (int s = 0; s < count[o]; s++) {
					switch (operations[o]) {
						case DELETION:
						case SKIPDELETION:
							calculateLocalDerivative(comp, sparseComp, sparse,
									(X) leftSequence.getNodes().get(i + s).getValue(origK),
									null,
									operations[o], tmp, P);
							break;
						case INSERTION:
						case SKIPINSERTION:
							calculateLocalDerivative(comp, sparseComp, sparse,
									null,
									(X) rightSequence.getNodes().get(j + s).getValue(origK),
									operations[o], tmp, P);
							break;
						case REPLACEMENT:
							calculateLocalDerivative(comp, sparseComp, sparse,
									(X) leftSequence.getNodes().get(i + s).getValue(origK),
									(X) rightSequence.getNodes().get(j + s).getValue(origK),
									operations[o], tmp, P);
							break;
						default:
							throw new UnsupportedOperationException("Unsupported operation: " + operations[o]);
					}
					for (int p = 0; p < P; p++) {
						unweightedDerivs[o][p] += tmp[p];
					}
				}
				//multiply with the weight.
				for (int p = 0; p < P; p++) {
					unweightedDerivs[o][p] *= weight;
				}
				//add the respective old value.
				final double[] old = sparseMatrix.get(
						new LocalMatrixCoordinate(options[o], i + i2, j + j2));
				for (int p = 0; p < P; p++) {
					unweightedDerivs[o][p] += old[p];
				}
			}
		}
		//calculate the new soft derivative for the parameters.
		for (int o = 0; o < options.length; o++) {
			if (contributesToDerivative(softminDerivs[o])) {
				for (int p = 0; p < P; p++) {
					newDeriv[p] += softminDerivs[o] * unweightedDerivs[o][p];
				}
			}
		}
		//store it in the sparse matrix.
		sparseMatrix.put(current, newDeriv);
	}

	/**
	 * Decides whether an alignment option contributes to the soft parameter
	 * derivative of a cell. softmin'(o) = (1 - beta * (cost(o) - softmin)) * p_o
	 * may be negative, so only exact zeros (e.g. p_o == 0) can be pruned
	 * without changing the result. NaN values are pruned as well.
	 *
	 * @param softminDeriv the softmin derivative of the option.
	 * @return true if the option has to be taken into account.
	 */
	private static boolean contributesToDerivative(double softminDeriv) {
		return softminDeriv != 0 && !Double.isNaN(softminDeriv);
	}

	/**
	 * Writes the derivative of the comparator's local cost for one operation,
	 * with respect to each of its P parameters, into localDeriv[0..P-1].
	 * Sparse comparators only report their non-zero entries; all other entries
	 * are reset to zero first. Dense comparators are queried once per
	 * parameter.
	 */
	private static  void calculateLocalDerivative(
			DerivableComparator comp,
			SparseDerivableComparator sparseComp, boolean sparse,
			X leftVal, X rightVal, OperationType type,
			double[] localDeriv, int P) {
		if (sparse) {
			// reset, then overwrite only the reported entries.
			for (int param = 0; param < P; param++) {
				localDeriv[param] = 0;
			}
			for (final Iterator entries = sparseComp.calculateSparseLocalDerivative(
					leftVal, rightVal, type).iterator(); entries.hasNext();) {
				final SparseLocalDerivative.SparseDeriativeEntry entry = entries.next();
				localDeriv[entry.getParameterIndex()] = entry.getDerivative();
			}
			return;
		}
		// dense comparator: ask for every parameter individually.
		for (int param = 0; param < P; param++) {
			localDeriv[param] = comp.calculateLocalDerivative(param, leftVal, rightVal, type);
		}
	}

	/**
	 * {@inheritDoc }
	 * <p>
	 * Computes the raw derivative vector with calculateRawParameterDerivative
	 * and lets the comparator transform it into its result representation.
	 */
	@Override
	public  Y calculateParameterDerivative(DerivableComparator comp, String keyword) {
		final double[] rawDerivative = calculateRawParameterDerivative(comp, keyword);
		return comp.transformToResult(rawDerivative);
	}

	/**
	 * {@inheritDoc }
	 * <p>
	 * The recursion is emulated with an explicit stack, starting at
	 * (SKIPDEL_START, 0, 0), and every visited cell is memoized in a sparse
	 * map. Valid alignment ends are zero vectors of length
	 * specification.size().
	 * <p>
	 * NOTE(review): manageWeightCalculationStep only follows options with
	 * {@code softminDerivs[o] > 0}. Softmin derivatives
	 * (1 - beta * (cost(o) - softmin)) * p_o can be negative, so confirm that
	 * dropping negative contributions is intended there.
	 */
	@Override
	public double[] calculateWeightDerivative() {
		final int K = specification.size();
		final int M = leftSequence.getNodes().size();
		final int N = rightSequence.getNodes().size();

		/*
		 * We emulate a recursion by using a stack: we do not have to backtrace
		 * the whole matrix but only the paths that contribute to the result.
		 */
		final HashMap<LocalMatrixCoordinate, double[]> sparseMatrix = new HashMap<>();
		final Stack<LocalMatrixCoordinate> calcStack = new Stack<>();
		final LocalMatrixCoordinate start = new LocalMatrixCoordinate(
				Recurrence.SKIPDEL_START, 0, 0);
		calcStack.push(start);
		while (!calcStack.empty()) {
			final LocalMatrixCoordinate current = calcStack.pop();

			// cells may be pushed several times; compute each only once.
			if (sparseMatrix.containsKey(current)) {
				continue;
			}

			final ArrayList<Recurrence> options = new ArrayList<>();
			final ArrayList<OperationType> operations = new ArrayList<>();
			final ArrayList<Integer> count = new ArrayList<>();
			if (getCurrentOptions(options, operations, count, current, M, N)) {
				// valid end of an alignment: the derivative is zero here.
				sparseMatrix.put(current, new double[K]);
				continue;
			}
			// either computes the cell or pushes its missing predecessors.
			manageWeightCalculationStep(current, sparseMatrix, K, calcStack,
					options, operations, count);
		}
		return sparseMatrix.get(start);
	}

	/*
	 * This manages one step of the soft weight derivative calculation.
	 * The pre-requisites of the current cell are the cells referenced by all
	 * options with a positive softmin derivative. Every pre-requisite that is
	 * not yet in sparseMatrix is pushed onto the calculation stack. In that
	 * case the current cell is pushed first (i.e. below its pre-requisites), so
	 * it is popped again once they are done. If all pre-requisites are
	 * available, the current matrix entry is calculated and stored.
	 *
	 * optionsCollection, operationsCollection and countCollection are parallel
	 * collections: their o-th entries together describe one alignment option.
	 */
	private void manageWeightCalculationStep(LocalMatrixCoordinate current,
			HashMap<LocalMatrixCoordinate, double[]> sparseMatrix, int K,
			final Stack<LocalMatrixCoordinate> calcStack,
			Collection<Recurrence> optionsCollection,
			Collection<OperationType> operationsCollection,
			Collection<Integer> countCollection) {
		final int OPNUM = optionsCollection.size();
		if (OPNUM == 0) {
			throw new UnsupportedOperationException("Unexpected internal error! "
					+ "No alignment choices!");
		}
		//the collections are indexed in parallel, so their sizes have to agree.
		//Otherwise trailing options would silently get a null operation or a
		//zero count.
		if (operationsCollection.size() != OPNUM || countCollection.size() != OPNUM) {
			throw new UnsupportedOperationException("Unexpected internal error! "
					+ "Got " + OPNUM + " options, " + operationsCollection.size()
					+ " operations and " + countCollection.size() + " counts.");
		}
		//transform input collections to arrays.
		final Recurrence[] options = optionsCollection.toArray(new Recurrence[OPNUM]);
		final OperationType[] operations = operationsCollection.toArray(new OperationType[OPNUM]);
		final int[] counts = new int[OPNUM];
		int idx = 0;
		for (final Integer count : countCollection) {
			counts[idx++] = count;
		}
		//get coordinates.
		final int i = current.i;
		final int j = current.j;

		/*
		 * calculate softmin derivatives.
		 */
		final double[] softminDerivs = calculateSoftminDerivatives(
				current, options, operations, counts);
		/*
		 * Look if all necessary pre-requisites have been calculated.
		 */
		boolean pushedBack = false;
		for (int o = 0; o < OPNUM; o++) {
			//options with non-positive softmin derivative do not contribute.
			if (!(softminDerivs[o] > 0)) {
				continue;
			}
			//the cell this option refers to; options without an operation
			//refer to another recurrence within the same cell.
			int iOld = i;
			int jOld = j;
			if (operations[o] != null) {
				switch (operations[o]) {
					case DELETION:
					case SKIPDELETION:
						iOld += counts[o];
						break;
					case INSERTION:
					case SKIPINSERTION:
						jOld += counts[o];
						break;
					case REPLACEMENT:
						iOld += counts[o];
						jOld += counts[o];
						break;
					default:
						throw new UnsupportedOperationException(
								"Unsupported operation: " + operations[o]);
				}
			}
			final LocalMatrixCoordinate oldCoord
					= new LocalMatrixCoordinate(options[o], iOld, jOld);
			if (!sparseMatrix.containsKey(oldCoord)) {
				//if not we calculate that first and revisit the current cell later.
				if (!pushedBack) {
					calcStack.push(current);
					pushedBack = true;
				}
				calcStack.push(oldCoord);
			}
		}
		//if we have everything, calculate the result.
		if (!pushedBack) {
			calculateLocalSoftminWeightDerivative(current, sparseMatrix,
					softminDerivs, K,
					options, operations, counts);
		}
	}

	/*
	 * This calculates the term
	 *
	 * sum_o softmin'(o) * (recursion(o,w) + local_weight_derivative(o,w))
	 *
	 * for each keyword weight w. The local weight derivative is equal to the
	 * local operation costs for keyword w, summed over the count[o] elements
	 * consumed by option o. Options without an operation (operations[o] ==
	 * null) contribute only the recursion term, i.e. the stored derivative of
	 * option o's recurrence in the same cell. The result is stored in
	 * sparseMatrix under the current coordinate.
	 *
	 * All referenced pre-requisite cells are expected to be present in
	 * sparseMatrix already (manageWeightCalculationStep only calls this method
	 * once it has found them there).
	 */
	private void calculateLocalSoftminWeightDerivative(
			LocalMatrixCoordinate current,
			HashMap<LocalMatrixCoordinate, double[]> sparseMatrix,
			double[] softminDerivs, int K,
			Recurrence[] options, OperationType[] operations, int[] count) {
		final int i = current.i;
		final int j = current.j;
		final double[] newDeriv = new double[K];
		for (int o = 0; o < options.length; o++) {
			//only options with a strictly positive softmin derivative contribute
			//(this also skips NaN); only their pre-requisites are scheduled.
			if (!(softminDerivs[o] > 0)) {
				continue;
			}
			//the unweighted derivative of option o for each keyword weight.
			final double[] unweightedDeriv;
			if (operations[o] == null) {
				//if there is no operation just use the old derivative.
				unweightedDeriv = sparseMatrix.get(
						new LocalMatrixCoordinate(options[o], i, j));
			} else {
				unweightedDeriv = new double[K];
				//find out the direction in which we moved in the dynamic programming matrix.
				final int i2;
				final int j2;
				switch (operations[o]) {
					case DELETION:
					case SKIPDELETION:
						i2 = count[o];
						j2 = 0;
						break;
					case INSERTION:
					case SKIPINSERTION:
						i2 = 0;
						j2 = count[o];
						break;
					case REPLACEMENT:
						i2 = count[o];
						j2 = count[o];
						break;
					default:
						throw new UnsupportedOperationException("Unsupported operation: " + operations[o]);
				}

				//calculate the local derivative, which is equal to the local
				//operation costs for each keyword.
				for (int s = 0; s < count[o]; s++) {
					final double[] localCosts;
					switch (operations[o]) {
						case DELETION:
							localCosts = specification.calculateDeletionCosts(
									leftSequence.getNodes().get(i + s));
							break;
						case SKIPDELETION:
							localCosts = specification.calculateSkipDeletionCosts(
									leftSequence.getNodes().get(i + s));
							break;
						case INSERTION:
							localCosts = specification.calculateInsertionCosts(
									rightSequence.getNodes().get(j + s));
							break;
						case SKIPINSERTION:
							localCosts = specification.calculateSkipInsertionCosts(
									rightSequence.getNodes().get(j + s));
							break;
						case REPLACEMENT:
							localCosts = specification.calculateReplacementCosts(
									leftSequence.getNodes().get(i + s),
									rightSequence.getNodes().get(j + s));
							break;
						default:
							throw new UnsupportedOperationException("Unsupported operation: " + operations[o]);
					}
					for (int k = 0; k < K; k++) {
						unweightedDeriv[k] += localCosts[k];
					}
				}
				//add the respective old value (the recursion term).
				final double[] old = sparseMatrix.get(
						new LocalMatrixCoordinate(options[o], i + i2, j + j2));
				for (int k = 0; k < K; k++) {
					unweightedDeriv[k] += old[k];
				}
			}
			//weight the contribution of option o by its softmin derivative.
			for (int k = 0; k < K; k++) {
				newDeriv[k] += softminDerivs[o] * unweightedDeriv[k];
			}
		}
		//store it in the sparse matrix.
		sparseMatrix.put(current, newDeriv);
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy