All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.joliciel.talismane.machineLearning.perceptron.PerceptronModelParameters Maven / Gradle / Ivy

There is a newer version: 6.1.8
Show newest version
///////////////////////////////////////////////////////////////////////////////
//Copyright (C) 2014 Joliciel Informatique
//
//This file is part of Talismane.
//
//Talismane is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//Talismane is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with Talismane.  If not, see <http://www.gnu.org/licenses/>.
//////////////////////////////////////////////////////////////////////////////
package com.joliciel.talismane.machineLearning.perceptron;

import gnu.trove.map.TObjectIntMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.procedure.TObjectIntProcedure;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import com.joliciel.talismane.machineLearning.features.FeatureResult;
import com.joliciel.talismane.utils.WeightedOutcome;

/**
 * Perceptron classification model parameters. Weights are a matrix of feature x
 * label.
 * 
 * @author Assaf Urieli
 *
 */
class PerceptronModelParameters implements Serializable {
  private static final long serialVersionUID = 1L;

  private List outcomes = new ArrayList();
  private int featureCount = 0;
  private int outcomeCount = 0;

  private TObjectIntMap outcomeIndexes = new TObjectIntHashMap(10, 0.7f, -1);
  private TObjectIntMap featureIndexes = new TObjectIntHashMap(1000, 0.7f, -1);
  private double[][] featureWeights;
  private int[] featureCounts;

  public PerceptronModelParameters() {
  }

  public PerceptronModelParameters clone() {
    PerceptronModelParameters params = new PerceptronModelParameters(this);
    return params;
  }

  private PerceptronModelParameters(PerceptronModelParameters params) {
    // need to perform deep clone for feature weights
    this.featureWeights = new double[params.getFeatureWeights().length][];
    for (int i = 0; i < params.getFeatureWeights().length; i++) {
      this.featureWeights[i] = params.getFeatureWeights()[i].clone();
    }
    // all the rest can be a shallow copy since it won't change
    this.outcomeIndexes = params.getOutcomeIndexes();
    this.featureIndexes = params.getFeatureIndexes();
    this.featureCounts = params.getFeatureCounts();
    this.outcomes = params.getOutcomes();
    this.featureCount = params.getFeatureCount();
    this.outcomeCount = params.getOutcomeCount();
  }

  public int[] initialise(PerceptronModelParameters oldParams, int cutoff) {
    int[] newIndexes = new int[oldParams.featureCount];
    this.outcomes = oldParams.outcomes;
    this.outcomeIndexes = oldParams.outcomeIndexes;
    this.outcomeCount = oldParams.outcomeCount;
    final String[] featureNames = new String[oldParams.featureCount];
    oldParams.featureIndexes.forEachEntry(new TObjectIntProcedure() {

      @Override
      public boolean execute(String key, int value) {
        featureNames[value] = key;
        return true;
      }
    });
    int i = 0;
    for (int count : oldParams.featureCounts) {
      if (count >= cutoff) {
        int newIndex = this.getOrCreateFeatureIndex(featureNames[i]);
        newIndexes[i] = newIndex;
      } else {
        newIndexes[i] = -1;
      }
      i++;
    }

    return newIndexes;
  }

  public int getOutcomeIndex(String outcome) {
    return outcomeIndexes.get(outcome);
  }

  public int getOrCreateOutcomeIndex(String outcome) {
    int outcomeIndex = outcomeIndexes.get(outcome);
    if (outcomeIndex < 0) {
      outcomeIndex = outcomeCount++;
      outcomeIndexes.put(outcome, outcomeIndex);
      outcomes.add(outcome);
    }
    return outcomeIndex;
  }

  public int getFeatureIndex(String featureName) {
    return featureIndexes.get(featureName);
  }

  public int getOrCreateFeatureIndex(String featureName) {
    int featureIndex = featureIndexes.get(featureName);
    if (featureIndex < 0) {
      featureIndex = featureCount++;
      featureIndexes.put(featureName, featureIndex);
    }
    return featureIndex;
  }

  public void initialiseCounts() {
    featureCounts = new int[featureCount];
  }

  public void initialiseWeights() {
    featureWeights = new double[featureCount][outcomeCount];
  }

  public double[][] getFeatureWeights() {
    return featureWeights;
  }

  public int[] getFeatureCounts() {
    return featureCounts;
  }

  public int getFeatureCount() {
    return featureCount;
  }

  public int getOutcomeCount() {
    return outcomeCount;
  }

  public List getOutcomes() {
    return outcomes;
  }

  public TObjectIntMap getOutcomeIndexes() {
    return outcomeIndexes;
  }

  public TObjectIntMap getFeatureIndexes() {
    return featureIndexes;
  }

  /**
   * Prepare the feature index list and weight list, based on the feature
   * results provided. If a feature is not in the model, leave it out.
   * 
   * @param featureResults
   *          the results to analyse
   * @param featureIndexList
   *          the list of feature indexes to populate
   * @param featureValueList
   *          the list of feature values to populate
   */
  public void prepareData(List> featureResults, List featureIndexList, List featureValueList) {
    this.prepareData(featureResults, featureIndexList, featureValueList, false);
  }

  /**
   * Prepare the feature index list and weight list, based on the feature
   * results provided.
   * 
   * @param create
   *          If true and a feature is not in the model, create it. Otherwise
   *          leave it out.
   */
  public void prepareData(List> featureResults, List featureIndexList, List featureValueList, boolean create) {
    for (FeatureResult featureResult : featureResults) {
      if (featureResult != null) {
        if (featureResult.getOutcome() instanceof List) {
          @SuppressWarnings("unchecked")
          FeatureResult>> stringCollectionResult = (FeatureResult>>) featureResult;
          for (WeightedOutcome stringOutcome : stringCollectionResult.getOutcome()) {
            String featureName = featureResult.getTrainingName() + "|" + featureResult.getTrainingOutcome(stringOutcome.getOutcome());
            int featureIndex = -1;
            if (create) {
              featureIndex = this.getOrCreateFeatureIndex(featureName);
            } else {
              featureIndex = this.getFeatureIndex(featureName);
            }
            if (featureIndex >= 0) {
              featureIndexList.add(featureIndex);
              featureValueList.add(stringOutcome.getWeight());
            }
          }
        } else {
          double value = 1;
          if (featureResult.getOutcome() instanceof Double) {
            @SuppressWarnings("unchecked")
            FeatureResult doubleResult = (FeatureResult) featureResult;
            value = doubleResult.getOutcome();
          }
          String featureName = featureResult.getTrainingName();
          int featureIndex = -1;
          if (create) {
            featureIndex = this.getOrCreateFeatureIndex(featureName);
          } else {
            featureIndex = this.getFeatureIndex(featureName);
          }
          if (featureIndex >= 0) {
            featureIndexList.add(featureIndex);
            featureValueList.add(value);
          }

        }
      }
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy