org.jpmml.evaluator.ModelManager Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-evaluator Show documentation
Show all versions of pmml-evaluator Show documentation
JPMML class model evaluator
The newest version!
/*
* Copyright (c) 2020 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see .
*/
package org.jpmml.evaluator;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Sets;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Field;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.jpmml.model.UnsupportedElementException;
import org.jpmml.model.visitors.FieldResolver;
abstract
public class ModelManager extends PMMLManager implements HasModel {
private M model = null;
private DefaultDataField defaultDataField = null;
private Map miningFields = Collections.emptyMap();
private Map localDerivedFields = Collections.emptyMap();
private Map targets = Collections.emptyMap();
private Map outputFields = Collections.emptyMap();
private Set resultFeatures = Collections.emptySet();
private List inputFields = null;
private List activeInputFields = null;
private List targetResultFields = null;
private List outputResultFields = null;
private ListMultimap> visibleFields = null;
protected ModelManager(){
}
protected ModelManager(PMML pmml, M model){
super(pmml);
setModel(model);
@SuppressWarnings("unused")
MiningFunction miningFunction = model.requireMiningFunction();
MiningSchema miningSchema = model.requireMiningSchema();
if(miningSchema.hasMiningFields()){
this.miningFields = ImmutableMap.copyOf(IndexableUtil.buildMap(miningSchema.getMiningFields()));
}
LocalTransformations localTransformations = model.getLocalTransformations();
if(localTransformations != null && localTransformations.hasDerivedFields()){
this.localDerivedFields = ImmutableMap.copyOf(IndexableUtil.buildMap(localTransformations.getDerivedFields()));
}
Targets targets = model.getTargets();
if(targets != null && targets.hasTargets()){
// Cannot use Guava's ImmutableMap, because it is null-hostile
this.targets = Collections.unmodifiableMap(IndexableUtil.buildMap(targets.getTargets()));
}
Output output = model.getOutput();
if(output != null && output.hasOutputFields()){
this.outputFields = ImmutableMap.copyOf(IndexableUtil.buildMap(output.getOutputFields()));
this.resultFeatures = Sets.immutableEnumSet(collectResultFeatures(output));
}
}
public MiningFunction getMiningFunction(){
M model = getModel();
return model.requireMiningFunction();
}
public MathContext getMathContext(){
M model = getModel();
return model.getMathContext();
}
@Override
public DataField getDataField(String name){
if(Objects.equals(Evaluator.DEFAULT_TARGET_NAME, name)){
return getDefaultDataField();
}
return super.getDataField(name);
}
/**
* @return A synthetic {@link DataField} element describing the default target field.
*/
public DefaultDataField getDefaultDataField(){
if(this.defaultDataField != null){
return this.defaultDataField;
}
MiningFunction miningFunction = getMiningFunction();
switch(miningFunction){
case REGRESSION:
MathContext mathContext = getMathContext();
switch(mathContext){
case FLOAT:
return DefaultDataField.CONTINUOUS_FLOAT;
default:
return DefaultDataField.CONTINUOUS_DOUBLE;
}
case CLASSIFICATION:
case CLUSTERING:
return DefaultDataField.CATEGORICAL_STRING;
default:
return null;
}
}
public void setDefaultDataField(DefaultDataField defaultDataField){
this.defaultDataField = defaultDataField;
}
public MiningField getMiningField(String name){
if(Objects.equals(Evaluator.DEFAULT_TARGET_NAME, name)){
return null;
}
return this.miningFields.get(name);
}
protected boolean hasLocalDerivedFields(){
return !this.localDerivedFields.isEmpty();
}
public DerivedField getLocalDerivedField(String name){
return this.localDerivedFields.get(name);
}
public Target getTarget(String name){
return this.targets.get(name);
}
protected boolean hasOutputFields(){
return !this.outputFields.isEmpty();
}
public org.dmg.pmml.OutputField getOutputField(String name){
return this.outputFields.get(name);
}
/**
*
* Indicates if this model manager provides the specified result feature.
*
*
*
* A result feature is first and foremost manifested through output fields.
* However, selected result features may make a secondary manifestation through a target field.
*
*
* @see org.dmg.pmml.OutputField#getResultFeature()
*/
public boolean hasResultFeature(ResultFeature resultFeature){
Set resultFeatures = getResultFeatures();
return resultFeatures.contains(resultFeature);
}
public void addResultFeatures(Set resultFeatures){
this.resultFeatures = Sets.immutableEnumSet(Iterables.concat(this.resultFeatures, resultFeatures));
}
protected Set getResultFeatures(){
return this.resultFeatures;
}
public List getInputFields(){
if(this.inputFields == null){
List inputFields = filterInputFields(createInputFields());
this.inputFields = ImmutableList.copyOf(inputFields);
}
return this.inputFields;
}
public List getActiveFields(){
if(this.activeInputFields == null){
List activeInputFields = filterInputFields(createInputFields(MiningField.UsageType.ACTIVE));
this.activeInputFields = ImmutableList.copyOf(activeInputFields);
}
return this.activeInputFields;
}
public List getTargetFields(){
if(this.targetResultFields == null){
List targetResultFields = filterTargetFields(createTargetFields());
this.targetResultFields = ImmutableList.copyOf(targetResultFields);
}
return this.targetResultFields;
}
public TargetField getTargetField(){
List targetFields = getTargetFields();
if(targetFields.size() != 1){
throw createMiningSchemaException("Expected 1 target field, got " + targetFields.size() + " target fields");
}
TargetField targetField = targetFields.get(0);
return targetField;
}
public String getTargetName(){
TargetField targetField = getTargetField();
return targetField.getFieldName();
}
TargetField findTargetField(String name){
List targetFields = getTargetFields();
if(targetFields.size() == 1){
TargetField targetField = targetFields.get(0);
if(Objects.equals(targetField.getFieldName(), name)){
return targetField;
}
} else
{
for(TargetField targetField : targetFields){
if(Objects.equals(targetField.getFieldName(), name)){
return targetField;
}
}
}
return null;
}
public List getOutputFields(){
if(this.outputResultFields == null){
List outputResultFields = filterOutputFields(createOutputFields());
this.outputResultFields = ImmutableList.copyOf(outputResultFields);
}
return this.outputResultFields;
}
protected void resetInputFields(){
this.inputFields = null;
this.activeInputFields = null;
}
protected void resetResultFields(){
this.targetResultFields = null;
this.outputResultFields = null;
}
protected Field> resolveField(String name){
ListMultimap> visibleFields = getVisibleFields();
List> fields = visibleFields.get(name);
if(fields.isEmpty()){
return null;
} else
if(fields.size() == 1){
return fields.get(0);
} else
{
throw new DuplicateFieldException(name);
}
}
protected ListMultimap> getVisibleFields(){
if(this.visibleFields == null){
this.visibleFields = collectVisibleFields();
}
return this.visibleFields;
}
protected EvaluationException createMiningSchemaException(String message){
M model = getModel();
MiningSchema miningSchema = model.requireMiningSchema();
return new EvaluationException(message, miningSchema);
}
protected List createInputFields(){
List inputFields = getActiveFields();
List outputFields = getOutputFields();
if(!outputFields.isEmpty()){
List expandedInputFields = null;
for(OutputField outputField : outputFields){
org.dmg.pmml.OutputField pmmlOutputField = outputField.getField();
if(pmmlOutputField.getResultFeature() != ResultFeature.RESIDUAL){
continue;
}
int depth = outputField.getDepth();
if(depth > 0){
throw new UnsupportedElementException(pmmlOutputField);
}
String targetFieldName = pmmlOutputField.getTargetField();
if(targetFieldName == null){
targetFieldName = getTargetName();
}
DataField dataField = getDataField(targetFieldName);
if(dataField == null){
throw new MissingFieldException(targetFieldName, pmmlOutputField);
}
MiningField miningField = getMiningField(targetFieldName);
if(miningField == null){
throw new InvisibleFieldException(targetFieldName, pmmlOutputField);
}
ResidualInputField residualInputField = new ResidualInputField(dataField, miningField);
if(expandedInputFields == null){
expandedInputFields = new ArrayList<>(inputFields);
}
expandedInputFields.add(residualInputField);
}
if(expandedInputFields != null){
return expandedInputFields;
}
}
return inputFields;
}
protected List createInputFields(MiningField.UsageType usageType){
M model = getModel();
List inputFields = new ArrayList<>();
MiningSchema miningSchema = model.requireMiningSchema();
if(miningSchema.hasMiningFields()){
List miningFields = miningSchema.getMiningFields();
for(MiningField miningField : miningFields){
String fieldName = miningField.requireName();
if(miningField.getUsageType() != usageType){
continue;
}
Field> field = getDataField(fieldName);
if(field == null){
field = new VariableField(fieldName);
}
InputField inputField = new InputField(field, miningField);
inputFields.add(inputField);
}
}
return inputFields;
}
protected List filterInputFields(List inputFields){
return inputFields;
}
protected List createTargetFields(){
M model = getModel();
List targetFields = new ArrayList<>();
MiningSchema miningSchema = model.requireMiningSchema();
if(miningSchema.hasMiningFields()){
List miningFields = miningSchema.getMiningFields();
for(MiningField miningField : miningFields){
String fieldName = miningField.requireName();
MiningField.UsageType usageType = miningField.getUsageType();
switch(usageType){
case PREDICTED:
case TARGET:
break;
default:
continue;
}
DataField dataField = getDataField(fieldName);
if(dataField == null){
throw new MissingFieldException(miningField);
}
Target target = getTarget(fieldName);
TargetField targetField = new TargetField(dataField, miningField, target);
targetFields.add(targetField);
}
}
synthesis:
if(targetFields.isEmpty()){
DefaultDataField defaultDataField = getDefaultDataField();
if(defaultDataField == null){
break synthesis;
}
Target target = getTarget(defaultDataField.requireName());
TargetField targetField = new SyntheticTargetField(defaultDataField, target);
targetFields.add(targetField);
}
return targetFields;
}
protected List filterTargetFields(List targetFields){
return targetFields;
}
protected List createOutputFields(){
M model = getModel();
Output output = model.getOutput();
List outputFields = new ArrayList<>();
if(output != null && output.hasOutputFields()){
List pmmlOutputFields = output.getOutputFields();
for(org.dmg.pmml.OutputField pmmlOutputField : pmmlOutputFields){
OutputField outputField = new OutputField(pmmlOutputField);
outputFields.add(outputField);
}
}
return outputFields;
}
protected List filterOutputFields(List outputFields){
return outputFields;
}
private ListMultimap> collectVisibleFields(){
PMML pmml = getPMML();
Model model = getModel();
ListMultimap> visibleFields = ArrayListMultimap.create();
FieldResolver fieldResolver = new FieldResolver(){
@Override
public PMMLObject popParent(){
PMMLObject parent = super.popParent();
if(Objects.equals(model, parent)){
Model model = (Model)parent;
Collection> fields = getFields(model);
for(Field> field : fields){
visibleFields.put(field.requireName(), field);
}
}
return parent;
}
};
fieldResolver.applyTo(pmml);
return ImmutableListMultimap.copyOf(visibleFields);
}
@Override
public M getModel(){
return this.model;
}
private void setModel(M model){
this.model = Objects.requireNonNull(model);
}
static
protected Set collectResultFeatures(Output output){
Set result = EnumSet.noneOf(ResultFeature.class);
if(output != null && output.hasOutputFields()){
List pmmlOutputFields = output.getOutputFields();
for(org.dmg.pmml.OutputField pmmlOutputField : pmmlOutputFields){
String segmentId = pmmlOutputField.getSegmentId();
if(segmentId != null){
continue;
}
result.add(pmmlOutputField.getResultFeature());
}
}
return result;
}
static
protected Map> collectSegmentResultFeatures(Output output){
Map> result = new LinkedHashMap<>();
List pmmlOutputFields = output.getOutputFields();
for(org.dmg.pmml.OutputField pmmlOutputField : pmmlOutputFields){
String segmentId = pmmlOutputField.getSegmentId();
if(segmentId == null){
continue;
}
Set resultFeatures = result.get(segmentId);
if(resultFeatures == null){
resultFeatures = EnumSet.noneOf(ResultFeature.class);
result.put(segmentId, resultFeatures);
}
resultFeatures.add(pmmlOutputField.getResultFeature());
}
return result;
}
}