org.apache.mahout.classifier.sequencelearning.hmm.HmmTrainer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-mr Show documentation
Show all versions of mahout-mr Show documentation
Scalable machine learning libraries
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.sequencelearning.hmm;
import java.util.Collection;
import java.util.Iterator;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
/**
* Class containing several algorithms used to train a Hidden Markov Model. The
* three main algorithms are: supervised learning, unsupervised Viterbi and
* unsupervised Baum-Welch.
*/
public final class HmmTrainer {
/**
 * Hidden constructor: this class only offers static training routines and is
 * never instantiated.
 */
private HmmTrainer() {
  // intentionally empty
}
/**
 * Builds a supervised initial estimate of an HMM from one sequence for which
 * both the observations and the hidden states are known.
 *
 * Transition and emission counts start at the pseudo count, are incremented
 * for every event found in the sequence pair, and are then normalized row by
 * row. The initial state distribution is uniform.
 *
 * @param nrOfHiddenStates The total number of hidden states
 * @param nrOfOutputStates The total number of output states
 * @param observedSequence Integer array containing the observed sequence
 * @param hiddenSequence Integer array containing the hidden sequence
 * @param pseudoCount Start value of every count, so that non-occurring events do not get a
 * zero probability. A value of 0 is replaced by {@link Double#MIN_VALUE}.
 * @return An initial model using the estimated parameters
 */
public static HmmModel trainSupervised(int nrOfHiddenStates, int nrOfOutputStates, int[] observedSequence,
    int[] hiddenSequence, double pseudoCount) {
  // a pseudo count of exactly zero would allow zero probabilities
  if (pseudoCount == 0) {
    pseudoCount = Double.MIN_VALUE;
  }

  // every cell starts at the pseudo count; observed events are added on top
  DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
  DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
  transitionMatrix.assign(pseudoCount);
  emissionMatrix.assign(pseudoCount);

  // count state->state and state->output events of the training sequence
  countTransitions(transitionMatrix, emissionMatrix, observedSequence, hiddenSequence);

  // turn the counts of each row into probabilities
  for (int state = 0; state < nrOfHiddenStates; state++) {
    double transitionTotal = 0;
    for (int next = 0; next < nrOfHiddenStates; next++) {
      transitionTotal += transitionMatrix.getQuick(state, next);
    }
    for (int next = 0; next < nrOfHiddenStates; next++) {
      transitionMatrix.setQuick(state, next, transitionMatrix.getQuick(state, next) / transitionTotal);
    }

    double emissionTotal = 0;
    for (int symbol = 0; symbol < nrOfOutputStates; symbol++) {
      emissionTotal += emissionMatrix.getQuick(state, symbol);
    }
    for (int symbol = 0; symbol < nrOfOutputStates; symbol++) {
      emissionMatrix.setQuick(state, symbol, emissionMatrix.getQuick(state, symbol) / emissionTotal);
    }
  }

  // without prior knowledge all hidden states are equally likely at the start
  DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates);
  initialProbabilities.assign(1.0 / nrOfHiddenStates);

  return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
}
/**
 * Adds one to the matching cell of the given matrices for every
 * state->state and state->output event of an observed/hidden sequence pair.
 *
 * The loop runs over the length of the observed sequence and both arrays are
 * read at index 0 unconditionally, so the sequences must be non-empty and the
 * hidden sequence must be at least as long as the observed one.
 *
 * @param transitionMatrix transition counts, updated in place.
 * @param emissionMatrix emission counts, updated in place.
 * @param observedSequence observation sequence to use.
 * @param hiddenSequence sequence of hidden states to use.
 */
private static void countTransitions(Matrix transitionMatrix,
    Matrix emissionMatrix, int[] observedSequence, int[] hiddenSequence) {
  // the first position has no predecessor, so it only contributes an emission
  int previousState = hiddenSequence[0];
  int firstSymbol = observedSequence[0];
  emissionMatrix.setQuick(previousState, firstSymbol,
      emissionMatrix.getQuick(previousState, firstSymbol) + 1);

  for (int t = 1; t < observedSequence.length; t++) {
    int currentState = hiddenSequence[t];
    int symbol = observedSequence[t];

    double transitions = transitionMatrix.getQuick(previousState, currentState);
    transitionMatrix.setQuick(previousState, currentState, transitions + 1);

    double emissions = emissionMatrix.getQuick(currentState, symbol);
    emissionMatrix.setQuick(currentState, symbol, emissions + 1);

    previousState = currentState;
  }
}
/**
 * Creates a supervised initial estimate of an HMM from several sequences for
 * which both the observations and the hidden states are known.
 *
 * The two collections are iterated in parallel; the n-th hidden sequence is
 * paired with the n-th observed sequence. Iteration stops as soon as either
 * collection is exhausted, so surplus sequences in the longer collection are
 * ignored. Every sequence must contain at least one element because its first
 * hidden state is used to estimate the initial probabilities.
 *
 * @param nrOfHiddenStates The total number of hidden states
 * @param nrOfOutputStates The total number of output states
 * @param hiddenSequences Collection of hidden sequences to use for training
 * @param observedSequences Collection of observed sequences to use for training associated with hidden sequences.
 * @param pseudoCount Start value of every count, so that non-occurring events do not get a
 * zero probability. A value of 0 is replaced by {@link Double#MIN_VALUE}.
 * @return An initial model using the estimated parameters
 */
public static HmmModel trainSupervisedSequence(int nrOfHiddenStates,
    int nrOfOutputStates, Collection<int[]> hiddenSequences,
    Collection<int[]> observedSequences, double pseudoCount) {
  // make sure the pseudo count is not zero
  pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;

  // initialize parameters
  DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
  DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
  DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates);

  // assign pseudo count to avoid zero probabilities
  transitionMatrix.assign(pseudoCount);
  emissionMatrix.assign(pseudoCount);
  initialProbabilities.assign(pseudoCount);

  // now loop over the sequences to count the number of transitions;
  // the iterators are typed so that next() yields int[] without a cast
  Iterator<int[]> hiddenSequenceIt = hiddenSequences.iterator();
  Iterator<int[]> observedSequenceIt = observedSequences.iterator();
  while (hiddenSequenceIt.hasNext() && observedSequenceIt.hasNext()) {
    // fetch the current pair of sequences
    int[] hiddenSequence = hiddenSequenceIt.next();
    int[] observedSequence = observedSequenceIt.next();

    // the first hidden state of each sequence counts towards the initial probabilities
    initialProbabilities.setQuick(hiddenSequence[0],
        initialProbabilities.getQuick(hiddenSequence[0]) + 1);
    countTransitions(transitionMatrix, emissionMatrix, observedSequence, hiddenSequence);
  }

  // make sure that probabilities are normalized
  double isum = 0; // sum of initial probabilities
  for (int i = 0; i < nrOfHiddenStates; i++) {
    isum += initialProbabilities.getQuick(i);

    // compute sum of probabilities for current row of transition matrix
    double sum = 0;
    for (int j = 0; j < nrOfHiddenStates; j++) {
      sum += transitionMatrix.getQuick(i, j);
    }
    // normalize current row of transition matrix
    for (int j = 0; j < nrOfHiddenStates; j++) {
      transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum);
    }

    // compute sum of probabilities for current row of emission matrix
    sum = 0;
    for (int j = 0; j < nrOfOutputStates; j++) {
      sum += emissionMatrix.getQuick(i, j);
    }
    // normalize current row of emission matrix
    for (int j = 0; j < nrOfOutputStates; j++) {
      emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum);
    }
  }

  // normalize the initial probabilities
  for (int i = 0; i < nrOfHiddenStates; ++i) {
    initialProbabilities.setQuick(i, initialProbabilities.getQuick(i) / isum);
  }

  // return a new model using the parameter estimates
  return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
}
/**
 * Iteratively re-estimates the transition and emission matrices of a model
 * with respect to an observed sequence using Viterbi training.
 *
 * In every round a hidden path is computed by
 * {@code HmmAlgorithms.viterbiAlgorithm} from the model of the previous
 * round; the matrices are then reset to the pseudo count, filled with the
 * counts found along that path and normalized row by row. Only the
 * transition and emission matrices are rewritten by this method. The given
 * model itself is not modified; training works on clones of it.
 *
 * @param initialModel The initial model that gets iterated
 * @param observedSequence The sequence of observed states
 * @param pseudoCount Start value of every count, so that non-occurring events do not get a
 * zero probability. A value of 0 is replaced by {@link Double#MIN_VALUE}.
 * @param epsilon Convergence criteria
 * @param maxIterations The maximum number of training iterations
 * @param scaled Use Log-scaled implementation, this is computationally more
 * expensive but offers better numerical stability for large observed
 * sequences
 * @return The iterated model
 */
public static HmmModel trainViterbi(HmmModel initialModel,
    int[] observedSequence, double pseudoCount, double epsilon,
    int maxIterations, boolean scaled) {
  // a pseudo count of exactly zero would allow zero probabilities
  if (pseudoCount == 0) {
    pseudoCount = Double.MIN_VALUE;
  }

  // models of the previous and the current round
  HmmModel lastIteration = initialModel.clone();
  HmmModel iteration = initialModel.clone();

  // work buffers handed to the Viterbi path calculation
  int sequenceLength = observedSequence.length;
  int[] viterbiPath = new int[sequenceLength];
  int[][] phi = new int[sequenceLength - 1][initialModel.getNrOfHiddenStates()];
  double[][] delta = new double[sequenceLength][initialModel.getNrOfHiddenStates()];

  boolean converged = false;
  for (int round = 0; round < maxIterations && !converged; round++) {
    // path computed from the model of the previous round
    HmmAlgorithms.viterbiAlgorithm(viterbiPath, delta, phi, lastIteration,
        observedSequence, scaled);

    Matrix emissionMatrix = iteration.getEmissionMatrix();
    Matrix transitionMatrix = iteration.getTransitionMatrix();

    // restart from the pseudo count, then count along the computed path
    emissionMatrix.assign(pseudoCount);
    transitionMatrix.assign(pseudoCount);
    countTransitions(transitionMatrix, emissionMatrix, observedSequence, viterbiPath);

    // convert the counts of each row into probabilities
    int hiddenCount = iteration.getNrOfHiddenStates();
    int outputCount = iteration.getNrOfOutputStates();
    for (int row = 0; row < hiddenCount; row++) {
      double transitionTotal = 0;
      for (int col = 0; col < hiddenCount; col++) {
        transitionTotal += transitionMatrix.getQuick(row, col);
      }
      for (int col = 0; col < hiddenCount; col++) {
        transitionMatrix.setQuick(row, col, transitionMatrix.getQuick(row, col) / transitionTotal);
      }

      double emissionTotal = 0;
      for (int col = 0; col < outputCount; col++) {
        emissionTotal += emissionMatrix.getQuick(row, col);
      }
      for (int col = 0; col < outputCount; col++) {
        emissionMatrix.setQuick(row, col, emissionMatrix.getQuick(row, col) / emissionTotal);
      }
    }

    // keep the new estimate as reference for the next round unless training has settled
    converged = checkConvergence(lastIteration, iteration, epsilon);
    if (!converged) {
      lastIteration.assign(iteration);
    }
  }

  return iteration;
}
/**
 * Iteratively re-estimates the parameters of a model with respect to an
 * observed sequence using Baum-Welch training.
 *
 * Every round computes the forward and backward factors for the current
 * estimate, lets {@code logScaledBaumWelch} or {@code unscaledBaumWelch}
 * rewrite the model parameters, and then normalizes the rows of the
 * transition and emission matrices as well as the initial probabilities.
 * The given model itself is not modified; training works on clones of it.
 *
 * @param initialModel The initial model that gets iterated
 * @param observedSequence The sequence of observed states
 * @param epsilon Convergence criteria
 * @param maxIterations The maximum number of training iterations
 * @param scaled Use log-scaled implementations of forward/backward algorithm. This
 * is computationally more expensive, but offers better numerical
 * stability for long output sequences.
 * @return The iterated model
 */
public static HmmModel trainBaumWelch(HmmModel initialModel,
    int[] observedSequence, double epsilon, int maxIterations, boolean scaled) {
  // models of the previous and the current round
  HmmModel lastIteration = initialModel.clone();
  HmmModel iteration = initialModel.clone();

  // forward (alpha) and backward (beta) factors: one row per observation, one column per hidden state
  int hiddenCount = initialModel.getNrOfHiddenStates();
  int sequenceLength = observedSequence.length;
  Matrix alpha = new DenseMatrix(sequenceLength, hiddenCount);
  Matrix beta = new DenseMatrix(sequenceLength, hiddenCount);

  boolean converged = false;
  for (int round = 0; round < maxIterations && !converged; round++) {
    // parameters of the current estimate
    Vector initialProbabilities = iteration.getInitialProbabilities();
    Matrix emissionMatrix = iteration.getEmissionMatrix();
    Matrix transitionMatrix = iteration.getTransitionMatrix();

    HmmAlgorithms.forwardAlgorithm(alpha, iteration, observedSequence, scaled);
    HmmAlgorithms.backwardAlgorithm(beta, iteration, observedSequence, scaled);

    // re-estimation step; leaves the parameters unnormalized
    if (scaled) {
      logScaledBaumWelch(observedSequence, iteration, alpha, beta);
    } else {
      unscaledBaumWelch(observedSequence, iteration, alpha, beta);
    }

    int states = iteration.getNrOfHiddenStates();
    int outputs = iteration.getNrOfOutputStates();

    // normalize each row of the transition and the emission matrix
    for (int row = 0; row < states; row++) {
      double transitionTotal = 0;
      for (int col = 0; col < states; col++) {
        transitionTotal += transitionMatrix.getQuick(row, col);
      }
      for (int col = 0; col < states; col++) {
        transitionMatrix.setQuick(row, col, transitionMatrix.getQuick(row, col) / transitionTotal);
      }

      double emissionTotal = 0;
      for (int col = 0; col < outputs; col++) {
        emissionTotal += emissionMatrix.getQuick(row, col);
      }
      for (int col = 0; col < outputs; col++) {
        emissionMatrix.setQuick(row, col, emissionMatrix.getQuick(row, col) / emissionTotal);
      }
    }

    // normalize the initial probabilities
    double initialTotal = 0;
    for (int state = 0; state < states; state++) {
      initialTotal += initialProbabilities.getQuick(state);
    }
    for (int state = 0; state < states; state++) {
      initialProbabilities.setQuick(state, initialProbabilities.getQuick(state) / initialTotal);
    }

    // keep the new estimate as reference for the next round unless training has settled
    converged = checkConvergence(lastIteration, iteration, epsilon);
    if (!converged) {
      lastIteration.assign(iteration);
    }
  }

  return iteration;
}
/**
 * Performs one Baum-Welch re-estimation step using unscaled forward/backward
 * factors and writes the result into the given model.
 *
 * The values written are not normalized by this method.
 *
 * @param observedSequence The sequence of observed states
 * @param iteration Model whose initial probabilities, transition matrix and emission
 * matrix are overwritten in place
 * @param alpha Forward factors, indexed (time step, hidden state)
 * @param beta Backward factors, indexed (time step, hidden state)
 */
private static void unscaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
  Vector initialProbabilities = iteration.getInitialProbabilities();
  Matrix emissionMatrix = iteration.getEmissionMatrix();
  Matrix transitionMatrix = iteration.getTransitionMatrix();
  double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, false);

  int hiddenCount = iteration.getNrOfHiddenStates();
  int outputCount = iteration.getNrOfOutputStates();

  // initial probabilities: product of forward and backward factor at the first time step
  for (int state = 0; state < hiddenCount; state++) {
    double joint = alpha.getQuick(0, state) * beta.getQuick(0, state);
    initialProbabilities.setQuick(state, joint);
  }

  // transitions: must run before the emission update because it reads the old emission values
  for (int from = 0; from < hiddenCount; from++) {
    for (int to = 0; to < hiddenCount; to++) {
      double expected = 0;
      for (int next = 1; next < observedSequence.length; next++) {
        expected += alpha.getQuick(next - 1, from)
            * emissionMatrix.getQuick(to, observedSequence[next])
            * beta.getQuick(next, to);
      }
      transitionMatrix.setQuick(from, to,
          transitionMatrix.getQuick(from, to) * expected / modelLikelihood);
    }
  }

  // emissions: only time steps at which the symbol was actually observed contribute
  for (int state = 0; state < hiddenCount; state++) {
    for (int symbol = 0; symbol < outputCount; symbol++) {
      double weight = 0;
      for (int t = 0; t < observedSequence.length; t++) {
        if (observedSequence[t] != symbol) {
          continue;
        }
        weight += alpha.getQuick(t, state) * beta.getQuick(t, state);
      }
      emissionMatrix.setQuick(state, symbol, weight / modelLikelihood);
    }
  }
}
/**
 * Performs one Baum-Welch re-estimation step on log-scaled forward/backward
 * factors and writes the result into the given model.
 *
 * Sums over time steps are accumulated in log space. The values written are
 * not normalized by this method.
 *
 * NOTE(review): {@code sum - modelLikelihood} is only a proper rescaling if
 * {@code HmmEvaluator.modelLikelihood(alpha, true)} returns a log-likelihood;
 * confirm against HmmEvaluator. The initial probabilities are not rescaled
 * at all and may underflow for long sequences.
 *
 * @param observedSequence The sequence of observed states
 * @param iteration Model whose initial probabilities, transition matrix and emission
 * matrix are overwritten in place
 * @param alpha Log-scaled forward factors, indexed (time step, hidden state)
 * @param beta Log-scaled backward factors, indexed (time step, hidden state)
 */
private static void logScaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
  Vector initialProbabilities = iteration.getInitialProbabilities();
  Matrix emissionMatrix = iteration.getEmissionMatrix();
  Matrix transitionMatrix = iteration.getTransitionMatrix();
  double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, true);

  int hiddenCount = iteration.getNrOfHiddenStates();
  int outputCount = iteration.getNrOfOutputStates();

  for (int i = 0; i < hiddenCount; ++i) {
    initialProbabilities.setQuick(i, Math.exp(alpha.getQuick(0, i) + beta.getQuick(0, i)));
  }

  // recompute transition probabilities (reads the old emission values, so it must run first)
  for (int i = 0; i < hiddenCount; ++i) {
    for (int j = 0; j < hiddenCount; ++j) {
      double sum = Double.NEGATIVE_INFINITY; // log(0)
      for (int t = 0; t < observedSequence.length - 1; ++t) {
        double temp = alpha.getQuick(t, i)
            + Math.log(emissionMatrix.getQuick(j, observedSequence[t + 1]))
            + beta.getQuick(t + 1, j);
        sum = logAdd(sum, temp);
      }
      transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j)
          * Math.exp(sum - modelLikelihood));
    }
  }

  // recompute emission probabilities
  for (int i = 0; i < hiddenCount; ++i) {
    for (int j = 0; j < outputCount; ++j) {
      double sum = Double.NEGATIVE_INFINITY; // log(0)
      for (int t = 0; t < observedSequence.length; ++t) {
        // delta tensor: only time steps at which symbol j was observed contribute
        if (observedSequence[t] == j) {
          sum = logAdd(sum, alpha.getQuick(t, i) + beta.getQuick(t, i));
        }
      }
      emissionMatrix.setQuick(i, j, Math.exp(sum - modelLikelihood));
    }
  }
}

/**
 * Returns {@code log(exp(logSum) + exp(logTerm))} without leaving log space.
 *
 * The larger operand is used as pivot so that the argument of
 * {@code Math.exp} is never positive; pivoting on the new term unconditionally
 * overflows to infinity once the running sum exceeds the term by more than
 * roughly 709.
 *
 * @param logSum running sum in log space; {@code Double.NEGATIVE_INFINITY} represents zero
 * @param logTerm term to add, in log space
 * @return the updated sum in log space; {@code logSum} unchanged when {@code logTerm}
 * is negative infinity or NaN
 */
private static double logAdd(double logSum, double logTerm) {
  // handle 0-probabilities: a term of log(0) contributes nothing
  if (!(logTerm > Double.NEGATIVE_INFINITY)) {
    return logSum;
  }
  if (logSum > logTerm) {
    return logSum + Math.log1p(Math.exp(logTerm - logSum));
  }
  return logTerm + Math.log1p(Math.exp(logSum - logTerm));
}
/**
 * Decides whether training has converged by comparing two models.
 *
 * The distance is the root of the summed squared differences of the
 * transition matrices plus the root of the summed squared differences of
 * the emission matrices. Initial probabilities are not part of the
 * comparison.
 *
 * @param oldModel Old HMM Model
 * @param newModel New HMM Model
 * @param epsilon Convergence Factor
 * @return true if the distance is smaller than epsilon.
 */
private static boolean checkConvergence(HmmModel oldModel, HmmModel newModel,
    double epsilon) {
  // squared distance between the transition matrices
  Matrix oldTransitionMatrix = oldModel.getTransitionMatrix();
  Matrix newTransitionMatrix = newModel.getTransitionMatrix();
  int hiddenCount = oldModel.getNrOfHiddenStates();
  double transitionSquares = 0;
  for (int row = 0; row < hiddenCount; row++) {
    for (int col = 0; col < hiddenCount; col++) {
      double delta = oldTransitionMatrix.getQuick(row, col) - newTransitionMatrix.getQuick(row, col);
      transitionSquares += delta * delta;
    }
  }

  // squared distance between the emission matrices
  Matrix oldEmissionMatrix = oldModel.getEmissionMatrix();
  Matrix newEmissionMatrix = newModel.getEmissionMatrix();
  int outputCount = oldModel.getNrOfOutputStates();
  double emissionSquares = 0;
  for (int row = 0; row < hiddenCount; row++) {
    for (int col = 0; col < outputCount; col++) {
      double delta = oldEmissionMatrix.getQuick(row, col) - newEmissionMatrix.getQuick(row, col);
      emissionSquares += delta * delta;
    }
  }

  // converged once the combined distance drops below the threshold
  return Math.sqrt(transitionSquares) + Math.sqrt(emissionSquares) < epsilon;
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy