org.jpmml.evaluator.ProbabilityAggregator Maven / Gradle / Ivy
/*
* Copyright (c) 2014 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.evaluator;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.google.common.base.Function;
abstract
public class ProbabilityAggregator extends KeyValueAggregator {
private List hasProbabilities = null;
private int size = 0;
private Vector weights = null;
public ProbabilityAggregator(int capacity){
this(capacity, null);
}
public ProbabilityAggregator(int capacity, Vector weights){
super(capacity);
if(capacity > 0){
this.hasProbabilities = new ArrayList<>(capacity);
}
this.weights = weights;
}
public void add(HasProbability hasProbability){
if(this.weights != null){
throw new IllegalStateException();
} // End if
if(this.hasProbabilities != null){
this.hasProbabilities.add(hasProbability);
}
Set categories = hasProbability.getCategories();
for(String category : categories){
Double probability = hasProbability.getProbability(category);
add(category, probability);
}
this.size++;
}
public void add(HasProbability hasProbability, double weight){
if(this.weights == null){
throw new IllegalStateException();
} // End if
if(weight < 0d){
throw new IllegalArgumentException();
} // End if
if(this.hasProbabilities != null){
this.hasProbabilities.add(hasProbability);
}
Set categories = hasProbability.getCategories();
for(String category : categories){
Double probability = hasProbability.getProbability(category);
add(category, weight, probability);
}
this.size++;
this.weights.add(weight);
}
public ValueMap averageMap(){
if(this.weights != null){
throw new IllegalStateException();
}
Function, Value> function = new Function, Value>(){
private final int size = ProbabilityAggregator.this.size;
@Override
public Value apply(Vector values){
if(this.size == 0){
throw new UndefinedResultException();
}
return (values.sum()).divide(this.size);
}
};
return new ValueMap<>(asTransformedMap(function));
}
public ValueMap weightedAverageMap(){
if(this.weights == null){
throw new IllegalStateException();
}
Function, Value> function = new Function, Value>(){
private final Value weightSum = ProbabilityAggregator.this.weights.sum();
@Override
public Value apply(Vector values){
if(this.weightSum.equals(0d)){
throw new UndefinedResultException();
}
return (values.sum()).divide(this.weightSum);
}
};
return new ValueMap<>(asTransformedMap(function));
}
public ValueMap maxMap(Collection categories){
if(this.hasProbabilities == null){
throw new IllegalStateException();
} // End if
if(this.weights != null){
throw new IllegalStateException();
}
Function, Value> function = new Function, Value>(){
@Override
public Value apply(Vector values){
return values.max();
}
};
Map> maxMap = asTransformedMap(function);
Map.Entry> winnerEntry = getWinner(maxMap, categories);
if(winnerEntry == null){
return new ValueMap<>();
}
String category = winnerEntry.getKey();
Value maxProbability = winnerEntry.getValue();
List contributors = new ArrayList<>();
Vector values = get(category);
for(int i = 0, max = values.size(); i < max; i++){
Value probability = values.get(i);
if((maxProbability).compareTo(probability) == 0){
HasProbability contributor = this.hasProbabilities.get(i);
contributors.add(contributor);
}
}
return averageMap(contributors);
}
public ValueMap medianMap(Collection categories){
if(this.hasProbabilities == null){
throw new IllegalStateException();
} // End if
if(this.weights != null){
throw new IllegalStateException();
}
Function, Value> function = new Function, Value>(){
@Override
public Value apply(Vector values){
return values.median();
}
};
Map> medianMap = asTransformedMap(function);
Map.Entry> winnerEntry = getWinner(medianMap, categories);
if(winnerEntry == null){
return new ValueMap<>();
}
String category = winnerEntry.getKey();
Value medianProbability = winnerEntry.getValue();
List contributors = new ArrayList<>();
double minDifference = Double.MAX_VALUE;
Vector values = get(category);
for(int i = 0, max = values.size(); i < max; i++){
Value probability = values.get(i);
// Choose models whose probability is closest to the calculated median probability.
// If the number of models is odd (the calculated median probability equals that of the middle model),
// then all the chosen models will have the same probability (ie. difference == 0).
// If the number of models is even (the calculated median probability equals the average of two middle-most models),
// then some of the chosen models will have lower probabilies (ie. difference > 0), whereas the others will have higher probabilities (ie. difference < 0).
double difference = Math.abs(medianProbability.doubleValue() - probability.doubleValue());
if(difference < minDifference){
contributors.clear();
minDifference = difference;
} // End if
if(difference <= minDifference){
HasProbability contributor = this.hasProbabilities.get(i);
contributors.add(contributor);
}
}
return averageMap(contributors);
}
private ValueMap averageMap(List hasProbabilities){
if(hasProbabilities.size() == 1){
HasProbability hasProbability = hasProbabilities.get(0);
ValueFactory valueFactory = getValueFactory();
ValueMap result = new ValueMap<>();
Set categories = hasProbability.getCategories();
for(String category : categories){
Double probability = hasProbability.getProbability(category);
Value value = valueFactory.newValue(probability);
result.put(category, value);
}
return result;
} else
{
ProbabilityAggregator aggregator = new ProbabilityAggregator(0){
@Override
public ValueFactory getValueFactory(){
return ProbabilityAggregator.this.getValueFactory();
}
};
for(HasProbability hasProbability : hasProbabilities){
aggregator.add(hasProbability);
}
return aggregator.averageMap();
}
}
static
private Map.Entry> getWinner(Map> values, Collection categories){
Map.Entry> maxEntry = null;
if(categories == null){
categories = values.keySet();
}
for(String category : categories){
Value value = values.get(category);
if(value == null){
continue;
} // End if
if(maxEntry == null || (maxEntry.getValue()).compareTo(value) < 0){
maxEntry = new AbstractMap.SimpleEntry<>(category, value);
}
}
return maxEntry;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy