All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.joliciel.talismane.machineLearning.linearsvm.LinearSVMOneVsRestDecisionMaker Maven / Gradle / Ivy

There is a newer version: 6.1.8
Show newest version
///////////////////////////////////////////////////////////////////////////////
//Copyright (C) 2014 Joliciel Informatique
//
//This file is part of Talismane.
//
//Talismane is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//Talismane is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with Talismane.  If not, see <http://www.gnu.org/licenses/>.
//////////////////////////////////////////////////////////////////////////////
package com.joliciel.talismane.machineLearning.linearsvm;

import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.joliciel.talismane.machineLearning.ClassificationSolution;
import com.joliciel.talismane.machineLearning.Decision;
import com.joliciel.talismane.machineLearning.DecisionMaker;
import com.joliciel.talismane.machineLearning.GeometricMeanScoringStrategy;
import com.joliciel.talismane.machineLearning.ScoringStrategy;
import com.joliciel.talismane.machineLearning.features.FeatureResult;

import de.bwaldvogel.liblinear.Feature;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import gnu.trove.map.TObjectIntMap;

class LinearSVMOneVsRestDecisionMaker implements DecisionMaker {
  private static final Logger LOG = LoggerFactory.getLogger(LinearSVMOneVsRestDecisionMaker.class);

  private List models;
  private TObjectIntMap featureIndexMap = null;
  private List outcomes = null;
  private transient ScoringStrategy scoringStrategy = null;

  public LinearSVMOneVsRestDecisionMaker(List models, TObjectIntMap featureIndexMap, List outcomes) {
    super();
    this.models = models;
    this.featureIndexMap = featureIndexMap;
    this.outcomes = outcomes;
  }

  @Override
  public List decide(List> featureResults) {
    List featureList = LinearSVMUtils.prepareData(featureResults, featureIndexMap);

    List decisions = null;

    if (featureList.size() == 0) {
      LOG.info("No features for current context.");
      TreeSet outcomeSet = new TreeSet();
      double uniformProb = 1 / outcomes.size();
      for (String outcome : outcomes) {
        Decision decision = new Decision(outcome, uniformProb);
        outcomeSet.add(decision);
      }
      decisions = new ArrayList(outcomeSet);
    } else {
      Feature[] instance = new Feature[1];
      instance = featureList.toArray(instance);

      TreeSet outcomeSet = new TreeSet();

      int i = 0;
      for (Model model : models) {
        int myLabel = 0;
        for (int j = 0; j < model.getLabels().length; j++)
          if (model.getLabels()[j] == 1)
            myLabel = j;
        double[] probabilities = new double[2];
        Linear.predictProbability(model, instance, probabilities);

        Decision decision = new Decision(outcomes.get(i), probabilities[myLabel]);
        outcomeSet.add(decision);
        i++;
      }
      decisions = new ArrayList(outcomeSet);
    }
    return decisions;
  }

  @Override
  public ScoringStrategy getDefaultScoringStrategy() {
    if (scoringStrategy == null)
      scoringStrategy = new GeometricMeanScoringStrategy();
    return scoringStrategy;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy