All downloads are free. Search and download functionality uses the official Maven repository.

org.jpmml.evaluator.AssociationModelEvaluator Maven / Gradle / Ivy

/*
 * Copyright (c) 2013 University of Tartu
 */
package org.jpmml.evaluator;

import java.util.*;

import org.jpmml.manager.*;

import org.dmg.pmml.*;

import com.google.common.collect.*;

/**
 * Evaluator for a PMML {@link AssociationModel}.
 *
 * <p>Only the {@link MiningFunctionType#ASSOCIATION_RULES} mining function is supported.
 * The item, itemset and association rule registries, as well as the item id/value lookup table,
 * are computed on first use and then cached in instance fields. This lazy caching is not synchronized.</p>
 */
public class AssociationModelEvaluator extends AssociationModelManager implements Evaluator {

	/** Cached result of the superclass item registry; {@code null} until first requested. */
	private BiMap<String, Item> items = null;

	/** Cached result of the superclass itemset registry; {@code null} until first requested. */
	private BiMap<String, Itemset> itemsets = null;

	/** Cached result of the superclass association rule registry; {@code null} until first requested. */
	private BiMap<String, AssociationRule> entities = null;

	/** Cached mapping from Item id to Item value; {@code null} until first requested. */
	private BiMap<String, String> itemValues = null;


	public AssociationModelEvaluator(PMML pmml){
		super(pmml);
	}

	public AssociationModelEvaluator(PMML pmml, AssociationModel associationModel){
		super(pmml, associationModel);
	}

	/**
	 * Returns the item registry, delegating to the superclass on the first call and caching the result.
	 */
	@Override
	public BiMap<String, Item> getItemRegistry(){

		if(this.items == null){
			this.items = super.getItemRegistry();
		}

		return this.items;
	}

	/**
	 * Returns the itemset registry, delegating to the superclass on the first call and caching the result.
	 */
	@Override
	public BiMap<String, Itemset> getItemsetRegistry(){

		if(this.itemsets == null){
			this.itemsets = super.getItemsetRegistry();
		}

		return this.itemsets;
	}

	/**
	 * Returns the association rule registry, delegating to the superclass on the first call and caching the result.
	 */
	@Override
	public BiMap<String, AssociationRule> getEntityRegistry(){

		if(this.entities == null){
			this.entities = super.getEntityRegistry();
		}

		return this.entities;
	}

	/**
	 * Converts a raw user-supplied value into a {@link FieldValue} for the named field,
	 * using that field's DataField and MiningField definitions.
	 */
	@Override
	public FieldValue prepare(FieldName name, Object value){
		return ArgumentUtil.prepare(getDataField(name), getMiningField(name), value);
	}

	/**
	 * Evaluates the model against the given arguments.
	 *
	 * @param arguments Field values; pushed as a frame onto a fresh evaluation context.
	 *
	 * @return The output of {@link OutputUtil#evaluate} applied to the association-rules prediction.
	 *
	 * @throws InvalidResultException If the model is not scorable.
	 * @throws UnsupportedFeatureException If the model's mining function is not {@code associationRules}.
	 */
	@Override
	public Map<FieldName, ?> evaluate(Map<FieldName, ?> arguments){
		AssociationModel associationModel = getModel();
		if(!associationModel.isScorable()){
			throw new InvalidResultException(associationModel);
		}

		Map<FieldName, ?> predictions;

		ModelManagerEvaluationContext context = new ModelManagerEvaluationContext(this);
		context.pushFrame(arguments);

		MiningFunctionType miningFunction = associationModel.getFunctionName();
		switch(miningFunction){
			case ASSOCIATION_RULES:
				predictions = evaluate(context);
				break;
			default:
				throw new UnsupportedFeatureException(associationModel, miningFunction);
		}

		return OutputUtil.evaluate(predictions, context);
	}

	/**
	 * Matches the active field's collection of item values against every association rule of the model.
	 *
	 * @return A singleton map from the target field to an {@link Association}. For each rule, the Association
	 * records whether the antecedent itemset and the consequent itemset are fully contained in the input.
	 *
	 * @throws MissingFieldException If the active field has no value.
	 * @throws TypeCheckException If the active field's value is not a {@link Collection}.
	 * @throws InvalidFeatureException If a rule refers to an itemset id that is not declared by the model.
	 */
	private Map<FieldName, ?> evaluate(EvaluationContext context){
		AssociationModel associationModel = getModel();

		FieldName activeField = getActiveField();

		FieldValue value = context.getArgument(activeField);
		if(value == null){
			throw new MissingFieldException(activeField, associationModel);
		}

		// Test the type explicitly rather than catching a ClassCastException; this also rejects a null payload,
		// which a cast would have let through
		Object rawValue = FieldValueUtil.getValue(value);
		if(!(rawValue instanceof Collection)){
			throw new TypeCheckException(Collection.class, value);
		}

		Collection<?> values = (Collection<?>)rawValue;

		Set<String> input = createInput(values, context);

		// Itemset id -> whether every item of that itemset is present in the input
		Map<String, Boolean> flags = Maps.newLinkedHashMap();

		List<Itemset> itemsets = getItemsets();
		for(Itemset itemset : itemsets){
			flags.put(itemset.getId(), isSubset(input, itemset));
		}

		List<AssociationRule> associationRules = getAssociationRules();

		// Bit i corresponds to the i-th association rule
		BitSet antecedentFlags = new BitSet(associationRules.size());
		BitSet consequentFlags = new BitSet(associationRules.size());

		for(int i = 0; i < associationRules.size(); i++){
			AssociationRule associationRule = associationRules.get(i);

			Boolean antecedentFlag = flags.get(associationRule.getAntecedent());
			if(antecedentFlag == null){
				throw new InvalidFeatureException(associationRule);
			}

			antecedentFlags.set(i, antecedentFlag);

			Boolean consequentFlag = flags.get(associationRule.getConsequent());
			if(consequentFlag == null){
				throw new InvalidFeatureException(associationRule);
			}

			consequentFlags.set(i, consequentFlag);
		}

		Association association = new Association(associationRules, antecedentFlags, consequentFlags){

			@Override
			public BiMap<String, Item> getItemRegistry(){
				return AssociationModelEvaluator.this.getItemRegistry();
			}

			@Override
			public BiMap<String, Itemset> getItemsetRegistry(){
				return AssociationModelEvaluator.this.getItemsetRegistry();
			}

			@Override
			public BiMap<String, AssociationRule> getAssociationRuleRegistry(){
				return AssociationModelEvaluator.this.getEntityRegistry();
			}
		};

		return Collections.singletonMap(getTargetField(), association);
	}

	/**
	 * Translates input values into Item identifiers. A value that matches no Item is skipped,
	 * and a warning is added to the evaluation context.
	 *
	 * @return A set of {@link Item#getId() Item identifiers}, in input iteration order.
	 */
	private Set<String> createInput(Collection<?> values, EvaluationContext context){
		Set<String> result = Sets.newLinkedHashSet();

		// Item value -> Item id
		Map<String, String> valueItems = getItemValues().inverse();

		for(Object value : values){
			String stringValue = TypeUtil.format(value);

			String id = valueItems.get(stringValue);
			if(id == null){
				context.addWarning("Unknown item value \"" + stringValue + "\"");

				continue;
			}

			result.add(id);
		}

		return result;
	}

	/**
	 * Returns the item id/value table, building it on the first call and caching the result.
	 *
	 * @return A bidirectional map between {@link Item#getId() Item identifiers} and {@link Item#getValue() Item values}.
	 */
	private BiMap<String, String> getItemValues(){

		if(this.itemValues == null){
			this.itemValues = createItemValues();
		}

		return this.itemValues;
	}

	/**
	 * Builds the Item id to Item value table from the model's Items.
	 *
	 * @throws InvalidFeatureException If two Items with different ids share the same value.
	 * A BiMap cannot represent that case.
	 */
	private BiMap<String, String> createItemValues(){
		BiMap<String, String> result = HashBiMap.create();

		List<Item> items = getItems();
		for(Item item : items){
			String id = item.getId();
			String itemValue = item.getValue();

			// HashBiMap#put would throw a bare IllegalArgumentException in exactly this case;
			// report the offending Item instead
			String previousId = result.inverse().get(itemValue);
			if(previousId != null && !previousId.equals(id)){
				throw new InvalidFeatureException(item);
			}

			result.put(id, itemValue);
		}

		return result;
	}

	/**
	 * Checks whether every Item referenced by the Itemset is present in the input.
	 *
	 * @param input A set of {@link Item#getId() Item identifiers}.
	 *
	 * @return {@code true} if all of the itemset's item references are contained in the input
	 * (vacuously {@code true} for an itemset without references), {@code false} otherwise.
	 */
	static
	private boolean isSubset(Set<String> input, Itemset itemset){
		List<ItemRef> itemRefs = itemset.getItemRefs();

		for(ItemRef itemRef : itemRefs){

			if(!input.contains(itemRef.getItemRef())){
				return false;
			}
		}

		return true;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy