All downloads are free. Search and download functionality uses the official Maven repository.

com.enterprisemath.math.statistics.HMM Maven / Gradle / Ivy

The newest version!
package com.enterprisemath.math.statistics;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.apache.commons.lang3.builder.ToStringBuilder;

import com.enterprisemath.math.probability.ProbabilityDistribution;
import com.enterprisemath.math.statistics.observation.ObservationIterator;
import com.enterprisemath.math.statistics.observation.ObservationProvider;
import com.enterprisemath.utils.DomainUtils;
import com.enterprisemath.utils.ValidationUtils;

/**
 * Basic definition of the hidden markov model.
 * 
 * @author radek.hecl
 *
 * @param  type of the hidden state
 * @param  type of the observation
 */
public class HMM {

    /**
     * Builder object.
     *
     * @param  type of the hidden state
     * @param  type of the observation
     */
    public static class Builder {

        /**
         * Set of hidden states.
         */
        private Set hiddenStates = new HashSet();

        /**
         * Initial probability. Must be 1 over all states.
         */
        private ProbabilityDistribution initialProbability;

        /**
         * Transition probabilities between states.
         */
        private Map> transitionProbabilities = new HashMap>();

        /**
         * Emission probabilities to the observations.
         */
        private Map> emissionProbabilities = new HashMap>();

        /**
         * Sets the hidden states.
         * 
         * @param hiddenStates hidden states
         * @return this instance
         */
        public Builder setHiddenStates(Set hiddenStates) {
            this.hiddenStates = DomainUtils.softCopySet(hiddenStates);
            return this;
        }

        /**
         * Adds hidden state.
         * 
         * @param hiddenState hidden state
         * @return this instance
         */
        public Builder addHiddenState(State hiddenState) {
            this.hiddenStates.add(hiddenState);
            return this;
        }

        /**
         * Sets initial probability distribution.
         * 
         * @param initialProbability initial probability distribution
         * @return this instance
         */
        public Builder setInitialProbability(ProbabilityDistribution initialProbability) {
            this.initialProbability = initialProbability;
            return this;
        }

        /**
         * Sets transition probability distributions.
         * 
         * @param transitionProbabilities transition probability distributions
         * @return this instance
         */
        public Builder setTransitionProbabilities(Map> transitionProbabilities) {
            this.transitionProbabilities = DomainUtils.softCopyMap(transitionProbabilities);
            return this;
        }

        /**
         * Add transition probability distribution.
         * 
         * @param state state from which transition starts
         * @param probability probability distribution to other transitions
         * @return this instance
         */
        public Builder addTransitionProbability(State state, ProbabilityDistribution probability) {
            this.transitionProbabilities.put(state, probability);
            return this;
        }

        /**
         * Sets emission probability distribution.
         * 
         * @param emissionProbabilities emission probability distributions
         * @return this instance
         */
        public Builder setEmissionProbabilities(Map> emissionProbabilities) {
            this.emissionProbabilities = DomainUtils.softCopyMap(emissionProbabilities);
            return this;
        }

        /**
         * Adds emission probability distribution.
         * 
         * @param state state of the emission
         * @param probability probability distribution of emission
         * @return this instance
         */
        public Builder addEmissionProbability(State state, ProbabilityDistribution probability) {
            this.emissionProbabilities.put(state, probability);
            return this;
        }

        /**
         * Builds the result object.
         * 
         * @return created object
         */
        public HMM build() {
            return new HMM(this);
        }
    }

    /**
     * Set of hidden states.
     */
    private Set hiddenStates;

    /**
     * Initial probability. Must be 1 over all states.
     */
    private ProbabilityDistribution initialProbability;

    /**
     * Transition probabilities between states.
     */
    private Map> transitionProbabilities;

    /**
     * Emission probabilities to the observations.
     */
    private Map> emissionProbabilities;

    /**
     * Creates new instance.
     * 
     * @param builder builder object
     */
    public HMM(Builder builder) {
        hiddenStates = Collections.unmodifiableSet(DomainUtils.softCopySet(builder.hiddenStates));
        initialProbability = builder.initialProbability;
        transitionProbabilities = Collections.unmodifiableMap(DomainUtils.softCopyMap(builder.transitionProbabilities));
        emissionProbabilities = Collections.unmodifiableMap(DomainUtils.softCopyMap(builder.emissionProbabilities));
        guardInvariants();
    }

    /**
     * Guards this object to be consistent. Throws exception if this is not the case.
     */
    private void guardInvariants() {
        //
        // first level validation
        ValidationUtils.guardNotNullCollection(hiddenStates, "hiddenStates cannot have null element");
        ValidationUtils.guardPositiveInt(hiddenStates.size(), "hiddenStates cannot be empty");
        ValidationUtils.guardNotNull(initialProbability, "initialProbability canno be null");
        ValidationUtils.guardNotNullMap(transitionProbabilities, "transitionProbabilities cannot have null element");
        ValidationUtils.guardNotNullMap(emissionProbabilities, "emissionProbabilities cannot have null element");
        //
        // second level validation
        ValidationUtils.guardEquals(hiddenStates, transitionProbabilities.keySet(),
                "transitionProbabilities must be defined for all hiddenStates");
        ValidationUtils.guardEquals(hiddenStates, emissionProbabilities.keySet(),
                "emissionProbabilities must be defined for all hiddenStates");
        //
        // third level validation
        double h = 0;
        for (State state : hiddenStates) {
            h = h + initialProbability.getValue(state);
        }
        ValidationUtils.guardGreaterOrEqualDouble(0.0001, Math.abs(1 - h),
                "initialProbability must be 1 in sum through all hiddenStates");
        for (State state : hiddenStates) {
            h = 0;
            for (State until : hiddenStates) {
                h = h + transitionProbabilities.get(state).getValue(until);
            }
            ValidationUtils.guardGreaterOrEqualDouble(0.0001, Math.abs(1 - h),
                    "transitionProbability must be 1 in sum for " + state);
        }
    }

    /**
     * Returns the most probable sequence of the hidden states for the given observations.
     * If there only one most probable sequence exists (which is expected in the majority of real world cases),
     * then there is not problem.
     * If there are multiple most probable sequences, then one of them is returned.
     * In that case there is not guaranteed which one and even on 2 calls with same observations
     * the returned sequence may be different.
     * 
     * @param observations observations
     * @return this instance
     */
    public List viterbi(ObservationProvider observations) {
        ObservationIterator iterator = observations.getIterator();

        //
        // handling edge case
        if (!iterator.isNextAvailable()) {
            return Collections.emptyList();
        }

        //
        // initialization
        List states = new ArrayList(hiddenStates);
        List delta = new ArrayList();
        List psi = new ArrayList();

        //
        // first observation
        {
            Observation obs = iterator.getNext();
            double[] deltaelm = new double[states.size()];
            int[] psielm = new int[states.size()];
            for (int i = 0; i < states.size(); ++i) {
                deltaelm[i] = initialProbability.getLnValue(states.get(i)) + emissionProbabilities.get(states.get(i)).getLnValue(obs);
                psielm[i] = i;
            }
            delta.add(deltaelm);
            psi.add(psielm);
        }

        //
        // iterations
        while (iterator.isNextAvailable()) {
            Observation obs = iterator.getNext();
            double[] deltaelm = new double[states.size()];
            int[] psielm = new int[states.size()];
            for (int i = 0; i < states.size(); ++i) {
                double max = Double.NEGATIVE_INFINITY;
                int argmax = 0;
                for (int j = 0; j < states.size(); ++j) {
                    double val = delta.get(delta.size() - 1)[j] + transitionProbabilities.get(states.get(j)).getLnValue(states.get(i));
                    if (val > max) {
                        max = val;
                        argmax = j;
                    }
                }
                deltaelm[i] = max + emissionProbabilities.get(states.get(i)).getLnValue(obs);
                psielm[i] = argmax;
            }
            delta.add(deltaelm);
            psi.add(psielm);
        }

        //
        // taking the last (current) observation
        List res = new ArrayList();
        {
            double max = Double.NEGATIVE_INFINITY;
            int argmax = -1;
            for (int j = 0; j < states.size(); ++j) {
                if (delta.get(delta.size() - 1)[j] > max) {
                    max = delta.get(delta.size() - 1)[j];
                    argmax = j;
                }
            }
            res.add(states.get(argmax));
        }

        //
        // back tracking to get the whole sequence
        for (int i = delta.size() - 1; i >= 1; --i) {
            double max = Double.NEGATIVE_INFINITY;
            int argmax = -1;
            for (int j = 0; j < states.size(); ++j) {
                if (delta.get(i)[j] > max) {
                    max = delta.get(i)[j];
                    argmax = j;
                }
            }
            res.add(states.get(psi.get(i)[argmax]));
        }
        Collections.reverse(res);
        return res;
    }

    @Override
    public int hashCode() {
        return HashCodeBuilder.reflectionHashCode(this);
    }

    @Override
    public boolean equals(Object obj) {
        return EqualsBuilder.reflectionEquals(this, obj);
    }

    @Override
    public String toString() {
        return ToStringBuilder.reflectionToString(this);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy