All downloads are FREE. The search and download functionality uses the official Maven repository.

org.jpmml.evaluator.ProbabilityAggregator Maven / Gradle / Ivy

There is a newer version: 1.6.11
Show newest version
/*
 * Copyright (c) 2014 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.common.base.Function;

/**
 * Aggregates the category probabilities of several {@link HasProbability} instances.
 *
 * <p>
 * Every added instance contributes its (optionally weighted) probability for each of its category values.
 * When the aggregator is constructed with a positive capacity, it additionally remembers every added instance
 * together with its weight, which is required by {@link #maxMap(Collection)} and {@link #medianMap(Collection)}.
 * </p>
 */
public class ProbabilityAggregator extends ClassificationAggregator {

	/**
	 * The added {@link HasProbability} instances in insertion order,
	 * or {@code null} when the aggregator was constructed with a non-positive capacity.
	 */
	private List<HasProbability> hasProbabilities = null;

	/**
	 * The weights that accompanied the elements of {@link #hasProbabilities} (index-aligned with it),
	 * or {@code null} when {@link #hasProbabilities} is {@code null}.
	 */
	private List<Double> weights = null;


	public ProbabilityAggregator(){
		this(0);
	}

	/**
	 * @param capacity The expected number of additions.
	 * A positive value enables the tracking of added instances, which {@link #maxMap(Collection)} and {@link #medianMap(Collection)} depend on.
	 */
	public ProbabilityAggregator(int capacity){
		super(capacity);

		if(capacity > 0){
			this.hasProbabilities = new ArrayList<>(capacity);
			this.weights = new ArrayList<>(capacity);
		}
	}

	/**
	 * Adds the probabilities of the given instance with a weight of {@code 1}.
	 */
	public void add(HasProbability hasProbability){
		add(hasProbability, 1d);
	}

	/**
	 * Adds the probability of every category value of the given instance, multiplied by {@code weight}.
	 */
	public void add(HasProbability hasProbability, double weight){

		if(this.hasProbabilities != null){
			this.hasProbabilities.add(hasProbability);
			this.weights.add(weight);
		}

		Set<String> categories = hasProbability.getCategoryValues();
		for(String category : categories){
			Double probability = hasProbability.getProbability(category);

			add(category, weigh(probability, weight));
		}
	}

	/**
	 * Finds the category (among {@code categories}) whose highest aggregated probability is the largest.
	 * Averages the probability distributions of all instances that attained that highest probability.
	 *
	 * @return The averaged distribution, or an empty map if none of the given categories has an aggregated value.
	 *
	 * @throws IllegalStateException If the aggregator does not track added instances (non-positive capacity).
	 */
	public Map<String, Double> maxMap(Collection<String> categories){

		if(this.hasProbabilities == null){
			throw new IllegalStateException();
		}

		Function<DoubleVector, Double> function = new Function<DoubleVector, Double>(){

			@Override
			public Double apply(DoubleVector values){
				return values.max();
			}
		};

		Map<String, Double> maxValues = transform(function);

		Map.Entry<String, Double> maxMaxValue = getWinner(maxValues, categories);
		if(maxMaxValue == null){
			return Collections.emptyMap();
		}

		String category = maxMaxValue.getKey();
		double maxProbability = maxMaxValue.getValue();

		List<HasProbability> contributors = new ArrayList<>();

		// Walk the recorded instances directly (rather than indexing the per-category vector),
		// because instances that lack this category did not add an element to that vector.
		for(int i = 0; i < this.hasProbabilities.size(); i++){
			HasProbability hasProbability = this.hasProbabilities.get(i);

			Double probability = getContribution(hasProbability, category, this.weights.get(i));
			if(probability == null){
				continue;
			} // End if

			if(probability.doubleValue() == maxProbability){
				contributors.add(hasProbability);
			}
		}

		return averageMap(contributors);
	}

	/**
	 * Finds the category (among {@code categories}) whose median aggregated probability is the largest.
	 * Averages the probability distributions of the instances whose probability is closest to that median.
	 *
	 * @return The averaged distribution, or an empty map if none of the given categories has an aggregated value.
	 *
	 * @throws IllegalStateException If the aggregator does not track added instances (non-positive capacity).
	 */
	public Map<String, Double> medianMap(Collection<String> categories){

		if(this.hasProbabilities == null){
			throw new IllegalStateException();
		}

		Function<DoubleVector, Double> function = new Function<DoubleVector, Double>(){

			@Override
			public Double apply(DoubleVector values){
				return values.median();
			}
		};

		Map<String, Double> medianValues = transform(function);

		Map.Entry<String, Double> maxMedianValue = getWinner(medianValues, categories);
		if(maxMedianValue == null){
			return Collections.emptyMap();
		}

		String category = maxMedianValue.getKey();
		double medianProbability = maxMedianValue.getValue();

		List<HasProbability> contributors = new ArrayList<>();

		double minDifference = Double.POSITIVE_INFINITY;

		// Walk the recorded instances directly (rather than indexing the per-category vector),
		// because instances that lack this category did not add an element to that vector.
		for(int i = 0; i < this.hasProbabilities.size(); i++){
			HasProbability hasProbability = this.hasProbabilities.get(i);

			Double probability = getContribution(hasProbability, category, this.weights.get(i));
			if(probability == null){
				continue;
			}

			// Choose models whose probability is closest to the calculated median probability.
			// If the number of models is odd (the calculated median probability equals that of the middle model),
			// then all the chosen models will have the same probability (ie. difference == 0).
			// If the number of models is even (the calculated median probability equals the average of the two middle-most models),
			// then some of the chosen models will have lower probabilities, whereas the others will have higher probabilities.
			double difference = Math.abs(medianProbability - probability);

			if(difference < minDifference){
				contributors.clear();

				minDifference = difference;
			} // End if

			if(difference <= minDifference){
				contributors.add(hasProbability);
			}
		}

		return averageMap(contributors);
	}

	/**
	 * Divides the sum of the aggregated values of every category by {@code denominator}.
	 */
	public Map<String, Double> averageMap(final double denominator){
		Function<DoubleVector, Double> function = new Function<DoubleVector, Double>(){

			@Override
			public Double apply(DoubleVector values){
				return values.sum() / denominator;
			}
		};

		return transform(function);
	}

	/**
	 * Selects the entry with the largest value among {@code categories}.
	 * Categories without a value are skipped; on ties the first category in iteration order wins.
	 *
	 * @return The winning entry, or {@code null} if none of the categories has a value.
	 *
	 * @throws EvaluationException If {@code categories} is {@code null} or empty.
	 */
	static
	private Map.Entry<String, Double> getWinner(Map<String, Double> values, Collection<String> categories){

		if(categories == null || categories.isEmpty()){
			throw new EvaluationException();
		}

		Map.Entry<String, Double> maxEntry = null;

		for(String category : categories){
			Double value = values.get(category);

			if(value == null){
				continue;
			} // End if

			if(maxEntry == null || (maxEntry.getValue()).compareTo(value) < 0){
				maxEntry = new AbstractMap.SimpleEntry<>(category, value);
			}
		}

		return maxEntry;
	}

	/**
	 * Averages the probability distributions of the given instances.
	 * A single instance is copied as-is; otherwise every category's probability sum is divided by the number of instances.
	 */
	static
	private Map<String, Double> averageMap(List<HasProbability> hasProbabilities){

		if(hasProbabilities.size() == 1){
			HasProbability hasProbability = hasProbabilities.get(0);

			Map<String, Double> result = new LinkedHashMap<>();

			Set<String> categories = hasProbability.getCategoryValues();
			for(String category : categories){
				Double probability = hasProbability.getProbability(category);

				result.put(category, probability);
			}

			return result;
		} else

		{
			ProbabilityAggregator aggregator = new ProbabilityAggregator();

			for(HasProbability hasProbability : hasProbabilities){
				aggregator.add(hasProbability);
			}

			return aggregator.averageMap(hasProbabilities.size());
		}
	}

	/**
	 * Returns the value that {@link #add(HasProbability, double)} stored for the given instance and category,
	 * or {@code null} if the category is not among the instance's category values (and was therefore not stored).
	 */
	static
	private Double getContribution(HasProbability hasProbability, String category, double weight){
		Set<String> categories = hasProbability.getCategoryValues();

		if(!categories.contains(category)){
			return null;
		}

		return weigh(hasProbability.getProbability(category), weight);
	}

	/**
	 * Applies the weight to a probability.
	 * A weight of exactly {@code 1} leaves the probability untouched (no multiplication is performed).
	 */
	static
	private double weigh(Double probability, double weight){
		return weight != 1d ? (probability * weight) : probability;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy