All downloads are free. Search and download functionality uses the official Maven repository.

eus.ixa.ixa.pipe.ml.SequenceLabelerTrainer Maven / Gradle / Ivy

There is a newer version: 0.0.8
Show newest version
/*
 *  Copyright 2016 Rodrigo Agerri

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
 */

package eus.ixa.ixa.pipe.ml;

import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Map;

import eus.ixa.ixa.pipe.ml.features.XMLFeatureDescriptor;
import eus.ixa.ixa.pipe.ml.formats.CoNLL02Format;
import eus.ixa.ixa.pipe.ml.formats.CoNLL03Format;
import eus.ixa.ixa.pipe.ml.formats.LemmatizerFormat;
import eus.ixa.ixa.pipe.ml.formats.TabulatedFormat;
import eus.ixa.ixa.pipe.ml.resources.LoadModelResources;
import eus.ixa.ixa.pipe.ml.sequence.BilouCodec;
import eus.ixa.ixa.pipe.ml.sequence.BioCodec;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelSample;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelSampleTypeFilter;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerCodec;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerEvaluator;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerFactory;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerME;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerModel;
import eus.ixa.ixa.pipe.ml.utils.Flags;
import eus.ixa.ixa.pipe.ml.utils.IOUtils;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;

/**
 * Trainer based on Apache OpenNLP Machine Learning API. This class creates a
 * feature set based on the features activated in the trainParams.properties
 * file:
 * <ol>
 * <li>Window: specify left and right window lengths.</li>
 * <li>TokenFeatures: tokens as features in a window length.</li>
 * <li>TokenClassFeatures: token shape features in a window length.</li>
 * <li>WordShapeSuperSenseFeatures: token shape features from Ciaramita and
 * Altun (2006).</li>
 * <li>OutcomePriorFeatures: take into account previous outcomes.</li>
 * <li>PreviousMapFeatures: add features based on tokens and previous
 * decisions.</li>
 * <li>SentenceFeatures: add beginning and end of sentence words.</li>
 * <li>PrefixFeatures: first 4 characters in current token.</li>
 * <li>SuffixFeatures: last 4 characters in current token.</li>
 * <li>BigramClassFeatures: bigrams of tokens and token class.</li>
 * <li>TrigramClassFeatures: trigrams of token and token class.</li>
 * <li>FourgramClassFeatures: fourgrams of token and token class.</li>
 * <li>FivegramClassFeatures: fivegrams of token and token class.</li>
 * <li>CharNgramFeatures: character ngram features of current token.</li>
 * <li>DictionaryFeatures: check if current token appears in some
 * gazetteer.</li>
 * <li>ClarkClusterFeatures: use the clustering class of a token as a
 * feature.</li>
 * <li>BrownClusterFeatures: use brown clusters as features for each feature
 * containing a token.</li>
 * <li>Word2VecClusterFeatures: use the word2vec clustering class of a token as
 * a feature.</li>
 * <li>POSTagModelFeatures: use pos tags, pos tag class as features.</li>
 * <li>LemmaModelFeatures: use lemma as features.</li>
 * <li>LemmaDictionaryFeatures: use lemma from a dictionary as features.</li>
 * <li>MFSFeatures: Most Frequent sense feature.</li>
 * <li>SuperSenseFeatures: Ciaramita and Altun (2006) features for super sense
 * tagging.</li>
 * <li>POSBaselineFeatures: train a baseline POS tagger.</li>
 * <li>LemmaBaselineFeatures: train a baseline Lemmatizer.</li>
 * <li>ChunkBaselineFeatures: train a baseline chunker.</li>
 * </ol>
 *
 * @author ragerri
 * @version 2016-05-06
 */
public class SequenceLabelerTrainer {

  // NOTE(review): the ObjectStream, Map and SequenceLabelerCodec declarations
  // in this class are raw types. The stream element type is presumably
  // SequenceLabelSample (imported by this file but otherwise unused) --
  // confirm against the project sources and parameterize.

  /**
   * The language.
   */
  private final String lang;
  /**
   * String holding the training data (value of the "TrainSet" setting).
   */
  private final String trainData;
  /**
   * String holding the test data (value of the "TestSet" setting).
   */
  private final String testData;
  /**
   * ObjectStream of the training data.
   */
  private ObjectStream trainSamples;
  /**
   * ObjectStream of the test data.
   */
  private ObjectStream testSamples;
  /**
   * The corpus format: conll02, conll03, lemmatizer, tabulated.
   */
  private final String corpusFormat;
  /**
   * The sequence encoding of the named entity spans, e.g., BIO or BILOU.
   */
  private String sequenceCodec;
  /**
   * Reset the adaptive features every newline in the training data.
   */
  private final String clearTrainingFeatures;
  /**
   * Reset the adaptive features every newline in the testing data.
   */
  private final String clearEvaluationFeatures;
  /**
   * Factory handed to the trainer; it is built from the feature descriptor,
   * the loaded resources and the sequence codec.
   */
  private SequenceLabelerFactory nameClassifierFactory;

  /**
   * Construct a trainer from the training parameters: language, clear-feature
   * flags, corpus format, sequence codec and the "TrainSet" / "TestSet"
   * settings. When a "Types" setting is present, both sample streams are
   * wrapped in a {@code SequenceLabelSampleTypeFilter} restricted to the
   * comma-separated types it lists.
   *
   * @param params
   *          the training parameters
   * @throws IOException
   *           io exception
   */
  public SequenceLabelerTrainer(final TrainingParameters params)
      throws IOException {
    this.lang = Flags.getLanguage(params);
    this.clearTrainingFeatures = Flags.getClearTrainingFeatures(params);
    this.clearEvaluationFeatures = Flags.getClearEvaluationFeatures(params);
    this.corpusFormat = Flags.getCorpusFormat(params);
    this.trainData = params.getSettings().get("TrainSet");
    this.testData = params.getSettings().get("TestSet");
    this.trainSamples = getSequenceStream(this.trainData,
        this.clearTrainingFeatures, this.corpusFormat);
    this.testSamples = getSequenceStream(this.testData,
        this.clearEvaluationFeatures, this.corpusFormat);
    this.sequenceCodec = Flags.getSequenceCodec(params);
    // Look the optional type restriction up once instead of twice.
    final String types = params.getSettings().get("Types");
    if (types != null) {
      final String[] neTypes = types.split(",");
      this.trainSamples = new SequenceLabelSampleTypeFilter(neTypes,
          this.trainSamples);
      this.testSamples = new SequenceLabelSampleTypeFilter(neTypes,
          this.testSamples);
    }
    // NOTE(review): this method is public and non-final, so an override would
    // run before the subclass constructor body; kept for compatibility.
    createSequenceLabelerFactory(params);
  }

  /**
   * Create {@code SequenceLabelerFactory} with custom features. The XML
   * feature descriptor generated from the parameters is echoed to standard
   * error and passed to the factory as UTF-8 bytes.
   *
   * @param params
   *          the parameter training file
   * @throws IOException
   *           if io error
   */
  public void createSequenceLabelerFactory(final TrainingParameters params)
      throws IOException {
    final SequenceLabelerCodec codec = SequenceLabelerFactory
        .instantiateSequenceCodec(getSequenceCodec());
    final String featureDescription = XMLFeatureDescriptor
        .createXMLFeatureDescriptor(params);
    System.err.println(featureDescription);
    final byte[] featureGeneratorBytes = featureDescription
        .getBytes(StandardCharsets.UTF_8);
    final Map resources = LoadModelResources.loadSequenceResources(params);
    setSequenceLabelerFactory(
        SequenceLabelerFactory.create(SequenceLabelerFactory.class.getName(),
            featureGeneratorBytes, resources, codec));
  }

  /**
   * Train a model on the training samples with the current factory, then
   * evaluate it on the test samples and print the scores to standard output.
   * If an {@code IOException} occurs, a message and the stack trace are
   * printed and the JVM exits with status 1.
   *
   * @param params
   *          the training parameters
   * @return the trained model
   * @throws IllegalStateException
   *           if no {@code SequenceLabelerFactory} has been set
   */
  public final SequenceLabelerModel train(final TrainingParameters params) {
    if (getSequenceLabelerFactory() == null) {
      throw new IllegalStateException(
          "The SequenceLabelerFactory must be instantiated!!");
    }
    SequenceLabelerModel trainedModel = null;
    try {
      trainedModel = SequenceLabelerME.train(this.lang, this.trainSamples,
          params, this.nameClassifierFactory);
      final SequenceLabelerME seqLabeler = new SequenceLabelerME(trainedModel);
      trainingEvaluate(seqLabeler);
    } catch (final IOException e) {
      System.err.println("IO error while loading traing and test sets!");
      e.printStackTrace();
      System.exit(1);
    }
    return trainedModel;
  }

  /**
   * Evaluate the given labeler on the test samples and print the result. The
   * "lemmatizer" and "tabulated" corpus formats report word and sentence
   * accuracy; every other format reports the F-measure.
   *
   * @param sequenceLabeler
   *          the labeler wrapping the freshly trained model
   */
  private void trainingEvaluate(final SequenceLabelerME sequenceLabeler) {
    final boolean accuracyBased = "lemmatizer"
        .equalsIgnoreCase(this.corpusFormat)
        || "tabulated".equalsIgnoreCase(this.corpusFormat);
    // Only the accuracy-based evaluator is given the training samples.
    final SequenceLabelerEvaluator evaluator = accuracyBased
        ? new SequenceLabelerEvaluator(this.trainSamples, this.corpusFormat,
            sequenceLabeler)
        : new SequenceLabelerEvaluator(this.corpusFormat, sequenceLabeler);
    try {
      evaluator.evaluate(this.testSamples);
    } catch (final IOException e) {
      // Best effort: report the failure and still print whatever was scored.
      e.printStackTrace();
    }
    if (accuracyBased) {
      System.out.println();
      System.out.println("Word Accuracy: " + evaluator.getWordAccuracy());
      System.out
          .println("Sentence Accuracy: " + evaluator.getSentenceAccuracy());
    } else {
      System.out.println("Final Result: \n" + evaluator.getFMeasure());
    }
  }

  /**
   * Getting the stream with the right corpus format. Supported formats,
   * compared case-insensitively, are conll03, conll02, tabulated and
   * lemmatizer. For any other value, including {@code null}, an error is
   * printed to standard error and the JVM exits with status 1 without opening
   * the input.
   *
   * @param inputData
   *          the input data
   * @param clearFeatures
   *          clear the features
   * @param aCorpusFormat
   *          the corpus format
   * @return the stream from the several corpus formats
   * @throws IOException
   *           the io exception
   */
  public static ObjectStream getSequenceStream(final String inputData,
      final String clearFeatures, final String aCorpusFormat)
      throws IOException {
    // Literal-first comparisons keep a null format out of NPE territory and
    // route it to the invalid-format branch below.
    if ("conll03".equalsIgnoreCase(aCorpusFormat)) {
      final ObjectStream lines = IOUtils
          .readFileIntoMarkableStreamFactory(inputData);
      return new CoNLL03Format(clearFeatures, lines);
    }
    if ("conll02".equalsIgnoreCase(aCorpusFormat)) {
      final ObjectStream lines = IOUtils
          .readFileIntoMarkableStreamFactory(inputData);
      return new CoNLL02Format(clearFeatures, lines);
    }
    if ("tabulated".equalsIgnoreCase(aCorpusFormat)) {
      final ObjectStream lines = IOUtils
          .readFileIntoMarkableStreamFactory(inputData);
      return new TabulatedFormat(clearFeatures, lines);
    }
    if ("lemmatizer".equalsIgnoreCase(aCorpusFormat)) {
      final ObjectStream lines = IOUtils
          .readFileIntoMarkableStreamFactory(inputData);
      return new LemmatizerFormat(clearFeatures, lines);
    }
    // This method serves both the training and the test set, so the message
    // names the offending value instead of blaming the test set.
    System.err.println("Corpus format not valid: " + aCorpusFormat
        + " (expected conll02, conll03, tabulated or lemmatizer)");
    System.exit(1);
    return null;
  }

  /**
   * Get the factory currently used by this trainer.
   *
   * @return the factory, or {@code null} if none has been set
   */
  public final SequenceLabelerFactory getSequenceLabelerFactory() {
    return this.nameClassifierFactory;
  }

  /**
   * Set the factory used by this trainer.
   *
   * @param tokenNameFinderFactory
   *          the factory to use
   * @return the factory that was just set
   */
  public final SequenceLabelerFactory setSequenceLabelerFactory(
      final SequenceLabelerFactory tokenNameFinderFactory) {
    this.nameClassifierFactory = tokenNameFinderFactory;
    return this.nameClassifierFactory;
  }

  /**
   * Get the Sequence codec as a class name: {@code BioCodec} for "BIO" and
   * {@code BilouCodec} for "BILOU" (exact, case-sensitive match).
   *
   * @return the codec class name, or {@code null} for any other value
   */
  public final String getSequenceCodec() {
    if ("BIO".equals(this.sequenceCodec)) {
      return BioCodec.class.getName();
    }
    if ("BILOU".equals(this.sequenceCodec)) {
      return BilouCodec.class.getName();
    }
    return null;
  }

  /**
   * Set the sequence codec.
   *
   * @param aSeqCodec
   *          the sequence codec to be set
   */
  public final void setSequenceCodec(final String aSeqCodec) {
    this.sequenceCodec = aSeqCodec;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy