All Downloads are FREE. Search and download functionalities are using the official Maven repository.
Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
org.jpmml.evaluator.ProbabilityAggregator Maven / Gradle / Ivy
/*
* Copyright (c) 2014 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.evaluator;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.google.common.base.Function;
public class ProbabilityAggregator extends KeyValueAggregator {
private List hasProbabilities = null;
private int size = 0;
private Vector weights = null;
protected ProbabilityAggregator(ValueFactory valueFactory, int capacity){
this(valueFactory, capacity, false);
}
protected ProbabilityAggregator(ValueFactory valueFactory, int capacity, boolean weighted){
super(valueFactory, capacity);
if(capacity > 0){
this.hasProbabilities = new ArrayList<>(capacity);
} // End if
if(weighted){
this.weights = valueFactory.newVector(0);
}
}
public void add(HasProbability hasProbability){
if(this.weights != null){
throw new IllegalStateException();
} // End if
if(this.hasProbabilities != null){
this.hasProbabilities.add(hasProbability);
}
Set> categories = hasProbability.getCategories();
for(Object category : categories){
Double probability = hasProbability.getProbability(category);
add(category, probability);
}
this.size++;
}
/**
* @param probabilities An array of numbers that sum to 1.
*
* @see #init(Collection)
*/
public void add(Number[] probabilities){
if(this.weights != null || this.hasProbabilities != null){
throw new IllegalStateException();
}
Collection> mapValues = values();
if(mapValues.size() != probabilities.length){
throw new IllegalArgumentException();
}
Iterator> valueIt = mapValues.iterator();
for(int i = 0; valueIt.hasNext(); i++){
Vector values = valueIt.next();
values.add(probabilities[i]);
}
this.size++;
}
public void add(HasProbability hasProbability, Number weight){
if(this.weights == null){
throw new IllegalStateException();
} // End if
if(weight.doubleValue() < 0d){
throw new IllegalArgumentException();
} // End if
if(this.hasProbabilities != null){
this.hasProbabilities.add(hasProbability);
}
Set> categories = hasProbability.getCategories();
for(Object category : categories){
Double probability = hasProbability.getProbability(category);
add(category, weight, probability);
}
this.size++;
this.weights.add(weight);
}
/**
* @param probabilities An array of numbers that sum to 1.
*
* @see #init(Collection)
*/
public void add(Number[] probabilities, Number weight){
if(this.weights == null || this.hasProbabilities != null){
throw new IllegalStateException();
}
Collection> mapValues = values();
if(mapValues.size() != probabilities.length){
throw new IllegalArgumentException();
} // End if
if(weight.doubleValue() < 0d){
throw new IllegalArgumentException();
}
Iterator> valueIt = mapValues.iterator();
for(int i = 0; valueIt.hasNext(); i++){
Vector values = valueIt.next();
if(weight.doubleValue() != 1d){
values.add(weight, probabilities[i]);
} else
{
values.add(probabilities[i]);
}
}
this.size++;
this.weights.add(weight);
}
public ValueMap averageMap(){
if(this.weights != null){
throw new IllegalStateException();
}
Function, Value> function = new Function, Value>(){
private int size = ProbabilityAggregator.this.size;
@Override
public Value apply(Vector values){
if(this.size == 0){
throw new UndefinedResultException();
}
return (values.sum()).divide(this.size);
}
};
return new ValueMap<>(asTransformedMap(function));
}
public ValueMap weightedAverageMap(){
if(this.weights == null){
throw new IllegalStateException();
}
Function, Value> function = new Function, Value>(){
private Value weightSum = ProbabilityAggregator.this.weights.sum();
@Override
public Value apply(Vector values){
if(this.weightSum.isZero()){
throw new UndefinedResultException();
}
return (values.sum()).divide(this.weightSum);
}
};
return new ValueMap<>(asTransformedMap(function));
}
public ValueMap maxMap(Collection> categories){
if(this.hasProbabilities == null){
throw new IllegalStateException();
} // End if
if(this.weights != null){
throw new IllegalStateException();
}
Map, Value> maxMap = asTransformedMap(Vector::max);
Map.Entry, Value> winnerEntry = getWinner(maxMap, categories);
if(winnerEntry == null){
return new ValueMap<>();
}
Object category = winnerEntry.getKey();
Value maxProbability = winnerEntry.getValue();
List contributors = new ArrayList<>();
Vector values = get(category);
for(int i = 0, max = values.size(); i < max; i++){
Value probability = values.get(i);
if((maxProbability).compareTo(probability) == 0){
HasProbability contributor = this.hasProbabilities.get(i);
contributors.add(contributor);
}
}
return averageMap(contributors);
}
public ValueMap medianMap(Collection> categories){
if(this.hasProbabilities == null){
throw new IllegalStateException();
} // End if
if(this.weights != null){
throw new IllegalStateException();
}
Map, Value> medianMap = asTransformedMap(Vector::median);
Map.Entry, Value> winnerEntry = getWinner(medianMap, categories);
if(winnerEntry == null){
return new ValueMap<>();
}
Object category = winnerEntry.getKey();
Value medianProbability = winnerEntry.getValue();
List contributors = new ArrayList<>();
double minDifference = Double.MAX_VALUE;
Vector values = get(category);
for(int i = 0, max = values.size(); i < max; i++){
Value probability = values.get(i);
// Choose models whose probability is closest to the calculated median probability.
// If the number of models is odd (the calculated median probability equals that of the middle model),
// then all the chosen models will have the same probability (ie. difference == 0).
// If the number of models is even (the calculated median probability equals the average of two middle-most models),
// then some of the chosen models will have lower probabilies (ie. difference > 0), whereas the others will have higher probabilities (ie. difference < 0).
double difference = Math.abs(medianProbability.doubleValue() - probability.doubleValue());
if(difference < minDifference){
contributors.clear();
minDifference = difference;
} // End if
if(difference <= minDifference){
HasProbability contributor = this.hasProbabilities.get(i);
contributors.add(contributor);
}
}
return averageMap(contributors);
}
private ValueMap averageMap(List hasProbabilities){
ValueFactory valueFactory = getValueFactory();
if(hasProbabilities.size() == 1){
HasProbability hasProbability = hasProbabilities.get(0);
ValueMap result = new ValueMap<>();
Set> categories = keySet();
for(Object category : categories){
Double probability = hasProbability.getProbability(category);
Value value = valueFactory.newValue(probability);
result.put(category, value);
}
return result;
} else
{
ProbabilityAggregator aggregator = new ProbabilityAggregator.Average<>(valueFactory);
aggregator.init(keySet());
for(int i = 0, max = hasProbabilities.size(); i < max; i++){
HasProbability hasProbability = hasProbabilities.get(i);
aggregator.add(hasProbability);
}
return aggregator.averageMap();
}
}
static
private Map.Entry> getWinner(Map, Value> values, Collection> categories){
Map.Entry> maxEntry = null;
if(categories == null){
categories = values.keySet();
}
for(Object category : categories){
Value value = values.get(category);
if(value == null){
continue;
} // End if
if(maxEntry == null || (maxEntry.getValue()).compareTo(value) < 0){
maxEntry = new AbstractMap.SimpleEntry<>(category, value);
}
}
return maxEntry;
}
static
public class Average extends ProbabilityAggregator {
public Average(ValueFactory valueFactory){
super(valueFactory, 0);
}
}
static
public class WeightedAverage extends ProbabilityAggregator {
public WeightedAverage(ValueFactory valueFactory){
super(valueFactory, 0, true);
}
}
static
public class Max extends ProbabilityAggregator {
public Max(ValueFactory valueFactory, int capacity){
super(valueFactory, capacity);
}
}
static
public class Median extends ProbabilityAggregator {
public Median(ValueFactory valueFactory, int capacity){
super(valueFactory, capacity);
}
}
}