org.jpmml.evaluator.ExpressionUtil Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-evaluator Show documentation
JPMML class model evaluator
/*
* Copyright (c) 2013 University of Tartu
*/
package org.jpmml.evaluator;
import java.util.*;
import org.jpmml.manager.*;
import org.dmg.pmml.*;
import com.google.common.base.*;
import com.google.common.collect.*;
public class ExpressionUtil {
/**
 * Not meant to be instantiated; every member of this helper is accessed statically.
 */
private ExpressionUtil(){
	// Intentionally empty
}
/**
 * Resolves the current value of the named field.
 *
 * <p>An argument binding held by the evaluation context takes precedence. Only when the name has no
 * argument entry is it resolved as a {@link DerivedField}, whose expression is then evaluated.
 *
 * @param name the field to look up
 * @param context the evaluation context holding argument bindings and field definitions
 * @return the bound argument value (returned as-is, which may be {@code null}), the value of the
 * resolved derived field, or {@code null} when the name can be resolved neither way
 */
static
public FieldValue evaluate(FieldName name, EvaluationContext context){
	// Parameterized (not raw) so that getValue() is statically typed as FieldValue
	Map.Entry<FieldName, FieldValue> entry = context.getArgumentEntry(name);

	// Testing the entry rather than its value means that an existing argument binding always
	// shadows a same-named derived field, even when the bound value itself is null
	if(entry != null){
		return entry.getValue();
	}

	DerivedField derivedField = context.resolveField(name);
	if(derivedField == null){
		return null;
	}

	return evaluate(derivedField, context);
}
/**
 * Evaluates the expression attached to a derived field.
 *
 * <p>The raw expression result is passed, together with the field definition, through
 * {@code FieldValueUtil.refine} before being returned (presumably to align it with the field's
 * declared type; confirm against FieldValueUtil).
 *
 * @param derivedField the derived field whose expression is evaluated
 * @param context the evaluation context
 * @return the refined value
 */
public static FieldValue evaluate(DerivedField derivedField, EvaluationContext context){
	Expression fieldExpression = derivedField.getExpression();
	FieldValue rawResult = evaluate(fieldExpression, context);

	return FieldValueUtil.refine(derivedField, rawResult);
}
/**
 * Evaluates an expression by dispatching on its concrete PMML element type.
 *
 * <p>Checks are performed in a fixed order and the first matching type wins.
 *
 * @param expression the expression to evaluate
 * @param context the evaluation context
 * @return the value produced by the matching type-specific evaluator
 * @throws UnsupportedFeatureException if the expression is not a Constant, FieldRef,
 * NormContinuous, NormDiscrete, Discretize, MapValues, Apply or Aggregate
 */
public static FieldValue evaluate(Expression expression, EvaluationContext context){

	if(expression instanceof Constant){
		return evaluateConstant((Constant)expression, context);
	}

	if(expression instanceof FieldRef){
		return evaluateFieldRef((FieldRef)expression, context);
	}

	if(expression instanceof NormContinuous){
		return evaluateNormContinuous((NormContinuous)expression, context);
	}

	if(expression instanceof NormDiscrete){
		return evaluateNormDiscrete((NormDiscrete)expression, context);
	}

	if(expression instanceof Discretize){
		return evaluateDiscretize((Discretize)expression, context);
	}

	if(expression instanceof MapValues){
		return evaluateMapValues((MapValues)expression, context);
	}

	if(expression instanceof Apply){
		return evaluateApply((Apply)expression, context);
	}

	if(expression instanceof Aggregate){
		return evaluateAggregate((Aggregate)expression, context);
	}

	// None of the supported expression types matched
	throw new UnsupportedFeatureException(expression);
}
/**
 * Evaluates a {@code Constant} expression.
 *
 * <p>When the constant declares no data type, one is derived from its textual value via
 * {@code TypeUtil.getConstantDataType}.
 *
 * @param constant the constant element
 * @param context the evaluation context (not consulted)
 * @return a field value built from the constant's text and its effective data type
 */
public static FieldValue evaluateConstant(Constant constant, EvaluationContext context){
	String text = constant.getValue();
	DataType declared = constant.getDataType();

	DataType effective = (declared != null) ? declared : TypeUtil.getConstantDataType(text);

	return FieldValueUtil.create(effective, null, text);
}
/**
 * Evaluates a {@code FieldRef} expression by looking up the referenced field.
 *
 * @param fieldRef the field reference
 * @param context the evaluation context
 * @return the field's value, or a value created from the element's {@code mapMissingTo}
 * attribute when the field is missing
 */
public static FieldValue evaluateFieldRef(FieldRef fieldRef, EvaluationContext context){
	FieldValue referenced = evaluate(fieldRef.getField(), context);

	return (referenced != null) ? referenced : FieldValueUtil.create(fieldRef.getMapMissingTo());
}
/**
 * Evaluates a {@code NormContinuous} expression.
 *
 * @param normContinuous the normalization element
 * @param context the evaluation context
 * @return the normalized field value, or a value created from {@code mapMissingTo} when the
 * input field is missing
 */
public static FieldValue evaluateNormContinuous(NormContinuous normContinuous, EvaluationContext context){
	FieldValue input = evaluate(normContinuous.getField(), context);

	if(input != null){
		return NormalizationUtil.normalize(normContinuous, input);
	}

	// Missing input: fall back to the element's replacement value
	return FieldValueUtil.create(normContinuous.getMapMissingTo());
}
/**
 * Evaluates a {@code NormDiscrete} expression as a 0/1 indicator.
 *
 * @param normDiscrete the normalization element
 * @param context the evaluation context
 * @return {@code 1.0} when the field value equals the element's {@code value} attribute,
 * {@code 0.0} otherwise, or a value created from {@code mapMissingTo} when the field is missing
 */
public static FieldValue evaluateNormDiscrete(NormDiscrete normDiscrete, EvaluationContext context){
	FieldValue input = evaluate(normDiscrete.getField(), context);

	if(input == null){
		return FieldValueUtil.create(normDiscrete.getMapMissingTo());
	}

	double indicator;
	if(input.equalsString(normDiscrete.getValue())){
		indicator = 1d;
	} else {
		indicator = 0d;
	}

	return FieldValueUtil.create(indicator);
}
/**
 * Evaluates a {@code Discretize} expression.
 *
 * @param discretize the discretization element
 * @param context the evaluation context
 * @return the discretized field value, or a value of the element's data type created from
 * {@code mapMissingTo} when the input field is missing
 */
public static FieldValue evaluateDiscretize(Discretize discretize, EvaluationContext context){
	FieldValue input = evaluate(discretize.getField(), context);

	if(input != null){
		return DiscretizationUtil.discretize(discretize, input);
	}

	return FieldValueUtil.create(discretize.getDataType(), null, discretize.getMapMissingTo());
}
/**
 * Evaluates a {@code MapValues} expression.
 *
 * <p>Every field/column pair is evaluated in document order and the resulting values are keyed by
 * column name (insertion order preserved) before the lookup is delegated to
 * {@code DiscretizationUtil.mapValue}.
 *
 * @param mapValues the mapping element
 * @param context the evaluation context
 * @return the mapped value, or a value of the element's data type created from
 * {@code mapMissingTo} as soon as any input field is missing
 */
static
public FieldValue evaluateMapValues(MapValues mapValues, EvaluationContext context){
	// Typed collections: the raw List cannot be iterated as FieldColumnPair elements
	Map<String, FieldValue> values = Maps.newLinkedHashMap();

	List<FieldColumnPair> fieldColumnPairs = mapValues.getFieldColumnPairs();
	for(FieldColumnPair fieldColumnPair : fieldColumnPairs){
		FieldValue value = evaluate(fieldColumnPair.getField(), context);

		// A single missing input short-circuits the whole mapping
		if(value == null){
			return FieldValueUtil.create(mapValues.getDataType(), null, mapValues.getMapMissingTo());
		}

		values.put(fieldColumnPair.getColumn(), value);
	}

	return DiscretizationUtil.mapValue(mapValues, values);
}
/**
 * Evaluates an {@code Apply} expression.
 *
 * <p>All argument expressions are evaluated left to right, and missing ({@code null}) results are
 * passed to {@code FunctionUtil.evaluate} unchanged. If the function reports an invalid result,
 * the element's {@code invalidValueTreatment} attribute decides the outcome:
 * <ul>
 * <li>{@code RETURN_INVALID}: a new {@link InvalidResultException} for this element is thrown;</li>
 * <li>{@code AS_IS}: the original {@link InvalidResultException} is rethrown;</li>
 * <li>{@code AS_MISSING}: a value created from {@code mapMissingTo} is returned.</li>
 * </ul>
 *
 * @param apply the function application element
 * @param context the evaluation context
 * @return the function result, or a value created from {@code mapMissingTo} when the function
 * returns {@code null}
 * @throws InvalidResultException under the {@code RETURN_INVALID} and {@code AS_IS} treatments
 * @throws UnsupportedFeatureException for an unrecognized invalid value treatment
 */
static
public FieldValue evaluateApply(Apply apply, EvaluationContext context){
	// Typed collections: the raw List cannot be iterated as Expression elements
	List<FieldValue> values = Lists.newArrayList();

	List<Expression> arguments = apply.getExpressions();
	for(Expression argument : arguments){
		FieldValue value = evaluate(argument, context);

		values.add(value);
	}

	FieldValue result;

	try {
		result = FunctionUtil.evaluate(apply, values, context);
	} catch(InvalidResultException ire){
		InvalidValueTreatmentMethodType invalidValueTreatmentMethod = apply.getInvalidValueTreatment();

		switch(invalidValueTreatmentMethod){
			case RETURN_INVALID:
				throw new InvalidResultException(apply);
			case AS_IS:
				// Re-throw the given InvalidResultException instance
				throw ire;
			case AS_MISSING:
				return FieldValueUtil.create(apply.getMapMissingTo());
			default:
				throw new UnsupportedFeatureException(apply, invalidValueTreatmentMethod);
		}
	}

	if(result == null){
		return FieldValueUtil.create(apply.getMapMissingTo());
	}

	return result;
}
@SuppressWarnings (
value = {"rawtypes", "unchecked"}
)
static
public FieldValue evaluateAggregate(Aggregate aggregate, EvaluationContext context){
FieldValue value = evaluate(aggregate.getField(), context);
Collection> values;
// The JPMML library operates with single records, so it's impossible to implement "proper" aggregation over multiple records
// It is assumed that the aggregation has been performed by application developer beforehand
try {
values = (Collection>)FieldValueUtil.getValue(value);
} catch(ClassCastException cce){
throw new TypeCheckException(Collection.class, value);
}
FieldName groupName = aggregate.getGroupField();
if(groupName != null){
FieldValue groupValue = evaluate(groupName, context);
// Ensure that the group value is a simple type, not a collection type
TypeUtil.getDataType(FieldValueUtil.getValue(groupValue));
}
// Remove missing values
values = Lists.newArrayList(Iterables.filter(values, Predicates.notNull()));
Aggregate.Function function = aggregate.getFunction();
switch(function){
case COUNT:
return FieldValueUtil.create(values.size());
case SUM:
return FunctionUtil.evaluate(new Apply("sum"), createValues(values), context);
case AVERAGE:
return FunctionUtil.evaluate(new Apply("avg"), createValues(values), context);
case MIN:
return FieldValueUtil.create(Collections.min((List)values));
case MAX:
return FieldValueUtil.create(Collections.max((List)values));
default:
throw new UnsupportedFeatureException(aggregate, function);
}
}
static
private List createValues(Collection> values){
Function
© 2015 - 2025 Weber Informatics LLC | Privacy Policy