All downloads are free. Search and download functionality uses the official Maven repository.

smile.sequence.CRFLabeler Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.sequence;

import java.io.Serial;
import java.util.Arrays;
import java.util.Objects;
import java.util.Properties;
import java.util.function.Function;
import smile.data.Tuple;

/**
 * First-order CRF sequence labeler. Each observation of type {@code T} is
 * mapped to a feature {@link Tuple} by a user-supplied feature function,
 * and the resulting tuple sequence is labeled by the underlying {@link CRF}.
 *
 * @param <T> the data type of model input objects.
 *
 * @author Haifeng Li
 */
public class CRFLabeler<T> implements SequenceLabeler<T> {
    @Serial
    private static final long serialVersionUID = 2L;

    /** The CRF model. */
    public final CRF model;
    /** The feature function that maps an observation to its feature tuple. */
    public final Function<T, Tuple> features;

    /**
     * Constructor.
     *
     * @param model the CRF model.
     * @param features the feature function.
     * @throws NullPointerException if {@code model} or {@code features} is null.
     */
    public CRFLabeler(CRF model, Function<T, Tuple> features) {
        this.model = Objects.requireNonNull(model, "model");
        this.features = Objects.requireNonNull(features, "features");
    }

    /**
     * Fits a CRF model with the default hyperparameters.
     *
     * @param sequences the training data.
     * @param labels the training sequence labels.
     * @param features the feature function.
     * @param <T> the data type of observations.
     * @return the model.
     */
    public static <T> CRFLabeler<T> fit(T[][] sequences, int[][] labels, Function<T, Tuple> features) {
        return fit(sequences, labels, features, new Properties());
    }

    /**
     * Fits a CRF model. The following properties are read from
     * {@code params} (defaults in parentheses):
     * {@code smile.crf.trees} (100), {@code smile.crf.max_depth} (20),
     * {@code smile.crf.max_nodes} (100), {@code smile.crf.node_size} (5),
     * {@code smile.crf.shrinkage} (1.0).
     *
     * @param sequences the training data.
     * @param labels the training sequence labels.
     * @param features the feature function.
     * @param params the hyperparameters.
     * @param <T> the data type of observations.
     * @return the model.
     * @throws NumberFormatException if a property value cannot be parsed.
     */
    public static <T> CRFLabeler<T> fit(T[][] sequences, int[][] labels, Function<T, Tuple> features, Properties params) {
        int ntrees = Integer.parseInt(params.getProperty("smile.crf.trees", "100"));
        int maxDepth = Integer.parseInt(params.getProperty("smile.crf.max_depth", "20"));
        int maxNodes = Integer.parseInt(params.getProperty("smile.crf.max_nodes", "100"));
        int nodeSize = Integer.parseInt(params.getProperty("smile.crf.node_size", "5"));
        double shrinkage = Double.parseDouble(params.getProperty("smile.crf.shrinkage", "1.0"));
        return fit(sequences, labels, features, ntrees, maxDepth, maxNodes, nodeSize, shrinkage);
    }

    /**
     * Fits a CRF.
     *
     * @param sequences the observation sequences.
     * @param labels the state labels of observations, of which states take
     *               values in [0, k), where k is the number of hidden states.
     * @param features the feature function.
     * @param ntrees the number of trees/iterations.
     * @param maxDepth the maximum depth of the tree.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     * @param nodeSize  the number of instances in a node below which the tree will
     *                  not split, setting nodeSize = 5 generally gives good results.
     * @param shrinkage the shrinkage parameter in (0, 1] controls the learning rate of procedure.
     * @param <T> the data type of observations.
     * @return the model.
     * @throws IllegalArgumentException if the number of observation sequences
     *         differs from the number of label sequences, or if any observation
     *         sequence and its label sequence have different lengths.
     */
    public static <T> CRFLabeler<T> fit(T[][] sequences, int[][] labels, Function<T, Tuple> features, int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage) {
        if (sequences.length != labels.length) {
            throw new IllegalArgumentException("The number of observation sequences and that of label sequences are different.");
        }

        // Every observation needs exactly one label; catch mismatches here
        // rather than deep inside the training routine.
        for (int i = 0; i < sequences.length; i++) {
            if (sequences[i].length != labels[i].length) {
                throw new IllegalArgumentException(String.format(
                        "The length of observation sequence %d (%d) differs from that of its label sequence (%d).",
                        i, sequences[i].length, labels[i].length));
            }
        }

        Tuple[][] x = Arrays.stream(sequences)
                .map(sequence -> toTuples(sequence, features))
                .toArray(Tuple[][]::new);

        CRF model = CRF.fit(x, labels, ntrees, maxDepth, maxNodes, nodeSize, shrinkage);
        return new CRFLabeler<>(model, features);
    }

    @Override
    public String toString() {
        return model.toString();
    }

    /**
     * Maps each observation of a sequence to its feature tuple.
     * Shared by training ({@code fit}) and inference ({@code translate}).
     *
     * @param sequence the observation sequence.
     * @param features the feature function.
     * @param <T> the data type of observations.
     * @return the feature tuple sequence, in the same order as {@code sequence}.
     */
    private static <T> Tuple[] toTuples(T[] sequence, Function<T, Tuple> features) {
        return Arrays.stream(sequence).map(features).toArray(Tuple[]::new);
    }

    /**
     * Translates an observation sequence to internal representation.
     */
    private Tuple[] translate(T[] o) {
        return toTuples(o, features);
    }

    /**
     * Returns the most likely label sequence given the feature sequence by the
     * forward-backward algorithm.
     *
     * @param o the observation sequence.
     * @return the most likely state sequence.
     */
    @Override
    public int[] predict(T[] o) {
        return model.predict(translate(o));
    }

    /**
     * Labels sequence with Viterbi algorithm. Viterbi algorithm
     * returns the whole sequence label that has the maximum probability,
     * which makes sense in applications (e.g. part-of-speech tagging) that
     * require coherent sequential labeling. The forward-backward algorithm
     * labels a sequence by individual prediction on each position.
     * This usually produces better accuracy although the results may not
     * be coherent.
     *
     * @param o the observation sequence.
     * @return the sequence labels.
     */
    public int[] viterbi(T[] o) {
        return model.viterbi(translate(o));
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy