All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.joliciel.talismane.machineLearning.perceptron.PerceptronClassificationModelTrainer Maven / Gradle / Ivy

There is a newer version: 6.1.8
Show newest version
///////////////////////////////////////////////////////////////////////////////
//Copyright (C) 2014 Joliciel Informatique
//
//This file is part of Talismane.
//
//Talismane is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//Talismane is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with Talismane.  If not, see <http://www.gnu.org/licenses/>.
//////////////////////////////////////////////////////////////////////////////
package com.joliciel.talismane.machineLearning.perceptron;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.joliciel.talismane.TalismaneException;
import com.joliciel.talismane.machineLearning.ClassificationEvent;
import com.joliciel.talismane.machineLearning.ClassificationEventStream;
import com.joliciel.talismane.machineLearning.ClassificationModel;
import com.joliciel.talismane.machineLearning.ClassificationModelTrainer;
import com.joliciel.talismane.machineLearning.MachineLearningModel;
import com.joliciel.talismane.utils.LogUtils;
import com.typesafe.config.Config;

public class PerceptronClassificationModelTrainer implements ClassificationModelTrainer {
  /**
   * A parameter accepted by the perceptron model trainer, paired with the Java
   * type declared for its value.
   * 
   * @author Assaf Urieli
   *
   */
  public enum PerceptronModelParameter {
    /** Declared with value type {@link Integer}. */
    Iterations(Integer.class),
    /** Declared with value type {@link Integer}. */
    Cutoff(Integer.class),
    /** Declared with value type {@link Double}. */
    Tolerance(Double.class),
    /** Declared with value type {@link Boolean}. */
    AverageAtIntervals(Boolean.class);

    // Wildcard-typed rather than raw, so callers get no unchecked warnings.
    private final Class<?> parameterType;

    private PerceptronModelParameter(Class<?> parameterType) {
      this.parameterType = parameterType;
    }

    /**
     * @return the class declared for this parameter's value
     */
    public Class<?> getParameterType() {
      return parameterType;
    }
  }

  private static final Logger LOG = LoggerFactory.getLogger(PerceptronClassificationModelTrainer.class);

  // Training settings, assigned through the setters or setParameters(Config).
  private int iterations;
  private int cutoff;
  private double tolerance;
  private PerceptronScoring scoring;

  // Running sum of the feature weights over the averaged iterations,
  // indexed [feature][outcome]; allocated at the end of prepareData.
  private double[][] totalFeatureWeights;
  private PerceptronModelParameters params;
  // Temporary file holding the serialised training events, one per line.
  private File eventFile;
  private PerceptronDecisionMaker decisionMaker;
  // Descriptor lists keyed by descriptor type (e.g.
  // MachineLearningModel.FEATURE_DESCRIPTOR_KEY), handed to the trained model.
  private Map<String, List<String>> descriptors;
  private ClassificationEventStream corpusEventStream;
  // Optional observer receiving intermediate models at the iteration numbers
  // listed in observationPoints.
  private PerceptronModelTrainerObserver observer;
  private List<Integer> observationPoints;
  private boolean averageAtIntervals = false;

  private Config config;

  /**
   * Creates a trainer with no settings applied; configure it through
   * {@code setParameters} or the individual setters before training.
   */
  public PerceptronClassificationModelTrainer() {
  }

  /**
   * Serialises the event stream to a temporary event file and prepares the
   * model parameters for training.
   * <p>
   * Every event is converted to a {@code PerceptronEvent} and written as one
   * line. When {@code cutoff > 1}, feature occurrences are counted, a new
   * parameter set is built via
   * {@code PerceptronModelParameters.initialise(params, cutoff)}, and the event
   * file is rewritten using the re-mapped feature indexes it returns (features
   * mapped to a negative index are dropped from the events). Finally the
   * weights are initialised and the averaging accumulator is allocated.
   *
   * @param eventStream
   *          the training events to serialise
   * @throws TalismaneException
   *           declared by this method; presumably raised while reading the
   *           event stream (not visible in this file)
   * @throws RuntimeException
   *           wrapping any {@link IOException} hit on the temporary files
   */
  void prepareData(ClassificationEventStream eventStream) throws TalismaneException {
    try {
      eventFile = File.createTempFile("events", "txt");
      eventFile.deleteOnExit();
      // try-with-resources: the writer is closed even if an event fails to
      // convert or write.
      try (Writer eventWriter = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(eventFile), "UTF-8"))) {
        while (eventStream.hasNext()) {
          ClassificationEvent corpusEvent = eventStream.next();
          PerceptronEvent event = new PerceptronEvent(corpusEvent, params);
          event.write(eventWriter);
        }
        eventWriter.flush();
      }

      if (cutoff > 1) {
        params.initialiseCounts();
        File originalEventFile = eventFile;
        try (Scanner scanner = new Scanner(new BufferedReader(new InputStreamReader(new FileInputStream(originalEventFile), "UTF-8")))) {
          while (scanner.hasNextLine()) {
            String line = scanner.nextLine();
            PerceptronEvent event = new PerceptronEvent(line);
            for (int featureIndex : event.getFeatureIndexes()) {
              params.getFeatureCounts()[featureIndex]++;
            }
          }
          // Scanner swallows read errors and reports end-of-input instead;
          // surface them so we never count a truncated file.
          IOException readError = scanner.ioException();
          if (readError != null) {
            throw readError;
          }
        }

        if (LOG.isDebugEnabled()) {
          // cutoffCounts[i] = number of features occurring at least i times
          int[] cutoffCounts = new int[21];
          for (int count : params.getFeatureCounts()) {
            for (int i = 1; i < 21; i++) {
              if (count >= i) {
                cutoffCounts[i]++;
              }
            }
          }
          LOG.debug("Feature counts:");
          for (int i = 1; i < 21; i++) {
            LOG.debug("Cutoff {}: {}", i, cutoffCounts[i]);
          }
        }

        PerceptronModelParameters cutoffParams = new PerceptronModelParameters();
        int[] newIndexes = cutoffParams.initialise(params, cutoff);
        decisionMaker = new PerceptronDecisionMaker(cutoffParams, this.scoring);

        File cutoffEventFile = File.createTempFile("eventsCutoff", "txt");
        cutoffEventFile.deleteOnExit();
        try (Scanner scanner = new Scanner(new BufferedReader(new InputStreamReader(new FileInputStream(originalEventFile), "UTF-8")));
            Writer eventCutoffWriter = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(cutoffEventFile), "UTF-8"))) {
          while (scanner.hasNextLine()) {
            String line = scanner.nextLine();
            PerceptronEvent oldEvent = new PerceptronEvent(line);
            PerceptronEvent newEvent = new PerceptronEvent(oldEvent, newIndexes);
            newEvent.write(eventCutoffWriter);
          }
          IOException readError = scanner.ioException();
          if (readError != null) {
            throw readError;
          }
          eventCutoffWriter.flush();
        }

        // Switch to the filtered data only once it has been written in full,
        // and delete the original only after its reader is closed (deleting
        // an open file fails on some platforms).
        eventFile = cutoffEventFile;
        params = cutoffParams;
        originalEventFile.delete();
      }

      params.initialiseWeights();
      totalFeatureWeights = new double[params.getFeatureCount()][params.getOutcomeCount()];
    } catch (IOException e) {
      LogUtils.logError(LOG, e);
      throw new RuntimeException(e);
    }
  }

  /**
   * Runs averaged-perceptron training over the prepared event file.
   * <p>
   * Each iteration reads every event, predicts an outcome with the current
   * weights and, on a wrong prediction, adds the feature values to the weights
   * of the actual outcome and subtracts them from those of the predicted one.
   * After each averaged iteration the current weights are added to
   * {@code totalFeatureWeights}. Training stops after {@code iterations}
   * iterations, or earlier when the accuracy differs from each of the three
   * previous iterations by less than {@code tolerance}. On exit the weights in
   * {@code params} are replaced by their average over the averaged iterations.
   * <p>
   * If an observer is set, it receives a model built from a clone of the
   * parameters (with averaged weights) at every iteration listed in
   * {@code observationPoints}.
   *
   * @throws RuntimeException
   *           wrapping any {@link IOException} hit while reading the event file
   */
  void train() {
    try {
      double prevAccuracy1 = 0.0;
      double prevAccuracy2 = 0.0;
      double prevAccuracy3 = 0.0;
      int averagingCount = 0;
      for (int i = 1; i <= iterations; i++) {
        LOG.debug("Iteration " + i);
        int totalErrors = 0;
        int totalEvents = 0;

        try (Scanner scanner = new Scanner(new BufferedReader(new InputStreamReader(new FileInputStream(eventFile), "UTF-8")))) {

          while (scanner.hasNextLine()) {
            PerceptronEvent event = new PerceptronEvent(scanner.nextLine());
            totalEvents++;
            List<Integer> featureIndexes = event.getFeatureIndexes();
            List<Double> featureValues = event.getFeatureValues();

            // don't normalise unless we calculate the log-likelihood,
            // to avoid mathematical cost of normalising
            double[] results = decisionMaker.predict(featureIndexes, featureValues);
            // arg-max over the outcome scores; ties keep the lowest index
            double maxValue = results[0];
            int predicted = 0;
            for (int j = 1; j < results.length; j++) {
              if (results[j] > maxValue) {
                maxValue = results[j];
                predicted = j;
              }
            }

            int actual = event.getOutcomeIndex();

            if (actual != predicted) {
              // perceptron update: reward the actual outcome, penalise the
              // wrongly predicted one
              double[][] featureWeights = params.getFeatureWeights();
              for (int j = 0; j < featureIndexes.size(); j++) {
                double[] classWeights = featureWeights[featureIndexes.get(j)];
                double value = featureValues.get(j);
                classWeights[actual] += value;
                classWeights[predicted] -= value;
              }
              totalErrors++;
            } // correct outcome?
          } // next event

          // Scanner swallows read errors and reports end-of-input instead;
          // surface them so we never train on a silently truncated file.
          IOException readError = scanner.ioException();
          if (readError != null) {
            throw readError;
          }
        }

        // Add feature weights for this iteration
        boolean addAverage = true;
        if (this.isAverageAtIntervals()) {
          addAverage = isAveragingIteration(i);
          if (addAverage) {
            LOG.debug("Averaging at iteration: " + i);
          }
        }

        if (addAverage) {
          for (int j = 0; j < params.getFeatureWeights().length; j++) {
            double[] totalClassWeights = totalFeatureWeights[j];
            double[] classWeights = params.getFeatureWeights()[j];
            for (int k = 0; k < params.getOutcomeCount(); k++) {
              totalClassWeights[k] += classWeights[k];
            }
          }
          averagingCount++;
        }

        // observationPoints may legitimately be null when no observer
        // schedule was supplied
        if (observer != null && observationPoints != null && observationPoints.contains(i)) {
          PerceptronModelParameters cloneParams = params.clone();
          // average the weights for this model
          for (int j = 0; j < cloneParams.getFeatureWeights().length; j++) {
            double[] totalClassWeights = totalFeatureWeights[j];
            double[] classWeights = cloneParams.getFeatureWeights()[j];
            for (int k = 0; k < cloneParams.getOutcomeCount(); k++) {
              classWeights[k] = totalClassWeights[k] / averagingCount;
            }
          }
          ClassificationModel model = this.getModel(cloneParams, i);
          observer.onNextModel(model, i);
        }

        // NaN when the event file is empty; every comparison below is then
        // false, so the loop simply runs to the iteration limit
        double accuracy = (double) (totalEvents - totalErrors) / (double) totalEvents;
        LOG.debug("Accuracy: " + accuracy);

        // exit if accuracy hasn't significantly changed in 3 iterations
        if (Math.abs(accuracy - prevAccuracy1) < tolerance && Math.abs(accuracy - prevAccuracy2) < tolerance
            && Math.abs(accuracy - prevAccuracy3) < tolerance) {
          LOG.info("Accuracy change < " + tolerance + " for 3 iterations: exiting after " + i + " iterations");
          break;
        }

        prevAccuracy3 = prevAccuracy2;
        prevAccuracy2 = prevAccuracy1;
        prevAccuracy1 = accuracy;
      } // next iteration

      // average the final weights; skipped when nothing was accumulated
      // (iterations <= 0), which would otherwise turn every weight into NaN
      if (averagingCount > 0) {
        for (int j = 0; j < params.getFeatureWeights().length; j++) {
          double[] totalClassWeights = totalFeatureWeights[j];
          double[] classWeights = params.getFeatureWeights()[j];
          for (int k = 0; k < params.getOutcomeCount(); k++) {
            classWeights[k] = totalClassWeights[k] / averagingCount;
          }
        }
      }

    } catch (IOException e) {
      LogUtils.logError(LOG, e);
      throw new RuntimeException(e);
    }
  }

  /**
   * Decides whether weights are averaged at the given iteration when averaging
   * at intervals: every iteration up to and including 20, then every
   * perfect-square iteration (25, 36, 49, ...), with no upper bound.
   */
  private static boolean isAveragingIteration(int iteration) {
    if (iteration <= 20) {
      return true;
    }
    // long arithmetic so that root * root cannot overflow
    long root = Math.round(Math.sqrt(iteration));
    return root * root == iteration;
  }

  /**
   * A single training event in the compact form used by the trainer: an
   * outcome index plus parallel lists of feature indexes and feature values.
   * <p>
   * Serialised as one line:
   * {@code outcomeIndex featureIndex1 featureValue1 featureIndex2 ...}, with
   * fields separated by single spaces.
   */
  private static final class PerceptronEvent {
    // featureIndexes.get(i) and featureValues.get(i) describe the same feature
    private final List<Integer> featureIndexes;
    private final List<Double> featureValues;
    private final int outcomeIndex;

    /**
     * Builds an event from a corpus event, letting {@code params} fill in the
     * feature indexes/values and resolve (or create) the outcome index.
     */
    public PerceptronEvent(ClassificationEvent corpusEvent, PerceptronModelParameters params) {
      featureIndexes = new ArrayList<>();
      featureValues = new ArrayList<>();
      params.prepareData(corpusEvent.getFeatureResults(), featureIndexes, featureValues, true);
      outcomeIndex = params.getOrCreateOutcomeIndex(corpusEvent.getClassification());
    }

    /**
     * Parses an event from a line produced by {@link #write(Writer)}.
     *
     * @throws NumberFormatException
     *           if a field is not a valid number
     */
    public PerceptronEvent(String line) {
      String[] parts = line.split(" ");
      this.outcomeIndex = Integer.parseInt(parts[0]);
      // everything after the outcome comes in (index, value) pairs
      int featureCount = (parts.length - 1) / 2;
      featureIndexes = new ArrayList<>(featureCount);
      featureValues = new ArrayList<>(featureCount);
      int j = 1;
      for (int i = 0; i < featureCount; i++) {
        featureIndexes.add(Integer.parseInt(parts[j++]));
        featureValues.add(Double.parseDouble(parts[j++]));
      }
    }

    /**
     * Copies an event while re-mapping its feature indexes through
     * {@code newIndexes}; features whose new index is negative are dropped.
     */
    public PerceptronEvent(PerceptronEvent oldEvent, int[] newIndexes) {
      featureIndexes = new ArrayList<>();
      featureValues = new ArrayList<>();
      for (int i = 0; i < oldEvent.featureIndexes.size(); i++) {
        int newIndex = newIndexes[oldEvent.featureIndexes.get(i)];
        if (newIndex >= 0) {
          featureIndexes.add(newIndex);
          featureValues.add(oldEvent.featureValues.get(i));
        }
      }
      outcomeIndex = oldEvent.outcomeIndex;
    }

    public List<Integer> getFeatureIndexes() {
      return featureIndexes;
    }

    public List<Double> getFeatureValues() {
      return featureValues;
    }

    public int getOutcomeIndex() {
      return outcomeIndex;
    }

    /**
     * Writes this event as a single line terminated by {@code \n}.
     * <p>
     * The writer is deliberately not flushed here: flushing after every event
     * defeats the caller's buffering. The caller must flush or close the
     * writer once all events have been written.
     */
    public void write(Writer writer) throws IOException {
      writer.write(Integer.toString(outcomeIndex));
      for (int i = 0; i < featureIndexes.size(); i++) {
        writer.write(" ");
        writer.write(String.valueOf(featureIndexes.get(i)));
        writer.write(" ");
        writer.write(String.valueOf(featureValues.get(i)));
      }
      writer.write("\n");
    }

  }

  /**
   * Returns the maximum number of training iterations to run.
   *
   * @return the configured iteration limit
   */
  public int getIterations() {
    return this.iterations;
  }

  /**
   * Sets the maximum number of training iterations to run.
   *
   * @param iterations
   *          the iteration limit used by the training loop
   */
  public void setIterations(int iterations) {
    this.iterations = iterations;
  }

  /**
   * Returns the feature cutoff. Data preparation only applies cutoff filtering
   * when this value is greater than 1.
   *
   * @return the configured cutoff
   */
  @Override
  public int getCutoff() {
    return this.cutoff;
  }

  /**
   * Sets the feature cutoff. Data preparation only applies cutoff filtering
   * when this value is greater than 1.
   *
   * @param cutoff
   *          the cutoff passed to the parameter initialisation during data
   *          preparation
   */
  @Override
  public void setCutoff(int cutoff) {
    this.cutoff = cutoff;
  }

  /**
   * Returns the early-exit tolerance: training stops once the accuracy differs
   * from each of the three previous iterations by less than this value.
   *
   * @return the configured tolerance
   */
  public double getTolerance() {
    return this.tolerance;
  }

  /**
   * Sets the early-exit tolerance: training stops once the accuracy differs
   * from each of the three previous iterations by less than this value.
   *
   * @param tolerance
   *          the accuracy-change threshold
   */
  public void setTolerance(double tolerance) {
    this.tolerance = tolerance;
  }

  /**
   * Whether weight averaging is restricted to selected iterations. If true,
   * averaging only happens for iterations up to and including 20, and then for
   * perfect-square iterations (25, 36, 49, 64, 81, 100, etc.). If false,
   * every iteration is averaged.
   *
   * @return true if averaging is done at intervals only
   */
  public boolean isAverageAtIntervals() {
    return this.averageAtIntervals;
  }

  /**
   * Sets whether weight averaging is restricted to selected iterations rather
   * than performed after every iteration.
   *
   * @param averageAtIntervals
   *          true to average at intervals only
   */
  public void setAverageAtIntervals(boolean averageAtIntervals) {
    this.averageAtIntervals = averageAtIntervals;
  }

  public void trainModelsWithObserver(ClassificationEventStream corpusEventStream, List featureDescriptors, PerceptronModelTrainerObserver observer,
      List observationPoints) throws TalismaneException {
    Map> descriptors = new HashMap>();
    descriptors.put(MachineLearningModel.FEATURE_DESCRIPTOR_KEY, featureDescriptors);
    this.trainModelsWithObserver(corpusEventStream, descriptors, observer, observationPoints);
  }

  public void trainModelsWithObserver(ClassificationEventStream corpusEventStream, Map> descriptors,
      PerceptronModelTrainerObserver observer, List observationPoints) throws TalismaneException {
    params = new PerceptronModelParameters();
    decisionMaker = new PerceptronDecisionMaker(params, this.getScoring());
    this.descriptors = descriptors;
    this.observer = observer;
    this.observationPoints = observationPoints;
    this.corpusEventStream = corpusEventStream;
    this.prepareData(corpusEventStream);
    this.train();

    if (this.eventFile != null) {
      this.eventFile.delete();
    }

  }

  @Override
  public ClassificationModel trainModel(ClassificationEventStream corpusEventStream, List featureDescriptors) throws TalismaneException {
    Map> descriptors = new HashMap>();
    descriptors.put(MachineLearningModel.FEATURE_DESCRIPTOR_KEY, featureDescriptors);
    return this.trainModel(corpusEventStream, descriptors);
  }

  @Override
  public ClassificationModel trainModel(ClassificationEventStream corpusEventStream, Map> descriptors) throws TalismaneException {
    params = new PerceptronModelParameters();
    decisionMaker = new PerceptronDecisionMaker(params, this.getScoring());
    this.descriptors = descriptors;
    this.corpusEventStream = corpusEventStream;
    this.prepareData(corpusEventStream);
    this.train();
    ClassificationModel model = this.getModel(params, this.getIterations());

    if (this.eventFile != null)
      this.eventFile.delete();

    return model;
  }

  /**
   * Wraps a set of parameters in a classification model and records the
   * training settings, plus the event stream's own attributes, as model
   * attributes.
   *
   * @param params
   *          the (averaged) parameters the model is built on
   * @param iterations
   *          the iteration count recorded in the model's "iterations"
   *          attribute: the configured limit for a final model, or the current
   *          iteration for an intermediate model handed to the observer
   * @return the assembled model
   */
  ClassificationModel getModel(PerceptronModelParameters params, int iterations) {
    PerceptronClassificationModel model = new PerceptronClassificationModel(params, config, descriptors);
    model.addModelAttribute("cutoff", this.getCutoff());
    // use the argument, not this.getIterations(): intermediate models must be
    // labelled with the iteration they were taken at
    model.addModelAttribute("iterations", iterations);
    model.addModelAttribute("tolerance", this.getTolerance());
    model.addModelAttribute("averageAtIntervals", this.isAverageAtIntervals());
    model.addModelAttribute("scoring", this.getScoring());

    model.getModelAttributes().putAll(corpusEventStream.getAttributes());

    return model;
  }

  /**
   * Applies the training settings found in the supplied configuration and
   * keeps the configuration for later use when building the model.
   * <p>
   * {@code cutoff} and {@code iterations} are read from the top level of the
   * configuration; {@code tolerance}, {@code average-at-intervals} and
   * {@code scoring} are read from its {@code Perceptron} section.
   *
   * @param config
   *          the configuration to read the settings from
   */
  @Override
  public void setParameters(Config config) {
    this.config = config;
    final Config perceptronSettings = config.getConfig("Perceptron");

    // generic settings, top level
    setCutoff(config.getInt("cutoff"));
    setIterations(config.getInt("iterations"));

    // perceptron-specific settings
    final double configuredTolerance = perceptronSettings.getDouble("tolerance");
    setTolerance(configuredTolerance);

    final boolean configuredAveraging = perceptronSettings.getBoolean("average-at-intervals");
    setAverageAtIntervals(configuredAveraging);

    final String scoringName = perceptronSettings.getString("scoring");
    setScoring(PerceptronScoring.valueOf(scoringName));
  }

  /**
   * Returns the scoring method handed to the decision maker and recorded in
   * the model's "scoring" attribute.
   *
   * @return the configured scoring, or null if none has been set
   */
  public PerceptronScoring getScoring() {
    return this.scoring;
  }

  /**
   * Sets the scoring method handed to the decision maker and recorded in the
   * model's "scoring" attribute.
   *
   * @param scoring
   *          the scoring to use
   */
  public void setScoring(PerceptronScoring scoring) {
    this.scoring = scoring;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy