All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.evaluator.Classification Maven / Gradle / Ivy

/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;

import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import com.google.common.collect.Range;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.jpmml.model.ToStringHelper;

/**
 * @see MiningFunction#CLASSIFICATION
 * @see MiningFunction#CLUSTERING
 */
public class Classification extends AbstractComputable implements HasPrediction {

	private Type type = null;

	private ValueMap values = null;

	private Object result = null;


	protected Classification(Type type, ValueMap values){
		setType(type);
		setValues(values);
	}

	@Override
	public Object getResult(){

		if(this.result == null){
			throw new EvaluationException("Classification result has not been computed");
		}

		return this.result;
	}

	protected void setResult(Object result){
		this.result = result;
	}

	protected void computeResult(DataType dataType){
		Map.Entry> entry = getWinner();

		if(entry == null){
			throw new EvaluationException("Empty classification");
		}

		K key = entry.getKey();
		Value value = entry.getValue();

		Object result = TypeUtil.parseOrCast(dataType, key);

		setResult(result);
	}

	@Override
	public Object getPrediction(){
		return getResult();
	}

	@Override
	public Report getPredictionReport(){
		Map.Entry> entry = getWinner();

		if(entry == null){
			return null;
		}

		K key = entry.getKey();
		Value value = entry.getValue();

		return ReportUtil.getReport(value);
	}

	@Override
	protected ToStringHelper toStringHelper(){
		Type type = getType();
		ValueMap values = getValues();

		ToStringHelper helper = super.toStringHelper()
			.add(type.entryKey(), values.entrySet());

		return helper;
	}

	public void put(K key, Value value){
		ValueMap values = getValues();

		if(values.containsKey(key)){
			throw new EvaluationException("Value for key " + EvaluationException.formatKey(key) + " has already been defined");
		}

		values.put(key, value);
	}

	public Double getValue(K key){
		Type type = getType();
		ValueMap values = getValues();

		Value value = values.get(key);

		return type.getValue(value);
	}

	public Report getValueReport(K key){
		ValueMap values = getValues();

		Value value = values.get(key);

		return ReportUtil.getReport(value);
	}

	protected Map.Entry> getWinner(){
		return getWinner(getType(), entrySet());
	}

	protected List>> getWinnerRanking(){
		return getWinnerList(getType(), entrySet());
	}

	protected List getWinnerKeys(){
		return entryKeys(getWinnerRanking());
	}

	protected List getWinnerValues(){
		return Lists.transform(entryValues(getWinnerRanking()), Value::doubleValue);
	}

	protected Set keySet(){
		ValueMap values = getValues();

		return values.keySet();
	}

	protected Set>> entrySet(){
		ValueMap values = getValues();

		return values.entrySet();
	}

	public Type getType(){
		return this.type;
	}

	private void setType(Type type){
		this.type = Objects.requireNonNull(type);
	}

	public ValueMap getValues(){
		return this.values;
	}

	private void setValues(ValueMap values){
		this.values = Objects.requireNonNull(values);
	}

	static
	public  Map.Entry> getWinner(Type type, Collection>> entries){
		Ordering>> ordering = Classification.createOrdering(type);

		try {
			return ordering.max(entries);
		} catch(NoSuchElementException nsee){
			return null;
		}
	}

	static
	public  List>> getWinnerList(Type type, Collection>> entries){
		Ordering>> ordering = (Classification.createOrdering(type)).reverse();

		return ordering.sortedCopy(entries);
	}

	static
	protected  Ordering>> createOrdering(Type type){
		return Ordering.from((Map.Entry> left, Map.Entry> right) -> type.compareValues(left.getValue(), right.getValue()));
	}

	static
	public  List entryKeys(List> entries){
		return Lists.transform(entries, Map.Entry::getKey);
	}

	static
	public  List entryValues(List> entries){
		return Lists.transform(entries, Map.Entry::getValue);
	}

	static
	public enum Type {
		PROBABILITY(true, Range.closed(Numbers.DOUBLE_ZERO, Numbers.DOUBLE_ONE)),
		CONFIDENCE(true, Range.atLeast(Numbers.DOUBLE_ZERO)),
		DISTANCE(false, Range.atLeast(Numbers.DOUBLE_ZERO)){

			@Override
			public Double getDefaultValue(){
				return Double.POSITIVE_INFINITY;
			}
		},
		SIMILARITY(true, Range.atLeast(Numbers.DOUBLE_ZERO)),
		VOTE(true, Range.atLeast(Numbers.DOUBLE_ZERO)),
		;

		private boolean ordering;

		private Range range;


		private Type(boolean ordering, Range range){
			setOrdering(ordering);
			setRange(range);
		}

		public  Double getValue(Value value){

			// The specified value was not encountered during scoring
			if(value == null){
				return getDefaultValue();
			}

			return value.doubleValue();
		}

		public  int compareValues(Value left, Value right){
			boolean ordering = getOrdering();

			int result = (left).compareTo(right);

			return (ordering ? result : -result);
		}

		public  boolean isValidValue(Value value){
			Range range = getRange();

			return range.contains(value.doubleValue());
		}

		/**
		 * 

* Gets the least optimal value in the range of valid values. *

*/ public Double getDefaultValue(){ return Numbers.DOUBLE_ZERO; } public String entryKey(){ String name = name(); return (name.toLowerCase() + "_entries"); } public boolean getOrdering(){ return this.ordering; } private void setOrdering(boolean ordering){ this.ordering = ordering; } public Range getRange(){ return this.range; } private void setRange(Range range){ this.range = range; } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy