All downloads are free. Search and download functionality uses the official Maven repository.

smile.sequence.HMMLabeler Maven / Gradle / Ivy

/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.sequence;

import java.io.Serial;
import java.util.Arrays;
import java.util.function.ToIntFunction;

/**
 * First-order Hidden Markov Model sequence labeler.
 * <p>
 * This class adapts an integer-symbol {@link HMM} to observations of an
 * arbitrary type: every observation is mapped to its symbol ordinal by the
 * {@code ordinal} function before being handed to the underlying model.
 *
 * @param <T> the data type of model input objects.
 *
 * @author Haifeng Li
 */
public class HMMLabeler<T> implements SequenceLabeler<T> {
    @Serial
    private static final long serialVersionUID = 2L;

    /** The HMM model. */
    public final HMM model;
    /**
     * The lambda returns the ordinal numbers of symbols.
     * <p>
     * NOTE(review): this class declares a serialVersionUID, so if instances
     * are serialized this function presumably has to be serializable as well
     * (e.g. a serializable lambda or method reference) — confirm.
     */
    public final ToIntFunction<T> ordinal;

    /**
     * Constructor.
     *
     * @param model the HMM model.
     * @param ordinal a lambda returning the ordinal numbers of symbols.
     */
    public HMMLabeler(HMM model, ToIntFunction<T> ordinal) {
        this.model = model;
        this.ordinal = ordinal;
    }

    /**
     * Fits an HMM by maximum likelihood estimation.
     *
     * @param observations the observation sequences.
     * @param labels the state labels of observations, of which states take
     *               values in [0, p), where p is the number of hidden states.
     * @param ordinal a lambda returning the ordinal numbers of symbols.
     * @param <T> the data type of observations.
     * @return the model.
     * @throws IllegalArgumentException if the number of observation sequences
     *         differs from the number of label sequences.
     */
    public static <T> HMMLabeler<T> fit(T[][] observations, int[][] labels, ToIntFunction<T> ordinal) {
        if (observations.length != labels.length) {
            throw new IllegalArgumentException("The number of observation sequences and that of label sequences are different.");
        }

        HMM model = HMM.fit(translateAll(observations, ordinal), labels);
        return new HMMLabeler<>(model, ordinal);
    }

    /**
     * Updates the HMM by the Baum-Welch algorithm.
     *
     * @param observations the training observation sequences.
     * @param iterations the number of iterations to execute.
     */
    public void update(T[][] observations, int iterations) {
        model.update(translateAll(observations, ordinal), iterations);
    }

    @Override
    public String toString() {
        return model.toString();
    }

    /**
     * Translates a batch of observation sequences to internal representation,
     * i.e. one array of symbol ordinals per input sequence.
     * <p>
     * Shared by {@link #fit} (static, so it cannot use the instance field)
     * and {@link #update} so the encoding is defined in a single place.
     *
     * @param sequences the observation sequences.
     * @param ordinal a lambda returning the ordinal numbers of symbols.
     * @param <S> the data type of observations.
     * @return the ordinal-encoded sequences, in the same order as the input.
     */
    private static <S> int[][] translateAll(S[][] sequences, ToIntFunction<? super S> ordinal) {
        return Arrays.stream(sequences)
                .map(sequence -> Arrays.stream(sequence).mapToInt(ordinal).toArray())
                .toArray(int[][]::new);
    }

    /**
     * Translates an observation sequence to internal representation.
     *
     * @param o an observation sequence.
     * @return the symbol ordinals of the observations, in order.
     */
    private int[] translate(T[] o) {
        return Arrays.stream(o).mapToInt(ordinal).toArray();
    }

    /**
     * Returns the joint probability of an observation sequence along a state
     * sequence.
     *
     * @param o an observation sequence.
     * @param s a state sequence.
     * @return the joint probability P(o, s | H) given the model H.
     */
    public double p(T[] o, int[] s) {
        return model.p(translate(o), s);
    }

    /**
     * Returns the log joint probability of an observation sequence along a
     * state sequence.
     *
     * @param o an observation sequence.
     * @param s a state sequence.
     * @return the log joint probability log P(o, s | H) given the model H.
     */
    public double logp(T[] o, int[] s) {
        return model.logp(translate(o), s);
    }

    /**
     * Returns the probability of an observation sequence.
     *
     * @param o an observation sequence.
     * @return the probability of this sequence.
     */
    public double p(T[] o) {
        return model.p(translate(o));
    }

    /**
     * Returns the logarithm probability of an observation sequence.
     * A scaling procedure is used in order to avoid underflow when
     * computing the probability of long sequences.
     *
     * @param o an observation sequence.
     * @return the log probability of this sequence.
     */
    public double logp(T[] o) {
        return model.logp(translate(o));
    }

    /**
     * Returns the most likely state sequence given the observation sequence by
     * the Viterbi algorithm, which maximizes the probability of
     * P(I | O, HMM). In the calculation, we may get ties. In this
     * case, one of them is chosen randomly.
     *
     * @param o an observation sequence.
     * @return the most likely state sequence.
     */
    @Override
    public int[] predict(T[] o) {
        return model.predict(translate(o));
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy