All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.evaluator.OutputUtil Maven / Gradle / Ivy

There is a newer version: 1.7.2
Show newest version
/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import com.google.common.base.Joiner;
import com.google.common.collect.BiMap;
import com.google.common.collect.Ordering;
import org.dmg.pmml.AssociationRule;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Item;
import org.dmg.pmml.ItemRef;
import org.dmg.pmml.Itemset;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.ResultFeatureType;
import org.dmg.pmml.RuleFeatureType;
import org.dmg.pmml.Target;
import org.dmg.pmml.TargetValue;
import org.dmg.pmml.Value;

public class OutputUtil {

	// This class only exposes static members; the private constructor prevents instantiation.
	private OutputUtil(){
	}

	/**
	 * Evaluates the {@link Output} element.
	 *
	 * <p>
	 * Output fields are processed in declaration order.
	 * Every computed value is declared in the evaluation context right away, which makes it available to the output fields that follow.
	 * An output field is skipped (ie. no entry is added to the result) if its segment result or its target value is missing.
	 * </p>
	 *
	 * @param predictions Map of {@link Evaluator#getTargetFields() target field} values.
	 * @param context The evaluation context. It is cast to {@link MiningModelEvaluationContext} when an output field specifies a segment identifier.
	 *
	 * @return Map of {@link Evaluator#getTargetFields() target field} values together with {@link Evaluator#getOutputFields() output field} values.
	 * If the model does not define an {@link Output} element, then the <code>predictions</code> argument is returned as is.
	 *
	 * @throws InvalidFeatureException If an output field specifies a segment identifier but the model is not a {@link MiningModel},
	 * or if a transformed value or decision output field is not defined properly.
	 * @throws MissingFieldException If a required target field value or expected value is not available.
	 * @throws UnsupportedFeatureException If the result feature (or the operational type of a residual) is not supported.
	 */
	@SuppressWarnings (
		value = {"fallthrough"}
	)
	static
	public Map<FieldName, ?> evaluate(Map<FieldName, ?> predictions, ModelEvaluationContext context){
		ModelEvaluator modelEvaluator = context.getModelEvaluator();

		Model model = modelEvaluator.getModel();

		Output output = model.getOutput();
		if(output == null){
			return predictions;
		}

		// Work on a copy, so that the caller's map is never modified
		Map<FieldName, Object> result = new LinkedHashMap<FieldName, Object>(predictions);

		List<OutputField> outputFields = output.getOutputFields();

		outputFields:
		for(OutputField outputField : outputFields){
			FieldName targetField = outputField.getTargetField();

			Object targetValue = null;

			ResultFeatureType resultFeature = outputField.getFeature();

			String segmentId = outputField.getSegmentId();

			SegmentResultMap segmentPredictions = null;

			// Load the target value of the specified segment
			if(segmentId != null){

				if(!(model instanceof MiningModel)){
					throw new InvalidFeatureException(outputField);
				}

				MiningModelEvaluationContext miningModelContext = (MiningModelEvaluationContext)context;

				segmentPredictions = miningModelContext.getResult(segmentId);

				// "If there is no Segment matching segmentId or if the predicate of the matching Segment evaluated to false, then the result delivered by this OutputField is missing"
				if(segmentPredictions == null){
					continue outputFields;
				} // End if

				if(targetField != null){

					if(!segmentPredictions.containsKey(targetField)){
						throw new MissingFieldException(targetField, outputField);
					}

					targetValue = segmentPredictions.get(targetField);
				} else

				{
					targetValue = segmentPredictions.getTargetValue();
				}
			} else

			// Load the target value
			{
				switch(resultFeature){
					case ENTITY_ID:
						{
							// "Result feature entityId returns the id of the winning segment"
							if(model instanceof MiningModel){
								targetValue = TypeUtil.cast(HasEntityId.class, predictions);

								break;
							}
						}
						// Falls through
					default:
						{
							// The target field defaults to the (sole) target field of the model
							if(targetField == null){
								targetField = modelEvaluator.getTargetField();
							} // End if

							if(!predictions.containsKey(targetField)){
								throw new MissingFieldException(targetField, outputField);
							}

							targetValue = predictions.get(targetField);
						}
						break;
				}
			}

			// If the target value is missing, then the result delivered by this OutputField is missing
			if(targetValue == null){
				continue outputFields;
			}

			Object value;

			// Perform the requested computation on the target value
			switch(resultFeature){
				case PREDICTED_VALUE:
					{
						value = getPredictedValue(targetValue);
					}
					break;
				case PREDICTED_DISPLAY_VALUE:
					{
						DataField dataField = modelEvaluator.getDataField(targetField);

						Target target = modelEvaluator.getTarget(targetField);

						value = getPredictedDisplayValue(targetValue, dataField, target);
					}
					break;
				case TRANSFORMED_VALUE:
				case DECISION:
					{
						// When referring to a segment, the value attribute names the segment output field to look up, and no expression is allowed
						if(segmentId != null){
							String name = outputField.getValue();
							if(name == null){
								throw new InvalidFeatureException(outputField);
							}

							Expression expression = outputField.getExpression();
							if(expression != null){
								throw new InvalidFeatureException(outputField);
							}

							value = segmentPredictions.get(FieldName.create(name));

							break;
						}

						Expression expression = outputField.getExpression();
						if(expression == null){
							throw new InvalidFeatureException(outputField);
						}

						value = FieldValueUtil.getValue(ExpressionUtil.evaluate(expression, context));
					}
					break;
				case PROBABILITY:
					{
						value = getProbability(targetValue, outputField);
					}
					break;
				case RESIDUAL:
					{
						// The expected value is looked up from the evaluation context under the name of the target field
						FieldValue expectedValue = context.getField(targetField);
						if(expectedValue == null){
							throw new MissingFieldException(targetField, outputField);
						}

						DataField dataField = modelEvaluator.getDataField(targetField);

						OpType opType = dataField.getOpType();
						switch(opType){
							case CONTINUOUS:
								value = getContinuousResidual(targetValue, expectedValue);
								break;
							case CATEGORICAL:
								value = getCategoricalResidual(targetValue, expectedValue);
								break;
							default:
								throw new UnsupportedFeatureException(dataField, opType);
						}
					}
					break;
				case CLUSTER_ID:
					{
						value = getClusterId(targetValue);
					}
					break;
				case ENTITY_ID:
					{
						// For rule-based results the entity identifier is the rule identifier
						if(targetValue instanceof HasRuleValues){
							value = getRuleValue(targetValue, outputField, RuleFeatureType.RULE_ID);

							break;
						}

						value = getEntityId(targetValue, outputField);
					}
					break;
				case AFFINITY:
					{
						value = getAffinity(targetValue, outputField);
					}
					break;
				case CLUSTER_AFFINITY:
				case ENTITY_AFFINITY:
					{
						value = getEntityAffinity(targetValue);
					}
					break;
				case REASON_CODE:
					{
						value = getReasonCode(targetValue, outputField);
					}
					break;
				case RULE_VALUE:
					{
						value = getRuleValue(targetValue, outputField);
					}
					break;
				case ANTECEDENT:
					{
						value = getRuleValue(targetValue, outputField, RuleFeatureType.ANTECEDENT);
					}
					break;
				case CONSEQUENT:
					{
						value = getRuleValue(targetValue, outputField, RuleFeatureType.CONSEQUENT);
					}
					break;
				case RULE:
					{
						value = getRuleValue(targetValue, outputField, RuleFeatureType.RULE);
					}
					break;
				case RULE_ID:
					{
						value = getRuleValue(targetValue, outputField, RuleFeatureType.RULE_ID);
					}
					break;
				case CONFIDENCE:
					{
						value = getRuleValue(targetValue, outputField, RuleFeatureType.CONFIDENCE);
					}
					break;
				case SUPPORT:
					{
						value = getRuleValue(targetValue, outputField, RuleFeatureType.SUPPORT);
					}
					break;
				case LIFT:
					{
						value = getRuleValue(targetValue, outputField, RuleFeatureType.LIFT);
					}
					break;
				case LEVERAGE:
					{
						value = getRuleValue(targetValue, outputField, RuleFeatureType.LEVERAGE);
					}
					break;
				case WARNING:
					{
						value = context.getWarnings();
					}
					break;
				default:
					throw new UnsupportedFeatureException(outputField, resultFeature);
			}

			FieldValue outputValue = FieldValueUtil.create(outputField, value);

			// The result of one output field becomes available to other output fields
			context.declare(outputField.getName(), outputValue);

			result.put(outputField.getName(), FieldValueUtil.getValue(outputValue));
		}

		return result;
	}

	/**
	 * Determines the data type of an output field without evaluating it.
	 *
	 * <p>
	 * An explicitly declared data type always wins.
	 * Otherwise, the data type is derived from the result feature of the output field.
	 * </p>
	 *
	 * @param outputField The output field.
	 * @param modelEvaluator The model evaluator that is consulted for the target field and its data field.
	 *
	 * @throws TypeAnalysisException If the data type cannot be determined.
	 * @throws InvalidFeatureException If a transformed value or decision output field does not define an expression.
	 * @throws UnsupportedFeatureException If the result feature is not supported.
	 */
	static
	public DataType getDataType(OutputField outputField, ModelEvaluator modelEvaluator){
		DataType declaredDataType = outputField.getDataType();
		if(declaredDataType != null){
			return declaredDataType;
		} // End if

		// Segment results cannot be analyzed statically
		if(outputField.getSegmentId() != null){
			throw new TypeAnalysisException(outputField);
		}

		ResultFeatureType feature = outputField.getFeature();

		switch(feature){
			case PREDICTED_VALUE:
				{
					FieldName name = outputField.getTargetField();
					if(name == null){
						name = modelEvaluator.getTargetField();
					}

					DataField targetDataField = modelEvaluator.getDataField(name);
					if(targetDataField == null){
						throw new TypeAnalysisException(outputField);
					}

					return targetDataField.getDataType();
				}
			case TRANSFORMED_VALUE:
			case DECISION:
				{
					Expression transformation = outputField.getExpression();
					if(transformation == null){
						throw new InvalidFeatureException(outputField);
					}

					return ExpressionUtil.getDataType(transformation, modelEvaluator);
				}
			// String-valued result features
			case PREDICTED_DISPLAY_VALUE: // XXX
			case ENTITY_ID:
			case CLUSTER_ID:
			case REASON_CODE:
				return DataType.STRING;
			// Double-valued result features
			case PROBABILITY:
			case RESIDUAL:
			case STANDARD_ERROR:
			case AFFINITY:
			case ENTITY_AFFINITY:
			case CLUSTER_AFFINITY:
				return DataType.DOUBLE;
			// Rule result features
			case RULE_VALUE:
				return getRuleDataType(outputField);
			case ANTECEDENT:
				return getRuleDataType(outputField, RuleFeatureType.ANTECEDENT);
			case CONSEQUENT:
				return getRuleDataType(outputField, RuleFeatureType.CONSEQUENT);
			case RULE:
				return getRuleDataType(outputField, RuleFeatureType.RULE);
			case RULE_ID:
				return getRuleDataType(outputField, RuleFeatureType.RULE_ID);
			case SUPPORT:
				return getRuleDataType(outputField, RuleFeatureType.SUPPORT);
			case CONFIDENCE:
				return getRuleDataType(outputField, RuleFeatureType.CONFIDENCE);
			case LIFT:
				return getRuleDataType(outputField, RuleFeatureType.LIFT);
			case LEVERAGE:
				return getRuleDataType(outputField, RuleFeatureType.LEVERAGE);
			case WARNING:
				throw new TypeAnalysisException(outputField);
			default:
				throw new UnsupportedFeatureException(outputField, feature);
		}
	}

	/**
	 * Obtains the predicted value by passing the target value through <code>EvaluatorUtil.decode</code>.
	 */
	static
	private Object getPredictedValue(Object object){
		Object decodedValue = EvaluatorUtil.decode(object);

		return decodedValue;
	}

	/**
	 * Obtains the display value of the predicted value.
	 *
	 * <p>
	 * The lookup order is: the target value itself (if it implements {@link HasDisplayValue}),
	 * the matching {@link TargetValue} element of the target (if any),
	 * the matching valid {@link Value} element of a categorical or ordinal data field,
	 * and finally the raw predicted value.
	 * </p>
	 *
	 * @param target The target element. May be <code>null</code>.
	 */
	static
	private Object getPredictedDisplayValue(Object object, DataField dataField, Target target){

		if(object instanceof HasDisplayValue){
			HasDisplayValue displayable = TypeUtil.cast(HasDisplayValue.class, object);

			return displayable.getDisplayValue();
		}

		Object predictedValue = getPredictedValue(object);

		if(target != null){
			TargetValue targetValue = TargetUtil.getTargetValue(target, predictedValue);

			String targetDisplayValue = (targetValue != null ? targetValue.getDisplayValue() : null);
			if(targetDisplayValue != null){
				return targetDisplayValue;
			}
		}

		OpType opType = dataField.getOpType();

		boolean discrete = (opType == OpType.CATEGORICAL || opType == OpType.ORDINAL);
		if(discrete){
			Value validValue = ArgumentUtil.getValidValue(dataField, predictedValue);

			String fieldDisplayValue = (validValue != null ? validValue.getDisplayValue() : null);
			if(fieldDisplayValue != null){
				return fieldDisplayValue;
			}
		}

		// "If the display value is not specified explicitly, then the raw predicted value is used by default"
		return predictedValue;
	}

	/**
	 * Obtains the probability of the category that is selected by the output field.
	 *
	 * @see #getCategoryValue(Object, OutputField)
	 */
	static
	private Double getProbability(Object object, OutputField outputField){
		// The type check comes first, so that an incompatible target value is rejected before anything else
		HasProbability probabilities = TypeUtil.cast(HasProbability.class, object);

		return probabilities.getProbability(getCategoryValue(object, outputField));
	}

	/**
	 * Obtains the category that the output field refers to.
	 *
	 * @return The value attribute of the output field, if specified. Otherwise, the formatted predicted value.
	 */
	static
	private String getCategoryValue(Object object, OutputField outputField){
		String explicitValue = outputField.getValue();
		if(explicitValue != null){
			return explicitValue;
		}

		// "If the value attribute is not specified, then the predicted categorical value should be returned as a result"
		return TypeUtil.format(getPredictedValue(object));
	}

	/**
	 * Computes the residual of a continuous target as the difference between the expected value and the predicted value.
	 */
	static
	private Double getContinuousResidual(Object object, FieldValue expectedObject){
		Number predicted = (Number)getPredictedValue(object);
		Number expected = (Number)FieldValueUtil.getValue(expectedObject);

		double residual = (expected.doubleValue() - predicted.doubleValue());

		return Double.valueOf(residual);
	}

	/**
	 * Computes the residual of a categorical target.
	 *
	 * <p>
	 * The result is <code>1</code> (if the predicted category equals the expected category) or <code>0</code> (otherwise),
	 * minus the probability of the predicted category.
	 * </p>
	 */
	static
	public Double getCategoricalResidual(Object object, FieldValue expectedObject){
		HasProbability probabilities = TypeUtil.cast(HasProbability.class, object);

		String predictedCategory = TypeUtil.format(getPredictedValue(object));
		String expectedCategory = TypeUtil.format(FieldValueUtil.getValue(expectedObject));

		double indicator = 0d;

		if(TypeUtil.equals(DataType.STRING, predictedCategory, expectedCategory)){
			indicator = 1d;
		}

		return Double.valueOf(indicator - probabilities.getProbability(predictedCategory));
	}

	/**
	 * Obtains the identifier of the winning cluster from a target value that implements {@link HasEntityId}.
	 */
	static
	private String getClusterId(Object object){
		HasEntityId cluster = TypeUtil.cast(HasEntityId.class, object);

		String clusterId = cluster.getEntityId();

		return clusterId;
	}

	/**
	 * Obtains the identifier of the entity at the rank that is requested by the output field.
	 *
	 * <p>
	 * Rank <code>1</code> yields the winning entity.
	 * Higher ranks require the target value to implement {@link HasEntityIdRanking}, and are supported for the descending rank order only.
	 * </p>
	 *
	 * @throws InvalidFeatureException If the rank is not positive.
	 * @throws UnsupportedFeatureException If the rank is greater than one and the rank order is not descending.
	 */
	static
	private String getEntityId(Object object, OutputField outputField){
		HasEntityId entity = TypeUtil.cast(HasEntityId.class, object);

		int rank = outputField.getRank();
		if(rank < 1){
			throw new InvalidFeatureException(outputField);
		} // End if

		// The winner does not need a ranking
		if(rank == 1){
			return entity.getEntityId();
		}

		HasEntityIdRanking ranking = TypeUtil.cast(HasEntityIdRanking.class, object);

		OutputField.RankOrder rankOrder = outputField.getRankOrder();
		switch(rankOrder){
			case DESCENDING:
				return getElement(ranking.getEntityIdRanking(), rank);
			default:
				throw new UnsupportedFeatureException(outputField, rankOrder);
		}
	}

	/**
	 * Obtains the affinity that is requested by the output field.
	 *
	 * <p>
	 * Rank <code>1</code> yields the affinity of the {@link #getCategoryValue(Object, OutputField) selected category}.
	 * Higher ranks require the target value to implement {@link HasAffinityRanking}, and are supported for the descending rank order only.
	 * </p>
	 *
	 * @throws InvalidFeatureException If the rank is not positive.
	 * @throws UnsupportedFeatureException If the rank is greater than one and the rank order is not descending.
	 */
	static
	public Double getAffinity(Object object, OutputField outputField){
		HasAffinity affinities = TypeUtil.cast(HasAffinity.class, object);

		int rank = outputField.getRank();
		if(rank < 1){
			throw new InvalidFeatureException(outputField);
		} // End if

		if(rank == 1){
			String category = getCategoryValue(object, outputField);

			return affinities.getAffinity(category);
		}

		HasAffinityRanking ranking = TypeUtil.cast(HasAffinityRanking.class, object);

		OutputField.RankOrder rankOrder = outputField.getRankOrder();
		switch(rankOrder){
			case DESCENDING:
				return getElement(ranking.getAffinityRanking(), rank);
			default:
				throw new UnsupportedFeatureException(outputField, rankOrder);
		}
	}

	/**
	 * Obtains the affinity of the winning entity from a target value that implements {@link HasEntityAffinity}.
	 */
	static
	public Double getEntityAffinity(Object object){
		HasEntityAffinity entity = TypeUtil.cast(HasEntityAffinity.class, object);

		Double entityAffinity = entity.getEntityAffinity();

		return entityAffinity;
	}

	/**
	 * Obtains the reason code at the rank that is requested by the output field.
	 *
	 * @return The reason code, or <code>null</code> if the ranking holds fewer elements than the rank.
	 *
	 * @throws InvalidFeatureException If the rank is not positive.
	 */
	static
	public String getReasonCode(Object object, OutputField outputField){
		HasReasonCodeRanking ranking = TypeUtil.cast(HasReasonCodeRanking.class, object);

		int rank = outputField.getRank();

		boolean validRank = (rank > 0);
		if(!validRank){
			throw new InvalidFeatureException(outputField);
		}

		return getElement(ranking.getReasonCodeRanking(), rank);
	}

	/**
	 * Obtains the specified feature of the association rule at the rank that is requested by the output field.
	 *
	 * <p>
	 * This variant always returns a single result.
	 * The rule feature is dictated by the caller, not by the output field.
	 * </p>
	 *
	 * @param object The target value. Must implement {@link HasRuleValues}.
	 * @param outputField The output field that provides the ranking criteria and the rank.
	 * @param ruleFeature The rule feature to extract.
	 *
	 * @return The rule feature, or <code>null</code> if there are fewer association rules than the rank.
	 *
	 * @throws UnsupportedFeatureException If the output field is declared multi-valued.
	 * @throws InvalidFeatureException If the rank is not positive.
	 */
	static
	public Object getRuleValue(Object object, OutputField outputField, RuleFeatureType ruleFeature){
		HasRuleValues hasRuleValues = TypeUtil.cast(HasRuleValues.class, object);

		List<AssociationRule> associationRules = getRuleValues(hasRuleValues, outputField);

		String isMultiValued = outputField.getIsMultiValued();
		if(!("0").equals(isMultiValued)){
			throw new UnsupportedFeatureException(outputField);
		}

		int rank = outputField.getRank();
		if(rank <= 0){
			throw new InvalidFeatureException(outputField);
		}

		AssociationRule associationRule = getElement(associationRules, rank);
		if(associationRule != null){
			return getRuleFeature(hasRuleValues, associationRule, outputField, ruleFeature);
		}

		return null;
	}

	/**
	 * Obtains the rule feature that is requested by the output field.
	 *
	 * <p>
	 * Depending on the <code>isMultiValued</code> attribute of the output field, the result is either
	 * the feature of the association rule at the requested rank (<code>"0"</code>),
	 * or a list of features of the top-ranking association rules (<code>"1"</code>).
	 * </p>
	 *
	 * @param object The target value. Must implement {@link HasRuleValues}.
	 * @param outputField The output field that provides the ranking criteria, the rank and the rule feature.
	 *
	 * @return In single-valued mode, the rule feature, or <code>null</code> if there are fewer association rules than the rank.
	 * In multi-valued mode, a list of rule features.
	 *
	 * @throws InvalidFeatureException If the rank is out of range for the mode, or if the <code>isMultiValued</code> attribute is neither <code>"0"</code> nor <code>"1"</code>.
	 */
	static
	public Object getRuleValue(Object object, OutputField outputField){
		HasRuleValues hasRuleValues = TypeUtil.cast(HasRuleValues.class, object);

		List<AssociationRule> associationRules = getRuleValues(hasRuleValues, outputField);

		String isMultiValued = outputField.getIsMultiValued();

		// Return a single result
		if(("0").equals(isMultiValued)){
			int rank = outputField.getRank();
			if(rank <= 0){
				throw new InvalidFeatureException(outputField);
			}

			AssociationRule associationRule = getElement(associationRules, rank);
			if(associationRule != null){
				return getRuleFeature(hasRuleValues, associationRule, outputField);
			}

			return null;
		} else

		// Return multiple results
		if(("1").equals(isMultiValued)){
			int size;

			int rank = outputField.getRank();
			if(rank < 0){
				throw new InvalidFeatureException(outputField);
			} else

			// "A zero value indicates that all output values are to be returned"
			if(rank == 0){
				size = associationRules.size();
			} else

			// "A positive value indicates the number of output values to be returned"
			{
				size = Math.min(rank, associationRules.size());
			}

			List<Object> result = new ArrayList<Object>(size);

			for(AssociationRule associationRule : associationRules.subList(0, size)){
				result.add(getRuleFeature(hasRuleValues, associationRule, outputField));
			}

			return result;
		} else

		{
			throw new InvalidFeatureException(outputField);
		}
	}

	/**
	 * Obtains the association rules that match the algorithm of the output field, ranked by the rank basis and rank order of the output field.
	 *
	 * @param hasRuleValues The source of association rules.
	 * @param outputField The output field that provides the algorithm and the ranking criteria.
	 *
	 * @return A sorted copy of the matching association rules.
	 *
	 * @throws InvalidFeatureException If an association rule does not define the feature that it is ranked by.
	 * @throws UnsupportedFeatureException If the rank basis or the rank order is not supported.
	 */
	static
	private List<AssociationRule> getRuleValues(HasRuleValues hasRuleValues, final OutputField outputField){
		List<AssociationRule> associationRules = hasRuleValues.getRuleValues(outputField.getAlgorithm());

		Comparator<AssociationRule> comparator = new Comparator<AssociationRule>(){

			private final OutputField.RankBasis rankBasis = outputField.getRankBasis();

			private final OutputField.RankOrder rankOrder = outputField.getRankOrder();


			@Override
			public int compare(AssociationRule left, AssociationRule right){
				int order;

				switch(this.rankBasis){
					case CONFIDENCE:
						order = (getConfidence(left)).compareTo(getConfidence(right));
						break;
					case SUPPORT:
						order = (getSupport(left)).compareTo(getSupport(right));
						break;
					case LIFT:
						order = (getLift(left)).compareTo(getLift(right));
						break;
					case LEVERAGE:
						order = (getLeverage(left)).compareTo(getLeverage(right));
						break;
					case AFFINITY:
						order = (getAffinity(left)).compareTo(getAffinity(right));
						break;
					default:
						throw new UnsupportedFeatureException(outputField, this.rankBasis);
				} // End switch

				switch(this.rankOrder){
					case ASCENDING:
						return order;
					case DESCENDING:
						return -order;
					default:
						throw new UnsupportedFeatureException(outputField, this.rankOrder);
				}
			}

			private Double getConfidence(AssociationRule rule){
				return checkRuleFeature(rule, rule.getConfidence());
			}

			private Double getSupport(AssociationRule rule){
				return checkRuleFeature(rule, rule.getSupport());
			}

			private Double getLift(AssociationRule rule){
				return checkRuleFeature(rule, rule.getLift());
			}

			private Double getLeverage(AssociationRule rule){
				return checkRuleFeature(rule, rule.getLeverage());
			}

			private Double getAffinity(AssociationRule rule){
				return checkRuleFeature(rule, rule.getAffinity());
			}

			// Rules cannot be ranked by a feature that they do not define
			private <V> V checkRuleFeature(AssociationRule rule, V value){

				if(value == null){
					throw new InvalidFeatureException(rule);
				}

				return value;
			}
		};

		Ordering<AssociationRule> ordering = Ordering.from(comparator);

		// The list that was returned by the source is left untouched
		return ordering.sortedCopy(associationRules);
	}

	/**
	 * Obtains the feature of the association rule that is selected by the <code>ruleFeature</code> attribute of the output field.
	 */
	static
	private Object getRuleFeature(HasRuleValues hasRuleValues, AssociationRule associationRule, OutputField outputField){
		RuleFeatureType requestedFeature = outputField.getRuleFeature();

		return getRuleFeature(hasRuleValues, associationRule, outputField, requestedFeature);
	}

	/**
	 * Obtains the specified feature of the association rule.
	 *
	 * <p>
	 * The antecedent and the consequent are returned as lists of item values.
	 * The rule is returned as a string in the form <code>{antecedent items}-&gt;{consequent items}</code>, where items are separated by commas.
	 * </p>
	 *
	 * @param element The element that is reported if the rule feature is not supported.
	 *
	 * @throws UnsupportedFeatureException If the rule feature is not supported.
	 */
	@SuppressWarnings (
		value = {"unchecked"}
	)
	static
	private Object getRuleFeature(HasRuleValues hasRuleValues, AssociationRule associationRule, PMMLObject element, RuleFeatureType ruleFeature){

		switch(ruleFeature){
			// Item-based features
			case ANTECEDENT:
				return getItemValues(hasRuleValues, associationRule.getAntecedent());
			case CONSEQUENT:
				return getItemValues(hasRuleValues, associationRule.getConsequent());
			case RULE:
				{
					Joiner commaJoiner = Joiner.on(',');

					String antecedent = commaJoiner.join(getItemValues(hasRuleValues, associationRule.getAntecedent()));
					String consequent = commaJoiner.join(getItemValues(hasRuleValues, associationRule.getConsequent()));

					return "{" + antecedent + "}" + "->" + "{" + consequent + "}";
				}
			case RULE_ID:
				{
					// The source of association rules doubles as their entity registry
					HasEntityRegistry registry = (HasEntityRegistry)hasRuleValues;

					return EntityUtil.getId(associationRule, registry);
				}
			// Numeric features
			case CONFIDENCE:
				return associationRule.getConfidence();
			case SUPPORT:
				return associationRule.getSupport();
			case LIFT:
				return associationRule.getLift();
			case LEVERAGE:
				return associationRule.getLeverage();
			case AFFINITY:
				return associationRule.getAffinity();
			default:
				throw new UnsupportedFeatureException(element, ruleFeature);
		}
	}

	/**
	 * Determines the data type of the rule feature that is selected by the <code>ruleFeature</code> attribute of the output field.
	 */
	static
	private DataType getRuleDataType(OutputField outputField){
		RuleFeatureType requestedFeature = outputField.getRuleFeature();

		return getRuleDataType(outputField, requestedFeature);
	}

	/**
	 * Determines the data type of the specified rule feature.
	 *
	 * @throws TypeAnalysisException If the output field is not declared single-valued, or if the rule feature is the antecedent or the consequent.
	 * @throws UnsupportedFeatureException If the rule feature is not supported.
	 */
	static
	private DataType getRuleDataType(OutputField outputField, RuleFeatureType ruleFeature){
		boolean singleValued = ("0").equals(outputField.getIsMultiValued());
		if(!singleValued){
			throw new TypeAnalysisException(outputField);
		}

		switch(ruleFeature){
			case RULE:
			case RULE_ID:
				return DataType.STRING;
			case CONFIDENCE:
			case SUPPORT:
			case LIFT:
			case LEVERAGE:
			case AFFINITY:
				return DataType.DOUBLE;
			case ANTECEDENT:
			case CONSEQUENT:
				throw new TypeAnalysisException(outputField);
			default:
				throw new UnsupportedFeatureException(outputField, ruleFeature);
		}
	}

	/**
	 * Collects the values of the items that make up the specified itemset.
	 *
	 * @param hasRuleValues The source of the item registry and the itemset registry.
	 * @param id The identifier of the itemset.
	 *
	 * @return The item values, in the order of the item references of the itemset.
	 */
	static
	private List<String> getItemValues(HasRuleValues hasRuleValues, String id){
		List<String> result = new ArrayList<String>();

		BiMap<String, Item> itemRegistry = hasRuleValues.getItemRegistry();
		BiMap<String, Itemset> itemsetRegistry = hasRuleValues.getItemsetRegistry();

		Itemset itemset = itemsetRegistry.get(id);

		List<ItemRef> itemRefs = itemset.getItemRefs();
		for(ItemRef itemRef : itemRefs){
			// Each item reference holds the identifier of an item
			Item item = itemRegistry.get(itemRef.getItemRef());

			result.add(item.getValue());
		}

		return result;
	}

	/**
	 * Obtains the list element at the specified rank.
	 *
	 * @param elements The list of elements.
	 * @param rank The one-based rank. Callers are expected to have verified that it is positive.
	 *
	 * @return The element, or <code>null</code> if the list holds fewer elements than the rank.
	 */
	static
	private <E> E getElement(List<E> elements, int rank){
		// Ranks are one-based, list indices are zero-based
		int index = (rank - 1);

		if(index < elements.size()){
			return elements.get(index);
		}

		return null;
	}
}