All downloads are free. Search and download functionality uses the official Maven repository.

org.jpmml.evaluator.AssociationModelEvaluator Maven / Gradle / Ivy

There is a newer version: 1.6.8
Show newest version
/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableBiMap;
import org.dmg.pmml.AssociationModel;
import org.dmg.pmml.AssociationRule;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Item;
import org.dmg.pmml.ItemRef;
import org.dmg.pmml.Itemset;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Target;

public class AssociationModelEvaluator extends ModelEvaluator implements HasEntityRegistry {

	public AssociationModelEvaluator(PMML pmml){
		super(pmml, AssociationModel.class);
	}

	public AssociationModelEvaluator(PMML pmml, AssociationModel associationModel){
		super(pmml, associationModel);
	}

	@Override
	public String getSummary(){
		return "Association rules";
	}

	/**
	 * @return null Always.
	 */
	@Override
	public Target getTarget(FieldName name){
		return null;
	}

	@Override
	public BiMap getEntityRegistry(){
		return getValue(AssociationModelEvaluator.entityCache);
	}

	@Override
	public void verify(){
		AssociationModel associationModel = getModel();

		List targetFields = getTargetFields();
		if(targetFields.size() > 0){
			MiningSchema miningSchema = associationModel.getMiningSchema();

			throw new InvalidFeatureException("Too many target fields", miningSchema);
		}

		super.verify();
	}

	@Override
	public Map evaluate(ModelEvaluationContext context){
		AssociationModel associationModel = getModel();
		if(!associationModel.isScorable()){
			throw new InvalidResultException(associationModel);
		}

		Map predictions;

		MiningFunctionType miningFunction = associationModel.getFunctionName();
		switch(miningFunction){
			case ASSOCIATION_RULES:
				predictions = evaluateAssociationRules(context);
				break;
			default:
				throw new UnsupportedFeatureException(associationModel, miningFunction);
		}

		return OutputUtil.evaluate(predictions, context);
	}

	private Map evaluateAssociationRules(EvaluationContext context){
		AssociationModel associationModel = getModel();

		Collection activeValue = getActiveValue(context);

		Set input = createInput(activeValue, context);

		Map flags = new HashMap<>();

		List itemsets = associationModel.getItemsets();
		for(Itemset itemset : itemsets){
			flags.put(itemset.getId(), isSubset(input, itemset));
		}

		List associationRules = associationModel.getAssociationRules();

		BitSet antecedentFlags = new BitSet(associationRules.size());
		BitSet consequentFlags = new BitSet(associationRules.size());

		for(int i = 0; i < associationRules.size(); i++){
			AssociationRule associationRule = associationRules.get(i);

			Boolean antecedentFlag = flags.get(associationRule.getAntecedent());
			if(antecedentFlag == null){
				throw new InvalidFeatureException(associationRule);
			}

			antecedentFlags.set(i, antecedentFlag);

			Boolean consequentFlag = flags.get(associationRule.getConsequent());
			if(consequentFlag == null){
				throw new InvalidFeatureException(associationRule);
			}

			consequentFlags.set(i, consequentFlag);
		}

		Association association = new Association(associationRules, antecedentFlags, consequentFlags){

			@Override
			public Map getItems(){
				return AssociationModelEvaluator.this.getItems();
			}

			@Override
			public Map getItemsets(){
				return AssociationModelEvaluator.this.getItemsets();
			}

			@Override
			public BiMap getAssociationRuleRegistry(){
				return AssociationModelEvaluator.this.getEntityRegistry();
			}
		};

		return Collections.singletonMap(getTargetField(), association);
	}

	public Collection getActiveValue(EvaluationContext context){
		AssociationModel associationModel = getModel();

		List activeFields = getActiveFields();
		List groupFields = getGroupFields();

		MiningSchema miningSchema = associationModel.getMiningSchema();

		// Custom IBM SPSS-style model: no group fields, one or more active fields
		if(groupFields.size() == 0){

			if(activeFields.size() < 1){
				throw new InvalidFeatureException("No active fields", miningSchema);
			}

			List result = new ArrayList<>();

			for(FieldName activeField : activeFields){
				FieldValue value = context.evaluate(activeField);

				if(value == null){
					continue;
				} // End if

				if(value.equalsString("T")){
					result.add(activeField.getValue());
				} else

				if(value.equalsString("F")){
					continue;
				} else

				{
					throw new EvaluationException();
				}
			}

			return result;
		} else

		// Standard model: one group field, one active field
		if(groupFields.size() == 1){

			if(activeFields.size() < 1){
				throw new InvalidFeatureException("No active fields", miningSchema);
			} else

			if(activeFields.size() > 1){
				throw new InvalidFeatureException("Too many active fields", miningSchema);
			}

			FieldName activeField = activeFields.get(0);

			FieldValue value = context.evaluate(activeField);
			if(value == null){
				throw new MissingValueException(activeField);
			}

			Collection result = FieldValueUtil.getValue(Collection.class, value);

			return result;
		} else

		{
			throw new InvalidFeatureException(miningSchema);
		}
	}

	/**
	 * @return A set of {@link Item#getId() Item identifiers}.
	 */
	private Set createInput(Collection values, EvaluationContext context){
		Set result = new HashSet<>();

		Map valueItems = (getItemValues().inverse());

		values:
		for(Object value : values){
			String stringValue = TypeUtil.format(value);

			String id = valueItems.get(stringValue);
			if(id == null){
				context.addWarning("Unknown item value \"" + stringValue + "\"");

				continue values;
			}

			result.add(id);
		}

		return result;
	}

	static
	private boolean isSubset(Set input, Itemset itemset){
		boolean result = true;

		List itemRefs = itemset.getItemRefs();
		for(ItemRef itemRef : itemRefs){
			result &= input.contains(itemRef.getItemRef());

			if(!result){
				return false;
			}
		}

		return result;
	}

	private Map getItems(){
		return getValue(AssociationModelEvaluator.itemCache);
	}

	private Map getItemsets(){
		return getValue(AssociationModelEvaluator.itemsetCache);
	}

	/**
	 * @return A bidirectional map between {@link Item#getId() Item identifiers} and {@link Item#getValue() Item values}.
	 */
	private BiMap getItemValues(){
		return getValue(AssociationModelEvaluator.itemValueCache);
	}

	static
	private BiMap parseItemValues(AssociationModel associationModel){
		BiMap result = HashBiMap.create();

		List items = associationModel.getItems();
		for(Item item : items){
			result.put(item.getId(), item.getValue());
		}

		return result;
	}

	private static final LoadingCache> entityCache = CacheUtil.buildLoadingCache(new CacheLoader>(){

		@Override
		public BiMap load(AssociationModel associationModel){
			return EntityUtil.buildBiMap(associationModel.getAssociationRules());
		}
	});

	private static final LoadingCache> itemCache = CacheUtil.buildLoadingCache(new CacheLoader>(){

		@Override
		public Map load(AssociationModel associationModel){
			return IndexableUtil.buildMap(associationModel.getItems());
		}
	});

	private static final LoadingCache> itemsetCache = CacheUtil.buildLoadingCache(new CacheLoader>(){

		@Override
		public Map load(AssociationModel associationModel){
			return IndexableUtil.buildMap(associationModel.getItemsets());
		}
	});

	private static final LoadingCache> itemValueCache = CacheUtil.buildLoadingCache(new CacheLoader>(){

		@Override
		public BiMap load(AssociationModel associationModel){
			return ImmutableBiMap.copyOf(parseItemValues(associationModel));
		}
	});
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy