/*
 * org.jpmml.evaluator.ModelEvaluator (Maven artifact: pmml-evaluator)
 * JPMML class model evaluator
 */
/*
* Copyright (c) 2016 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.evaluator;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.collect.Sets.SetView;
import com.google.common.collect.Table;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Field;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.ModelVerification;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.VerificationField;
import org.dmg.pmml.VerificationFields;
import org.jpmml.model.InvalidAttributeException;
import org.jpmml.model.InvalidElementException;
import org.jpmml.model.UnsupportedAttributeException;
/**
* @see ModelEvaluatorBuilder
*/
abstract
public class ModelEvaluator extends ModelManager implements Evaluator {
private Configuration configuration = null;
private InputMapper inputMapper = null;
private ResultMapper resultMapper = null;
private ValueFactory> valueFactory = null;
private Boolean parentCompatible = null;
private Boolean pure = null;
private Integer numberOfVisibleFields = null;
/**
 * Creates a model evaluator that is not bound to any PMML document or model.
 *
 * <p>
 * Unlike {@link #ModelEvaluator(PMML, Model)}, this constructor performs no math context validation.
 * </p>
 */
protected ModelEvaluator(){
	super();
}
/**
 * Creates a model evaluator for the given model.
 *
 * @param pmml The PMML document that contains the model.
 * @param model The model.
 *
 * @throws UnsupportedAttributeException If the model's math context is neither {@code FLOAT} nor {@code DOUBLE}.
 */
protected ModelEvaluator(PMML pmml, M model){
	super(pmml, model);

	MathContext mathContext = model.getMathContext();

	switch(mathContext){
		case FLOAT:
		case DOUBLE:
			// Supported math contexts
			return;
		default:
			break;
	}

	throw new UnsupportedAttributeException(model, mathContext);
}
/**
 * Configures the runtime behaviour of this model evaluator.
 *
 * <p>
 * Must be called once before the first evaluation.
 * May be called any number of times between subsequent evaluations.
 * </p>
 *
 * <p>
 * Installing a new configuration also discards the current value factory,
 * and resets the input field and result field declarations.
 * NOTE(review): these are presumably rebuilt on next access against the new configuration.
 * </p>
 *
 * @param configuration The new configuration.
 */
public void configure(Configuration configuration){
	this.setConfiguration(configuration);

	// Drop state that may have been derived from the previous configuration
	this.setValueFactory(null);
	this.resetInputFields();
	this.resetResultFields();
}
/**
 * Indicates if this model evaluator is compatible with its parent model evaluator.
 *
 * <p>
 * A parent compatible model evaluator inherits {@link DataField} declarations unchanged,
 * which makes it possible to propagate {@link DataField} and global {@link DerivedField} values between evaluation contexts during evaluation.
 * </p>
 *
 * <p>
 * The answer is computed by {@code assessParentCompatibility()} on first access, and memoized afterwards.
 * </p>
 */
public boolean isParentCompatible(){
	Boolean result = this.parentCompatible;

	if(result == null){
		result = assessParentCompatibility();

		this.parentCompatible = result;
	}

	return result;
}
/**
 * Indicates if this model evaluator represents a pure function.
 *
 * <p>
 * A pure model evaluator does not tamper with the evaluation context during evaluation.
 * </p>
 *
 * <p>
 * The answer is computed by {@code assessPurity()} on first access, and memoized afterwards.
 * </p>
 */
public boolean isPure(){
	Boolean cached = this.pure;
	if(cached != null){
		return cached;
	}

	Boolean assessed = assessPurity();
	this.pure = assessed;

	return assessed;
}
/**
 * Returns the number of entries in {@code getVisibleFields()}.
 *
 * <p>
 * The count is computed on first access, and memoized afterwards.
 * </p>
 */
protected int getNumberOfVisibleFields(){

	if(this.numberOfVisibleFields == null){
		// Only the size is needed, so the multimap itself is not kept around
		this.numberOfVisibleFields = getVisibleFields().size();
	}

	return this.numberOfVisibleFields;
}
@Override
public ModelEvaluator verify(){
M model = getModel();
ModelVerification modelVerification = model.getModelVerification();
if(modelVerification == null){
return this;
}
VerificationBatch batch = parseModelVerification(modelVerification);
List extends Map> records = batch.getRecords();
List inputFields = getInputFields();
if(this instanceof HasGroupFields){
HasGroupFields hasGroupFields = (HasGroupFields)this;
records = EvaluatorUtil.groupRows(hasGroupFields, records);
}
List targetFields = getTargetFields();
List outputFields = getOutputFields();
SetView intersection = Sets.intersection(batch.keySet(), new LinkedHashSet<>(Lists.transform(outputFields, OutputField::getFieldName)));
boolean disjoint = intersection.isEmpty();
for(Map record : records){
Map arguments = new LinkedHashMap<>();
for(InputField inputField : inputFields){
String name = inputField.getFieldName();
FieldValue value = inputField.prepare(record.get(name));
arguments.put(name, value);
}
ModelEvaluationContext context = createEvaluationContext();
context.setArguments(arguments);
Map results = evaluateInternal(context);
// "If there exist VerificationField elements that refer to OutputField elements,
// then any VerificationField element that refers to a MiningField element whose "usageType=target" should be ignored,
// because they are considered to represent a dependent variable from the training data set, not an expected output"
if(!disjoint){
for(OutputField outputField : outputFields){
String name = outputField.getFieldName();
VerificationField verificationField = batch.get(name);
if(verificationField == null){
continue;
}
verify(record.get(name), results.get(name), verificationField.getPrecision(), verificationField.getZeroThreshold());
}
} else
// "If there are no such VerificationField elements,
// then any VerificationField element that refers to a MiningField element whose "usageType=target" should be considered to represent an expected output"
{
for(TargetField targetField : targetFields){
String name = targetField.getFieldName();
VerificationField verificationField = batch.get(name);
if(verificationField == null){
continue;
}
Number precision = verificationField.getPrecision();
Number zeroThreshold = verificationField.getZeroThreshold();
verify(record.get(name), EvaluatorUtil.decode(results.get(name)), precision, zeroThreshold);
}
}
}
return this;
}
/**
 * Checks a single expected value against the corresponding actual value.
 *
 * <p>
 * Unless the actual value is a collection, the expected value is first converted
 * to the data type of the actual value. The comparison itself is delegated to
 * {@code VerificationUtil.acceptable(expected, actual, precision, zeroThreshold)}.
 * </p>
 *
 * @param expected The expected value, or {@code null} if no value was specified (in which case nothing is checked).
 * @param actual The actual value, as computed by this model evaluator.
 * @param precision The precision of the {@link VerificationField} element.
 * @param zeroThreshold The zero threshold of the {@link VerificationField} element.
 *
 * @throws EvaluationException If the values do not match.
 */
private void verify(Object expected, Object actual, Number precision, Number zeroThreshold){

	if(expected == null){
		return;
	}

	boolean acceptable;

	if(actual == null){
		// A specified expected value cannot match a missing actual value.
		// Report this as a regular mismatch instead of handing null to TypeUtil.getDataType(Object)
		// (NOTE(review): which presumably cannot map null to a data type)
		acceptable = false;
	} else {

		if(!(actual instanceof Collection)){
			DataType dataType = TypeUtil.getDataType(actual);

			expected = TypeUtil.parseOrCast(dataType, expected);
		}

		acceptable = VerificationUtil.acceptable(expected, actual, precision.doubleValue(), zeroThreshold.doubleValue());
	}

	if(!acceptable){
		throw new EvaluationException("Values " + EvaluationException.formatValue(expected) + " and " + EvaluationException.formatValue(actual) + " do not match");
	}
}
/**
 * Creates a new evaluation context that is bound to this model evaluator.
 *
 * @return A new {@link ModelEvaluationContext} instance.
 */
public ModelEvaluationContext createEvaluationContext(){
	ModelEvaluationContext context = new ModelEvaluationContext(this);

	return context;
}
/**
 * Evaluates the model with the given arguments.
 *
 * <p>
 * Arguments are passed through {@link #processArguments(Map)} (optional input name remapping) before evaluation,
 * and results are passed through {@link #processResults(Map)} (output cleanup and optional result name remapping) after evaluation.
 * </p>
 *
 * <p>
 * If the configuration provides a derived field guard and/or a function guard, a forked copy of it is installed
 * into the corresponding {@code EvaluationContext} provider for the duration of {@link #evaluateInternal(ModelEvaluationContext)}.
 * The previously installed value is restored afterwards, also when the evaluation fails.
 * </p>
 */
@Override
public Map<String, ?> evaluate(Map<String, ?> arguments){
	Configuration configuration = ensureConfiguration();

	SymbolTable derivedFieldGuard = configuration.getDerivedFieldGuard();
	SymbolTable functionGuard = configuration.getFunctionGuard();

	arguments = processArguments(arguments);

	ModelEvaluationContext context = createEvaluationContext();
	context.setArguments(arguments);

	// Capture the currently installed guards before touching either provider.
	// Previously the function guard was only captured inside the try block, after the derived field guard had been forked,
	// so a failing fork() made the finally block reset the function guard provider to null.
	SymbolTable prevDerivedFieldGuard = (derivedFieldGuard != null) ? EvaluationContext.DERIVEDFIELD_GUARD_PROVIDER.get() : null;
	SymbolTable prevFunctionGuard = (functionGuard != null) ? EvaluationContext.FUNCTION_GUARD_PROVIDER.get() : null;

	Map<String, ?> results;

	try {
		if(derivedFieldGuard != null){
			EvaluationContext.DERIVEDFIELD_GUARD_PROVIDER.set(derivedFieldGuard.fork());
		} // End if

		if(functionGuard != null){
			EvaluationContext.FUNCTION_GUARD_PROVIDER.set(functionGuard.fork());
		}

		results = evaluateInternal(context);
	} finally {
		if(derivedFieldGuard != null){
			EvaluationContext.DERIVEDFIELD_GUARD_PROVIDER.set(prevDerivedFieldGuard);
		} // End if

		if(functionGuard != null){
			EvaluationContext.FUNCTION_GUARD_PROVIDER.set(prevFunctionGuard);
		}
	}

	return processResults(results);
}
protected Map processArguments(Map arguments){
InputMapper inputMapper = getInputMapper();
if(inputMapper != null){
Map remappedArguments = new AbstractMap(){
@Override
public Object get(Object key){
return arguments.get(inputMapper.apply((String)key));
}
@Override
public Set> entrySet(){
throw new UnsupportedOperationException();
}
};
return remappedArguments;
}
return arguments;
}
protected Map processResults(Map results){
ResultMapper resultMapper = getResultMapper();
if(results instanceof OutputMap){
OutputMap outputMap = (OutputMap)results;
outputMap.clearPrivate();
} // End if
if(resultMapper != null){
if(results.isEmpty()){
return results;
} else
if(results.size() == 1){
Map.Entry entry = Iterables.getOnlyElement(results.entrySet());
return Collections.singletonMap(resultMapper.apply(entry.getKey()), entry.getValue());
}
Map remappedResults = new LinkedHashMap<>(2 * results.size());
Collection extends Map.Entry> entries = results.entrySet();
for(Map.Entry entry : entries){
remappedResults.put(resultMapper.apply(entry.getKey()), entry.getValue());
}
return remappedResults;
}
return results;
}
/**
 * Renames the given input fields with the configured {@link InputMapper}, if any.
 *
 * @param inputFields The input fields.
 *
 * @return The input fields, unchanged when no input mapper is configured.
 */
@Override
protected List filterInputFields(List inputFields){
	InputMapper mapper = getInputMapper();

	if(mapper == null){
		return inputFields;
	}

	return updateNames(inputFields, mapper);
}
/**
 * Renames the given target fields with the configured {@link ResultMapper}, if any.
 *
 * @param targetFields The target fields.
 *
 * @return The target fields, unchanged when no result mapper is configured.
 */
@Override
protected List filterTargetFields(List targetFields){
	ResultMapper mapper = getResultMapper();

	if(mapper == null){
		return targetFields;
	}

	return updateNames(targetFields, mapper);
}
/**
 * Drops the output fields that are rejected by the output filter,
 * then renames the remaining ones with the configured {@link ResultMapper}, if any.
 *
 * <p>
 * Filtering removes elements from the given list in place.
 * </p>
 *
 * @param outputFields The output fields.
 *
 * @return The filtered (and possibly renamed) output fields.
 */
@Override
protected List<OutputField> filterOutputFields(List<OutputField> outputFields){
	ResultMapper resultMapper = getResultMapper();

	if(!outputFields.isEmpty()){
		OutputFilter outputFilter = ensureOutputFilter();

		// Iterator-based removal only mutates the list when an element is actually rejected
		for(Iterator<OutputField> it = outputFields.iterator(); it.hasNext(); ){
			OutputField outputField = it.next();

			// Fully qualified, because org.jpmml.evaluator.OutputField shadows the PMML class name
			org.dmg.pmml.OutputField pmmlOutputField = outputField.getField();

			if(!outputFilter.test(pmmlOutputField)){
				it.remove();
			}
		}
	} // End if

	if(resultMapper != null){
		outputFields = updateNames(outputFields, resultMapper);
	}

	return outputFields;
}
public Map evaluateInternal(ModelEvaluationContext context){
M model = getModel();
if(!model.isScorable()){
throw new EvaluationException("Model is not scorable", model);
}
ValueFactory> valueFactory;
MathContext mathContext = model.getMathContext();
switch(mathContext){
case FLOAT:
case DOUBLE:
valueFactory = ensureValueFactory();
break;
default:
throw new UnsupportedAttributeException(model, mathContext);
}
Map predictions;
MiningFunction miningFunction = model.requireMiningFunction();
switch(miningFunction){
case REGRESSION:
predictions = evaluateRegression(valueFactory, context);
break;
case CLASSIFICATION:
predictions = evaluateClassification(valueFactory, context);
break;
case CLUSTERING:
predictions = evaluateClustering(valueFactory, context);
break;
case ASSOCIATION_RULES:
predictions = evaluateAssociationRules(valueFactory, context);
break;
case SEQUENCES:
predictions = evaluateSequences(valueFactory, context);
break;
case TIME_SERIES:
predictions = evaluateTimeSeries(valueFactory, context);
break;
case MIXED:
predictions = evaluateMixed(valueFactory, context);
break;
default:
throw new UnsupportedAttributeException(model, miningFunction);
}
predictions = evaluateOutput(predictions, context);
return predictions;
}
/**
 * Computes predictions for the {@code regression} mining function.
 *
 * <p>
 * The default implementation rejects the mining function; model evaluators that support it override this method.
 * </p>
 */
protected Map evaluateRegression(ValueFactory valueFactory, EvaluationContext context){
	return this.evaluateDefault();
}

/**
 * Computes predictions for the {@code classification} mining function.
 *
 * <p>
 * The default implementation rejects the mining function; model evaluators that support it override this method.
 * </p>
 */
protected Map evaluateClassification(ValueFactory valueFactory, EvaluationContext context){
	return this.evaluateDefault();
}

/**
 * Computes predictions for the {@code clustering} mining function.
 *
 * <p>
 * The default implementation rejects the mining function; model evaluators that support it override this method.
 * </p>
 */
protected Map evaluateClustering(ValueFactory valueFactory, EvaluationContext context){
	return this.evaluateDefault();
}

/**
 * Computes predictions for the {@code associationRules} mining function.
 *
 * <p>
 * The default implementation rejects the mining function; model evaluators that support it override this method.
 * </p>
 */
protected Map evaluateAssociationRules(ValueFactory valueFactory, EvaluationContext context){
	return this.evaluateDefault();
}

/**
 * Computes predictions for the {@code sequences} mining function.
 *
 * <p>
 * The default implementation rejects the mining function; model evaluators that support it override this method.
 * </p>
 */
protected Map evaluateSequences(ValueFactory valueFactory, EvaluationContext context){
	return this.evaluateDefault();
}

/**
 * Computes predictions for the {@code timeSeries} mining function.
 *
 * <p>
 * The default implementation rejects the mining function; model evaluators that support it override this method.
 * </p>
 */
protected Map evaluateTimeSeries(ValueFactory valueFactory, EvaluationContext context){
	return this.evaluateDefault();
}

/**
 * Computes predictions for the {@code mixed} mining function.
 *
 * <p>
 * The default implementation rejects the mining function; model evaluators that support it override this method.
 * </p>
 */
protected Map evaluateMixed(ValueFactory valueFactory, EvaluationContext context){
	return this.evaluateDefault();
}
/**
 * Rejects the model's mining function as invalid for this model evaluator.
 *
 * @throws InvalidAttributeException Always.
 */
private Map evaluateDefault(){
	Model model = getModel();

	throw new InvalidAttributeException(model, model.requireMiningFunction());
}
/**
 * Post-processes the predictions by delegating to {@code OutputUtil.evaluate(predictions, context)}.
 * NOTE(review): presumably this adds the values of the model's output fields - confirm against OutputUtil.
 *
 * @param predictions The predictions, as computed by a mining function-specific hook method.
 * @param context The evaluation context.
 *
 * @return The post-processed predictions.
 */
protected Map evaluateOutput(Map predictions, ModelEvaluationContext context){
	Map result = OutputUtil.evaluate(predictions, context);

	return result;
}
protected Classification
© 2015 - 2025 Weber Informatics LLC | Privacy Policy