All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.classifier.sequencelearning.hmm.HmmModel Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sequencelearning.hmm;

import java.util.Map;
import java.util.Random;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

/**
 * Main class defining a Hidden Markov Model
 */
public class HmmModel implements Cloneable {

  /** Bi-directional Map for storing the observed state names */
  private BiMap outputStateNames;

  /** Bi-Directional Map for storing the hidden state names */
  private BiMap hiddenStateNames;

  /* Number of hidden states */
  private int nrOfHiddenStates;

  /** Number of output states */
  private int nrOfOutputStates;

  /**
   * Transition matrix containing the transition probabilities between hidden
   * states. TransitionMatrix(i,j) is the probability that we change from hidden
   * state i to hidden state j In general: P(h(t+1)=h_j | h(t) = h_i) =
   * transitionMatrix(i,j) Since we have to make sure that each hidden state can
   * be "left", the following normalization condition has to hold:
   * sum(transitionMatrix(i,j),j=1..hiddenStates) = 1
   */
  private Matrix transitionMatrix;

  /**
   * Output matrix containing the probabilities that we observe a given output
   * state given a hidden state. outputMatrix(i,j) is the probability that we
   * observe output state j if we are in hidden state i Formally: P(o(t)=o_j |
   * h(t)=h_i) = outputMatrix(i,j) Since we always have an observation for each
   * hidden state, the following normalization condition has to hold:
   * sum(outputMatrix(i,j),j=1..outputStates) = 1
   */
  private Matrix emissionMatrix;

  /**
   * Vector containing the initial hidden state probabilities. That is
   * P(h(0)=h_i) = initialProbabilities(i). Since we are dealing with
   * probabilities the following normalization condition has to hold:
   * sum(initialProbabilities(i),i=1..hiddenStates) = 1
   */
  private Vector initialProbabilities;


  /**
   * Get a copy of this model
   */
  @Override
  public HmmModel clone() {
    HmmModel model = new HmmModel(transitionMatrix.clone(), emissionMatrix.clone(), initialProbabilities.clone());
    if (hiddenStateNames != null) {
      model.hiddenStateNames = HashBiMap.create(hiddenStateNames);
    }
    if (outputStateNames != null) {
      model.outputStateNames = HashBiMap.create(outputStateNames);
    }
    return model;
  }

  /**
   * Assign the content of another HMM model to this one
   *
   * @param model The HmmModel that will be assigned to this one
   */
  public void assign(HmmModel model) {
    this.nrOfHiddenStates = model.nrOfHiddenStates;
    this.nrOfOutputStates = model.nrOfOutputStates;
    this.hiddenStateNames = model.hiddenStateNames;
    this.outputStateNames = model.outputStateNames;
    // for now clone the matrix/vectors
    this.initialProbabilities = model.initialProbabilities.clone();
    this.emissionMatrix = model.emissionMatrix.clone();
    this.transitionMatrix = model.transitionMatrix.clone();
  }

  /**
   * Construct a valid random Hidden-Markov parameter set with the given number
   * of hidden and output states using a given seed.
   *
   * @param nrOfHiddenStates Number of hidden states
   * @param nrOfOutputStates Number of output states
   * @param seed             Seed for the random initialization, if set to 0 the current time
   *                         is used
   */
  public HmmModel(int nrOfHiddenStates, int nrOfOutputStates, long seed) {
    this.nrOfHiddenStates = nrOfHiddenStates;
    this.nrOfOutputStates = nrOfOutputStates;
    this.transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
    this.emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
    this.initialProbabilities = new DenseVector(nrOfHiddenStates);
    // initialize a random, valid parameter set
    initRandomParameters(seed);
  }

  /**
   * Construct a valid random Hidden-Markov parameter set with the given number
   * of hidden and output states.
   *
   * @param nrOfHiddenStates Number of hidden states
   * @param nrOfOutputStates Number of output states
   */
  public HmmModel(int nrOfHiddenStates, int nrOfOutputStates) {
    this(nrOfHiddenStates, nrOfOutputStates, 0);
  }

  /**
   * Generates a Hidden Markov model using the specified parameters
   *
   * @param transitionMatrix     transition probabilities.
   * @param emissionMatrix       emission probabilities.
   * @param initialProbabilities initial start probabilities.
   * @throws IllegalArgumentException If the given parameter set is invalid
   */
  public HmmModel(Matrix transitionMatrix, Matrix emissionMatrix, Vector initialProbabilities) {
    this.nrOfHiddenStates = initialProbabilities.size();
    this.nrOfOutputStates = emissionMatrix.numCols();
    this.transitionMatrix = transitionMatrix;
    this.emissionMatrix = emissionMatrix;
    this.initialProbabilities = initialProbabilities;
  }

  /**
   * Initialize a valid random set of HMM parameters
   *
   * @param seed seed to use for Random initialization. Use 0 to use Java-built-in-version.
   */
  private void initRandomParameters(long seed) {
    Random rand;
    // initialize the random number generator
    if (seed == 0) {
      rand = RandomUtils.getRandom();
    } else {
      rand = RandomUtils.getRandom(seed);
    }
    // initialize the initial Probabilities
    double sum = 0; // used for normalization
    for (int i = 0; i < nrOfHiddenStates; i++) {
      double nextRand = rand.nextDouble();
      initialProbabilities.set(i, nextRand);
      sum += nextRand;
    }
    // "normalize" the vector to generate probabilities
    initialProbabilities = initialProbabilities.divide(sum);

    // initialize the transition matrix
    double[] values = new double[nrOfHiddenStates];
    for (int i = 0; i < nrOfHiddenStates; i++) {
      sum = 0;
      for (int j = 0; j < nrOfHiddenStates; j++) {
        values[j] = rand.nextDouble();
        sum += values[j];
      }
      // normalize the random values to obtain probabilities
      for (int j = 0; j < nrOfHiddenStates; j++) {
        values[j] /= sum;
      }
      // set this row of the transition matrix
      transitionMatrix.set(i, values);
    }

    // initialize the output matrix
    values = new double[nrOfOutputStates];
    for (int i = 0; i < nrOfHiddenStates; i++) {
      sum = 0;
      for (int j = 0; j < nrOfOutputStates; j++) {
        values[j] = rand.nextDouble();
        sum += values[j];
      }
      // normalize the random values to obtain probabilities
      for (int j = 0; j < nrOfOutputStates; j++) {
        values[j] /= sum;
      }
      // set this row of the output matrix
      emissionMatrix.set(i, values);
    }
  }

  /**
   * Getter Method for the number of hidden states
   *
   * @return Number of hidden states
   */
  public int getNrOfHiddenStates() {
    return nrOfHiddenStates;
  }

  /**
   * Getter Method for the number of output states
   *
   * @return Number of output states
   */
  public int getNrOfOutputStates() {
    return nrOfOutputStates;
  }

  /**
   * Getter function to get the hidden state transition matrix
   *
   * @return returns the model's transition matrix.
   */
  public Matrix getTransitionMatrix() {
    return transitionMatrix;
  }

  /**
   * Getter function to get the output state probability matrix
   *
   * @return returns the models emission matrix.
   */
  public Matrix getEmissionMatrix() {
    return emissionMatrix;
  }

  /**
   * Getter function to return the vector of initial hidden state probabilities
   *
   * @return returns the model's init probabilities.
   */
  public Vector getInitialProbabilities() {
    return initialProbabilities;
  }

  /**
   * Getter method for the hidden state Names map
   *
   * @return hidden state names.
   */
  public Map getHiddenStateNames() {
    return hiddenStateNames;
  }

  /**
   * Register an array of hidden state Names. We assume that the state name at
   * position i has the ID i
   *
   * @param stateNames names of hidden states.
   */
  public void registerHiddenStateNames(String[] stateNames) {
    if (stateNames != null) {
      hiddenStateNames = HashBiMap.create();
      for (int i = 0; i < stateNames.length; ++i) {
        hiddenStateNames.put(stateNames[i], i);
      }
    }
  }

  /**
   * Register a map of hidden state Names/state IDs
   *
   * @param stateNames  Map that assigns each state name an integer ID
   */
  public void registerHiddenStateNames(Map stateNames) {
    if (stateNames != null) {
      hiddenStateNames = HashBiMap.create(stateNames);
    }
  }

  /**
   * Lookup the name for the given hidden state ID
   *
   * @param id Integer id of the hidden state
   * @return String containing the name for the given ID, null if this ID is not
   *         known or no hidden state names were specified
   */
  public String getHiddenStateName(int id) {
    if (hiddenStateNames == null) {
      return null;
    }
    return hiddenStateNames.inverse().get(id);
  }

  /**
   * Lookup the ID for the given hidden state name
   *
   * @param name Name of the hidden state
   * @return int containing the ID for the given name, -1 if this name is not
   *         known or no hidden state names were specified
   */
  public int getHiddenStateID(String name) {
    if (hiddenStateNames == null) {
      return -1;
    }
    Integer tmp = hiddenStateNames.get(name);
    return tmp == null ? -1 : tmp;
  }

  /**
   * Getter method for the output state Names map
   *
   * @return names of output states.
   */
  public Map getOutputStateNames() {
    return outputStateNames;
  }

  /**
   * Register an array of hidden state Names. We assume that the state name at
   * position i has the ID i
   *
   * @param stateNames state names to register.
   */
  public void registerOutputStateNames(String[] stateNames) {
    if (stateNames != null) {
      outputStateNames = HashBiMap.create();
      for (int i = 0; i < stateNames.length; ++i) {
        outputStateNames.put(stateNames[i], i);
      }
    }
  }

  /**
   * Register a map of hidden state Names/state IDs
   *
   * @param stateNames  Map that assigns each state name an integer ID
   */
  public void registerOutputStateNames(Map stateNames) {
    if (stateNames != null) {
      outputStateNames = HashBiMap.create(stateNames);
    }
  }

  /**
   * Lookup the name for the given output state id
   *
   * @param id Integer id of the output state
   * @return String containing the name for the given id, null if this id is not
   *         known or no output state names were specified
   */
  public String getOutputStateName(int id) {
    if (outputStateNames == null) {
      return null;
    }
    return outputStateNames.inverse().get(id);
  }

  /**
   * Lookup the ID for the given output state name
   *
   * @param name Name of the output state
   * @return int containing the ID for the given name, -1 if this name is not
   *         known or no output state names were specified
   */
  public int getOutputStateID(String name) {
    if (outputStateNames == null) {
      return -1;
    }
    Integer tmp = outputStateNames.get(name);
    return tmp == null ? -1 : tmp;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy