com.joliciel.talismane.machineLearning.linearsvm.LinearSVMDecisionMaker Maven / Gradle / Ivy
///////////////////////////////////////////////////////////////////////////////
//Copyright (C) 2014 Joliciel Informatique
//
//This file is part of Talismane.
//
//Talismane is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//Talismane is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with Talismane. If not, see <http://www.gnu.org/licenses/>.
//////////////////////////////////////////////////////////////////////////////
package com.joliciel.talismane.machineLearning.linearsvm;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.joliciel.talismane.machineLearning.ClassificationSolution;
import com.joliciel.talismane.machineLearning.Decision;
import com.joliciel.talismane.machineLearning.DecisionMaker;
import com.joliciel.talismane.machineLearning.GeometricMeanScoringStrategy;
import com.joliciel.talismane.machineLearning.ScoringStrategy;
import com.joliciel.talismane.machineLearning.features.FeatureResult;
import de.bwaldvogel.liblinear.Feature;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import gnu.trove.map.TObjectIntMap;
class LinearSVMDecisionMaker implements DecisionMaker {
private static final Logger LOG = LoggerFactory.getLogger(LinearSVMDecisionMaker.class);
private Model model;
TObjectIntMap featureIndexMap = null;
List outcomes = null;
private transient ScoringStrategy scoringStrategy = null;
public LinearSVMDecisionMaker(Model model, TObjectIntMap featureIndexMap, List outcomes) {
super();
this.model = model;
this.featureIndexMap = featureIndexMap;
this.outcomes = outcomes;
}
@Override
public List decide(List> featureResults) {
List featureList = LinearSVMUtils.prepareData(featureResults, featureIndexMap);
List decisions = null;
if (featureList.size() == 0) {
LOG.info("No features for current context.");
TreeSet outcomeSet = new TreeSet();
double uniformProb = 1 / outcomes.size();
for (String outcome : outcomes) {
Decision decision = new Decision(outcome, uniformProb);
outcomeSet.add(decision);
}
decisions = new ArrayList(outcomeSet);
} else {
Feature[] instance = new Feature[1];
instance = featureList.toArray(instance);
double[] probabilities = new double[model.getLabels().length];
Linear.predictProbability(model, instance, probabilities);
TreeSet outcomeSet = new TreeSet();
for (int i = 0; i < model.getLabels().length; i++) {
Decision decision = new Decision(outcomes.get(i), probabilities[i]);
outcomeSet.add(decision);
}
decisions = new ArrayList(outcomeSet);
}
return decisions;
}
@Override
public ScoringStrategy getDefaultScoringStrategy() {
if (scoringStrategy == null)
scoringStrategy = new GeometricMeanScoringStrategy();
return scoringStrategy;
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy