org.jpmml.evaluator.Classification Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-evaluator Show documentation
Show all versions of pmml-evaluator Show documentation
JPMML class model evaluator
/*
* Copyright (c) 2013 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.evaluator;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;

import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import com.google.common.collect.Range;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.jpmml.model.ToStringHelper;
/**
* @see MiningFunction#CLASSIFICATION
* @see MiningFunction#CLUSTERING
*/
public class Classification extends AbstractComputable implements HasPrediction {
private Type type = null;
private ValueMap values = null;
private Object result = null;
protected Classification(Type type, ValueMap values){
setType(type);
setValues(values);
}
@Override
public Object getResult(){
if(this.result == null){
throw new EvaluationException("Classification result has not been computed");
}
return this.result;
}
protected void setResult(Object result){
this.result = result;
}
protected void computeResult(DataType dataType){
Map.Entry> entry = getWinner();
if(entry == null){
throw new EvaluationException("Empty classification");
}
K key = entry.getKey();
Value value = entry.getValue();
Object result = TypeUtil.parseOrCast(dataType, key);
setResult(result);
}
@Override
public Object getPrediction(){
return getResult();
}
@Override
public Report getPredictionReport(){
Map.Entry> entry = getWinner();
if(entry == null){
return null;
}
K key = entry.getKey();
Value value = entry.getValue();
return ReportUtil.getReport(value);
}
@Override
protected ToStringHelper toStringHelper(){
Type type = getType();
ValueMap values = getValues();
ToStringHelper helper = super.toStringHelper()
.add(type.entryKey(), values.entrySet());
return helper;
}
public void put(K key, Value value){
ValueMap values = getValues();
if(values.containsKey(key)){
throw new EvaluationException("Value for key " + EvaluationException.formatKey(key) + " has already been defined");
}
values.put(key, value);
}
public Double getValue(K key){
Type type = getType();
ValueMap values = getValues();
Value value = values.get(key);
return type.getValue(value);
}
public Report getValueReport(K key){
ValueMap values = getValues();
Value value = values.get(key);
return ReportUtil.getReport(value);
}
protected Map.Entry> getWinner(){
return getWinner(getType(), entrySet());
}
protected List>> getWinnerRanking(){
return getWinnerList(getType(), entrySet());
}
protected List getWinnerKeys(){
return entryKeys(getWinnerRanking());
}
protected List getWinnerValues(){
return Lists.transform(entryValues(getWinnerRanking()), Value::doubleValue);
}
protected Set keySet(){
ValueMap values = getValues();
return values.keySet();
}
protected Set>> entrySet(){
ValueMap values = getValues();
return values.entrySet();
}
public Type getType(){
return this.type;
}
private void setType(Type type){
this.type = Objects.requireNonNull(type);
}
public ValueMap getValues(){
return this.values;
}
private void setValues(ValueMap values){
this.values = Objects.requireNonNull(values);
}
static
public Map.Entry> getWinner(Type type, Collection>> entries){
Ordering>> ordering = Classification.createOrdering(type);
try {
return ordering.max(entries);
} catch(NoSuchElementException nsee){
return null;
}
}
static
public List>> getWinnerList(Type type, Collection>> entries){
Ordering>> ordering = (Classification.createOrdering(type)).reverse();
return ordering.sortedCopy(entries);
}
static
protected Ordering>> createOrdering(Type type){
return Ordering.from((Map.Entry> left, Map.Entry> right) -> type.compareValues(left.getValue(), right.getValue()));
}
static
public List entryKeys(List> entries){
return Lists.transform(entries, Map.Entry::getKey);
}
static
public List entryValues(List> entries){
return Lists.transform(entries, Map.Entry::getValue);
}
static
public enum Type {
PROBABILITY(true, Range.closed(Numbers.DOUBLE_ZERO, Numbers.DOUBLE_ONE)),
CONFIDENCE(true, Range.atLeast(Numbers.DOUBLE_ZERO)),
DISTANCE(false, Range.atLeast(Numbers.DOUBLE_ZERO)){
@Override
public Double getDefaultValue(){
return Double.POSITIVE_INFINITY;
}
},
SIMILARITY(true, Range.atLeast(Numbers.DOUBLE_ZERO)),
VOTE(true, Range.atLeast(Numbers.DOUBLE_ZERO)),
;
private boolean ordering;
private Range range;
private Type(boolean ordering, Range range){
setOrdering(ordering);
setRange(range);
}
public Double getValue(Value value){
// The specified value was not encountered during scoring
if(value == null){
return getDefaultValue();
}
return value.doubleValue();
}
public int compareValues(Value left, Value right){
boolean ordering = getOrdering();
int result = (left).compareTo(right);
return (ordering ? result : -result);
}
public boolean isValidValue(Value value){
Range range = getRange();
return range.contains(value.doubleValue());
}
/**
*
* Gets the least optimal value in the range of valid values.
*
*/
public Double getDefaultValue(){
return Numbers.DOUBLE_ZERO;
}
public String entryKey(){
String name = name();
return (name.toLowerCase() + "_entries");
}
public boolean getOrdering(){
return this.ordering;
}
private void setOrdering(boolean ordering){
this.ordering = ordering;
}
public Range getRange(){
return this.range;
}
private void setRange(Range range){
this.range = range;
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy