All downloads are FREE. Search and download functionality uses the official Maven repository.

com.gengoai.apollo.ml.model.sequence.MalletCrf Maven / Gradle / Ivy

There is a newer version: 2.1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.gengoai.apollo.ml.model.sequence;

import cc.mallet.fst.*;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.types.*;
import cc.mallet.util.MalletLogger;
import com.gengoai.ParameterDef;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.model.LabelType;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.apollo.ml.model.SingleSourceFitParameters;
import com.gengoai.apollo.ml.model.SingleSourceModel;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Variable;
import com.gengoai.apollo.ml.observation.VariableSequence;
import com.gengoai.conversion.Cast;
import lombok.NonNull;

import java.util.Arrays;
import java.util.function.Consumer;
import java.util.logging.Level;
import java.util.regex.Pattern;

import static com.gengoai.apollo.ml.model.sequence.Order.FIRST;
import static com.gengoai.collection.Arrays2.arrayOfInt;
import static com.gengoai.function.Functional.with;

/**
 * 

A wrapper around Mallet's CRF implementation

* * @author David B. Bracewell */ public class MalletCrf extends SingleSourceModel { private static final long serialVersionUID = 1L; public static final ParameterDef FULLY_CONNECTED = ParameterDef.boolParam("fullyConnected"); public static final ParameterDef ORDER = ParameterDef.param("order", Order.class); public static final ParameterDef START_STATE = ParameterDef.strParam("startState"); public static final ParameterDef THREADS = ParameterDef.intParam("numThreads"); private SerialPipes pipes; private CRF model; private String startState; /** * Instantiates a new MalletCrf with default parameters. */ public MalletCrf() { super(new Parameters()); } /** * Instantiates a new MalletCrf with the given parameters. * * @param parameters the parameters */ public MalletCrf(@NonNull Parameters parameters) { super(parameters); } /** * Instantiates a new MalletCrf with the given parameter updater. * * @param updater the updater */ public MalletCrf(@NonNull Consumer updater) { super(with(new Parameters(), updater)); } @Override public void estimate(DataSet preprocessed) { if(parameters.verbose.value()) { MalletLogger.getLogger(ThreadedOptimizable.class.getName()) .setLevel(Level.INFO); MalletLogger.getLogger(CRFTrainerByValueGradients.class.getName()) .setLevel(Level.INFO); MalletLogger.getLogger(CRF.class.getName()) .setLevel(Level.INFO); MalletLogger.getLogger(CRFOptimizableByBatchLabelLikelihood.class.getName()) .setLevel(Level.INFO); MalletLogger.getLogger(LimitedMemoryBFGS.class.getName()) .setLevel(Level.INFO); } else { MalletLogger.getLogger(ThreadedOptimizable.class.getName()) .setLevel(Level.OFF); MalletLogger.getLogger(CRFTrainerByValueGradients.class.getName()) .setLevel(Level.OFF); MalletLogger.getLogger(CRF.class.getName()) .setLevel(Level.OFF); MalletLogger.getLogger(CRFOptimizableByBatchLabelLikelihood.class.getName()) .setLevel(Level.OFF); MalletLogger.getLogger(LimitedMemoryBFGS.class.getName()) .setLevel(Level.OFF); } Alphabet dataAlphabet = new Alphabet(); 
pipes = new SerialPipes(Arrays.asList(new SequenceToTokenSequence(), new TokenSequence2FeatureVectorSequence(dataAlphabet, false, true))); pipes.setDataAlphabet(dataAlphabet); pipes.setTargetAlphabet(new LabelAlphabet()); InstanceList trainingData = new InstanceList(pipes); for(Datum datum : preprocessed) { com.gengoai.apollo.ml.observation.Sequence x = datum.get(parameters.input.value()).asSequence(); com.gengoai.apollo.ml.observation.Sequence y = datum.get(parameters.output.value()).asSequence(); Label[] target = new Label[x.size()]; LabelAlphabet labelAlphabet = Cast.as(trainingData.getTargetAlphabet()); for(int j = 0; j < target.length; j++) { target[j] = labelAlphabet.lookupLabel(y.get(j).asVariable().getName(), true); } trainingData.addThruPipe(new Instance(x, new LabelSequence(target), null, null)); } model = new CRF(pipes, null); int[] order = {}; switch(parameters.order.value()) { case FIRST: order = arrayOfInt(1); break; case SECOND: order = arrayOfInt(1, 2); break; case THIRD: order = arrayOfInt(1, 2, 3); break; } MalletSequenceValidator sv = Cast.as(parameters.validator.value() instanceof MalletSequenceValidator ? parameters.validator.value() : null); Pattern allowed = sv == null ? null : sv.getAllowed(); Pattern forbidden = sv == null ? 
null : sv.getForbidden(); model.addOrderNStates(trainingData, order, null, parameters.startState.value(), forbidden, allowed, parameters.fullyConnected.value()); this.startState = parameters.startState.value(); model.setWeightsDimensionAsIn(trainingData, false); CRFOptimizableByBatchLabelLikelihood batchOptLabel = new CRFOptimizableByBatchLabelLikelihood(model, trainingData, parameters.numberOfThreads .value()); ThreadedOptimizable optLabel = new ThreadedOptimizable(batchOptLabel, trainingData, model.getParameters().getNumFactors(), new CRFCacheStaleIndicator(model)); Optimizable.ByGradientValue[] opts = {optLabel}; CRFTrainerByValueGradients crfTrainer = new CRFTrainerByValueGradients(model, opts); crfTrainer.setMaxResets(0); crfTrainer.train(trainingData, parameters.maxIterations.value()); optLabel.shutdown(); } @Override public Parameters getFitParameters() { return new Parameters(); } @Override public LabelType getLabelType(@NonNull String name) { if(parameters.output.value().equals(name)) { return LabelType.Sequence; } throw new IllegalArgumentException("'" + name + "' is not a valid output for this model."); } @Override protected Observation transform(@NonNull Observation observation) { int length = observation.asSequence().size(); Sequence sequence = Cast.as(model.getInputPipe() .instanceFrom(new Instance(observation, null, null, null)).getData()); Sequence bestOutput = model.transduce(sequence); SumLattice lattice = new SumLatticeDefault(model, sequence, true); Transducer.State sj = model.getState(startState); VariableSequence labeling = new VariableSequence(); for(int i = 0; i < sequence.size(); i++) { Transducer.State si = model.getState((String) bestOutput.get(i)); String label = (String) bestOutput.get(i); double pS = lattice.getGammaProbability(i, si); double PSjSi = lattice.getXiProbability(i, sj, si); double score = Math.max(pS, PSjSi); sj = si; labeling.add(Variable.real(label, score)); } return labeling; } @Override protected void 
updateMetadata(@NonNull DataSet data) { data.updateMetadata(parameters.output.value(), m -> { m.setEncoder(null); m.setType(VariableSequence.class); m.setDimension(-1); }); } /** * MalletCrf Fit Parameters */ public static class Parameters extends SingleSourceFitParameters { private static final long serialVersionUID = 1L; /** * The sequence validator to user during inference (default {@link SequenceValidator#ALWAYS_TRUE}) */ public final Parameter validator = parameter(Params.Sequence.validator, SequenceValidator.ALWAYS_TRUE); /** * The number of threads to use for training (default 20) */ public final Parameter numberOfThreads = parameter(THREADS, 20); /** * The order of the CRF (default {@link Order#FIRST}) */ public final Parameter order = parameter(ORDER, FIRST); /** * The maximum number of iterations to run for (default 250). */ public final Parameter maxIterations = parameter(Params.Optimizable.maxIterations, 250); /** * Parameter denoting whether the generated graphical model is fully connected (default true). */ public final Parameter fullyConnected = parameter(FULLY_CONNECTED, true); /** * Parameter denoting the start state (default O). */ public final Parameter startState = parameter(START_STATE, "O"); } }//END OF MalletCRF




© 2015 - 2025 Weber Informatics LLC | Privacy Policy