// org.jpmml.evaluator.AssociationModelEvaluator (Maven / Gradle / Ivy)
/*
* Copyright (c) 2013 University of Tartu
*/
package org.jpmml.evaluator;
import java.util.*;
import org.jpmml.manager.*;
import org.dmg.pmml.*;
import com.google.common.collect.*;
public class AssociationModelEvaluator extends AssociationModelManager implements Evaluator {
private BiMap items = null;
private BiMap itemsets = null;
private BiMap entities = null;
private BiMap itemValues = null;
public AssociationModelEvaluator(PMML pmml){
super(pmml);
}
public AssociationModelEvaluator(PMML pmml, AssociationModel associationModel){
super(pmml, associationModel);
}
@Override
public BiMap getItemRegistry(){
if(this.items == null){
this.items = super.getItemRegistry();
}
return this.items;
}
@Override
public BiMap getItemsetRegistry(){
if(this.itemsets == null){
this.itemsets = super.getItemsetRegistry();
}
return this.itemsets;
}
@Override
public BiMap getEntityRegistry(){
if(this.entities == null){
this.entities = super.getEntityRegistry();
}
return this.entities;
}
@Override
public FieldValue prepare(FieldName name, Object value){
return ArgumentUtil.prepare(getDataField(name), getMiningField(name), value);
}
@Override
public Map evaluate(Map arguments){
AssociationModel associationModel = getModel();
if(!associationModel.isScorable()){
throw new InvalidResultException(associationModel);
}
Map predictions;
ModelManagerEvaluationContext context = new ModelManagerEvaluationContext(this);
context.pushFrame(arguments);
MiningFunctionType miningFunction = associationModel.getFunctionName();
switch(miningFunction){
case ASSOCIATION_RULES:
predictions = evaluate(context);
break;
default:
throw new UnsupportedFeatureException(associationModel, miningFunction);
}
return OutputUtil.evaluate(predictions, context);
}
private Map evaluate(EvaluationContext context){
AssociationModel associationModel = getModel();
FieldName activeField = getActiveField();
FieldValue value = context.getArgument(activeField);
if(value == null){
throw new MissingFieldException(activeField, associationModel);
}
Collection> values;
try {
values = (Collection>)FieldValueUtil.getValue(value);
} catch(ClassCastException cce){
throw new TypeCheckException(Collection.class, value);
}
Set input = createInput(values, context);
Map flags = Maps.newLinkedHashMap();
List itemsets = getItemsets();
for(Itemset itemset : itemsets){
flags.put(itemset.getId(), isSubset(input, itemset));
}
List associationRules = getAssociationRules();
BitSet antecedentFlags = new BitSet(associationRules.size());
BitSet consequentFlags = new BitSet(associationRules.size());
for(int i = 0; i < associationRules.size(); i++){
AssociationRule associationRule = associationRules.get(i);
Boolean antecedentFlag = flags.get(associationRule.getAntecedent());
if(antecedentFlag == null){
throw new InvalidFeatureException(associationRule);
}
antecedentFlags.set(i, antecedentFlag);
Boolean consequentFlag = flags.get(associationRule.getConsequent());
if(consequentFlag == null){
throw new InvalidFeatureException(associationRule);
}
consequentFlags.set(i, consequentFlag);
}
Association association = new Association(associationRules, antecedentFlags, consequentFlags){
@Override
public BiMap getItemRegistry(){
return AssociationModelEvaluator.this.getItemRegistry();
}
@Override
public BiMap getItemsetRegistry(){
return AssociationModelEvaluator.this.getItemsetRegistry();
}
@Override
public BiMap getAssociationRuleRegistry(){
return AssociationModelEvaluator.this.getEntityRegistry();
}
};
return Collections.singletonMap(getTargetField(), association);
}
/**
* @return A set of {@link Item#getId() Item identifiers}.
*/
private Set createInput(Collection> values, EvaluationContext context){
Set result = Sets.newLinkedHashSet();
Map valueItems = (getItemValues().inverse());
values:
for(Object value : values){
String stringValue = TypeUtil.format(value);
String id = valueItems.get(stringValue);
if(id == null){
context.addWarning("Unknown item value \"" + stringValue + "\"");
continue values;
}
result.add(id);
}
return result;
}
/**
* @return A bidirectional map between {@link Item#getId() Item identifiers} and {@link Item#getValue() Item values}.
*/
private BiMap getItemValues(){
if(this.itemValues == null){
this.itemValues = createItemValues();
}
return this.itemValues;
}
private BiMap createItemValues(){
BiMap result = HashBiMap.create();
List- items = getItems();
for(Item item : items){
result.put(item.getId(), item.getValue());
}
return result;
}
static
private boolean isSubset(Set
input, Itemset itemset){
boolean result = true;
List itemRefs = itemset.getItemRefs();
for(ItemRef itemRef : itemRefs){
result &= input.contains(itemRef.getItemRef());
if(!result){
return false;
}
}
return result;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy