All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.joliciel.talismane.machineLearning.maxent.OpenNLPDecisionMaker Maven / Gradle / Ivy

There is a newer version: 6.1.8
Show newest version
///////////////////////////////////////////////////////////////////////////////
//Copyright (C) 2014 Joliciel Informatique
//
//This file is part of Talismane.
//
//Talismane is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//Talismane is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with Talismane.  If not, see <http://www.gnu.org/licenses/>.
//////////////////////////////////////////////////////////////////////////////
package com.joliciel.talismane.machineLearning.maxent;

import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;

import com.joliciel.talismane.machineLearning.ClassificationSolution;
import com.joliciel.talismane.machineLearning.Decision;
import com.joliciel.talismane.machineLearning.DecisionMaker;
import com.joliciel.talismane.machineLearning.GeometricMeanScoringStrategy;
import com.joliciel.talismane.machineLearning.ScoringStrategy;
import com.joliciel.talismane.machineLearning.features.FeatureResult;
import com.joliciel.talismane.utils.WeightedOutcome;

import opennlp.model.MaxentModel;

class OpenNLPDecisionMaker implements DecisionMaker {
  private MaxentModel model;
  private transient ScoringStrategy scoringStrategy = null;

  public OpenNLPDecisionMaker(MaxentModel model) {
    super();
    this.model = model;
  }

  @Override
  public List decide(List> featureResults) {
    List contextList = new ArrayList();
    List weightList = new ArrayList();
    OpenNLPDecisionMaker.prepareData(featureResults, contextList, weightList);

    String[] contexts = new String[contextList.size()];
    float[] weights = new float[weightList.size()];

    int i = 0;
    for (String context : contextList) {
      contexts[i++] = context;
    }
    i = 0;
    for (Float weight : weightList) {
      weights[i++] = weight;
    }

    double[] probs = model.eval(contexts, weights);

    String[] outcomes = new String[probs.length];
    for (i = 0; i < probs.length; i++)
      outcomes[i] = model.getOutcome(i);

    TreeSet outcomeSet = new TreeSet();
    for (i = 0; i < probs.length; i++) {
      Decision decision = new Decision(outcomes[i], probs[i]);
      outcomeSet.add(decision);
    }

    List decisions = new ArrayList(outcomeSet);

    return decisions;
  }

  static void prepareData(List> featureResults, List contextList, List weightList) {
    for (FeatureResult featureResult : featureResults) {
      if (featureResult != null) {
        if (featureResult.getOutcome() instanceof List) {
          @SuppressWarnings("unchecked")
          FeatureResult>> stringCollectionResult = (FeatureResult>>) featureResult;
          for (WeightedOutcome stringOutcome : stringCollectionResult.getOutcome()) {
            contextList.add(featureResult.getTrainingName() + "|" + featureResult.getTrainingOutcome(stringOutcome.getOutcome()));
            weightList.add(((Double) stringOutcome.getWeight()).floatValue());
          }
        } else {
          float weight = 1;
          if (featureResult.getOutcome() instanceof Double) {
            @SuppressWarnings("unchecked")
            FeatureResult doubleResult = (FeatureResult) featureResult;
            weight = doubleResult.getOutcome().floatValue();
          }
          contextList.add(featureResult.getTrainingName());
          weightList.add(weight);
        }
      }
    }
  }

  @Override
  public ScoringStrategy getDefaultScoringStrategy() {
    if (scoringStrategy == null)
      scoringStrategy = new GeometricMeanScoringStrategy();
    return scoringStrategy;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy