All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.joliciel.talismane.machineLearning.perceptron.PerceptronDecisionMaker Maven / Gradle / Ivy

There is a newer version: 6.1.8
Show newest version
///////////////////////////////////////////////////////////////////////////////
//Copyright (C) 2014 Joliciel Informatique
//
//This file is part of Talismane.
//
//Talismane is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//Talismane is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with Talismane.  If not, see <http://www.gnu.org/licenses/>.
//////////////////////////////////////////////////////////////////////////////
package com.joliciel.talismane.machineLearning.perceptron;

import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;

import com.joliciel.talismane.machineLearning.AdditiveScoringStrategy;
import com.joliciel.talismane.machineLearning.ClassificationSolution;
import com.joliciel.talismane.machineLearning.Decision;
import com.joliciel.talismane.machineLearning.DecisionMaker;
import com.joliciel.talismane.machineLearning.GeometricMeanScoringStrategy;
import com.joliciel.talismane.machineLearning.ScoringStrategy;
import com.joliciel.talismane.machineLearning.features.FeatureResult;
import com.joliciel.talismane.utils.JolicielException;

class PerceptronDecisionMaker implements DecisionMaker {
  private PerceptronModelParameters modelParameters;
  private transient ScoringStrategy scoringStrategy = null;
  private transient PerceptronScoring perceptronScoring = null;

  public PerceptronDecisionMaker(PerceptronModelParameters params, PerceptronScoring perceptronScoring) {
    super();
    this.modelParameters = params;
    this.perceptronScoring = perceptronScoring;
  }

  @Override
  public List decide(List> featureResults) {
    List featureIndexList = new ArrayList();
    List featureValueList = new ArrayList();
    modelParameters.prepareData(featureResults, featureIndexList, featureValueList);

    double[] results = this.predict(featureIndexList, featureValueList);
    double[] probs = new double[results.length];

    if (this.getPerceptronScoring() == PerceptronScoring.normalisedExponential) {
      // e^(x/absmax)/sum(e^(x/absmax))
      // where x/absmax is in [-1,1]
      // e^(x/absmax) is in [1/e,e]
      double absoluteMax = 1;
      for (int i = 0; i < results.length; i++) {
        if (Math.abs(results[i]) > absoluteMax)
          absoluteMax = Math.abs(results[i]);
      }

      double total = 0.0;
      for (int i = 0; i < results.length; i++) {
        probs[i] = Math.exp(results[i] / absoluteMax);
        total += probs[i];
      }

      for (int i = 0; i < probs.length; i++) {
        probs[i] /= total;
      }
    } else {
      // make all results >= 1
      double min = Double.MAX_VALUE;
      for (int i = 0; i < results.length; i++) {
        if (results[i] < min)
          min = results[i];
      }

      if (min < 0) {
        for (int i = 0; i < results.length; i++) {
          probs[i] = (results[i] - min) + 1;
        }
      }

      // then divide by total to get a probability distribution
      double total = 0.0;
      for (int i = 0; i < probs.length; i++) {
        total += probs[i];
      }

      for (int i = 0; i < probs.length; i++) {
        probs[i] /= total;
      }
    }

    int i = 0;
    TreeSet outcomeSet = new TreeSet();
    for (String outcome : modelParameters.getOutcomes()) {
      Decision decision = new Decision(outcome, results[i], probs[i]);
      outcomeSet.add(decision);
      i++;
    }

    List decisions = new ArrayList(outcomeSet);

    return decisions;

  }

  public double[] predict(List featureIndexList, List featureValueList) {
    double[] results = new double[modelParameters.getOutcomeCount()];
    for (int i = 0; i < featureIndexList.size(); i++) {
      int featureIndex = featureIndexList.get(i);
      double value = featureValueList.get(i);

      for (int j = 0; j < results.length; j++) {
        double[] classWeights = modelParameters.getFeatureWeights()[featureIndex];
        double weight = classWeights[j];
        results[j] += value * weight;
      }
    }

    return results;
  }

  public PerceptronModelParameters getModelParameters() {
    return modelParameters;
  }

  @Override
  public ScoringStrategy getDefaultScoringStrategy() {
    if (scoringStrategy == null) {
      if (this.getPerceptronScoring() == PerceptronScoring.normalisedLinear) {
        scoringStrategy = new GeometricMeanScoringStrategy();
      } else if (this.getPerceptronScoring() == PerceptronScoring.normalisedExponential) {
        scoringStrategy = new GeometricMeanScoringStrategy();
      } else if (this.getPerceptronScoring() == PerceptronScoring.additive) {
        scoringStrategy = new AdditiveScoringStrategy();
      } else {
        throw new JolicielException("Unknown perceptron scoring strategy: " + this.getPerceptronScoring());
      }
    }
    return scoringStrategy;
  }

  public PerceptronScoring getPerceptronScoring() {
    return perceptronScoring;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy