All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.evaluator.Classification Maven / Gradle / Ivy

There is a newer version: 1.7.2
Show newest version
/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.Collection;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;

import com.google.common.base.Function;
import com.google.common.base.Objects;
import com.google.common.base.Objects.ToStringHelper;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import com.google.common.collect.Range;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunctionType;

/**
 * @see MiningFunctionType#CLASSIFICATION
 * @see MiningFunctionType#CLUSTERING
 */
public class Classification implements Computable {

	private Map map = new LinkedHashMap<>();

	private Object result = null;

	private Type type = null;


	protected Classification(Type type){
		setType(type);
	}

	@Override
	public Object getResult(){

		if(this.result == null){
			throw new EvaluationException();
		}

		return this.result;
	}

	void computeResult(DataType dataType){
		Map.Entry entry = getWinner();
		if(entry == null){
			throw new EvaluationException();
		}

		Object result = TypeUtil.parseOrCast(dataType, entry.getKey());

		setResult(result);
	}

	void setResult(Object result){
		this.result = result;
	}

	@Override
	public String toString(){
		ToStringHelper helper = toStringHelper();

		return helper.toString();
	}

	protected ToStringHelper toStringHelper(){
		ToStringHelper helper = Objects.toStringHelper(this)
			.add("result", getResult())
			.add(getType().entryKey(), entrySet());

		return helper;
	}

	Double get(String key){
		Double value = this.map.get(key);

		// The specified value was not encountered during scoring
		if(value == null){
			Type type = getType();

			return type.getDefault();
		}

		return value;
	}

	Double put(String key, Double value){
		return this.map.put(key, value);
	}

	void putAll(Map values){
		this.map.putAll(values);
	}

	boolean isEmpty(){
		return this.map.isEmpty();
	}

	Map.Entry getWinner(){
		return getWinner(getType(), entrySet());
	}

	List> getWinnerRanking(){
		return getWinnerList(getType(), entrySet());
	}

	List getWinnerKeys(){
		return entryKeys(getWinnerRanking());
	}

	List getWinnerValues(){
		return entryValues(getWinnerRanking());
	}

	Double sumValues(){
		return sum(this.map);
	}

	void normalizeValues(){
		normalize(this.map);
	}

	Set keySet(){
		return this.map.keySet();
	}

	Set> entrySet(){
		return this.map.entrySet();
	}

	public Type getType(){
		return this.type;
	}

	private void setType(Type type){
		this.type = type;
	}

	static
	Map.Entry getWinner(Type type, Collection> entries){
		Ordering> ordering = createOrdering(type);

		try {
			return ordering.max(entries);
		} catch(NoSuchElementException nsee){
			return null;
		}
	}

	static
	List> getWinnerList(Type type, Collection> entries){
		Ordering> ordering = (createOrdering(type)).reverse();

		return ordering.sortedCopy(entries);
	}

	static
	Ordering> createOrdering(final Type type){
		Comparator> comparator = new Comparator>(){

			@Override
			public int compare(Map.Entry left, Map.Entry right){
				return type.compare(left.getValue(), right.getValue());
			}
		};

		return Ordering.from(comparator);
	}

	static
	public  List entryKeys(List> entries){
		Function, K> function = new Function, K>(){

			@Override
			public K apply(Map.Entry entry){
				return entry.getKey();
			}
		};

		return Lists.transform(entries, function);
	}

	static
	public  List entryValues(List> entries){
		Function, V> function = new Function, V>(){

			@Override
			public V apply(Map.Entry entry){
				return entry.getValue();
			}
		};

		return Lists.transform(entries, function);
	}

	static
	public  Double sum(Map map){
		return sum(map, null);
	}

	static
	private  Double sum(Map map, Function function){
		double sum = 0d;

		Collection values = map.values();
		for(Double value : values){

			if(function != null){
				value = function.apply(value);
			}

			sum += value.doubleValue();
		}

		return sum;
	}

	static
	public  void normalize(Map map){
		normalize(map, null);
	}

	static
	public  void normalizeSoftMax(Map map){
		Function function = new Function(){

			@Override
			public Double apply(Double value){
				return Math.exp(value.doubleValue());
			}
		};

		normalize(map, function);
	}

	static
	private  void normalize(Map map, Function function){
		double sum = sum(map, function);

		Collection> entries = map.entrySet();
		for(Map.Entry entry : entries){
			Double value = entry.getValue();

			if(function != null){
				value = function.apply(value);
			}

			entry.setValue(value / sum);
		}
	}

	private static final Ordering BIGGER_IS_BETTER = Ordering.natural();
	private static final Ordering SMALLER_IS_BETTER = Ordering.natural().reverse();

	static
	public enum Type implements Comparator {
		PROBABILITY(Classification.BIGGER_IS_BETTER, Range.closed(0d, 1d)),
		CONFIDENCE(Classification.BIGGER_IS_BETTER, Range.atLeast(0d)),
		DISTANCE(Classification.SMALLER_IS_BETTER, Range.atLeast(0d)){

			@Override
			public double getDefault(){
				return Double.POSITIVE_INFINITY;
			}
		},
		SIMILARITY(Classification.BIGGER_IS_BETTER, Range.atLeast(0d)),
		VOTE(Classification.BIGGER_IS_BETTER, Range.atLeast(0d)),
		;

		private Ordering ordering;

		private Range range;


		private Type(Ordering ordering, Range range){
			setOrdering(ordering);
			setRange(range);
		}

		/**
		 * Calculates the order between arguments.
		 *
		 * @param left A value
		 * @param right The reference value
		 */
		@Override
		public int compare(Double left, Double right){

			// The behaviour of missing values in comparison operations is not defined
			if(left == null || right == null){
				throw new EvaluationException();
			}

			Ordering ordering = getOrdering();

			return ordering.compare(left, right);
		}

		/**
		 * Gets the least optimal value in the range of valid values.
		 */
		public double getDefault(){
			return 0d;
		}

		public boolean isValid(Double value){
			Range range = getRange();

			return range.contains(value);
		}

		protected String entryKey(){
			String name = name();

			return (name.toLowerCase() + "_entries");
		}

		public Ordering getOrdering(){
			return this.ordering;
		}

		private void setOrdering(Ordering ordering){
			this.ordering = ordering;
		}

		public Range getRange(){
			return this.range;
		}

		private void setRange(Range range){
			this.range = range;
		}
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy