![JAR search and dependency download from the Maven repository](/logo.png)
org.jpmml.evaluator.ModelManager Maven / Gradle / Ivy
/*
* Copyright (c) 2020 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.evaluator;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import com.google.common.base.Function;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLAttributes;
import org.dmg.pmml.PMMLElements;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.dmg.pmml.TransformationDictionary;
import org.jpmml.model.XPathUtil;
import org.jpmml.model.visitors.FieldResolver;
abstract
public class ModelManager implements HasModel, Serializable {
private PMML pmml = null;
private M model = null;
private DataField defaultDataField = null;
private Map dataFields = Collections.emptyMap();
private Map derivedFields = Collections.emptyMap();
private Map defineFunctions = Collections.emptyMap();
private Map miningFields = Collections.emptyMap();
private Map localDerivedFields = Collections.emptyMap();
private Map targets = Collections.emptyMap();
private Map outputFields = Collections.emptyMap();
private Set resultFeatures = Collections.emptySet();
private List inputFields = null;
private List activeInputFields = null;
private List targetResultFields = null;
private List outputResultFields = null;
private ListMultimap> visibleFields = null;
/**
 * Creates a manager without a PMML document or a model.
 * All lookup tables keep their empty initial values.
 */
protected ModelManager(){
    super();
}
/**
 * Creates a manager for the specified model and indexes its building blocks by name.
 *
 * @param pmml The PMML document. Must declare a data dictionary.
 * @param model The model element. Must declare a mining function and a mining schema.
 *
 * @throws NullPointerException If either argument is <code>null</code>.
 * @throws MissingElementException If the data dictionary or the mining schema is missing.
 * @throws MissingAttributeException If the mining function is missing.
 */
protected ModelManager(PMML pmml, M model){
    setPMML(Objects.requireNonNull(pmml));
    setModel(Objects.requireNonNull(model));

    // Global scope: data fields
    DataDictionary dataDictionary = pmml.getDataDictionary();
    if(dataDictionary == null){
        throw new MissingElementException(pmml, PMMLElements.PMML_DATADICTIONARY);
    }

    if(dataDictionary.hasDataFields()){
        this.dataFields = ImmutableMap.copyOf(IndexableUtil.buildMap(dataDictionary.getDataFields(), PMMLAttributes.DATAFIELD_NAME));
    }

    // Global scope: derived fields and user-defined functions
    TransformationDictionary transformationDictionary = pmml.getTransformationDictionary();
    if(transformationDictionary != null){

        if(transformationDictionary.hasDerivedFields()){
            this.derivedFields = ImmutableMap.copyOf(IndexableUtil.buildMap(transformationDictionary.getDerivedFields(), PMMLAttributes.DERIVEDFIELD_NAME));
        }

        if(transformationDictionary.hasDefineFunctions()){
            this.defineFunctions = ImmutableMap.copyOf(IndexableUtil.buildMap(transformationDictionary.getDefineFunctions(), PMMLAttributes.DEFINEFUNCTION_NAME));
        }
    }

    // Model scope: mandatory attribute and element checks
    if(model.getMiningFunction() == null){
        throw new MissingAttributeException(MissingAttributeException.formatMessage(XPathUtil.formatElement(model.getClass()) + "@functionName"), model);
    }

    MiningSchema miningSchema = model.getMiningSchema();
    if(miningSchema == null){
        throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(model.getClass()) + "/" + XPathUtil.formatElement(MiningSchema.class)), model);
    }

    if(miningSchema.hasMiningFields()){
        this.miningFields = ImmutableMap.copyOf(IndexableUtil.buildMap(miningSchema.getMiningFields(), PMMLAttributes.MININGFIELD_NAME));
    }

    // Model scope: local derived fields
    LocalTransformations localTransformations = model.getLocalTransformations();
    if(localTransformations != null && localTransformations.hasDerivedFields()){
        this.localDerivedFields = ImmutableMap.copyOf(IndexableUtil.buildMap(localTransformations.getDerivedFields(), PMMLAttributes.DERIVEDFIELD_NAME));
    }

    // Model scope: targets
    Targets targets = model.getTargets();
    if(targets != null && targets.hasTargets()){
        // Cannot use Guava's ImmutableMap, because it is null-hostile
        this.targets = Collections.unmodifiableMap(IndexableUtil.buildMap(targets.getTargets(), PMMLAttributes.TARGET_FIELD, true));
    }

    // Model scope: output fields and the result features that they request
    Output output = model.getOutput();
    if(output != null && output.hasOutputFields()){
        this.outputFields = ImmutableMap.copyOf(IndexableUtil.buildMap(output.getOutputFields(), PMMLAttributes.OUTPUTFIELD_NAME));

        this.resultFeatures = Sets.immutableEnumSet(collectResultFeatures(output));
    }
}
/**
 * @return The mining function of the managed model.
 */
public MiningFunction getMiningFunction(){
    return getModel().getMiningFunction();
}
/**
 * @return The math context of the managed model.
 */
public MathContext getMathContext(){
    return getModel().getMathContext();
}
/**
 * Looks up a data field by name.
 *
 * @param name The name of the data field.
 * The default target name ({@link Evaluator#DEFAULT_TARGET_NAME}) maps to the default data field.
 *
 * @return The data field, or <code>null</code> if there is no such field.
 *
 * @see #getDefaultDataField()
 */
public DataField getDataField(FieldName name){

    if(!Objects.equals(Evaluator.DEFAULT_TARGET_NAME, name)){
        return this.dataFields.get(name);
    }

    return getDefaultDataField();
}
/**
 * Returns the data field that stands in for the default target field.
 *
 * A data field that was set via {@link #setDefaultDataField(DataField)} takes precedence.
 * Otherwise, the choice depends on the mining function and, for regression, on the math context.
 *
 * @return A synthetic {@link DataField} element describing the default target field,
 * or <code>null</code> if the mining function is neither regression, classification nor clustering.
 */
public DataField getDefaultDataField(){
    DataField customDataField = this.defaultDataField;

    if(customDataField != null){
        return customDataField;
    }

    switch(getMiningFunction()){
        case REGRESSION:
            switch(getMathContext()){
                case FLOAT:
                    return ModelManager.DEFAULT_TARGET_CONTINUOUS_FLOAT;
                default:
                    return ModelManager.DEFAULT_TARGET_CONTINUOUS_DOUBLE;
            }
        case CLASSIFICATION:
        case CLUSTERING:
            return ModelManager.DEFAULT_TARGET_CATEGORICAL_STRING;
        default:
            return null;
    }
}
/**
 * Overrides the synthetic default target field.
 *
 * @param defaultDataField The replacement data field.
 * Use <code>null</code> to fall back to the built-in choice of {@link #getDefaultDataField()}.
 */
public void setDefaultDataField(DataField defaultDataField){
    this.defaultDataField = defaultDataField;
}
/**
 * Looks up a derived field of the transformation dictionary by name.
 *
 * @return The derived field, or <code>null</code> if there is no such field.
 */
public DerivedField getDerivedField(FieldName name){
    DerivedField derivedField = this.derivedFields.get(name);

    return derivedField;
}
/**
 * Looks up a user-defined function of the transformation dictionary by name.
 *
 * @return The function definition, or <code>null</code> if there is no such function.
 */
public DefineFunction getDefineFunction(String name){
    DefineFunction defineFunction = this.defineFunctions.get(name);

    return defineFunction;
}
/**
 * Looks up a mining field of the mining schema by name.
 *
 * @return The mining field, or <code>null</code> if there is no such field.
 * The default target name ({@link Evaluator#DEFAULT_TARGET_NAME}) never has a mining field.
 */
public MiningField getMiningField(FieldName name){
    boolean defaultTarget = Objects.equals(Evaluator.DEFAULT_TARGET_NAME, name);

    if(defaultTarget){
        return null;
    }

    return this.miningFields.get(name);
}
/**
 * @return <code>true</code> if at least one local derived field has been indexed, <code>false</code> otherwise.
 */
protected boolean hasLocalDerivedFields(){
    boolean empty = this.localDerivedFields.isEmpty();

    return !empty;
}
/**
 * Looks up a derived field of the model's local transformations by name.
 *
 * @return The derived field, or <code>null</code> if there is no such field.
 */
public DerivedField getLocalDerivedField(FieldName name){
    DerivedField localDerivedField = this.localDerivedFields.get(name);

    return localDerivedField;
}
/**
 * Looks up a target element by the name of its target field.
 *
 * @return The target element, or <code>null</code> if there is no such element.
 */
public Target getTarget(FieldName name){
    Target target = this.targets.get(name);

    return target;
}
/**
 * @return <code>true</code> if at least one output field has been indexed, <code>false</code> otherwise.
 */
protected boolean hasOutputFields(){
    boolean empty = this.outputFields.isEmpty();

    return !empty;
}
/**
 * Looks up an output field element of the model by name.
 *
 * @return The output field element, or <code>null</code> if there is no such element.
 */
public org.dmg.pmml.OutputField getOutputField(FieldName name){
    org.dmg.pmml.OutputField pmmlOutputField = this.outputFields.get(name);

    return pmmlOutputField;
}
/**
 * <p>
 * Indicates if this model evaluator provides the specified result feature.
 * </p>
 *
 * <p>
 * A result feature is first and foremost manifested through output fields.
 * However, selected result features may make a secondary manifestation through a target field.
 * </p>
 *
 * @see org.dmg.pmml.OutputField#getResultFeature()
 */
public boolean hasResultFeature(ResultFeature resultFeature){
    return getResultFeatures().contains(resultFeature);
}
/**
 * Merges the specified result features into the set of result features of this manager.
 * The merged set is stored as a new immutable enum set.
 *
 * @param resultFeatures The result features to add.
 */
public void addResultFeatures(Set<ResultFeature> resultFeatures){
    this.resultFeatures = Sets.immutableEnumSet(Iterables.concat(this.resultFeatures, resultFeatures));
}
/**
 * @return The result features collected from the output fields of the model,
 * plus the ones added via {@link #addResultFeatures(Set)}.
 */
protected Set<ResultFeature> getResultFeatures(){
    return this.resultFeatures;
}
/**
 * Returns the input fields of the model.
 *
 * The list is computed on first access and cached until {@link #resetInputFields()} is called.
 *
 * @return An immutable list of input fields.
 */
public List<InputField> getInputFields(){

    if(this.inputFields == null){
        List<InputField> inputFields = filterInputFields(createInputFields());

        this.inputFields = ImmutableList.copyOf(inputFields);
    }

    return this.inputFields;
}
/**
 * Returns the input fields that correspond to mining fields with the "active" usage type.
 *
 * The list is computed on first access and cached until {@link #resetInputFields()} is called.
 *
 * @return An immutable list of active input fields.
 */
public List<InputField> getActiveFields(){

    if(this.activeInputFields == null){
        List<InputField> activeInputFields = filterInputFields(createInputFields(MiningField.UsageType.ACTIVE));

        this.activeInputFields = ImmutableList.copyOf(activeInputFields);
    }

    return this.activeInputFields;
}
/**
 * Returns the target fields of the model.
 *
 * The list is computed on first access and cached until {@link #resetResultFields()} is called.
 *
 * @return An immutable list of target fields.
 */
public List<TargetField> getTargetFields(){

    if(this.targetResultFields == null){
        List<TargetField> targetResultFields = filterTargetFields(createTargetFields());

        this.targetResultFields = ImmutableList.copyOf(targetResultFields);
    }

    return this.targetResultFields;
}
/**
 * Returns the sole target field of the model.
 *
 * @return The target field.
 *
 * @throws EvaluationException If the number of target fields is other than one.
 */
public TargetField getTargetField(){
    List<TargetField> targetFields = getTargetFields();

    if(targetFields.size() != 1){
        throw createMiningSchemaException("Expected 1 target field, got " + targetFields.size() + " target fields");
    }

    return targetFields.get(0);
}
/**
 * @return The name of the sole target field of the model.
 *
 * @see #getTargetField()
 */
public FieldName getTargetName(){
    return getTargetField().getFieldName();
}
/**
 * Searches the target fields of the model by field name.
 *
 * @param name The name of the target field. May be <code>null</code>.
 *
 * @return The first matching target field, or <code>null</code> if there is no match.
 */
TargetField findTargetField(FieldName name){
    List<TargetField> targetFields = getTargetFields();

    for(TargetField targetField : targetFields){

        if(Objects.equals(targetField.getFieldName(), name)){
            return targetField;
        }
    }

    return null;
}
/**
 * Returns the output fields of the model.
 *
 * The list is computed on first access and cached until {@link #resetResultFields()} is called.
 *
 * @return An immutable list of output fields.
 */
public List<OutputField> getOutputFields(){

    if(this.outputResultFields == null){
        List<OutputField> outputResultFields = filterOutputFields(createOutputFields());

        this.outputResultFields = ImmutableList.copyOf(outputResultFields);
    }

    return this.outputResultFields;
}
/**
 * Discards the cached lists of input fields and active input fields.
 * They will be re-computed on next access.
 */
protected void resetInputFields(){
    this.activeInputFields = null;
    this.inputFields = null;
}
/**
 * Discards the cached lists of target fields and output fields.
 * They will be re-computed on next access.
 */
protected void resetResultFields(){
    this.outputResultFields = null;
    this.targetResultFields = null;
}
protected Field> resolveField(FieldName name){
ListMultimap> visibleFields = getVisibleFields();
List> fields = visibleFields.get(name);
if(fields.isEmpty()){
return null;
} else
if(fields.size() == 1){
return fields.get(0);
} else
{
throw new DuplicateFieldException(name);
}
}
protected ListMultimap> getVisibleFields(){
if(this.visibleFields == null){
this.visibleFields = collectVisibleFields();
}
return this.visibleFields;
}
/**
 * Creates an exception whose context is the mining schema of the model.
 *
 * @param message The exception message.
 *
 * @return A new exception. The caller is responsible for throwing it.
 */
protected EvaluationException createMiningSchemaException(String message){
    MiningSchema miningSchema = getModel().getMiningSchema();

    return new EvaluationException(message, miningSchema);
}
/**
 * Creates the list of input fields.
 *
 * The result consists of the active fields, followed by one residual input field
 * for every output field whose result feature is "residual".
 *
 * @return The list of active fields as-is if there are no residual output fields, a new extended list otherwise.
 *
 * @throws UnsupportedElementException If a residual output field has a depth greater than zero.
 * @throws MissingFieldException If the target field of a residual output field has no data field.
 * @throws InvisibleFieldException If the target field of a residual output field has no mining field.
 */
protected List<InputField> createInputFields(){
    List<InputField> inputFields = getActiveFields();

    List<OutputField> outputFields = getOutputFields();

    // Allocated lazily, so that the common case (no residual output fields) does not copy anything
    List<InputField> expandedInputFields = null;

    for(OutputField outputField : outputFields){
        org.dmg.pmml.OutputField pmmlOutputField = outputField.getField();

        if(ResultFeature.RESIDUAL != pmmlOutputField.getResultFeature()){
            continue;
        }

        int depth = outputField.getDepth();
        if(depth > 0){
            throw new UnsupportedElementException(pmmlOutputField);
        }

        // An output field without an explicit target field refers to the sole target field of the model
        FieldName targetName = pmmlOutputField.getTargetField();
        if(targetName == null){
            targetName = getTargetName();
        }

        DataField dataField = getDataField(targetName);
        if(dataField == null){
            throw new MissingFieldException(targetName, pmmlOutputField);
        }

        MiningField miningField = getMiningField(targetName);
        if(miningField == null){
            throw new InvisibleFieldException(targetName, pmmlOutputField);
        }

        if(expandedInputFields == null){
            expandedInputFields = new ArrayList<>(inputFields);
        }

        expandedInputFields.add(new ResidualInputField(dataField, miningField));
    }

    return (expandedInputFields != null) ? expandedInputFields : inputFields;
}
/**
 * Creates input fields for the mining fields that have the specified usage type.
 *
 * A mining field that has no data field is backed by a new {@link VariableField}.
 *
 * @param usageType The usage type to select.
 *
 * @return A new mutable list of input fields, in mining schema order.
 */
protected List<InputField> createInputFields(MiningField.UsageType usageType){
    MiningSchema miningSchema = getModel().getMiningSchema();

    List<InputField> inputFields = new ArrayList<>();

    if(miningSchema.hasMiningFields()){
        List<MiningField> miningFields = miningSchema.getMiningFields();

        for(MiningField miningField : miningFields){
            FieldName name = miningField.getName();

            if(!(miningField.getUsageType()).equals(usageType)){
                continue;
            }

            Field<?> field = getDataField(name);
            if(field == null){
                field = new VariableField(name);
            }

            inputFields.add(new InputField(field, miningField));
        }
    }

    return inputFields;
}
/**
 * Post-processes a list of newly created input fields before it is cached.
 *
 * This implementation returns the argument unchanged; subclasses can override it.
 *
 * @param inputFields The candidate input fields.
 *
 * @return The input fields to expose.
 */
protected List<InputField> filterInputFields(List<InputField> inputFields){
    return inputFields;
}
/**
 * Creates target fields for the mining fields that have the "predicted" or "target" usage type.
 *
 * If there are no such mining fields, then a single default target field is synthesized,
 * provided that {@link #getDefaultDataField()} returns a data field.
 *
 * @return A new mutable list of target fields. May be empty.
 *
 * @throws MissingFieldException If a target mining field has no data field.
 */
protected List<TargetField> createTargetFields(){
    MiningSchema miningSchema = getModel().getMiningSchema();

    List<TargetField> targetFields = new ArrayList<>();

    if(miningSchema.hasMiningFields()){
        List<MiningField> miningFields = miningSchema.getMiningFields();

        for(MiningField miningField : miningFields){
            FieldName name = miningField.getName();

            MiningField.UsageType usageType = miningField.getUsageType();
            switch(usageType){
                case PREDICTED:
                case TARGET:
                    break;
                default:
                    continue;
            }

            DataField dataField = getDataField(name);
            if(dataField == null){
                throw new MissingFieldException(name, miningField);
            }

            Target target = getTarget(name);

            targetFields.add(new TargetField(dataField, miningField, target));
        }
    }

    // No explicit target fields: fall back to the synthetic default target field, if there is one
    if(targetFields.isEmpty()){
        DataField dataField = getDefaultDataField();

        if(dataField != null){
            Target target = getTarget(dataField.getName());

            targetFields.add(new DefaultTargetField(dataField, target));
        }
    }

    return targetFields;
}
/**
 * Post-processes a list of newly created target fields before it is cached.
 *
 * This implementation returns the argument unchanged; subclasses can override it.
 *
 * @param targetFields The candidate target fields.
 *
 * @return The target fields to expose.
 */
protected List<TargetField> filterTargetFields(List<TargetField> targetFields){
    return targetFields;
}
/**
 * Creates an output field for every output field element of the model.
 *
 * @return A new mutable list of output fields, in declaration order. May be empty.
 */
protected List<OutputField> createOutputFields(){
    Output output = getModel().getOutput();

    List<OutputField> outputFields = new ArrayList<>();

    if(output != null && output.hasOutputFields()){
        List<org.dmg.pmml.OutputField> pmmlOutputFields = output.getOutputFields();

        for(org.dmg.pmml.OutputField pmmlOutputField : pmmlOutputFields){
            outputFields.add(new OutputField(pmmlOutputField));
        }
    }

    return outputFields;
}
/**
 * Post-processes a list of newly created output fields before it is cached.
 *
 * This implementation returns the argument unchanged; subclasses can override it.
 *
 * @param outputFields The candidate output fields.
 *
 * @return The output fields to expose.
 */
protected List<OutputField> filterOutputFields(List<OutputField> outputFields){
    return outputFields;
}
private ListMultimap> collectVisibleFields(){
PMML pmml = getPMML();
Model model = getModel();
ListMultimap> visibleFields = ArrayListMultimap.create();
FieldResolver fieldResolver = new FieldResolver(){
@Override
public PMMLObject popParent(){
PMMLObject parent = super.popParent();
if(Objects.equals(model, parent)){
Model model = (Model)parent;
Collection> fields = getFields(model);
for(Field> field : fields){
visibleFields.put(field.getName(), field);
}
}
return parent;
}
};
fieldResolver.applyTo(pmml);
return ImmutableListMultimap.copyOf(visibleFields);
}
/**
 * @return The PMML document that was supplied to the constructor, or <code>null</code> if none was supplied.
 */
@Override
public PMML getPMML(){
    return this.pmml;
}
/**
 * Stores the PMML document. Called by the constructor only.
 */
private void setPMML(PMML pmml){
    this.pmml = pmml;
}
/**
 * @return The model element that was supplied to the constructor, or <code>null</code> if none was supplied.
 */
@Override
public M getModel(){
    return this.model;
}
/**
 * Stores the model element. Called by the constructor only.
 */
private void setModel(M model){
    this.model = model;
}
/**
 * Collects the result features of the output fields that do not refer to a segment.
 *
 * @param output The output element. May be <code>null</code>.
 *
 * @return A new mutable set of result features. May be empty.
 */
static
protected Set<ResultFeature> collectResultFeatures(Output output){
    Set<ResultFeature> result = EnumSet.noneOf(ResultFeature.class);

    if(output != null && output.hasOutputFields()){
        List<org.dmg.pmml.OutputField> pmmlOutputFields = output.getOutputFields();

        for(org.dmg.pmml.OutputField pmmlOutputField : pmmlOutputFields){

            // Output fields with a segment identifier are handled by collectSegmentResultFeatures(Output)
            if(pmmlOutputField.getSegmentId() != null){
                continue;
            }

            result.add(pmmlOutputField.getResultFeature());
        }
    }

    return result;
}
static
protected Map> collectSegmentResultFeatures(Output output){
Map> result = new LinkedHashMap<>();
List pmmlOutputFields = output.getOutputFields();
for(org.dmg.pmml.OutputField pmmlOutputField : pmmlOutputFields){
String segmentId = pmmlOutputField.getSegmentId();
if(segmentId == null){
continue;
}
Set resultFeatures = result.get(segmentId);
if(resultFeatures == null){
resultFeatures = EnumSet.noneOf(ResultFeature.class);
result.put(segmentId, resultFeatures);
}
resultFeatures.add(pmmlOutputField.getResultFeature());
}
return result;
}
/**
 * Wraps a map of lists so that every value is presented as a Guava ImmutableList.
 *
 * NOTE(review): the generic type arguments of this method (method type parameters,
 * parameter type and return type) appear to have been lost from this copy of the source;
 * restore them from the upstream project before compiling.
 *
 * NOTE(review): per Guava documentation, Maps.transformValues returns a lazily transformed view
 * rather than a copy - confirm that callers copy the result if they need a snapshot.
 *
 * @param map The map whose list values should be converted.
 *
 * @return The result of Maps.transformValues for the specified map.
 */
static
protected Map> toImmutableListMap(Map> map){
// Converts a single list value into an immutable copy
Function, ImmutableList> function = new Function, ImmutableList>(){
@Override
public ImmutableList apply(List list){
return ImmutableList.copyOf(list);
}
};
return Maps.transformValues(map, function);
}
/**
 * Wraps a map of sets so that every value is presented as a Guava ImmutableSet.
 *
 * NOTE(review): the generic type arguments of this method (method type parameters,
 * parameter type and return type) appear to have been lost from this copy of the source;
 * restore them from the upstream project before compiling.
 *
 * NOTE(review): per Guava documentation, Maps.transformValues returns a lazily transformed view
 * rather than a copy - confirm that callers copy the result if they need a snapshot.
 *
 * @param map The map whose set values should be converted.
 *
 * @return The result of Maps.transformValues for the specified map.
 */
static
protected Map> toImmutableSetMap(Map> map){
Function, ImmutableSet> function = new Function, ImmutableSet>(){
@Override
public ImmutableSet apply(Set set){
// Enum sets are converted via Sets.immutableEnumSet instead of the general-purpose ImmutableSet.copyOf
if(set instanceof EnumSet){
EnumSet> enumSet = (EnumSet>)set;
return (ImmutableSet)Sets.immutableEnumSet(enumSet);
}
return ImmutableSet.copyOf(set);
}
};
return Maps.transformValues(map, function);
}
static
protected Map> toImmutableMapMap(Map> map){
Function
© 2015 - 2025 Weber Informatics LLC | Privacy Policy