smile.sequence.HMM

/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *   
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *  
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package smile.sequence;

import java.util.HashMap;
import java.util.Map;
import smile.math.Math;

/**
 * First-order Hidden Markov Model. A hidden Markov model (HMM) is a statistical
 * Markov model in which the system being modeled is assumed to be a Markov
 * process with unobserved (hidden) states. An HMM can be considered as the
 * simplest dynamic Bayesian network.
 * <p>
 * In a regular Markov model, the state is directly visible to the observer,
 * and therefore the state transition probabilities are the only parameters.
 * In a hidden Markov model, the state is not directly visible, but output,
 * dependent on the state, is visible. Each state has a probability
 * distribution over the possible output tokens. Therefore the sequence of
 * tokens generated by an HMM gives some information about the sequence of
 * states.
 *
 * @author Haifeng Li
 */
public class HMM<O> implements SequenceLabeler<O> {

    /**
     * The number of states.
     */
    private int numStates;
    /**
     * The number of emission symbols.
     */
    private int numSymbols;
    /**
     * Initial state probabilities.
     */
    private double[] pi;
    /**
     * State transition probabilities.
     */
    private double[][] a;
    /**
     * Symbol emission probabilities.
     */
    private double[][] b;
    /**
     * The mapping from emission symbols to indices.
     */
    private Map<O, Integer> symbols;

    /**
     * Constructor.
     *
     * @param numStates the number of states.
     * @param numSymbols the number of emission symbols.
     */
    private HMM(int numStates, int numSymbols) {
        if (numStates <= 0) {
            throw new IllegalArgumentException("Invalid number of states: " + numStates);
        }

        if (numSymbols <= 0) {
            throw new IllegalArgumentException("Invalid number of emission symbols: " + numSymbols);
        }

        this.numStates = numStates;
        this.numSymbols = numSymbols;

        pi = new double[numStates];
        a = new double[numStates][numStates];
        b = new double[numStates][numSymbols];
    }

    /**
     * Constructor.
     *
     * @param pi the initial state probabilities.
     * @param a the state transition probabilities, of which a[i][j] is
     *          P(s_j | s_i);
     * @param b the symbol emission probabilities, of which b[i][j] is
     *          P(o_j | s_i).
     */
    public HMM(double[] pi, double[][] a, double[][] b) {
        this(pi, a, b, null);
    }

    /**
     * Constructor.
     *
     * @param pi the initial state probabilities.
     * @param a the state transition probabilities, of which a[i][j] is
     *          P(s_j | s_i);
     * @param b the symbol emission probabilities, of which b[i][j] is
     *          P(o_j | s_i).
     * @param symbols the list of emission symbols.
     */
    public HMM(double[] pi, double[][] a, double[][] b, O[] symbols) {
        if (pi.length == 0) {
            throw new IllegalArgumentException("Invalid initial state probabilities.");
        }

        if (pi.length != a.length) {
            throw new IllegalArgumentException("Invalid state transition probability matrix.");
        }

        if (a.length != b.length) {
            throw new IllegalArgumentException("Invalid symbol emission probability matrix.");
        }

        if (symbols != null) {
            if (b[0].length != symbols.length) {
                throw new IllegalArgumentException("Invalid size of emission symbol list.");
            }

            this.symbols = new HashMap<O, Integer>();
            for (int i = 0; i < symbols.length; i++) {
                this.symbols.put(symbols[i], i);
            }
        }

        numStates = pi.length;
        numSymbols = b[0].length;

        for (int i = 0; i < numStates; i++) {
            if (a[i].length != numStates) {
                throw new IllegalArgumentException("Invalid state transition probability matrix.");
            }

            double sum = 0.0;
            for (int j = 0; j < numStates; j++) {
                if (a[i][j] < 0 || a[i][j] > 1.0) {
                    throw new IllegalArgumentException("Invalid state transition probability: " + a[i][j]);
                }
                sum += a[i][j];
            }

            if (Math.abs(1.0 - sum) > 1E-7) {
                throw new IllegalArgumentException(String.format("The row %d of state transition probability matrix doesn't sum to 1.", i));
            }
        }

        for (int i = 0; i < numStates; i++) {
            if (b[i].length != numSymbols) {
                throw new IllegalArgumentException("Invalid symbol emission probability matrix.");
            }

            double sum = 0.0;
            for (int j = 0; j < numSymbols; j++) {
                if (b[i][j] < 0 || b[i][j] > 1.0) {
                    throw new IllegalArgumentException("Invalid symbol emission probability: " + b[i][j]);
                }
                sum += b[i][j];
            }

            if (Math.abs(1.0 - sum) > 1E-7) {
                throw new IllegalArgumentException(String.format("The row %d of symbol emission probability matrix doesn't sum to 1.", i));
            }
        }

        this.pi = pi;
        this.a = a;
        this.b = b;
    }

    /**
     * Returns the number of states.
     */
    public int numStates() {
        return numStates;
    }

    /**
     * Returns the number of emission symbols.
     */
    public int numSymbols() {
        return numSymbols;
    }

    /**
     * Returns the initial state probabilities.
     */
    public double[] getInitialStateProbabilities() {
        return pi;
    }

    /**
     * Returns the state transition probabilities.
     */
    public double[][] getStateTransitionProbabilities() {
        return a;
    }

    /**
     * Returns the symbol emission probabilities.
     */
    public double[][] getSymbolEmissionProbabilities() {
        return b;
    }

    /**
     * Returns natural log without underflow.
     */
    private static double log(double x) {
        double y;
        if (x < 1E-300) {
            y = -690.7755;
        } else {
            y = java.lang.Math.log(x);
        }
        return y;
    }

    /**
     * Returns the joint probability of an observation sequence along a state
     * sequence given this HMM.
     *
     * @param o an observation sequence.
     * @param s a state sequence.
     * @return the joint probability P(o, s | H) given the model H.
     */
    public double p(int[] o, int[] s) {
        return Math.exp(logp(o, s));
    }

    /**
     * Returns the log joint probability of an observation sequence along a
     * state sequence given this HMM.
     *
     * @param o an observation sequence.
     * @param s a state sequence.
     * @return the log joint probability P(o, s | H) given the model H.
     */
    public double logp(int[] o, int[] s) {
        if (o.length != s.length) {
            throw new IllegalArgumentException("The observation sequence and state sequence are not the same length.");
        }

        int n = s.length;
        double p = log(pi[s[0]]) + log(b[s[0]][o[0]]);
        for (int i = 1; i < n; i++) {
            p += log(a[s[i - 1]][s[i]]) + log(b[s[i]][o[i]]);
        }

        return p;
    }

    /**
     * Returns the probability of an observation sequence given this HMM.
     *
     * @param o an observation sequence.
     * @return the probability of this sequence.
     */
    public double p(int[] o) {
        return Math.exp(logp(o));
    }

    /**
     * Returns the log probability of an observation sequence given this HMM.
     * A scaling procedure is used in order to avoid underflows when computing
     * the probability of long sequences.
     *
     * @param o an observation sequence.
     * @return the log probability of this sequence.
     */
    public double logp(int[] o) {
        double[][] alpha = new double[o.length][numStates];
        double[] scaling = new double[o.length];
        forward(o, alpha, scaling);

        double p = 0.0;
        for (int t = 0; t < o.length; t++) {
            p += java.lang.Math.log(scaling[t]);
        }

        return p;
    }

    /**
     * Normalize alpha[t] and put the normalization factor in scaling[t].
     */
    private void scale(double[] scaling, double[][] alpha, int t) {
        double[] table = alpha[t];

        double sum = 0.0;
        for (int i = 0; i < table.length; i++) {
            sum += table[i];
        }

        scaling[t] = sum;
        for (int i = 0; i < table.length; i++) {
            table[i] /= sum;
        }
    }

    /**
     * Scaled forward procedure without underflow.
     *
     * @param o an observation sequence.
     * @param alpha on output, alpha(i, j) holds the scaled total probability
     *              of ending up in state i at time j.
     * @param scaling on output, it holds scaling factors.
     */
    private void forward(int[] o, double[][] alpha, double[] scaling) {
        for (int k = 0; k < numStates; k++) {
            alpha[0][k] = pi[k] * b[k][o[0]];
        }
        scale(scaling, alpha, 0);

        for (int t = 1; t < o.length; t++) {
            for (int k = 0; k < numStates; k++) {
                double sum = 0.0;
                for (int i = 0; i < numStates; i++) {
                    sum += alpha[t - 1][i] * a[i][k];
                }

                alpha[t][k] = sum * b[k][o[t]];
            }
            scale(scaling, alpha, t);
        }
    }

    /**
     * Scaled backward procedure without underflow.
     *
     * @param o an observation sequence.
     * @param beta on output, beta(i, j) holds the scaled total probability of
     *             starting up in state i at time j.
     * @param scaling on input, it should hold the scaling factors computed by
     *                the forward procedure.
     */
    private void backward(int[] o, double[][] beta, double[] scaling) {
        int n = o.length - 1;
        for (int i = 0; i < numStates; i++) {
            beta[n][i] = 1.0 / scaling[n];
        }

        for (int t = n; t-- > 0;) {
            for (int i = 0; i < numStates; i++) {
                double sum = 0.0;
                for (int j = 0; j < numStates(); j++) {
                    sum += beta[t + 1][j] * a[i][j] * b[j][o[t + 1]];
                }

                beta[t][i] = sum / scaling[t];
            }
        }
    }

    /**
     * Returns the most likely state sequence given the observation sequence
     * by the Viterbi algorithm, which maximizes the probability of
     * P(I | O, HMM). In the calculation, we may get ties. In this case, one
     * of them is chosen randomly.
     *
     * @param o an observation sequence.
     * @return the most likely state sequence.
     */
    public int[] predict(int[] o) {
        // The probability of the most probable path.
        double[][] trellis = new double[o.length][numStates];
        // Backtrace.
        int[][] psy = new int[o.length][numStates];
        // The most likely state sequence.
        int[] s = new int[o.length];

        // forward
        for (int i = 0; i < numStates; i++) {
            trellis[0][i] = log(pi[i]) + log(b[i][o[0]]);
            psy[0][i] = 0;
        }

        for (int t = 1; t < o.length; t++) {
            for (int j = 0; j < numStates; j++) {
                double maxDelta = Double.NEGATIVE_INFINITY;
                int maxPsy = 0;

                for (int i = 0; i < numStates; i++) {
                    double delta = trellis[t - 1][i] + log(a[i][j]);
                    if (maxDelta < delta) {
                        maxDelta = delta;
                        maxPsy = i;
                    }
                }

                trellis[t][j] = maxDelta + log(b[j][o[t]]);
                psy[t][j] = maxPsy;
            }
        }

        // trace back
        int n = o.length - 1;
        double maxDelta = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < numStates; i++) {
            if (maxDelta < trellis[n][i]) {
                maxDelta = trellis[n][i];
                s[n] = i;
            }
        }

        for (int t = n; t-- > 0;) {
            s[t] = psy[t + 1][s[t + 1]];
        }

        return s;
    }

    /**
     * Learn an HMM from labeled observation sequences by maximum likelihood
     * estimation.
     *
     * @param observations the observation sequences, of which symbols take
     *                     values in [0, n), where n is the number of unique
     *                     symbols.
     * @param labels the state labels of observations, of which states take
     *               values in [0, p), where p is the number of hidden states.
     */
    public HMM(int[][] observations, int[][] labels) {
        if (observations.length != labels.length) {
            throw new IllegalArgumentException("The number of observation sequences and that of label sequences are different.");
        }

        numStates = 0;
        numSymbols = 0;
        for (int i = 0; i < observations.length; i++) {
            if (observations[i].length != labels[i].length) {
                throw new IllegalArgumentException(String.format("The length of observation sequence %d and that of corresponding label sequence are different.", i));
            }

            numStates = Math.max(numStates, Math.max(labels[i]) + 1);
            numSymbols = Math.max(numSymbols, Math.max(observations[i]) + 1);
        }

        pi = new double[numStates];
        a = new double[numStates][numStates];
        b = new double[numStates][numSymbols];

        for (int i = 0; i < observations.length; i++) {
            pi[labels[i][0]]++;
            b[labels[i][0]][observations[i][0]]++;
            for (int j = 1; j < observations[i].length; j++) {
                a[labels[i][j - 1]][labels[i][j]]++;
                b[labels[i][j]][observations[i][j]]++;
            }
        }

        Math.unitize1(pi);
        for (int i = 0; i < numStates; i++) {
            Math.unitize1(a[i]);
            Math.unitize1(b[i]);
        }
    }

    /**
     * Learn an HMM from labeled observation sequences by maximum likelihood
     * estimation.
     *
     * @param observations the observation sequences.
     * @param labels the state labels of observations, of which states take
     *               values in [0, p), where p is the number of hidden states.
     */
    public HMM(O[][] observations, int[][] labels) {
        if (observations.length != labels.length) {
            throw new IllegalArgumentException("The number of observation sequences and that of label sequences are different.");
        }

        int index = 0;
        symbols = new HashMap<O, Integer>();
        for (int i = 0; i < observations.length; i++) {
            if (observations[i].length != labels[i].length) {
                throw new IllegalArgumentException(String.format("The length of observation sequence %d and that of corresponding label sequence are different.", i));
            }

            for (int j = 0; j < observations[i].length; j++) {
                Integer sym = symbols.get(observations[i][j]);
                if (sym == null) {
                    symbols.put(observations[i][j], index++);
                }
            }
        }

        int[][] obs = new int[observations.length][];
        for (int i = 0; i < obs.length; i++) {
            obs[i] = translate(symbols, observations[i]);
        }

        numStates = 0;
        numSymbols = 0;
        for (int i = 0; i < obs.length; i++) {
            numStates = Math.max(numStates, Math.max(labels[i]) + 1);
            numSymbols = Math.max(numSymbols, Math.max(obs[i]) + 1);
        }

        pi = new double[numStates];
        a = new double[numStates][numStates];
        b = new double[numStates][numSymbols];

        for (int i = 0; i < obs.length; i++) {
            pi[labels[i][0]]++;
            b[labels[i][0]][obs[i][0]]++;
            for (int j = 1; j < obs[i].length; j++) {
                a[labels[i][j - 1]][labels[i][j]]++;
                b[labels[i][j]][obs[i][j]]++;
            }
        }

        Math.unitize1(pi);
        for (int i = 0; i < numStates; i++) {
            Math.unitize1(a[i]);
            Math.unitize1(b[i]);
        }
    }

    /**
     * With this HMM as the initial model, learn an HMM by the Baum-Welch
     * algorithm.
     *
     * @param observations the observation sequences on which the learning is
     *                     based. Each sequence must have a length of at
     *                     least 2.
     * @param iterations the number of iterations to execute.
     * @return the updated HMM.
     */
    public HMM<O> learn(O[][] observations, int iterations) {
        int[][] obs = new int[observations.length][];
        for (int i = 0; i < obs.length; i++) {
            obs[i] = translate(observations[i]);
        }

        return learn(obs, iterations);
    }

    /**
     * With this HMM as the initial model, learn an HMM by the Baum-Welch
     * algorithm.
     *
     * @param observations the observation sequences on which the learning is
     *                     based. Each sequence must have a length of at
     *                     least 2.
     * @param iterations the number of iterations to execute.
     * @return the updated HMM.
     */
    public HMM<O> learn(int[][] observations, int iterations) {
        HMM<O> hmm = this;
        for (int iter = 0; iter < iterations; iter++) {
            hmm = hmm.iterate(observations);
        }

        return hmm;
    }

    /**
     * Performs one iteration of the Baum-Welch algorithm.
     *
     * @param sequences the training observation sequences.
     * @return an updated HMM.
     */
    private HMM<O> iterate(int[][] sequences) {
        HMM<O> hmm = new HMM<O>(numStates, numSymbols);
        hmm.symbols = symbols;

        // gamma[n] = gamma array associated with observation sequence n
        double gamma[][][] = new double[sequences.length][][];

        // a[i][j] = aijNum[i][j] / aijDen[i]
        // aijDen[i] = expected number of transitions from state i
        // aijNum[i][j] = expected number of transitions from state i to j
        double aijNum[][] = new double[numStates][numStates];
        double aijDen[] = new double[numStates];

        for (int k = 0; k < sequences.length; k++) {
            if (sequences[k].length <= 2) {
                throw new IllegalArgumentException(String.format("Training sequence %d is too short.", k));
            }

            int[] o = sequences[k];
            double[][] alpha = new double[o.length][numStates];
            double[][] beta = new double[o.length][numStates];
            double[] scaling = new double[o.length];
            forward(o, alpha, scaling);
            backward(o, beta, scaling);

            double xi[][][] = estimateXi(o, alpha, beta);
            double g[][] = gamma[k] = estimateGamma(xi);

            int n = o.length - 1;
            for (int i = 0; i < numStates; i++) {
                for (int t = 0; t < n; t++) {
                    aijDen[i] += g[t][i];
                    for (int j = 0; j < numStates; j++) {
                        aijNum[i][j] += xi[t][i][j];
                    }
                }
            }
        }

        for (int i = 0; i < numStates; i++) {
            if (aijDen[i] == 0.0) {
                // State i is not reachable.
                System.arraycopy(a[i], 0, hmm.a[i], 0, numStates);
            } else {
                for (int j = 0; j < numStates; j++) {
                    hmm.a[i][j] = aijNum[i][j] / aijDen[i];
                }
            }
        }

        /*
         * initial state probability computation
         */
        for (int j = 0; j < sequences.length; j++) {
            for (int i = 0; i < numStates; i++) {
                hmm.pi[i] += gamma[j][0][i];
            }
        }

        for (int i = 0; i < numStates; i++) {
            hmm.pi[i] /= sequences.length;
        }

        /*
         * emission probability computation
         */
        for (int i = 0; i < numStates; i++) {
            double sum = 0.0;
            for (int j = 0; j < sequences.length; j++) {
                int[] o = sequences[j];
                for (int t = 0; t < o.length; t++) {
                    hmm.b[i][o[t]] += gamma[j][t][i];
                    sum += gamma[j][t][i];
                }
            }

            for (int j = 0; j < numSymbols; j++) {
                hmm.b[i][j] /= sum;
            }
        }

        return hmm;
    }

    /**
     * Here, the xi (and, thus, gamma) values are not divided by the
     * probability of the sequence because this probability might be too
     * small and induce an underflow. xi[t][i][j] can still be interpreted as
     * P(q_t = i and q_(t+1) = j | O, HMM) because we assume that the scaling
     * factors are such that their product is equal to the inverse of the
     * probability of the sequence.
     */
    private double[][][] estimateXi(int[] o, double[][] alpha, double[][] beta) {
        if (o.length <= 1) {
            throw new IllegalArgumentException("Observation sequence is too short.");
        }

        int n = o.length - 1;
        double xi[][][] = new double[n][numStates][numStates];

        for (int t = 0; t < n; t++) {
            for (int i = 0; i < numStates; i++) {
                for (int j = 0; j < numStates; j++) {
                    xi[t][i][j] = alpha[t][i] * a[i][j] * b[j][o[t + 1]] * beta[t + 1][j];
                }
            }
        }

        return xi;
    }

    /**
     * gamma[][] could be computed directly using the alpha and beta arrays,
     * but this (slower) method is preferred because it doesn't change if the
     * xi array has been scaled (and should be changed with the scaled alpha
     * and beta arrays).
     */
    private double[][] estimateGamma(double[][][] xi) {
        double[][] gamma = new double[xi.length + 1][numStates];

        for (int t = 0; t < xi.length; t++) {
            for (int i = 0; i < numStates; i++) {
                for (int j = 0; j < numStates; j++) {
                    gamma[t][i] += xi[t][i][j];
                }
            }
        }

        int n = xi.length - 1;
        for (int j = 0; j < numStates; j++) {
            for (int i = 0; i < numStates; i++) {
                gamma[xi.length][j] += xi[n][i][j];
            }
        }

        return gamma;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("HMM (%d states, %d emission symbols)\n", numStates, numSymbols));

        sb.append("\tInitial state probability:\n\t\t");
        for (int i = 0; i < numStates; i++) {
            sb.append(String.format("%.4g ", pi[i]));
        }

        sb.append("\n\tState transition probability:");
        for (int i = 0; i < numStates; i++) {
            sb.append("\n\t\t");
            for (int j = 0; j < numStates; j++) {
                sb.append(String.format("%.4g ", a[i][j]));
            }
        }

        sb.append("\n\tSymbol emission probability:");
        for (int i = 0; i < numStates; i++) {
            sb.append("\n\t\t");
            for (int j = 0; j < numSymbols; j++) {
                sb.append(String.format("%.4g ", b[i][j]));
            }
        }

        return sb.toString();
    }

    /**
     * Translate an observation sequence to internal representation.
     */
    private int[] translate(O[] o) {
        return translate(symbols, o);
    }

    /**
     * Translate an observation sequence to internal representation.
     */
    private static <O> int[] translate(Map<O, Integer> symbols, O[] o) {
        if (symbols == null) {
            throw new IllegalArgumentException("No available emission symbol list.");
        }

        int[] seq = new int[o.length];
        for (int i = 0; i < o.length; i++) {
            Integer sym = symbols.get(o[i]);
            if (sym != null) {
                seq[i] = sym;
            } else {
                throw new IllegalArgumentException("Invalid observation symbol: " + o[i]);
            }
        }

        return seq;
    }

    /**
     * Returns the joint probability of an observation sequence along a state
     * sequence given this HMM.
     *
     * @param o an observation sequence.
     * @param s a state sequence.
     * @return the joint probability P(o, s | H) given the model H.
     */
    public double p(O[] o, int[] s) {
        return Math.exp(logp(o, s));
    }

    /**
     * Returns the log joint probability of an observation sequence along a
     * state sequence given this HMM.
     *
     * @param o an observation sequence.
     * @param s a state sequence.
     * @return the log joint probability P(o, s | H) given the model H.
     */
    public double logp(O[] o, int[] s) {
        return logp(translate(o), s);
    }

    /**
     * Returns the probability of an observation sequence given this HMM.
     *
     * @param o an observation sequence.
     * @return the probability of this sequence.
     */
    public double p(O[] o) {
        return Math.exp(logp(o));
    }

    /**
     * Returns the log probability of an observation sequence given this HMM.
     * A scaling procedure is used in order to avoid underflows when computing
     * the probability of long sequences.
     *
     * @param o an observation sequence.
     * @return the log probability of this sequence.
     */
    public double logp(O[] o) {
        return logp(translate(o));
    }

    /**
     * Returns the most likely state sequence given the observation sequence
     * by the Viterbi algorithm, which maximizes the probability of
     * P(I | O, HMM). In the calculation, we may get ties. In this case, one
     * of them is chosen randomly.
     *
     * @param o an observation sequence.
     * @return the most likely state sequence.
     */
    @Override
    public int[] predict(O[] o) {
        return predict(translate(o));
    }
}
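
Usage sketch. The snippet below is a minimal example of how this class can be driven end to end; it is not part of the artifact. The class name HmmExample and the toy sequences are invented for illustration, and it assumes smile.sequence.HMM (with its smile.math dependency) is on the classpath. It fits a model from labeled integer sequences via the supervised constructor, decodes a new sequence with predict (Viterbi), scores it with p/logp (scaled forward procedure), and refines the parameters with a few Baum-Welch iterations via learn. A model can also be specified directly from pi, a, and b through the public constructor, in which case each row of a and b must sum to 1 within 1E-7, as enforced above.

import java.util.Arrays;
import smile.sequence.HMM;

public class HmmExample {

    public static void main(String[] args) {
        // Hypothetical training data: two observation sequences over the
        // symbols {0, 1} together with their hidden state labels over {0, 1}.
        int[][] observations = {
            {0, 0, 1, 1, 0, 1, 1, 1},
            {1, 1, 0, 0, 1, 0, 0, 0}
        };
        int[][] labels = {
            {0, 0, 1, 1, 0, 1, 1, 1},
            {1, 1, 0, 0, 1, 0, 0, 0}
        };

        // Supervised maximum likelihood estimation from labeled sequences.
        HMM<Integer> hmm = new HMM<>(observations, labels);

        // Viterbi decoding: the most likely hidden state sequence.
        int[] o = {0, 1, 1, 0};
        int[] states = hmm.predict(o);
        System.out.println("states  = " + Arrays.toString(states));

        // Sequence likelihood via the scaled forward procedure.
        System.out.println("p(o)    = " + hmm.p(o));
        System.out.println("logp(o) = " + hmm.logp(o));

        // Unsupervised refinement: a few Baum-Welch iterations starting from
        // the current model. Note that iterate() rejects sequences of length
        // <= 2, so training sequences must be longer than that.
        HMM<Integer> refined = hmm.learn(observations, 10);
        System.out.println(refined);
    }
}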




