
org.jpmml.evaluator.ProbabilityAggregator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-evaluator Show documentation
Show all versions of pmml-evaluator Show documentation
JPMML class model evaluator
/*
* Copyright (c) 2014 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.evaluator;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.google.common.base.Function;
/**
 * Aggregates the per-category probabilities of several {@link HasProbability} results.
 *
 * <p>
 * Every added result contributes one (optionally weighted) value per category to the
 * underlying {@link ClassificationAggregator}. The {@link #maxMap(Collection)} and
 * {@link #medianMap(Collection)} methods additionally need the original result objects,
 * which are retained only when the aggregator is created with a positive capacity.
 * </p>
 */
public class ProbabilityAggregator extends ClassificationAggregator<String> {

	/**
	 * The added results, in insertion order.
	 * {@code null} when this aggregator was created with a non-positive capacity.
	 */
	private List<HasProbability> hasProbabilities = null;


	/**
	 * Creates an aggregator that does not retain the added results.
	 * Such an aggregator cannot compute {@link #maxMap(Collection)} or {@link #medianMap(Collection)}.
	 */
	public ProbabilityAggregator(){
		this(0);
	}

	/**
	 * @param capacity Capacity hint, passed on to the superclass. When positive, the added results
	 * are also retained, which makes {@link #maxMap(Collection)} and {@link #medianMap(Collection)} available.
	 */
	public ProbabilityAggregator(int capacity){
		super(capacity);

		if(capacity > 0){
			this.hasProbabilities = new ArrayList<>(capacity);
		}
	}

	/**
	 * Adds a result with unit weight.
	 *
	 * @param hasProbability The result to add.
	 */
	public void add(HasProbability hasProbability){
		add(hasProbability, 1d);
	}

	/**
	 * Adds a result, multiplying each of its category probabilities by the given weight.
	 *
	 * @param hasProbability The result to add.
	 * @param weight The multiplier applied to every category probability of the result.
	 */
	public void add(HasProbability hasProbability, double weight){

		if(this.hasProbabilities != null){
			this.hasProbabilities.add(hasProbability);
		}

		Set<String> categories = hasProbability.getCategoryValues();
		for(String category : categories){
			Double probability = hasProbability.getProbability(category);

			// Unit weights are passed through unchanged
			add(category, weight != 1d ? (probability * weight) : probability);
		}
	}

	/**
	 * Picks the winning category by maximum value, and averages the distributions of the results that produced it.
	 *
	 * <p>
	 * For every category, the largest added (weighted) value is computed. Among {@code categories},
	 * the category with the largest such value wins; on ties, the first one in iteration order wins.
	 * The full distributions of all retained results whose value for the winning category equals
	 * that maximum are then averaged (unweighted).
	 * </p>
	 *
	 * @param categories The candidate categories. Must not be {@code null} or empty.
	 *
	 * @return The averaged distribution, or an empty map if none of the candidate categories has an aggregated value.
	 *
	 * @throws IllegalStateException If the added results were not retained (see {@link #ProbabilityAggregator(int)}).
	 * @throws EvaluationException If {@code categories} is {@code null} or empty.
	 */
	public Map<String, Double> maxMap(Collection<String> categories){
		List<HasProbability> hasProbabilities = requireHasProbabilities();

		Function<DoubleVector, Double> function = new Function<DoubleVector, Double>(){

			@Override
			public Double apply(DoubleVector values){
				return values.max();
			}
		};

		Map<String, Double> maxValues = transform(function);

		Map.Entry<String, Double> maxMaxValue = getWinner(maxValues, categories);
		if(maxMaxValue == null){
			return Collections.emptyMap();
		}

		String category = maxMaxValue.getKey();
		double maxProbability = maxMaxValue.getValue();

		List<HasProbability> contributors = new ArrayList<>();

		// NOTE(review): position i in this category's vector is taken to be the i-th added result.
		// That only lines up when every added result reported a probability for this category - confirm.
		DoubleVector values = get(category);
		for(int i = 0; i < values.size(); i++){
			double probability = values.get(i);

			// Exact comparison: presumably DoubleVector#max() returns one of the stored values
			if(probability == maxProbability){
				contributors.add(hasProbabilities.get(i));
			}
		}

		return averageMap(contributors);
	}

	/**
	 * Picks the winning category by median value, and averages the distributions of the results closest to it.
	 *
	 * <p>
	 * For every category, the median of the added (weighted) values is computed. Among {@code categories},
	 * the category with the largest median wins; on ties, the first one in iteration order wins.
	 * The full distributions of all retained results whose value for the winning category lies at the
	 * minimal absolute distance from that median are then averaged (unweighted).
	 * </p>
	 *
	 * @param categories The candidate categories. Must not be {@code null} or empty.
	 *
	 * @return The averaged distribution, or an empty map if none of the candidate categories has an aggregated value.
	 *
	 * @throws IllegalStateException If the added results were not retained (see {@link #ProbabilityAggregator(int)}).
	 * @throws EvaluationException If {@code categories} is {@code null} or empty.
	 */
	public Map<String, Double> medianMap(Collection<String> categories){
		List<HasProbability> hasProbabilities = requireHasProbabilities();

		Function<DoubleVector, Double> function = new Function<DoubleVector, Double>(){

			@Override
			public Double apply(DoubleVector values){
				return values.median();
			}
		};

		Map<String, Double> medianValues = transform(function);

		Map.Entry<String, Double> maxMedianValue = getWinner(medianValues, categories);
		if(maxMedianValue == null){
			return Collections.emptyMap();
		}

		String category = maxMedianValue.getKey();
		double medianProbability = maxMedianValue.getValue();

		List<HasProbability> contributors = new ArrayList<>();

		double minDifference = Double.MAX_VALUE;

		// NOTE(review): same index-alignment assumption as in maxMap - confirm.
		DoubleVector values = get(category);
		for(int i = 0; i < values.size(); i++){
			double probability = values.get(i);

			// Choose the models whose probability is closest to the calculated median probability.
			// If the number of models is odd (the median equals the value of the middle model),
			// then all the chosen models have exactly that value (i.e. difference == 0).
			// If the number of models is even (the median equals the average of the two middle-most values),
			// then the chosen models lie at the same distance below and/or above the median (i.e. difference > 0).
			double difference = Math.abs(medianProbability - probability);

			// A strictly closer model invalidates all previously chosen ones
			if(difference < minDifference){
				contributors.clear();

				minDifference = difference;
			} // End if

			if(difference <= minDifference){
				contributors.add(hasProbabilities.get(i));
			}
		}

		return averageMap(contributors);
	}

	/**
	 * Divides the sum of every category's values by a fixed denominator.
	 *
	 * @param denominator The divisor (e.g. the number of added results).
	 *
	 * @return A map from category to averaged value.
	 */
	public Map<String, Double> averageMap(final double denominator){
		Function<DoubleVector, Double> function = new Function<DoubleVector, Double>(){

			@Override
			public Double apply(DoubleVector values){
				return values.sum() / denominator;
			}
		};

		return transform(function);
	}

	/**
	 * Returns the retained results.
	 *
	 * @throws IllegalStateException If this aggregator was created with a non-positive capacity,
	 * in which case the added results were never retained.
	 */
	private List<HasProbability> requireHasProbabilities(){

		if(this.hasProbabilities == null){
			throw new IllegalStateException("The added results are retained only when the aggregator is created with a positive capacity");
		}

		return this.hasProbabilities;
	}

	/**
	 * Finds the candidate category with the largest value.
	 *
	 * @param values A map from category to value.
	 * @param categories The candidate categories. Categories without a value are skipped.
	 *
	 * @return The winning category and its value, or {@code null} if no candidate category has a value.
	 * On ties, the first category in iteration order wins.
	 *
	 * @throws EvaluationException If {@code categories} is {@code null} or empty.
	 */
	static
	private Map.Entry<String, Double> getWinner(Map<String, Double> values, Collection<String> categories){

		if(categories == null || categories.isEmpty()){
			throw new EvaluationException();
		}

		Map.Entry<String, Double> maxEntry = null;

		for(String category : categories){
			Double value = values.get(category);

			if(value == null){
				continue;
			} // End if

			// Strict comparison keeps the earlier category on ties
			if(maxEntry == null || (maxEntry.getValue()).compareTo(value) < 0){
				maxEntry = new AbstractMap.SimpleEntry<>(category, value);
			}
		}

		return maxEntry;
	}

	/**
	 * Averages the (unweighted) probability distributions of the given results.
	 *
	 * @param hasProbabilities The results to average.
	 *
	 * @return A map from category to averaged probability. For a single result, its distribution is copied as-is.
	 */
	static
	private Map<String, Double> averageMap(List<HasProbability> hasProbabilities){

		if(hasProbabilities.size() == 1){
			HasProbability hasProbability = hasProbabilities.get(0);

			Map<String, Double> result = new LinkedHashMap<>();

			Set<String> categories = hasProbability.getCategoryValues();
			for(String category : categories){
				Double probability = hasProbability.getProbability(category);

				result.put(category, probability);
			}

			return result;
		} else

		{
			ProbabilityAggregator aggregator = new ProbabilityAggregator();

			for(HasProbability hasProbability : hasProbabilities){
				aggregator.add(hasProbability);
			}

			return aggregator.averageMap(hasProbabilities.size());
		}
	}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy