All Downloads are FREE. Search and download functionalities are using the official Maven repository.

opennlp.model.TrainUtil Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package opennlp.model;

import java.io.IOException;
import java.util.Map;

import opennlp.maxent.quasinewton.QNTrainer;
import opennlp.perceptron.PerceptronTrainer;
import opennlp.perceptron.SimplePerceptronSequenceTrainer;

public class TrainUtil {

  /** Training-parameter key that selects the training algorithm. */
  public static final String ALGORITHM_PARAM = "Algorithm";

  /** Algorithm value for GIS maximum entropy training; the default of the event-based {@code train}. */
  public static final String MAXENT_VALUE = "MAXENT";
  /** Algorithm value for the experimental quasi-Newton maximum entropy trainer. */
  public static final String MAXENT_QN_VALUE = "MAXENT_QN_EXPERIMENTAL";
  /** Algorithm value for (non-sequence) perceptron training. */
  public static final String PERCEPTRON_VALUE = "PERCEPTRON";
  /** Algorithm value for sequence perceptron training; requires a {@code SequenceStream}. */
  public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";

  /** Training-parameter key for the feature cutoff; the value must parse as an {@code int}. */
  public static final String CUTOFF_PARAM = "Cutoff";
  private static final int CUTOFF_DEFAULT = 5;

  /** Training-parameter key for the number of iterations; the value must parse as an {@code int}. */
  public static final String ITERATIONS_PARAM = "Iterations";
  private static final int ITERATIONS_DEFAULT = 100;

  /** Training-parameter key that selects the data indexer used by the event-based {@code train}. */
  public static final String DATA_INDEXER_PARAM = "DataIndexer";
  /** Data-indexer value selecting {@code OnePassDataIndexer}. */
  public static final String DATA_INDEXER_ONE_PASS_VALUE = "OnePass";
  /** Data-indexer value selecting {@code TwoPassDataIndexer} (the default). */
  public static final String DATA_INDEXER_TWO_PASS_VALUE = "TwoPass";

  // Algorithm-specific parameter keys read by train(...). These are NOT checked by isValid.
  private static final String THREADS_PARAM = "Threads";
  private static final String QN_NUM_OF_UPDATES_PARAM = "numOfUpdates";
  private static final String QN_MAX_FCT_EVAL_PARAM = "maxFctEval";
  private static final String USE_AVERAGE_PARAM = "UseAverage";
  private static final String USE_SKIPPED_AVERAGING_PARAM = "UseSkippedAveraging";
  private static final String STEP_SIZE_DECREASE_PARAM = "StepSizeDecrease";
  private static final String TOLERANCE_PARAM = "Tolerance";

  /** Report key under which the hex-encoded hash sum of the training events is stored. */
  private static final String TRAINING_EVENT_HASH_REPORT_KEY = "Training-Eventhash";

  /**
   * Records {@code key -> value} in {@code reportMap} when a report map was supplied.
   */
  private static void report(Map<String, String> reportMap, String key, String value) {
    if (reportMap != null) {
      reportMap.put(key, value);
    }
  }

  /**
   * Reads a string parameter, falling back to {@code defaultValue}, and reports the value used.
   */
  private static String getStringParam(Map<String, String> trainParams, String key,
      String defaultValue, Map<String, String> reportMap) {

    String valueString = trainParams.get(key);
    if (valueString == null) {
      valueString = defaultValue;
    }

    report(reportMap, key, valueString);
    return valueString;
  }

  /**
   * Reads an int parameter, falling back to {@code defaultValue}, and reports the value used.
   *
   * @throws NumberFormatException if the parameter is present but not a valid int
   */
  private static int getIntParam(Map<String, String> trainParams, String key,
      int defaultValue, Map<String, String> reportMap) {

    String valueString = trainParams.get(key);
    int value = (valueString != null) ? Integer.parseInt(valueString) : defaultValue;

    report(reportMap, key, Integer.toString(value));
    return value;
  }

  /**
   * Reads a double parameter, falling back to {@code defaultValue}, and reports the value used.
   *
   * @throws NumberFormatException if the parameter is present but not a valid double
   */
  private static double getDoubleParam(Map<String, String> trainParams, String key,
      double defaultValue, Map<String, String> reportMap) {

    String valueString = trainParams.get(key);
    double value = (valueString != null) ? Double.parseDouble(valueString) : defaultValue;

    report(reportMap, key, Double.toString(value));
    return value;
  }

  /**
   * Reads a boolean parameter, falling back to {@code defaultValue}, and reports the value used.
   * A present value is parsed with {@link Boolean#parseBoolean(String)}, so anything other than
   * a case-insensitive {@code "true"} yields {@code false}.
   */
  private static boolean getBooleanParam(Map<String, String> trainParams, String key,
      boolean defaultValue, Map<String, String> reportMap) {

    String valueString = trainParams.get(key);
    boolean value = (valueString != null) ? Boolean.parseBoolean(valueString) : defaultValue;

    report(reportMap, key, Boolean.toString(value));
    return value;
  }

  /**
   * Performs a partial validation of training parameters.
   *
   * <p>Checks that the algorithm name (if present) is one of the known algorithm values, that
   * {@code Cutoff} and {@code Iterations} (if present) parse as ints, and that the data indexer
   * name (if present) is {@code OnePass} or {@code TwoPass}. Algorithm-specific parameters such as
   * {@code Threads} or {@code Tolerance} are not checked, so {@code train} may still throw a
   * {@link NumberFormatException} for them.
   *
   * @param trainParams the training parameters; must not be null
   * @return {@code true} if the checked parameters are valid
   */
  public static boolean isValid(Map<String, String> trainParams) {

    // TODO: Need to validate all parameters correctly ... error prone?!

    String algorithmName = trainParams.get(ALGORITHM_PARAM);

    if (algorithmName != null
        && !(MAXENT_VALUE.equals(algorithmName)
            || MAXENT_QN_VALUE.equals(algorithmName)
            || PERCEPTRON_VALUE.equals(algorithmName)
            || PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName))) {
      return false;
    }

    try {
      String cutoffString = trainParams.get(CUTOFF_PARAM);
      if (cutoffString != null) {
        Integer.parseInt(cutoffString);
      }

      String iterationsString = trainParams.get(ITERATIONS_PARAM);
      if (iterationsString != null) {
        Integer.parseInt(iterationsString);
      }
    } catch (NumberFormatException ignored) {
      // A malformed numeric value simply makes the parameter set invalid.
      return false;
    }

    String dataIndexer = trainParams.get(DATA_INDEXER_PARAM);

    if (dataIndexer != null
        && !(DATA_INDEXER_ONE_PASS_VALUE.equals(dataIndexer)
            || DATA_INDEXER_TWO_PASS_VALUE.equals(dataIndexer))) {
      return false;
    }

    return true;
  }

  // TODO: Need a way to report results and settings back for inclusion in model ...

  /**
   * Trains a non-sequence model on the given events.
   *
   * <p>The algorithm defaults to {@code MAXENT} and the data indexer to {@code TwoPass}. Every
   * setting that is read, including defaults, is written to {@code reportMap}. The hash sum of the
   * training events is also written there under {@code Training-Eventhash}.
   *
   * @param events the training events
   * @param trainParams the training parameters; must not be null
   * @param reportMap receives the settings used; may be null
   * @return the trained model
   * @throws IllegalArgumentException if {@link #isValid(Map)} rejects the parameters, if a
   *     sequence algorithm is requested, or (as {@link NumberFormatException}) if an
   *     algorithm-specific numeric parameter is malformed
   * @throws IOException propagated from data indexing or training
   */
  public static AbstractModel train(EventStream events, Map<String, String> trainParams,
      Map<String, String> reportMap) throws IOException {

    if (!isValid(trainParams)) {
      throw new IllegalArgumentException("trainParams are not valid!");
    }

    if (isSequenceTraining(trainParams)) {
      throw new IllegalArgumentException("sequence training is not supported by this method!");
    }

    String algorithmName = getStringParam(trainParams, ALGORITHM_PARAM, MAXENT_VALUE, reportMap);
    int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
    int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);

    // The MaxEnt trainers are indexed with sortAndMerge enabled; the perceptron trainer is not.
    boolean sortAndMerge;
    if (MAXENT_VALUE.equals(algorithmName) || MAXENT_QN_VALUE.equals(algorithmName)) {
      sortAndMerge = true;
    } else if (PERCEPTRON_VALUE.equals(algorithmName)) {
      sortAndMerge = false;
    } else {
      throw new IllegalStateException("Unexpected algorithm name: " + algorithmName);
    }

    // Wrapped so that a hash sum of the training events can be reported after training.
    HashSumEventStream hses = new HashSumEventStream(events);

    String dataIndexerName = getStringParam(trainParams, DATA_INDEXER_PARAM,
        DATA_INDEXER_TWO_PASS_VALUE, reportMap);

    DataIndexer indexer;
    if (DATA_INDEXER_ONE_PASS_VALUE.equals(dataIndexerName)) {
      indexer = new OnePassDataIndexer(hses, cutoff, sortAndMerge);
    } else if (DATA_INDEXER_TWO_PASS_VALUE.equals(dataIndexerName)) {
      indexer = new TwoPassDataIndexer(hses, cutoff, sortAndMerge);
    } else {
      throw new IllegalStateException("Unexpected data indexer name: " + dataIndexerName);
    }

    AbstractModel model;
    if (MAXENT_VALUE.equals(algorithmName)) {
      int threads = getIntParam(trainParams, THREADS_PARAM, 1, reportMap);

      model = opennlp.maxent.GIS.trainModel(iterations, indexer,
          true, false, null, 0, threads);
    } else if (MAXENT_QN_VALUE.equals(algorithmName)) {
      int m = getIntParam(trainParams, QN_NUM_OF_UPDATES_PARAM, QNTrainer.DEFAULT_M, reportMap);
      int maxFctEval = getIntParam(trainParams, QN_MAX_FCT_EVAL_PARAM,
          QNTrainer.DEFAULT_MAX_FCT_EVAL, reportMap);

      model = new QNTrainer(m, maxFctEval, true).trainModel(indexer);
    } else if (PERCEPTRON_VALUE.equals(algorithmName)) {
      boolean useAverage = getBooleanParam(trainParams, USE_AVERAGE_PARAM, true, reportMap);
      boolean useSkippedAveraging =
          getBooleanParam(trainParams, USE_SKIPPED_AVERAGING_PARAM, false, reportMap);

      // Skipped averaging needs averaging enabled; overwrite otherwise it might not work,
      // and report the effective setting rather than the requested one.
      if (useSkippedAveraging) {
        useAverage = true;
        report(reportMap, USE_AVERAGE_PARAM, Boolean.toString(useAverage));
      }

      double stepSizeDecrease =
          getDoubleParam(trainParams, STEP_SIZE_DECREASE_PARAM, 0, reportMap);
      double tolerance = getDoubleParam(trainParams, TOLERANCE_PARAM,
          PerceptronTrainer.TOLERANCE_DEFAULT, reportMap);

      PerceptronTrainer perceptronTrainer = new PerceptronTrainer();
      perceptronTrainer.setSkippedAveraging(useSkippedAveraging);

      // Only a positive value is passed on; otherwise the trainer keeps its own setting.
      if (stepSizeDecrease > 0) {
        perceptronTrainer.setStepSizeDecrease(stepSizeDecrease);
      }

      perceptronTrainer.setTolerance(tolerance);

      model = perceptronTrainer.trainModel(iterations, indexer, cutoff, useAverage);
    } else {
      throw new IllegalStateException("Algorithm not supported: " + algorithmName);
    }

    // Guarded explicitly so the hash sum is only computed when it will actually be reported.
    if (reportMap != null) {
      reportMap.put(TRAINING_EVENT_HASH_REPORT_KEY, hses.calculateHashSum().toString(16));
    }

    return model;
  }

  /**
   * Detects if the training algorithm requires sequence based feature generation
   * or not.
   *
   * @param trainParams the training parameters; must not be null
   * @return {@code true} if the algorithm parameter is {@code PERCEPTRON_SEQUENCE}
   */
  public static boolean isSequenceTraining(Map<String, String> trainParams) {
    return PERCEPTRON_SEQUENCE_VALUE.equals(trainParams.get(ALGORITHM_PARAM));
  }

  /**
   * Trains a sequence model with the simple perceptron sequence trainer.
   *
   * @param events the training sequences
   * @param trainParams the training parameters; must not be null and must select
   *     {@code PERCEPTRON_SEQUENCE}
   * @param reportMap receives the settings used; may be null
   * @return the trained model
   * @throws IllegalArgumentException if {@link #isValid(Map)} rejects the parameters or the
   *     algorithm is not a sequence algorithm
   * @throws IOException propagated from training
   */
  public static AbstractModel train(SequenceStream events, Map<String, String> trainParams,
      Map<String, String> reportMap) throws IOException {

    if (!isValid(trainParams)) {
      throw new IllegalArgumentException("trainParams are not valid!");
    }

    if (!isSequenceTraining(trainParams)) {
      throw new IllegalArgumentException("Algorithm must be a sequence algorithm!");
    }

    // Report the algorithm like the event-based overload does.
    report(reportMap, ALGORITHM_PARAM, PERCEPTRON_SEQUENCE_VALUE);

    int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
    int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);
    boolean useAverage = getBooleanParam(trainParams, USE_AVERAGE_PARAM, true, reportMap);

    return new SimplePerceptronSequenceTrainer().trainModel(
        iterations, events, cutoff, useAverage);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy