
org.jpmml.evaluator.Classification Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-evaluator Show documentation
Show all versions of pmml-evaluator Show documentation
JPMML class model evaluator
/*
* Copyright (c) 2013 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.evaluator;
import java.util.Collection;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;

import com.google.common.base.Function;
import com.google.common.base.Objects;
import com.google.common.base.Objects.ToStringHelper;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import com.google.common.collect.Range;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunctionType;
/**
* @see MiningFunctionType#CLASSIFICATION
* @see MiningFunctionType#CLUSTERING
*/
public class Classification implements Computable {
private Map map = new LinkedHashMap<>();
private Object result = null;
private Type type = null;
protected Classification(Type type){
setType(type);
}
@Override
public Object getResult(){
if(this.result == null){
throw new EvaluationException();
}
return this.result;
}
void computeResult(DataType dataType){
Map.Entry entry = getWinner();
if(entry == null){
throw new EvaluationException();
}
Object result = TypeUtil.parseOrCast(dataType, entry.getKey());
setResult(result);
}
void setResult(Object result){
this.result = result;
}
@Override
public String toString(){
ToStringHelper helper = toStringHelper();
return helper.toString();
}
protected ToStringHelper toStringHelper(){
ToStringHelper helper = Objects.toStringHelper(this)
.add("result", getResult())
.add(getType().entryKey(), entrySet());
return helper;
}
Double get(String key){
Double value = this.map.get(key);
// The specified value was not encountered during scoring
if(value == null){
Type type = getType();
return type.getDefault();
}
return value;
}
Double put(String key, Double value){
return this.map.put(key, value);
}
void putAll(Map values){
this.map.putAll(values);
}
boolean isEmpty(){
return this.map.isEmpty();
}
Map.Entry getWinner(){
return getWinner(getType(), entrySet());
}
List> getWinnerRanking(){
return getWinnerList(getType(), entrySet());
}
List getWinnerKeys(){
return entryKeys(getWinnerRanking());
}
List getWinnerValues(){
return entryValues(getWinnerRanking());
}
Double sumValues(){
return sum(this.map);
}
void normalizeValues(){
normalize(this.map);
}
Set keySet(){
return this.map.keySet();
}
Set> entrySet(){
return this.map.entrySet();
}
public Type getType(){
return this.type;
}
private void setType(Type type){
this.type = type;
}
static
Map.Entry getWinner(Type type, Collection> entries){
Ordering> ordering = createOrdering(type);
try {
return ordering.max(entries);
} catch(NoSuchElementException nsee){
return null;
}
}
static
List> getWinnerList(Type type, Collection> entries){
Ordering> ordering = (createOrdering(type)).reverse();
return ordering.sortedCopy(entries);
}
static
Ordering> createOrdering(final Type type){
Comparator> comparator = new Comparator>(){
@Override
public int compare(Map.Entry left, Map.Entry right){
return type.compare(left.getValue(), right.getValue());
}
};
return Ordering.from(comparator);
}
static
public List entryKeys(List> entries){
Function, K> function = new Function, K>(){
@Override
public K apply(Map.Entry entry){
return entry.getKey();
}
};
return Lists.transform(entries, function);
}
static
public List entryValues(List> entries){
Function, V> function = new Function, V>(){
@Override
public V apply(Map.Entry entry){
return entry.getValue();
}
};
return Lists.transform(entries, function);
}
static
public Double sum(Map map){
double sum = 0d;
Collection values = map.values();
for(Double value : values){
sum += value;
}
return sum;
}
static
public void normalize(Map map){
normalize(map, false);
}
static
public void normalizeSoftMax(Map map){
normalize(map, true);
}
static
private void normalize(Map map, boolean transform){
Collection> entries = map.entrySet();
double sum = 0d;
for(Map.Entry entry : entries){
Double value = entry.getValue();
sum += (transform ? Math.exp(value) : value);
}
for(Map.Entry entry : entries){
Double value = entry.getValue();
entry.setValue((transform ? Math.exp(value) : value) / sum);
}
}
private static final Ordering BIGGER_IS_BETTER = Ordering.natural();
private static final Ordering SMALLER_IS_BETTER = Ordering.natural().reverse();
static
public enum Type implements Comparator {
PROBABILITY(Classification.BIGGER_IS_BETTER, Range.closed(Values.DOUBLE_ZERO, Values.DOUBLE_ONE)),
CONFIDENCE(Classification.BIGGER_IS_BETTER, Range.atLeast(Values.DOUBLE_ZERO)),
DISTANCE(Classification.SMALLER_IS_BETTER, Range.atLeast(Values.DOUBLE_ZERO)){
@Override
public Double getDefault(){
return Double.POSITIVE_INFINITY;
}
},
SIMILARITY(Classification.BIGGER_IS_BETTER, Range.atLeast(Values.DOUBLE_ZERO)),
VOTE(Classification.BIGGER_IS_BETTER, Range.atLeast(Values.DOUBLE_ZERO)),
;
private Ordering ordering;
private Range range;
private Type(Ordering ordering, Range range){
setOrdering(ordering);
setRange(range);
}
/**
*
* Calculates the order between arguments.
*
*
* @param left A value
* @param right The reference value
*/
@Override
public int compare(Double left, Double right){
// The behaviour of missing values in comparison operations is not defined
if(left == null || right == null){
throw new EvaluationException();
}
Ordering ordering = getOrdering();
return ordering.compare(left, right);
}
/**
*
* Gets the least optimal value in the range of valid values.
*
*/
public Double getDefault(){
return Values.DOUBLE_ZERO;
}
public boolean isValid(Double value){
Range range = getRange();
return range.contains(value);
}
protected String entryKey(){
String name = name();
return (name.toLowerCase() + "_entries");
}
public Ordering getOrdering(){
return this.ordering;
}
private void setOrdering(Ordering ordering){
this.ordering = ordering;
}
public Range getRange(){
return this.range;
}
private void setRange(Range range){
this.range = range;
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy