All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.classifier.sequencelearning.hmm.HmmTrainer Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sequencelearning.hmm;

import java.util.Collection;
import java.util.Iterator;

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

/**
 * Class containing several algorithms used to train a Hidden Markov Model. The
 * three main algorithms are: supervised learning, unsupervised Viterbi and
 * unsupervised Baum-Welch.
 */
public final class HmmTrainer {

  /**
   * No public constructor for utility classes.
   */
  private HmmTrainer() {
    // nothing to do here really.
  }

  /**
   * Create an supervised initial estimate of an HMM Model based on a sequence
   * of observed and hidden states.
   *
   * @param nrOfHiddenStates The total number of hidden states
   * @param nrOfOutputStates The total number of output states
   * @param observedSequence Integer array containing the observed sequence
   * @param hiddenSequence   Integer array containing the hidden sequence
   * @param pseudoCount      Value that is assigned to non-occurring transitions to avoid zero
   *                         probabilities.
   * @return An initial model using the estimated parameters
   */
  public static HmmModel trainSupervised(int nrOfHiddenStates, int nrOfOutputStates, int[] observedSequence,
      int[] hiddenSequence, double pseudoCount) {
    // make sure the pseudo count is not zero
    pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;

    // initialize the parameters
    DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
    DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
    // assign a small initial probability that is larger than zero, so
    // unseen states will not get a zero probability
    transitionMatrix.assign(pseudoCount);
    emissionMatrix.assign(pseudoCount);
    // given no prior knowledge, we have to assume that all initial hidden
    // states are equally likely
    DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates);
    initialProbabilities.assign(1.0 / nrOfHiddenStates);

    // now loop over the sequences to count the number of transitions
    countTransitions(transitionMatrix, emissionMatrix, observedSequence,
        hiddenSequence);

    // make sure that probabilities are normalized
    for (int i = 0; i < nrOfHiddenStates; i++) {
      // compute sum of probabilities for current row of transition matrix
      double sum = 0;
      for (int j = 0; j < nrOfHiddenStates; j++) {
        sum += transitionMatrix.getQuick(i, j);
      }
      // normalize current row of transition matrix
      for (int j = 0; j < nrOfHiddenStates; j++) {
        transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum);
      }
      // compute sum of probabilities for current row of emission matrix
      sum = 0;
      for (int j = 0; j < nrOfOutputStates; j++) {
        sum += emissionMatrix.getQuick(i, j);
      }
      // normalize current row of emission matrix
      for (int j = 0; j < nrOfOutputStates; j++) {
        emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum);
      }
    }

    // return a new model using the parameter estimations
    return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
  }

  /**
   * Function that counts the number of state->state and state->output
   * transitions for the given observed/hidden sequence.
   *
   * @param transitionMatrix transition matrix to use.
   * @param emissionMatrix emission matrix to use for counting.
   * @param observedSequence observation sequence to use.
   * @param hiddenSequence sequence of hidden states to use.
   */
  private static void countTransitions(Matrix transitionMatrix,
                                       Matrix emissionMatrix, int[] observedSequence, int[] hiddenSequence) {
    emissionMatrix.setQuick(hiddenSequence[0], observedSequence[0],
        emissionMatrix.getQuick(hiddenSequence[0], observedSequence[0]) + 1);
    for (int i = 1; i < observedSequence.length; ++i) {
      transitionMatrix
          .setQuick(hiddenSequence[i - 1], hiddenSequence[i], transitionMatrix
              .getQuick(hiddenSequence[i - 1], hiddenSequence[i]) + 1);
      emissionMatrix.setQuick(hiddenSequence[i], observedSequence[i],
          emissionMatrix.getQuick(hiddenSequence[i], observedSequence[i]) + 1);
    }
  }

  /**
   * Create an supervised initial estimate of an HMM Model based on a number of
   * sequences of observed and hidden states.
   *
   * @param nrOfHiddenStates The total number of hidden states
   * @param nrOfOutputStates The total number of output states
   * @param hiddenSequences Collection of hidden sequences to use for training
   * @param observedSequences Collection of observed sequences to use for training associated with hidden sequences.
   * @param pseudoCount      Value that is assigned to non-occurring transitions to avoid zero
   *                         probabilities.
   * @return An initial model using the estimated parameters
   */
  public static HmmModel trainSupervisedSequence(int nrOfHiddenStates,
                                                 int nrOfOutputStates, Collection hiddenSequences,
                                                 Collection observedSequences, double pseudoCount) {

    // make sure the pseudo count is not zero
    pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;

    // initialize parameters
    DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates,
        nrOfHiddenStates);
    DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates,
        nrOfOutputStates);
    DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates);

    // assign pseudo count to avoid zero probabilities
    transitionMatrix.assign(pseudoCount);
    emissionMatrix.assign(pseudoCount);
    initialProbabilities.assign(pseudoCount);

    // now loop over the sequences to count the number of transitions
    Iterator hiddenSequenceIt = hiddenSequences.iterator();
    Iterator observedSequenceIt = observedSequences.iterator();
    while (hiddenSequenceIt.hasNext() && observedSequenceIt.hasNext()) {
      // fetch the current set of sequences
      int[] hiddenSequence = hiddenSequenceIt.next();
      int[] observedSequence = observedSequenceIt.next();
      // increase the count for initial probabilities
      initialProbabilities.setQuick(hiddenSequence[0], initialProbabilities
          .getQuick(hiddenSequence[0]) + 1);
      countTransitions(transitionMatrix, emissionMatrix, observedSequence,
          hiddenSequence);
    }

    // make sure that probabilities are normalized
    double isum = 0; // sum of initial probabilities
    for (int i = 0; i < nrOfHiddenStates; i++) {
      isum += initialProbabilities.getQuick(i);
      // compute sum of probabilities for current row of transition matrix
      double sum = 0;
      for (int j = 0; j < nrOfHiddenStates; j++) {
        sum += transitionMatrix.getQuick(i, j);
      }
      // normalize current row of transition matrix
      for (int j = 0; j < nrOfHiddenStates; j++) {
        transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum);
      }
      // compute sum of probabilities for current row of emission matrix
      sum = 0;
      for (int j = 0; j < nrOfOutputStates; j++) {
        sum += emissionMatrix.getQuick(i, j);
      }
      // normalize current row of emission matrix
      for (int j = 0; j < nrOfOutputStates; j++) {
        emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum);
      }
    }
    // normalize the initial probabilities
    for (int i = 0; i < nrOfHiddenStates; ++i) {
      initialProbabilities.setQuick(i, initialProbabilities.getQuick(i) / isum);
    }

    // return a new model using the parameter estimates
    return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
  }

  /**
   * Iteratively train the parameters of the given initial model wrt to the
   * observed sequence using Viterbi training.
   *
   * @param initialModel     The initial model that gets iterated
   * @param observedSequence The sequence of observed states
   * @param pseudoCount      Value that is assigned to non-occurring transitions to avoid zero
   *                         probabilities.
   * @param epsilon          Convergence criteria
   * @param maxIterations    The maximum number of training iterations
   * @param scaled           Use Log-scaled implementation, this is computationally more
   *                         expensive but offers better numerical stability for large observed
   *                         sequences
   * @return The iterated model
   */
  public static HmmModel trainViterbi(HmmModel initialModel,
                                      int[] observedSequence, double pseudoCount, double epsilon,
                                      int maxIterations, boolean scaled) {

    // make sure the pseudo count is not zero
    pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;

    // allocate space for iteration models
    HmmModel lastIteration = initialModel.clone();
    HmmModel iteration = initialModel.clone();

    // allocate space for Viterbi path calculation
    int[] viterbiPath = new int[observedSequence.length];
    int[][] phi = new int[observedSequence.length - 1][initialModel
        .getNrOfHiddenStates()];
    double[][] delta = new double[observedSequence.length][initialModel
        .getNrOfHiddenStates()];

    // now run the Viterbi training iteration
    for (int i = 0; i < maxIterations; ++i) {
      // compute the Viterbi path
      HmmAlgorithms.viterbiAlgorithm(viterbiPath, delta, phi, lastIteration,
          observedSequence, scaled);
      // Viterbi iteration uses the viterbi path to update
      // the probabilities
      Matrix emissionMatrix = iteration.getEmissionMatrix();
      Matrix transitionMatrix = iteration.getTransitionMatrix();

      // first, assign the pseudo count
      emissionMatrix.assign(pseudoCount);
      transitionMatrix.assign(pseudoCount);

      // now count the transitions
      countTransitions(transitionMatrix, emissionMatrix, observedSequence,
          viterbiPath);

      // and normalize the probabilities
      for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
        double sum = 0;
        // normalize the rows of the transition matrix
        for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
          sum += transitionMatrix.getQuick(j, k);
        }
        for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
          transitionMatrix
              .setQuick(j, k, transitionMatrix.getQuick(j, k) / sum);
        }
        // normalize the rows of the emission matrix
        sum = 0;
        for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
          sum += emissionMatrix.getQuick(j, k);
        }
        for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
          emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum);
        }
      }
      // check for convergence
      if (checkConvergence(lastIteration, iteration, epsilon)) {
        break;
      }
      // overwrite the last iterated model by the new iteration
      lastIteration.assign(iteration);
    }
    // we are done :)
    return iteration;
  }

  /**
   * Iteratively train the parameters of the given initial model wrt the
   * observed sequence using Baum-Welch training.
   *
   * @param initialModel     The initial model that gets iterated
   * @param observedSequence The sequence of observed states
   * @param epsilon          Convergence criteria
   * @param maxIterations    The maximum number of training iterations
   * @param scaled           Use log-scaled implementations of forward/backward algorithm. This
   *                         is computationally more expensive, but offers better numerical
   *                         stability for long output sequences.
   * @return The iterated model
   */
  public static HmmModel trainBaumWelch(HmmModel initialModel,
                                        int[] observedSequence, double epsilon, int maxIterations, boolean scaled) {
    // allocate space for the iterations
    HmmModel lastIteration = initialModel.clone();
    HmmModel iteration = initialModel.clone();

    // allocate space for baum-welch factors
    int hiddenCount = initialModel.getNrOfHiddenStates();
    int visibleCount = observedSequence.length;
    Matrix alpha = new DenseMatrix(visibleCount, hiddenCount);
    Matrix beta = new DenseMatrix(visibleCount, hiddenCount);

    // now run the baum Welch training iteration
    for (int it = 0; it < maxIterations; ++it) {
      // fetch emission and transition matrix of current iteration
      Vector initialProbabilities = iteration.getInitialProbabilities();
      Matrix emissionMatrix = iteration.getEmissionMatrix();
      Matrix transitionMatrix = iteration.getTransitionMatrix();

      // compute forward and backward factors
      HmmAlgorithms.forwardAlgorithm(alpha, iteration, observedSequence, scaled);
      HmmAlgorithms.backwardAlgorithm(beta, iteration, observedSequence, scaled);

      if (scaled) {
        logScaledBaumWelch(observedSequence, iteration, alpha, beta);
      } else {
        unscaledBaumWelch(observedSequence, iteration, alpha, beta);
      }
      // normalize transition/emission probabilities
      // and normalize the probabilities
      double isum = 0;
      for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
        double sum = 0;
        // normalize the rows of the transition matrix
        for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
          sum += transitionMatrix.getQuick(j, k);
        }
        for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
          transitionMatrix
              .setQuick(j, k, transitionMatrix.getQuick(j, k) / sum);
        }
        // normalize the rows of the emission matrix
        sum = 0;
        for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
          sum += emissionMatrix.getQuick(j, k);
        }
        for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
          emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum);
        }
        // normalization parameter for initial probabilities
        isum += initialProbabilities.getQuick(j);
      }
      // normalize initial probabilities
      for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
        initialProbabilities.setQuick(i, initialProbabilities.getQuick(i)
            / isum);
      }
      // check for convergence
      if (checkConvergence(lastIteration, iteration, epsilon)) {
        break;
      }
      // overwrite the last iterated model by the new iteration
      lastIteration.assign(iteration);
    }
    // we are done :)
    return iteration;
  }

  private static void unscaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
    Vector initialProbabilities = iteration.getInitialProbabilities();
    Matrix emissionMatrix = iteration.getEmissionMatrix();
    Matrix transitionMatrix = iteration.getTransitionMatrix();
    double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, false);

    for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
      initialProbabilities.setQuick(i, alpha.getQuick(0, i)
          * beta.getQuick(0, i));
    }

    // recompute transition probabilities
    for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
      for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
        double temp = 0;
        for (int t = 0; t < observedSequence.length - 1; ++t) {
          temp += alpha.getQuick(t, i)
              * emissionMatrix.getQuick(j, observedSequence[t + 1])
              * beta.getQuick(t + 1, j);
        }
        transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j)
            * temp / modelLikelihood);
      }
    }
    // recompute emission probabilities
    for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
      for (int j = 0; j < iteration.getNrOfOutputStates(); ++j) {
        double temp = 0;
        for (int t = 0; t < observedSequence.length; ++t) {
          // delta tensor
          if (observedSequence[t] == j) {
            temp += alpha.getQuick(t, i) * beta.getQuick(t, i);
          }
        }
        emissionMatrix.setQuick(i, j, temp / modelLikelihood);
      }
    }
  }

  private static void logScaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
    Vector initialProbabilities = iteration.getInitialProbabilities();
    Matrix emissionMatrix = iteration.getEmissionMatrix();
    Matrix transitionMatrix = iteration.getTransitionMatrix();
    double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, true);

    for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
      initialProbabilities.setQuick(i, Math.exp(alpha.getQuick(0, i) + beta.getQuick(0, i)));
    }

    // recompute transition probabilities
    for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
      for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
        double sum = Double.NEGATIVE_INFINITY; // log(0)
        for (int t = 0; t < observedSequence.length - 1; ++t) {
          double temp = alpha.getQuick(t, i)
              + Math.log(emissionMatrix.getQuick(j, observedSequence[t + 1]))
              + beta.getQuick(t + 1, j);
          if (temp > Double.NEGATIVE_INFINITY) {
            // handle 0-probabilities
            sum = temp + Math.log1p(Math.exp(sum - temp));
          }
        }
        transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j)
            * Math.exp(sum - modelLikelihood));
      }
    }
    // recompute emission probabilities
    for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
      for (int j = 0; j < iteration.getNrOfOutputStates(); ++j) {
        double sum = Double.NEGATIVE_INFINITY; // log(0)
        for (int t = 0; t < observedSequence.length; ++t) {
          // delta tensor
          if (observedSequence[t] == j) {
            double temp = alpha.getQuick(t, i) + beta.getQuick(t, i);
            if (temp > Double.NEGATIVE_INFINITY) {
              // handle 0-probabilities
              sum = temp + Math.log1p(Math.exp(sum - temp));
            }
          }
        }
        emissionMatrix.setQuick(i, j, Math.exp(sum - modelLikelihood));
      }
    }
  }

  /**
   * Check convergence of two HMM models by computing a simple distance between
   * emission / transition matrices
   *
   * @param oldModel Old HMM Model
   * @param newModel New HMM Model
   * @param epsilon  Convergence Factor
   * @return true if training converged to a stable state.
   */
  private static boolean checkConvergence(HmmModel oldModel, HmmModel newModel,
                                          double epsilon) {
    // check convergence of transitionProbabilities
    Matrix oldTransitionMatrix = oldModel.getTransitionMatrix();
    Matrix newTransitionMatrix = newModel.getTransitionMatrix();
    double diff = 0;
    for (int i = 0; i < oldModel.getNrOfHiddenStates(); ++i) {
      for (int j = 0; j < oldModel.getNrOfHiddenStates(); ++j) {
        double tmp = oldTransitionMatrix.getQuick(i, j)
            - newTransitionMatrix.getQuick(i, j);
        diff += tmp * tmp;
      }
    }
    double norm = Math.sqrt(diff);
    diff = 0;
    // check convergence of emissionProbabilities
    Matrix oldEmissionMatrix = oldModel.getEmissionMatrix();
    Matrix newEmissionMatrix = newModel.getEmissionMatrix();
    for (int i = 0; i < oldModel.getNrOfHiddenStates(); i++) {
      for (int j = 0; j < oldModel.getNrOfOutputStates(); j++) {

        double tmp = oldEmissionMatrix.getQuick(i, j)
            - newEmissionMatrix.getQuick(i, j);
        diff += tmp * tmp;
      }
    }
    norm += Math.sqrt(diff);
    // iteration has converged :)
    return norm < epsilon;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy