All downloads are free. The search and download functionality uses the official Maven repository.

org.jpmml.evaluator.ProbabilityAggregator Maven / Gradle / Ivy

/*
 * Copyright (c) 2014 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.common.base.Function;

abstract
public class ProbabilityAggregator extends KeyValueAggregator {

	private List hasProbabilities = null;

	private int size = 0;

	private Vector weights = null;


	public ProbabilityAggregator(int capacity){
		this(capacity, null);
	}

	public ProbabilityAggregator(int capacity, Vector weights){
		super(capacity);

		if(capacity > 0){
			this.hasProbabilities = new ArrayList<>(capacity);
		}

		this.weights = weights;
	}

	public void add(HasProbability hasProbability){

		if(this.weights != null){
			throw new IllegalStateException();
		} // End if

		if(this.hasProbabilities != null){
			this.hasProbabilities.add(hasProbability);
		}

		Set categories = hasProbability.getCategories();
		for(String category : categories){
			Double probability = hasProbability.getProbability(category);

			add(category, probability);
		}

		this.size++;
	}

	public void add(HasProbability hasProbability, double weight){

		if(this.weights == null){
			throw new IllegalStateException();
		} // End if

		if(weight < 0d){
			throw new IllegalArgumentException();
		} // End if

		if(this.hasProbabilities != null){
			this.hasProbabilities.add(hasProbability);
		}

		Set categories = hasProbability.getCategories();
		for(String category : categories){
			Double probability = hasProbability.getProbability(category);

			add(category, weight, probability);
		}

		this.size++;

		this.weights.add(weight);
	}

	public ValueMap averageMap(){

		if(this.weights != null){
			throw new IllegalStateException();
		}

		Function, Value> function = new Function, Value>(){

			private final int size = ProbabilityAggregator.this.size;


			@Override
			public Value apply(Vector values){

				if(this.size == 0){
					throw new UndefinedResultException();
				}

				return (values.sum()).divide(this.size);
			}
		};

		return new ValueMap<>(asTransformedMap(function));
	}

	public ValueMap weightedAverageMap(){

		if(this.weights == null){
			throw new IllegalStateException();
		}

		Function, Value> function = new Function, Value>(){

			private final Value weightSum = ProbabilityAggregator.this.weights.sum();


			@Override
			public Value apply(Vector values){

				if(this.weightSum.equals(0d)){
					throw new UndefinedResultException();
				}

				return (values.sum()).divide(this.weightSum);
			}
		};

		return new ValueMap<>(asTransformedMap(function));
	}

	public ValueMap maxMap(Collection categories){

		if(this.hasProbabilities == null){
			throw new IllegalStateException();
		} // End if

		if(this.weights != null){
			throw new IllegalStateException();
		}

		Function, Value> function = new Function, Value>(){

			@Override
			public Value apply(Vector values){
				return values.max();
			}
		};

		Map> maxMap = asTransformedMap(function);

		Map.Entry> winnerEntry = getWinner(maxMap, categories);
		if(winnerEntry == null){
			return new ValueMap<>();
		}

		String category = winnerEntry.getKey();
		Value maxProbability = winnerEntry.getValue();

		List contributors = new ArrayList<>();

		Vector values = get(category);

		for(int i = 0, max = values.size(); i < max; i++){
			Value probability = values.get(i);

			if((maxProbability).compareTo(probability) == 0){
				HasProbability contributor = this.hasProbabilities.get(i);

				contributors.add(contributor);
			}
		}

		return averageMap(contributors);
	}

	public ValueMap medianMap(Collection categories){

		if(this.hasProbabilities == null){
			throw new IllegalStateException();
		} // End if

		if(this.weights != null){
			throw new IllegalStateException();
		}

		Function, Value> function = new Function, Value>(){

			@Override
			public Value apply(Vector values){
				return values.median();
			}
		};

		Map> medianMap = asTransformedMap(function);

		Map.Entry> winnerEntry = getWinner(medianMap, categories);
		if(winnerEntry == null){
			return new ValueMap<>();
		}

		String category = winnerEntry.getKey();
		Value medianProbability = winnerEntry.getValue();

		List contributors = new ArrayList<>();

		double minDifference = Double.MAX_VALUE;

		Vector values = get(category);

		for(int i = 0, max = values.size(); i < max; i++){
			Value probability = values.get(i);

			// Choose models whose probability is closest to the calculated median probability.
			// If the number of models is odd (the calculated median probability equals that of the middle model),
			// then all the chosen models will have the same probability (ie. difference == 0).
			// If the number of models is even (the calculated median probability equals the average of two middle-most models),
			// then some of the chosen models will have lower probabilies (ie. difference > 0), whereas the others will have higher probabilities (ie. difference < 0).
			double difference = Math.abs(medianProbability.doubleValue() - probability.doubleValue());

			if(difference < minDifference){
				contributors.clear();

				minDifference = difference;
			} // End if

			if(difference <= minDifference){
				HasProbability contributor = this.hasProbabilities.get(i);

				contributors.add(contributor);
			}
		}

		return averageMap(contributors);
	}

	private ValueMap averageMap(List hasProbabilities){

		if(hasProbabilities.size() == 1){
			HasProbability hasProbability = hasProbabilities.get(0);

			ValueFactory valueFactory = getValueFactory();

			ValueMap result = new ValueMap<>();

			Set categories = hasProbability.getCategories();
			for(String category : categories){
				Double probability = hasProbability.getProbability(category);

				Value value = valueFactory.newValue(probability);

				result.put(category, value);
			}

			return result;
		} else

		{
			ProbabilityAggregator aggregator = new ProbabilityAggregator(0){

				@Override
				public ValueFactory getValueFactory(){
					return ProbabilityAggregator.this.getValueFactory();
				}
			};

			for(HasProbability hasProbability : hasProbabilities){
				aggregator.add(hasProbability);
			}

			return aggregator.averageMap();
		}
	}

	static
	private  Map.Entry> getWinner(Map> values, Collection categories){
		Map.Entry> maxEntry = null;

		if(categories == null){
			categories = values.keySet();
		}

		for(String category : categories){
			Value value = values.get(category);

			if(value == null){
				continue;
			} // End if

			if(maxEntry == null || (maxEntry.getValue()).compareTo(value) < 0){
				maxEntry = new AbstractMap.SimpleEntry<>(category, value);
			}
		}

		return maxEntry;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy