All downloads are free. The search and download functionality uses the official Maven repository.

org.jpmml.evaluator.ExpressionUtil Maven / Gradle / Ivy

There is a newer version: 1.6.11
Show newest version
/*
 * Copyright (c) 2013 University of Tartu
 */
package org.jpmml.evaluator;

import java.util.*;

import org.jpmml.manager.*;

import org.dmg.pmml.*;

import com.google.common.base.*;
import com.google.common.collect.*;

/**
 * Static helpers for evaluating PMML {@link Expression} elements against an {@link EvaluationContext}.
 *
 * <p>
 * Throughout this class a <code>null</code> {@link FieldValue} represents a missing value.
 * </p>
 */
public class ExpressionUtil {

	// Utility class; not meant to be instantiated
	private ExpressionUtil(){
	}

	/**
	 * Looks up the value of a field by name.
	 *
	 * <p>
	 * The argument entries of the context are consulted first. If there is no entry for the name,
	 * then the name is resolved to a derived field, which is evaluated on the spot.
	 * </p>
	 *
	 * @param name The name of the field.
	 * @param context The evaluation context.
	 *
	 * @return The value of the argument entry (which may be <code>null</code>), the value of the derived field,
	 * or <code>null</code> if the name matches neither an argument entry nor a derived field.
	 */
	static
	public FieldValue evaluate(FieldName name, EvaluationContext context){
		Map.Entry<FieldName, FieldValue> entry = context.getArgumentEntry(name);
		if(entry == null){
			DerivedField derivedField = context.resolveField(name);
			if(derivedField == null){
				return null;
			}

			return evaluate(derivedField, context);
		}

		return entry.getValue();
	}

	/**
	 * Evaluates the expression of a derived field, and passes the result through {@link FieldValueUtil#refine}
	 * together with the derived field.
	 *
	 * @param derivedField The derived field.
	 * @param context The evaluation context.
	 *
	 * @return The refined value.
	 */
	static
	public FieldValue evaluate(DerivedField derivedField, EvaluationContext context){
		FieldValue value = evaluate(derivedField.getExpression(), context);

		return FieldValueUtil.refine(derivedField, value);
	}

	/**
	 * Dispatches the expression to the <code>evaluateXXX</code> method that matches its runtime type.
	 *
	 * @param expression The expression.
	 * @param context The evaluation context.
	 *
	 * @return The value of the expression.
	 *
	 * @throws UnsupportedFeatureException If the expression is none of {@link Constant}, {@link FieldRef},
	 * {@link NormContinuous}, {@link NormDiscrete}, {@link Discretize}, {@link MapValues}, {@link Apply} or {@link Aggregate}.
	 */
	static
	public FieldValue evaluate(Expression expression, EvaluationContext context){

		// Every branch returns, so the checks are written as independent guard clauses
		if(expression instanceof Constant){
			return evaluateConstant((Constant)expression, context);
		}

		if(expression instanceof FieldRef){
			return evaluateFieldRef((FieldRef)expression, context);
		}

		if(expression instanceof NormContinuous){
			return evaluateNormContinuous((NormContinuous)expression, context);
		}

		if(expression instanceof NormDiscrete){
			return evaluateNormDiscrete((NormDiscrete)expression, context);
		}

		if(expression instanceof Discretize){
			return evaluateDiscretize((Discretize)expression, context);
		}

		if(expression instanceof MapValues){
			return evaluateMapValues((MapValues)expression, context);
		}

		if(expression instanceof Apply){
			return evaluateApply((Apply)expression, context);
		}

		if(expression instanceof Aggregate){
			return evaluateAggregate((Aggregate)expression, context);
		}

		throw new UnsupportedFeatureException(expression);
	}

	/**
	 * Creates a value from a constant.
	 *
	 * <p>
	 * The data type declared on the constant is used if present.
	 * Otherwise, the data type is obtained from {@link TypeUtil#getConstantDataType}.
	 * </p>
	 *
	 * @param constant The constant.
	 * @param context The evaluation context. Not used by this method.
	 *
	 * @return The value of the constant.
	 */
	static
	public FieldValue evaluateConstant(Constant constant, EvaluationContext context){
		String value = constant.getValue();

		DataType dataType = constant.getDataType();
		if(dataType == null){
			dataType = TypeUtil.getConstantDataType(value);
		}

		return FieldValueUtil.create(dataType, null, value);
	}

	/**
	 * Evaluates the referenced field.
	 *
	 * @param fieldRef The field reference.
	 * @param context The evaluation context.
	 *
	 * @return The value of the field, or a value created from the <code>mapMissingTo</code> attribute if the field value is missing.
	 */
	static
	public FieldValue evaluateFieldRef(FieldRef fieldRef, EvaluationContext context){
		FieldValue value = evaluate(fieldRef.getField(), context);
		if(value == null){
			return FieldValueUtil.create(fieldRef.getMapMissingTo());
		}

		return value;
	}

	/**
	 * Evaluates the input field and normalizes it using {@link NormalizationUtil#normalize}.
	 *
	 * @param normContinuous The normalization.
	 * @param context The evaluation context.
	 *
	 * @return The normalized value, or a value created from the <code>mapMissingTo</code> attribute if the field value is missing.
	 */
	static
	public FieldValue evaluateNormContinuous(NormContinuous normContinuous, EvaluationContext context){
		FieldValue value = evaluate(normContinuous.getField(), context);
		if(value == null){
			return FieldValueUtil.create(normContinuous.getMapMissingTo());
		}

		return NormalizationUtil.normalize(normContinuous, value);
	}

	/**
	 * Evaluates the input field and compares it with the value of the <code>NormDiscrete</code> element.
	 *
	 * @param normDiscrete The indicator.
	 * @param context The evaluation context.
	 *
	 * @return <code>1.0</code> if the field value equals the target value, <code>0.0</code> if it does not,
	 * or a value created from the <code>mapMissingTo</code> attribute if the field value is missing.
	 */
	static
	public FieldValue evaluateNormDiscrete(NormDiscrete normDiscrete, EvaluationContext context){
		FieldValue value = evaluate(normDiscrete.getField(), context);
		if(value == null){
			return FieldValueUtil.create(normDiscrete.getMapMissingTo());
		}

		boolean equals = value.equalsString(normDiscrete.getValue());

		return FieldValueUtil.create(equals ? 1d : 0d);
	}

	/**
	 * Evaluates the input field and discretizes it using {@link DiscretizationUtil#discretize}.
	 *
	 * @param discretize The discretization.
	 * @param context The evaluation context.
	 *
	 * @return The discretized value, or a value created from the <code>dataType</code> and <code>mapMissingTo</code> attributes
	 * if the field value is missing.
	 */
	static
	public FieldValue evaluateDiscretize(Discretize discretize, EvaluationContext context){
		FieldValue value = evaluate(discretize.getField(), context);
		if(value == null){
			return FieldValueUtil.create(discretize.getDataType(), null, discretize.getMapMissingTo());
		}

		return DiscretizationUtil.discretize(discretize, value);
	}

	/**
	 * Evaluates all input fields, and looks up the result using {@link DiscretizationUtil#mapValue}.
	 *
	 * @param mapValues The mapping.
	 * @param context The evaluation context.
	 *
	 * @return The mapped value, or a value created from the <code>dataType</code> and <code>mapMissingTo</code> attributes
	 * as soon as any of the input field values is found to be missing.
	 */
	static
	public FieldValue evaluateMapValues(MapValues mapValues, EvaluationContext context){
		// Column name to field value; a linked map keeps the declaration order of field-column pairs
		Map<String, FieldValue> values = Maps.newLinkedHashMap();

		List<FieldColumnPair> fieldColumnPairs = mapValues.getFieldColumnPairs();
		for(FieldColumnPair fieldColumnPair : fieldColumnPairs){
			FieldValue value = evaluate(fieldColumnPair.getField(), context);
			if(value == null){
				return FieldValueUtil.create(mapValues.getDataType(), null, mapValues.getMapMissingTo());
			}

			values.put(fieldColumnPair.getColumn(), value);
		}

		return DiscretizationUtil.mapValue(mapValues, values);
	}

	/**
	 * Evaluates all argument expressions in their declaration order, and applies the function to them using {@link FunctionUtil#evaluate}.
	 *
	 * <p>
	 * Missing argument values are passed on to the function as <code>null</code> elements.
	 * </p>
	 *
	 * @param apply The function application.
	 * @param context The evaluation context.
	 *
	 * @return The result of the function. A value created from the <code>mapMissingTo</code> attribute is returned instead
	 * if the result is missing, or if the function failed with an {@link InvalidResultException} and the invalid value treatment is <code>AS_MISSING</code>.
	 *
	 * @throws InvalidResultException If the function failed and the invalid value treatment is <code>RETURN_INVALID</code> (a new exception instance)
	 * or <code>AS_IS</code> (the original exception instance).
	 * @throws UnsupportedFeatureException If the function failed and the invalid value treatment is none of the above.
	 */
	static
	public FieldValue evaluateApply(Apply apply, EvaluationContext context){
		List<FieldValue> values = Lists.newArrayList();

		List<Expression> arguments = apply.getExpressions();
		for(Expression argument : arguments){
			FieldValue value = evaluate(argument, context);

			values.add(value);
		}

		FieldValue result;

		try {
			result = FunctionUtil.evaluate(apply, values, context);
		} catch(InvalidResultException ire){
			InvalidValueTreatmentMethodType invalidValueTreatmentMethod = apply.getInvalidValueTreatment();

			switch(invalidValueTreatmentMethod){
				case RETURN_INVALID:
					throw new InvalidResultException(apply);
				case AS_IS:
					// Re-throw the given InvalidResultException instance
					throw ire;
				case AS_MISSING:
					return FieldValueUtil.create(apply.getMapMissingTo());
				default:
					throw new UnsupportedFeatureException(apply, invalidValueTreatmentMethod);
			}
		}

		if(result == null){
			return FieldValueUtil.create(apply.getMapMissingTo());
		}

		return result;
	}

	/**
	 * Aggregates a collection-valued field.
	 *
	 * <p>
	 * The field value is expected to hold a {@link Collection} of pre-collected values.
	 * <code>null</code> elements are discarded before the aggregation function is applied.
	 * </p>
	 *
	 * <p>
	 * If a group field is declared, then it is evaluated and type-checked, but it does not affect the result.
	 * </p>
	 *
	 * @param aggregate The aggregation.
	 * @param context The evaluation context.
	 *
	 * @return The aggregated value.
	 *
	 * @throws TypeCheckException If the field value is not a collection.
	 * @throws UnsupportedFeatureException If the aggregation function is none of <code>COUNT</code>, <code>SUM</code>,
	 * <code>AVERAGE</code>, <code>MIN</code> or <code>MAX</code>.
	 * @throws NoSuchElementException If the aggregation function is <code>MIN</code> or <code>MAX</code>, and there are no non-missing elements.
	 */
	@SuppressWarnings (
		value = {"rawtypes", "unchecked"}
	)
	static
	public FieldValue evaluateAggregate(Aggregate aggregate, EvaluationContext context){
		FieldValue value = evaluate(aggregate.getField(), context);

		Collection<?> values;

		// The JPMML library operates with single records, so it's impossible to implement "proper" aggregation over multiple records
		// It is assumed that the aggregation has been performed by application developer beforehand
		// NOTE(review): presumably a missing field value yields a null collection here, which the filtering step below would reject - confirm
		try {
			values = (Collection<?>)FieldValueUtil.getValue(value);
		} catch(ClassCastException cce){
			throw new TypeCheckException(Collection.class, value);
		}

		FieldName groupName = aggregate.getGroupField();
		if(groupName != null){
			FieldValue groupValue = evaluate(groupName, context);

			// Ensure that the group value is a simple type, not a collection type
			TypeUtil.getDataType(FieldValueUtil.getValue(groupValue));
		}

		// Remove missing values
		values = Lists.newArrayList(Iterables.filter(values, Predicates.notNull()));

		Aggregate.Function function = aggregate.getFunction();
		switch(function){
			case COUNT:
				return FieldValueUtil.create(values.size());
			case SUM:
				return FunctionUtil.evaluate(new Apply("sum"), createValues(values), context);
			case AVERAGE:
				return FunctionUtil.evaluate(new Apply("avg"), createValues(values), context);
			case MIN:
				// The element type is unknown at compile time, so the natural ordering is applied via a raw list
				return FieldValueUtil.create(Collections.min((List)values));
			case MAX:
				return FieldValueUtil.create(Collections.max((List)values));
			default:
				throw new UnsupportedFeatureException(aggregate, function);
		}
	}

	/**
	 * Wraps every element of the collection into a {@link FieldValue} using {@link FieldValueUtil#create(Object)}.
	 *
	 * @param values The raw values.
	 *
	 * @return A new list of wrapped values, in the iteration order of the collection.
	 */
	static
	private List<FieldValue> createValues(Collection<?> values){
		Function<Object, FieldValue> function = new Function<Object, FieldValue>(){

			@Override
			public FieldValue apply(Object value){
				return FieldValueUtil.create(value);
			}
		};

		return Lists.newArrayList(Iterables.transform(values, function));
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy