All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.evaluator.RuleSetModelEvaluator Maven / Gradle / Ivy

/*
 * Copyright (c) 2013 University of Tartu
 */
package org.jpmml.evaluator;

import java.util.*;

import org.jpmml.manager.*;

import org.dmg.pmml.*;

import com.google.common.collect.*;

public class RuleSetModelEvaluator extends RuleSetModelManager implements Evaluator {

	public RuleSetModelEvaluator(PMML pmml){
		super(pmml);
	}

	public RuleSetModelEvaluator(PMML pmml, RuleSetModel ruleSetModel){
		super(pmml, ruleSetModel);
	}

	@Override
	public FieldValue prepare(FieldName name, Object value){
		return ArgumentUtil.prepare(getDataField(name), getMiningField(name), value);
	}

	@Override
	public Map evaluate(Map arguments){
		RuleSetModel ruleSetModel = getModel();
		if(!ruleSetModel.isScorable()){
			throw new InvalidResultException(ruleSetModel);
		}

		Map predictions;

		ModelManagerEvaluationContext context = new ModelManagerEvaluationContext(this);
		context.pushFrame(arguments);

		MiningFunctionType miningFunction = ruleSetModel.getFunctionName();
		switch(miningFunction){
			case CLASSIFICATION:
				predictions = evaluateRuleSet(context);
				break;
			default:
				throw new UnsupportedFeatureException(ruleSetModel, miningFunction);
		}

		return OutputUtil.evaluate(predictions, context);
	}

	private Map evaluateRuleSet(ModelManagerEvaluationContext context){
		RuleSet ruleSet = getRuleSet();

		List ruleSelectionMethods = ruleSet.getRuleSelectionMethods();

		RuleSelectionMethod ruleSelectionMethod;

		// "If more than one method is included, the first method is used as the default method for scoring"
		if(ruleSelectionMethods.size() > 0){
			ruleSelectionMethod = ruleSelectionMethods.get(0);
		} else

		{
			throw new InvalidFeatureException(ruleSet);
		}

		// Both the ordering of keys and values is significant
		ListMultimap firedRules = LinkedListMultimap.create();

		List rules = ruleSet.getRules();
		for(Rule rule : rules){
			collectFiredRules(firedRules, rule, context);
		}

		RuleClassificationMap result = new RuleClassificationMap();

		RuleSelectionMethod.Criterion criterion = ruleSelectionMethod.getCriterion();

		Set keys = firedRules.keySet();
		for(String key : keys){
			List keyRules = firedRules.get(key);

			switch(criterion){
				case FIRST_HIT:
					{
						SimpleRule winner = keyRules.get(0);

						// The first value of the first key
						if(result.getEntity() == null){
							result.setEntity(winner);
						}

						result.put(key, winner.getConfidence());
					}
					break;
				case WEIGHTED_SUM:
					{
						SimpleRule winner = null;

						double totalWeight = 0;

						for(SimpleRule keyRule : keyRules){

							if(winner == null || (winner.getWeight() < keyRule.getWeight())){
								winner = keyRule;
							}

							totalWeight += keyRule.getWeight();
						}

						result.put(winner, key, totalWeight / firedRules.size());
					}
					break;
				case WEIGHTED_MAX:
					{
						SimpleRule winner = null;

						for(SimpleRule keyRule : keyRules){

							if(winner == null || (winner.getWeight() < keyRule.getWeight())){
								winner = keyRule;
							}
						}

						result.put(winner, key, winner.getConfidence());
					}
					break;
				default:
					throw new UnsupportedFeatureException(ruleSelectionMethod, criterion);
			}
		}

		return TargetUtil.evaluateClassification(result, context);
	}

	static
	private void collectFiredRules(ListMultimap firedRules, Rule rule, EvaluationContext context){
		Predicate predicate = rule.getPredicate();
		if(predicate == null){
			throw new InvalidFeatureException(rule);
		}

		Boolean status = PredicateUtil.evaluate(predicate, context);
		if(status == null || !status.booleanValue()){
			return;
		} // End if

		if(rule instanceof SimpleRule){
			SimpleRule simpleRule = (SimpleRule)rule;

			firedRules.put(simpleRule.getScore(), simpleRule);
		} else

		if(rule instanceof CompoundRule){
			CompoundRule compoundRule = (CompoundRule)rule;

			List childRules = compoundRule.getRules();
			for(Rule childRule : childRules){
				collectFiredRules(firedRules, childRule, context);
			}
		}
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy