org.jpmml.evaluator.ModelEvaluator Maven / Gradle / Ivy
/*
* Copyright (c) 2016 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.evaluator;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Callable;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import com.google.common.collect.Sets.SetView;
import com.google.common.collect.Table;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.ModelVerification;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.dmg.pmml.TransformationDictionary;
import org.dmg.pmml.VerificationField;
import org.dmg.pmml.VerificationFields;
abstract
public class ModelEvaluator implements Evaluator, Serializable {
private PMML pmml = null;
private M model = null;
private ModelEvaluatorFactory modelEvaluatorFactory = null;
private ValueFactoryFactory valueFactoryFactory = null;
private ValueFactory> valueFactory = null;
private Map dataFields = Collections.emptyMap();
private Map derivedFields = Collections.emptyMap();
private Map defineFunctions = Collections.emptyMap();
private Map miningFields = Collections.emptyMap();
private Map localDerivedFields = Collections.emptyMap();
private Map targets = Collections.emptyMap();
private Map outputFields = Collections.emptyMap();
transient
private List inputFields = null;
transient
private List activeInputFields = null;
transient
private List targetResultFields = null;
transient
private List outputResultFields = null;
/**
 * Binds this evaluator to a model of a PMML document and populates the
 * name-keyed lookup tables from the elements that are present.
 *
 * @param pmml The PMML document. Must not be {@code null}.
 * @param model The model. Must not be {@code null}.
 *
 * @throws NullPointerException If either argument is {@code null}.
 * @throws MissingElementException If the data dictionary or the mining schema is missing.
 * @throws MissingAttributeException If the mining function is missing.
 */
protected ModelEvaluator(PMML pmml, M model) {
	setPMML(Objects.requireNonNull(pmml));
	setModel(Objects.requireNonNull(model));

	// Document-level declarations
	DataDictionary dictionary = pmml.getDataDictionary();
	if (dictionary == null) {
		throw new MissingElementException(pmml, PMMLElements.PMML_DATADICTIONARY);
	}
	if (dictionary.hasDataFields()) {
		this.dataFields = CacheUtil.getValue(dictionary, ModelEvaluator.dataFieldCache);
	}

	TransformationDictionary transformations = pmml.getTransformationDictionary();
	if (transformations != null) {
		if (transformations.hasDerivedFields()) {
			this.derivedFields = CacheUtil.getValue(transformations, ModelEvaluator.derivedFieldCache);
		}
		if (transformations.hasDefineFunctions()) {
			this.defineFunctions = CacheUtil.getValue(transformations, ModelEvaluator.defineFunctionCache);
		}
	}

	// Model-level declarations
	if (model.getMiningFunction() == null) {
		throw new MissingAttributeException(MissingAttributeException.formatMessage(XPathUtil.formatElement(model.getClass()) + "@miningFunction"), model);
	}

	MiningSchema schema = model.getMiningSchema();
	if (schema == null) {
		throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(model.getClass()) + "/" + XPathUtil.formatElement(MiningSchema.class)), model);
	}
	if (schema.hasMiningFields()) {
		this.miningFields = CacheUtil.getValue(schema, ModelEvaluator.miningFieldCache);
	}

	LocalTransformations localTransformations = model.getLocalTransformations();
	if (localTransformations != null && localTransformations.hasDerivedFields()) {
		this.localDerivedFields = CacheUtil.getValue(localTransformations, ModelEvaluator.localDerivedFieldCache);
	}

	Targets modelTargets = model.getTargets();
	if (modelTargets != null && modelTargets.hasTargets()) {
		this.targets = CacheUtil.getValue(modelTargets, ModelEvaluator.targetCache);
	}

	Output modelOutput = model.getOutput();
	if (modelOutput != null && modelOutput.hasOutputFields()) {
		this.outputFields = CacheUtil.getValue(modelOutput, ModelEvaluator.outputFieldCache);
	}
}
/**
 * Evaluates the model in the given evaluation context.
 * The {@link #evaluate(Map)} overload creates a context for its arguments and returns the result of this method unchanged.
 *
 * @param context The evaluation context.
 *
 * @return The evaluation result, keyed by field name.
 */
public abstract Map<FieldName, ?> evaluate(ModelEvaluationContext context);
/**
 * <p>
 * Configures the runtime behaviour of this Evaluator instance.
 * </p>
 *
 * <p>
 * Must be called once before the first evaluation.
 * May be called any number of times between subsequent evaluations.
 * </p>
 *
 * <p>
 * Every call also clears the current value factory factory and value factory.
 * </p>
 *
 * @param modelEvaluatorFactory The model evaluator factory to use from now on.
 */
public void configure(ModelEvaluatorFactory modelEvaluatorFactory) {
	setModelEvaluatorFactory(modelEvaluatorFactory);

	// Drop the objects that were derived from the previous configuration
	setValueFactoryFactory(null);
	setValueFactory(null);
}
/**
 * @return The mining function of the model.
 */
@Override
public MiningFunction getMiningFunction() {
	return getModel().getMiningFunction();
}
/**
 * @return The math context of the model.
 */
public MathContext getMathContext() {
	return getModel().getMathContext();
}
/**
 * Looks up a data field by name.
 *
 * @param name The field name. The default target name maps to the synthetic default target field.
 *
 * @return The data field, or {@code null} if there is no such field.
 */
public DataField getDataField(FieldName name) {
	if (!Objects.equals(Evaluator.DEFAULT_TARGET_NAME, name)) {
		return this.dataFields.get(name);
	}

	return getDataField();
}
/**
 * Selects the synthetic {@link DataField} element that describes the default target field.
 *
 * @return {@code DEFAULT_TARGET_CONTINUOUS_FLOAT} or {@code DEFAULT_TARGET_CONTINUOUS_DOUBLE} for regression models (depending on the math context),
 * {@code DEFAULT_TARGET_CATEGORICAL_STRING} for classification and clustering models,
 * {@code null} for all other mining functions.
 */
protected DataField getDataField() {
	MiningFunction function = getMiningFunction();

	switch (function) {
		case CLUSTERING:
		case CLASSIFICATION:
			return ModelEvaluator.DEFAULT_TARGET_CATEGORICAL_STRING;
		case REGRESSION: {
			MathContext context = getMathContext();

			switch (context) {
				case FLOAT:
					return ModelEvaluator.DEFAULT_TARGET_CONTINUOUS_FLOAT;
				default:
					return ModelEvaluator.DEFAULT_TARGET_CONTINUOUS_DOUBLE;
			}
		}
		default:
			break;
	}

	// No synthetic default target field for the remaining mining functions
	return null;
}
/**
 * Looks up a derived field among those indexed from the transformation dictionary.
 *
 * @return The derived field, or {@code null} if there is no such field.
 */
public DerivedField getDerivedField(FieldName name) {
	DerivedField derivedField = this.derivedFields.get(name);

	return derivedField;
}
/**
 * Looks up a user-defined function among those indexed from the transformation dictionary.
 *
 * @return The function definition, or {@code null} if there is no such function.
 */
public DefineFunction getDefineFunction(String name) {
	DefineFunction defineFunction = this.defineFunctions.get(name);

	return defineFunction;
}
/**
 * Looks up a mining field among those indexed from the mining schema.
 *
 * @return The mining field, or {@code null} if there is no such field.
 * Always {@code null} for the default target name.
 */
public MiningField getMiningField(FieldName name) {
	boolean defaultTarget = Objects.equals(Evaluator.DEFAULT_TARGET_NAME, name);

	return defaultTarget ? null : this.miningFields.get(name);
}
/**
 * Looks up a derived field among those indexed from the local transformations of the model.
 *
 * @return The derived field, or {@code null} if there is no such field.
 */
public DerivedField getLocalDerivedField(FieldName name) {
	DerivedField localDerivedField = this.localDerivedFields.get(name);

	return localDerivedField;
}
/**
 * Looks up a target among those indexed from the targets of the model.
 *
 * @return The target, or {@code null} if there is no such target.
 */
public Target getTarget(FieldName name) {
	Target target = this.targets.get(name);

	return target;
}
/**
 * Looks up a PMML output field among those indexed from the output of the model.
 *
 * @return The output field, or {@code null} if there is no such field.
 */
public org.dmg.pmml.OutputField getOutputField(FieldName name) {
	org.dmg.pmml.OutputField pmmlOutputField = this.outputFields.get(name);

	return pmmlOutputField;
}
/**
 * @return {@code true} if the model declares neither local derived fields nor output fields, {@code false} otherwise.
 */
public boolean isPrimitive() {
	if (!this.localDerivedFields.isEmpty()) {
		return false;
	}

	return this.outputFields.isEmpty();
}
/**
 * Returns the input fields, creating the list on the first call and reusing it afterwards.
 *
 * @return The list built by {@link #createInputFields()}.
 */
@Override
public List<InputField> getInputFields() {
	if (this.inputFields == null) {
		this.inputFields = createInputFields();
	}

	return this.inputFields;
}
/**
 * @return The input field with the given name, or {@code null} if there is no such field.
 */
InputField findInputField(FieldName name) {
	InputField match = findModelField(getInputFields(), name);

	return match;
}
/**
 * Returns the input fields whose mining field usage type is active,
 * creating the list on the first call and reusing it afterwards.
 *
 * @return The list built by {@link #createInputFields(MiningField.UsageType)} for {@code ACTIVE}.
 */
@Override
public List<InputField> getActiveFields() {
	if (this.activeInputFields == null) {
		this.activeInputFields = createInputFields(MiningField.UsageType.ACTIVE);
	}

	return this.activeInputFields;
}
/**
 * Returns the target fields, creating the list on the first call and reusing it afterwards.
 *
 * @return The list built by {@link #createTargetFields()}.
 */
@Override
public List<TargetField> getTargetFields() {
	if (this.targetResultFields == null) {
		this.targetResultFields = createTargetFields();
	}

	return this.targetResultFields;
}
/**
 * Returns the sole target field of the model.
 *
 * @return The target field.
 *
 * @throws EvaluationException If the number of target fields is not exactly one.
 */
public TargetField getTargetField() {
	List<TargetField> targetFields = getTargetFields();

	if (targetFields.size() != 1) {
		throw createMiningSchemaException("Expected 1 target field, got " + targetFields.size() + " target fields");
	}

	return targetFields.get(0);
}
/**
 * @return The target field with the given name, or {@code null} if there is no such field.
 */
TargetField findTargetField(FieldName name) {
	TargetField match = findModelField(getTargetFields(), name);

	return match;
}
/**
 * @return The name of the sole target field.
 *
 * @see #getTargetField()
 */
public FieldName getTargetFieldName() {
	return getTargetField().getName();
}
/**
 * Returns the output fields, creating the list on the first call and reusing it afterwards.
 *
 * @return The list built by {@link #createOutputFields()}.
 */
@Override
public List<OutputField> getOutputFields() {
	if (this.outputResultFields == null) {
		this.outputResultFields = createOutputFields();
	}

	return this.outputResultFields;
}
/**
 * @return The output field with the given name, or {@code null} if there is no such field.
 */
OutputField findOutputField(FieldName name) {
	OutputField match = findModelField(getOutputFields(), name);

	return match;
}
/**
 * Creates (but does not throw) an exception whose context is the mining schema of the model.
 *
 * @param message The exception message.
 */
protected EvaluationException createMiningSchemaException(String message) {
	MiningSchema schema = getModel().getMiningSchema();

	return new EvaluationException(message, schema);
}
@Override
public void verify(){
M model = getModel();
ModelVerification modelVerification = model.getModelVerification();
if(modelVerification == null){
return;
}
VerificationBatch batch = CacheUtil.getValue(modelVerification, ModelEvaluator.batchCache);
List extends Map> records = batch.getRecords();
List inputFields = getInputFields();
if(this instanceof HasGroupFields){
HasGroupFields hasGroupFields = (HasGroupFields)this;
records = EvaluatorUtil.groupRows(hasGroupFields, records);
}
List targetFields = getTargetFields();
List outputFields = getOutputFields();
SetView intersection = Sets.intersection(batch.keySet(), new LinkedHashSet<>(EvaluatorUtil.getNames(outputFields)));
for(Map record : records){
Map arguments = new LinkedHashMap<>();
for(InputField inputField : inputFields){
FieldName name = inputField.getName();
FieldValue value = inputField.prepare(record.get(name));
arguments.put(name, value);
}
Map result = evaluate(arguments);
// "If there exist VerificationField elements that refer to OutputField elements,
// then any VerificationField element that refers to a MiningField element whose "usageType=target" should be ignored,
// because they are considered to represent a dependent variable from the training data set, not an expected output"
if(intersection.size() > 0){
for(OutputField outputField : outputFields){
FieldName name = outputField.getName();
VerificationField verificationField = batch.get(name);
if(verificationField == null){
continue;
}
verify(record.get(name), result.get(name), verificationField.getPrecision(), verificationField.getZeroThreshold());
}
} else
// "If there are no such VerificationField elements,
// then any VerificationField element that refers to a MiningField element whose "usageType=target" should be considered to represent an expected output"
{
for(TargetField targetField : targetFields){
FieldName name = targetField.getName();
VerificationField verificationField = batch.get(name);
if(verificationField == null){
continue;
}
verify(record.get(name), EvaluatorUtil.decode(result.get(name)), verificationField.getPrecision(), verificationField.getZeroThreshold());
}
}
}
}
/**
 * Compares one expected value with one actual value.
 *
 * A {@code null} expected value is not checked.
 * Unless the actual value is a collection, the expected value is first converted to the data type of the actual value.
 *
 * @throws EvaluationException If the values are not acceptably close, given the precision and the zero threshold.
 */
private void verify(Object expected, Object actual, double precision, double zeroThreshold) {
	if (expected == null) {
		return;
	}

	Object expectedValue = expected;
	if (!(actual instanceof Collection)) {
		expectedValue = TypeUtil.parseOrCast(TypeUtil.getDataType(actual), expected);
	}

	if (VerificationUtil.acceptable(expectedValue, actual, precision, zeroThreshold)) {
		return;
	}

	throw new EvaluationException("Values " + PMMLException.formatValue(expectedValue) + " and " + PMMLException.formatValue(actual) + " do not match");
}
/**
 * Evaluates the model with the given arguments.
 * Creates a fresh evaluation context for this evaluator, stores the arguments in it,
 * and delegates to {@link #evaluate(ModelEvaluationContext)}.
 *
 * @param arguments The arguments, keyed by field name.
 *
 * @return The evaluation result, keyed by field name.
 */
@Override
public Map<FieldName, ?> evaluate(Map<FieldName, ?> arguments) {
	ModelEvaluationContext context = new ModelEvaluationContext(this);
	context.setArguments(arguments);

	return evaluate(context);
}
protected Field> resolveField(FieldName name){
Field> result = getDataField(name);
if(result == null){
result = resolveDerivedField(name);
}
return result;
}
/**
 * Resolves a derived field by name, giving transformation dictionary fields precedence over local fields.
 *
 * @return The derived field, or {@code null} if there is no such field.
 */
protected DerivedField resolveDerivedField(FieldName name) {
	DerivedField globalField = getDerivedField(name);
	if (globalField != null) {
		return globalField;
	}

	return getLocalDerivedField(name);
}
/**
 * Creates the list of input fields.
 * The list starts with the active fields. One target reference field is appended
 * for every output field whose result feature is {@code RESIDUAL}.
 *
 * @return The active fields as-is if there is nothing to append, an immutable combined list otherwise.
 *
 * @throws UnsupportedElementException If a residual output field has a depth greater than zero.
 * @throws MissingFieldException If the referenced target field has no data field.
 * @throws InvisibleFieldException If the referenced target field has no mining field.
 */
protected List<InputField> createInputFields() {
	List<InputField> inputFields = getActiveFields();

	List<OutputField> outputFields = getOutputFields();
	if (outputFields.size() > 0) {
		List<InputField> targetReferenceFields = null;

		for (OutputField outputField : outputFields) {
			org.dmg.pmml.OutputField pmmlOutputField = outputField.getOutputField();

			if (!(pmmlOutputField.getResultFeature()).equals(ResultFeature.RESIDUAL)) {
				continue;
			}

			int depth = outputField.getDepth();
			if (depth > 0) {
				throw new UnsupportedElementException(pmmlOutputField);
			}

			// Fall back to the sole target field of the model if the output field does not name one
			FieldName targetFieldName = pmmlOutputField.getTargetField();
			if (targetFieldName == null) {
				targetFieldName = getTargetFieldName();
			}

			DataField dataField = getDataField(targetFieldName);
			if (dataField == null) {
				throw new MissingFieldException(targetFieldName, pmmlOutputField);
			}

			MiningField miningField = getMiningField(targetFieldName);
			if (miningField == null) {
				throw new InvisibleFieldException(targetFieldName, pmmlOutputField);
			}

			Target target = getTarget(targetFieldName);

			TargetReferenceField targetReferenceField = new TargetReferenceField(dataField, miningField, target);

			if (targetReferenceFields == null) {
				targetReferenceFields = new ArrayList<>();
			}

			targetReferenceFields.add(targetReferenceField);
		}

		if (targetReferenceFields != null && targetReferenceFields.size() > 0) {
			inputFields = ImmutableList.copyOf(Iterables.concat(inputFields, targetReferenceFields));
		}
	}

	return inputFields;
}
/**
 * Creates the list of input fields for the mining fields that have the given usage type.
 * A mining field that has no data field is backed by a {@link VariableField} of the same name.
 *
 * @param usageType The usage type to select.
 *
 * @return An immutable list, in mining schema order.
 */
protected List<InputField> createInputFields(MiningField.UsageType usageType) {
	M model = getModel();

	MiningSchema miningSchema = model.getMiningSchema();

	List<InputField> inputFields = new ArrayList<>();

	if (miningSchema.hasMiningFields()) {
		List<MiningField> miningFields = miningSchema.getMiningFields();

		for (MiningField miningField : miningFields) {
			FieldName name = miningField.getName();

			if (!(miningField.getUsageType()).equals(usageType)) {
				continue;
			}

			Field<?> field = getDataField(name);
			if (field == null) {
				field = new VariableField(name);
			}

			InputField inputField = new InputField(field, miningField);

			inputFields.add(inputField);
		}
	}

	return ImmutableList.copyOf(inputFields);
}
/**
 * Creates the list of target fields for the mining fields whose usage type is {@code TARGET} or {@code PREDICTED}.
 * If there are none, the synthetic default target field (if available for the mining function) is used instead.
 *
 * @return An immutable list, in mining schema order.
 *
 * @throws MissingFieldException If a target mining field has no data field.
 */
protected List<TargetField> createTargetFields() {
	M model = getModel();

	MiningSchema miningSchema = model.getMiningSchema();

	List<TargetField> targetFields = new ArrayList<>();

	if (miningSchema.hasMiningFields()) {
		List<MiningField> miningFields = miningSchema.getMiningFields();

		for (MiningField miningField : miningFields) {
			FieldName name = miningField.getName();

			MiningField.UsageType usageType = miningField.getUsageType();
			switch (usageType) {
				case TARGET:
				case PREDICTED:
					break;
				default:
					continue;
			}

			DataField dataField = getDataField(name);
			if (dataField == null) {
				throw new MissingFieldException(name, miningField);
			}

			Target target = getTarget(name);

			TargetField targetField = new TargetField(dataField, miningField, target);

			targetFields.add(targetField);
		}
	}

	// Synthesize the default target field if the mining schema does not declare any target fields
	if (targetFields.isEmpty()) {
		DataField dataField = getDataField();

		if (dataField != null) {
			Target target = getTarget(dataField.getName());

			TargetField targetField = new TargetField(dataField, null, target);

			targetFields.add(targetField);
		}
	}

	return ImmutableList.copyOf(targetFields);
}
/**
 * Creates the list of output fields by wrapping every PMML output field of the model.
 *
 * @return An immutable list, in output declaration order. Empty if the model has no output fields.
 */
protected List<OutputField> createOutputFields() {
	M model = getModel();

	Output output = model.getOutput();

	List<OutputField> resultFields = new ArrayList<>();

	if (output != null && output.hasOutputFields()) {
		List<org.dmg.pmml.OutputField> outputFields = output.getOutputFields();

		for (org.dmg.pmml.OutputField outputField : outputFields) {
			OutputField resultField = new OutputField(outputField);

			resultFields.add(resultField);
		}
	}

	return ImmutableList.copyOf(resultFields);
}
/**
 * @return The model, after checking that it is scorable.
 *
 * @throws EvaluationException If the model is not scorable.
 */
protected M ensureScorableModel() {
	M model = getModel();

	if (model.isScorable()) {
		return model;
	}

	throw new EvaluationException("Model is not scorable", model);
}
/**
 * Obtains a value from the given loading cache by delegating to {@code CacheUtil.getValue},
 * passing the model of this evaluator as the first argument.
 *
 * @param cache The cache.
 */
public <V> V getValue(LoadingCache<M, V> cache) {
	M model = getModel();

	return CacheUtil.getValue(model, cache);
}
public V getValue(Cache cache, Callable extends V> loader){
M model = getModel();
return CacheUtil.getValue(model, cache, loader);
}
/**
 * @return The configured model evaluator factory. If none is set, a new instance is created, stored and returned.
 */
protected ModelEvaluatorFactory ensureModelEvaluatorFactory() {
	ModelEvaluatorFactory current = getModelEvaluatorFactory();
	if (current != null) {
		return current;
	}

	ModelEvaluatorFactory created = ModelEvaluatorFactory.newInstance();
	setModelEvaluatorFactory(created);

	return created;
}
/**
 * @return The current value factory factory. If none is set, it is taken from the model evaluator factory,
 * or newly created if the model evaluator factory does not provide one; the result is stored before it is returned.
 */
protected ValueFactoryFactory ensureValueFactoryFactory() {
	ValueFactoryFactory current = getValueFactoryFactory();
	if (current != null) {
		return current;
	}

	ValueFactoryFactory resolved = ensureModelEvaluatorFactory().getValueFactoryFactory();
	if (resolved == null) {
		resolved = ValueFactoryFactory.newInstance();
	}

	setValueFactoryFactory(resolved);

	return resolved;
}
protected ValueFactory> ensureValueFactory(){
ValueFactory> valueFactory = getValueFactory();
if(valueFactory == null){
ValueFactoryFactory valueFactoryFactory = ensureValueFactoryFactory();
MathContext mathContext = getMathContext();
valueFactory = valueFactoryFactory.newValueFactory(mathContext);
setValueFactory(valueFactory);
}
return valueFactory;
}
/**
 * @return The PMML document of this evaluator.
 */
public PMML getPMML() {
	return pmml;
}
/**
 * @param pmml The PMML document of this evaluator.
 */
private void setPMML(PMML pmml) {
	this.pmml = pmml;
}
/**
 * @return The model of this evaluator.
 */
public M getModel() {
	return model;
}
/**
 * @param model The model of this evaluator.
 */
private void setModel(M model) {
	this.model = model;
}
/**
 * @return The model evaluator factory, or {@code null} if none is set.
 */
public ModelEvaluatorFactory getModelEvaluatorFactory() {
	return modelEvaluatorFactory;
}
/**
 * @param modelEvaluatorFactory The model evaluator factory. May be {@code null}.
 */
private void setModelEvaluatorFactory(ModelEvaluatorFactory modelEvaluatorFactory) {
	this.modelEvaluatorFactory = modelEvaluatorFactory;
}
/**
 * @return The value factory factory, or {@code null} if none is set.
 */
public ValueFactoryFactory getValueFactoryFactory() {
	return valueFactoryFactory;
}
/**
 * @param valueFactoryFactory The value factory factory. May be {@code null}.
 */
private void setValueFactoryFactory(ValueFactoryFactory valueFactoryFactory) {
	this.valueFactoryFactory = valueFactoryFactory;
}
public ValueFactory> getValueFactory(){
return this.valueFactory;
}
private void setValueFactory(ValueFactory> valueFactory){
this.valueFactory = valueFactory;
}
static
protected M selectModel(PMML pmml, Class extends M> clazz){
if(!pmml.hasModels()){
throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(pmml.getClass()) + "/" + XPathUtil.formatElement(clazz)), pmml);
}
List models = pmml.getModels();
Iterable extends M> filteredModels = Iterables.filter(models, clazz);
M model = Iterables.getFirst(filteredModels, null);
if(model == null){
throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(pmml.getClass()) + "/" + XPathUtil.formatElement(clazz)), pmml);
}
return model;
}
static
private F findModelField(Collection extends F> fields, FieldName name){
for(F field : fields){
if(Objects.equals(field.getName(), name)){
return field;
}
}
return null;
}
static
private VerificationBatch parseModelVerification(ModelVerification modelVerification){
VerificationBatch result = new VerificationBatch();
VerificationFields verificationFields = modelVerification.getVerificationFields();
if(verificationFields == null){
throw new MissingElementException(modelVerification, PMMLElements.MODELVERIFICATION_VERIFICATIONFIELDS);
}
for(VerificationField verificationField : verificationFields){
result.put(verificationField.getField(), verificationField);
}
InlineTable inlineTable = modelVerification.getInlineTable();
if(inlineTable == null){
throw new MissingElementException(modelVerification, PMMLElements.MODELVERIFICATION_INLINETABLE);
}
Table table = InlineTableUtil.getContent(inlineTable);
List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy