org.jpmml.evaluator.MiningModelEvaluator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-evaluator Show documentation
JPMML class model evaluator
/*
* Copyright (c) 2013 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.evaluator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.BiMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.EmbeddedModel;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.Model;
import org.dmg.pmml.MultipleModelMethodType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.Segment;
import org.dmg.pmml.Segmentation;
/**
 * Evaluator for PMML {@code MiningModel} elements, i.e. ensembles of models
 * organized as a {@code Segmentation} of {@code Segment}s.
 *
 * <p>Evaluation proceeds in two phases: {@code evaluateSegmentation} runs every
 * segment whose predicate evaluates to {@code true}, then a per-mining-function
 * method aggregates the per-segment results according to the segmentation's
 * multiple-model method (majority vote, average, model chain, etc.).</p>
 *
 * <p>NOTE(review): generic type parameters appear to have been stripped from this
 * listing (e.g. raw {@code Map}, {@code BiMap}, and the malformed
 * {@code ModelEvaluator>} below) — presumably angle-bracket content was lost during
 * HTML extraction. Restore them against the upstream JPMML-Evaluator source before
 * compiling; no code tokens have been altered here.</p>
 */
public class MiningModelEvaluator extends ModelEvaluator implements HasEntityRegistry {
// Optional factory used to build evaluators for the segment models.
// When left null, a default factory is created on demand in evaluateSegmentation().
private ModelEvaluatorFactory evaluatorFactory = null;
/**
 * Builds an evaluator for the (sole) MiningModel found in the given PMML document.
 */
public MiningModelEvaluator(PMML pmml){
super(pmml, MiningModel.class);
}
/**
 * Builds an evaluator for the specified MiningModel of the given PMML document.
 */
public MiningModelEvaluator(PMML pmml, MiningModel miningModel){
super(pmml, miningModel);
}
@Override
public String getSummary(){
return "Ensemble model";
}
/**
 * Returns the registry of Segment entities (segment id to Segment), computed
 * once per MiningModel instance via the static weak-keyed {@code entityCache}.
 */
@Override
public BiMap getEntityRegistry(){
return getValue(MiningModelEvaluator.entityCache);
}
/**
 * Returns the target data field, or {@code null} when it is not applicable.
 *
 * <p>For the SELECT_ALL multiple-model method every segment contributes its own
 * result (see {@code selectAll(..)}), so there is no single target field.</p>
 */
@Override
protected DataField getDataField(){
MiningModel miningModel = getModel();
Segmentation segmentation = miningModel.getSegmentation();
if(segmentation == null){
return null;
}
MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();
switch(multipleModelMethod){
case SELECT_ALL:
return null;
default:
return super.getDataField();
}
}
@Override
public MiningModelEvaluationContext createContext(ModelEvaluationContext parent){
return new MiningModelEvaluationContext(parent, this);
}
// Narrowing bridge: mining models always evaluate with a MiningModelEvaluationContext.
@Override
public Map evaluate(ModelEvaluationContext context){
return evaluate((MiningModelEvaluationContext)context);
}
/**
 * Evaluates this mining model.
 *
 * <p>Validates that the model is scorable, has no embedded model (unsupported)
 * and has a Segmentation (required), then dispatches on the mining function and
 * finally applies output post-processing.</p>
 *
 * @throws InvalidResultException if the model is marked non-scorable
 * @throws UnsupportedFeatureException if the model carries an embedded model
 * @throws InvalidFeatureException if the Segmentation element is missing
 */
public Map evaluate(MiningModelEvaluationContext context){
MiningModel miningModel = getModel();
if(!miningModel.isScorable()){
throw new InvalidResultException(miningModel);
}
EmbeddedModel embeddedModel = Iterables.getFirst(miningModel.getEmbeddedModels(), null);
if(embeddedModel != null){
throw new UnsupportedFeatureException(embeddedModel);
}
Segmentation segmentation = miningModel.getSegmentation();
if(segmentation == null){
throw new InvalidFeatureException(miningModel);
}
Map predictions;
MiningFunctionType miningFunction = miningModel.getFunctionName();
switch(miningFunction){
case REGRESSION:
predictions = evaluateRegression(context);
break;
case CLASSIFICATION:
predictions = evaluateClassification(context);
break;
case CLUSTERING:
predictions = evaluateClustering(context);
break;
default:
predictions = evaluateAny(context);
break;
}
return OutputUtil.evaluate(predictions, context);
}
/**
 * Regression flow: run the segments; if the multiple-model method is handled
 * structurally (select/chain/missing — see getSegmentationResult), return that
 * result directly, otherwise aggregate numeric target values.
 */
private Map evaluateRegression(MiningModelEvaluationContext context){
MiningModel miningModel = getModel();
List segmentResults = evaluateSegmentation(context);
Map predictions = getSegmentationResult(REGRESSION_METHODS, segmentResults);
if(predictions != null){
return predictions;
}
Segmentation segmentation = miningModel.getSegmentation();
Double result = aggregateValues(segmentation, segmentResults);
return TargetUtil.evaluateRegression(result, context);
}
/**
 * Classification flow: run the segments; if not handled structurally, aggregate
 * either votes or class probabilities depending on the multiple-model method.
 */
private Map evaluateClassification(MiningModelEvaluationContext context){
MiningModel miningModel = getModel();
List segmentResults = evaluateSegmentation(context);
Map predictions = getSegmentationResult(CLASSIFICATION_METHODS, segmentResults);
if(predictions != null){
return predictions;
}
Segmentation segmentation = miningModel.getSegmentation();
MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();
Classification result;
switch(multipleModelMethod){
case MAJORITY_VOTE:
case WEIGHTED_MAJORITY_VOTE:
{
result = new ProbabilityDistribution();
result.putAll(aggregateVotes(segmentation, segmentResults));
// Convert from votes to probabilities
result.normalizeValues();
}
break;
case MAX:
case MEDIAN:
{
// The max and median aggregation functions yield non-probability distributions
result = new Classification(Classification.Type.VOTE);
result.putAll(aggregateProbabilities(segmentation, segmentResults));
}
break;
case AVERAGE:
case WEIGHTED_AVERAGE:
{
// The average and weighted average (with weights summing to 1) aggregation functions yield probability distributions
result = new ProbabilityDistribution();
result.putAll(aggregateProbabilities(segmentation, segmentResults));
}
break;
default:
throw new UnsupportedFeatureException(segmentation, multipleModelMethod);
}
return TargetUtil.evaluateClassification(result, context);
}
/**
 * Clustering flow: only (weighted) majority vote is supported (see
 * CLUSTERING_METHODS); cluster ids are tallied as String-typed votes.
 */
private Map evaluateClustering(MiningModelEvaluationContext context){
MiningModel miningModel = getModel();
List segmentResults = evaluateSegmentation(context);
Map predictions = getSegmentationResult(CLUSTERING_METHODS, segmentResults);
if(predictions != null){
return predictions;
}
Segmentation segmentation = miningModel.getSegmentation();
Classification result = new Classification(Classification.Type.VOTE);
result.putAll(aggregateVotes(segmentation, segmentResults));
result.computeResult(DataType.STRING);
return Collections.singletonMap(getTargetField(), result);
}
/**
 * Fallback for mining functions without a dedicated aggregation path: only the
 * structurally handled multiple-model methods (select/chain) are permitted,
 * hence the empty set of allowed aggregation methods.
 */
private Map evaluateAny(MiningModelEvaluationContext context){
List segmentResults = evaluateSegmentation(context);
return getSegmentationResult(Collections.emptySet(), segmentResults);
}
/**
 * Runs every active segment and collects its result.
 *
 * <p>For each segment: the predicate is evaluated first — a {@code null}
 * (unknown) or {@code false} outcome skips the segment entirely. The segment
 * model is then evaluated in a child context, its output fields are copied back
 * into the parent context (this is what enables model chaining), and its
 * warnings are propagated. Results are wrapped in a SegmentResultMap that also
 * carries the segment's entity id.</p>
 *
 * <p>SELECT_FIRST short-circuits: the first matching segment's result is
 * returned as a singleton list without evaluating the remaining segments.</p>
 *
 * @throws UnsupportedFeatureException if the Segmentation declares LocalTransformations
 * @throws InvalidFeatureException if a segment lacks a predicate or a model, or
 *         (outside MODEL_CHAIN) a segment model's mining function differs from the ensemble's
 * @throws MissingFieldException if a declared output field has no value after evaluation
 */
private List evaluateSegmentation(MiningModelEvaluationContext context){
MiningModel miningModel = getModel();
List results = new ArrayList<>();
Segmentation segmentation = miningModel.getSegmentation();
LocalTransformations localTransformations = segmentation.getLocalTransformations();
if(localTransformations != null){
throw new UnsupportedFeatureException(localTransformations);
}
ModelEvaluatorFactory evaluatorFactory = getEvaluatorFactory();
if(evaluatorFactory == null){
// No user-supplied factory; fall back to the default implementation
evaluatorFactory = ModelEvaluatorFactory.newInstance();
}
BiMap entityRegistry = getEntityRegistry();
MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();
Model lastModel = null;
MiningFunctionType miningFunction = miningModel.getFunctionName();
List segments = segmentation.getSegments();
for(Segment segment : segments){
Predicate predicate = segment.getPredicate();
if(predicate == null){
throw new InvalidFeatureException(segment);
}
// Unknown (null) predicate outcomes are treated the same as false: skip the segment
Boolean status = PredicateUtil.evaluate(predicate, context);
if(status == null || !status.booleanValue()){
continue;
}
Model model = segment.getModel();
if(model == null){
throw new InvalidFeatureException(segment);
}
// "With the exception of modelChain models, all model elements used inside Segment elements in one MiningModel must have the same MINING-FUNCTION"
switch(multipleModelMethod){
case MODEL_CHAIN:
// Remember the last executed model; its mining function is checked after the loop
lastModel = model;
break;
default:
if(!(miningFunction).equals(model.getFunctionName())){
throw new InvalidFeatureException(model);
}
break;
}
// NOTE(review): generic parameter stripped in this listing; upstream is ModelEvaluator<?>
ModelEvaluator> evaluator = evaluatorFactory.newModelManager(getPMML(), model);
ModelEvaluationContext segmentContext = evaluator.createContext(context);
Map result = evaluator.evaluate(segmentContext);
FieldName targetField = evaluator.getTargetField();
List outputFields = evaluator.getOutputFields();
for(FieldName outputField : outputFields){
FieldValue outputValue = segmentContext.getField(outputField);
if(outputValue == null){
throw new MissingFieldException(outputField, segment);
}
// "The OutputFields from one model element can be passed as input to the MiningSchema of subsequent models"
context.declare(outputField, outputValue);
}
// Propagate warnings from the segment evaluation to the parent context
List warnings = segmentContext.getWarnings();
for(String warning : warnings){
context.addWarning(warning);
}
final
String entityId = EntityUtil.getId(segment, entityRegistry);
// Wrap the raw result so that it also reports the originating segment's entity id
SegmentResultMap segmentResult = new SegmentResultMap(segment, targetField){
@Override
public String getEntityId(){
return entityId;
}
};
segmentResult.putAll(result);
context.putResult(entityId, segmentResult);
switch(multipleModelMethod){
case SELECT_FIRST:
// First matching segment wins; do not evaluate the remaining segments
return Collections.singletonList(segmentResult);
default:
results.add(segmentResult);
break;
}
}
// "The model element used inside the last Segment element executed must have the same MINING-FUNCTION"
switch(multipleModelMethod){
case MODEL_CHAIN:
if(lastModel != null && !(miningFunction).equals(lastModel.getFunctionName())){
throw new InvalidFeatureException(lastModel);
}
break;
default:
break;
}
return results;
}
/**
 * Handles the multiple-model methods that are resolved structurally rather than
 * by value aggregation.
 *
 * <p>SELECT_ALL merges all segment results; SELECT_FIRST returns the first
 * result; MODEL_CHAIN returns the last. Any other method must be contained in
 * the caller-supplied set of supported aggregation methods. An empty segment
 * result list maps to a missing value (per the PMML spec quote below). A
 * {@code null} return tells the caller to aggregate the results itself.</p>
 *
 * @throws UnsupportedFeatureException if the method is neither structural nor
 *         in {@code multipleModelMethods}
 */
private Map getSegmentationResult(Set multipleModelMethods, List segmentResults){
MiningModel miningModel = getModel();
Segmentation segmentation = miningModel.getSegmentation();
MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();
switch(multipleModelMethod){
case SELECT_ALL:
return selectAll(segmentResults);
case SELECT_FIRST:
if(segmentResults.size() > 0){
return segmentResults.get(0);
}
break;
case MODEL_CHAIN:
if(segmentResults.size() > 0){
return segmentResults.get(segmentResults.size() - 1);
}
break;
default:
if(!(multipleModelMethods).contains(multipleModelMethod)){
throw new UnsupportedFeatureException(segmentation, multipleModelMethod);
}
break;
}
// "If no segments have predicates that evaluate to true, then the result is a missing value"
if(segmentResults.size() == 0){
return Collections.singletonMap(getTargetField(), null);
}
return null;
}
public ModelEvaluatorFactory getEvaluatorFactory(){
return this.evaluatorFactory;
}
public void setEvaluatorFactory(ModelEvaluatorFactory evaluatorFactory){
this.evaluatorFactory = evaluatorFactory;
}
/**
 * Aggregates numeric (regression) segment target values.
 *
 * <p>SUM and MEDIAN accumulate raw values; AVERAGE counts segments into the
 * denominator; WEIGHTED_AVERAGE scales each value by the segment weight and
 * accumulates the weights as the denominator.</p>
 */
static
private Double aggregateValues(Segmentation segmentation, List segmentResults){
RegressionAggregator aggregator = new RegressionAggregator();
MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();
double denominator = 0d;
for(SegmentResultMap segmentResult : segmentResults){
Object targetValue = EvaluatorUtil.decode(segmentResult.getTargetValue());
Double value = (Double)TypeUtil.parseOrCast(DataType.DOUBLE, targetValue);
switch(multipleModelMethod){
case SUM:
case MEDIAN:
aggregator.add(value);
break;
case AVERAGE:
aggregator.add(value);
denominator += 1d;
break;
case WEIGHTED_AVERAGE:
double weight = segmentResult.getWeight();
aggregator.add(value * weight);
denominator += weight;
break;
default:
throw new UnsupportedFeatureException(segmentation, multipleModelMethod);
}
}
switch(multipleModelMethod){
case SUM:
return aggregator.sum();
case MEDIAN:
return aggregator.median();
case AVERAGE:
case WEIGHTED_AVERAGE:
return aggregator.average(denominator);
default:
throw new UnsupportedFeatureException(segmentation, multipleModelMethod);
}
}
/**
 * Tallies categorical segment target values as votes: weight 1 per segment for
 * MAJORITY_VOTE, the segment's declared weight for WEIGHTED_MAJORITY_VOTE.
 */
static
private Map aggregateVotes(Segmentation segmentation, List segmentResults){
VoteAggregator aggregator = new VoteAggregator<>();
MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();
for(SegmentResultMap segmentResult : segmentResults){
Object targetValue = EvaluatorUtil.decode(segmentResult.getTargetValue());
String key = (String)targetValue;
switch(multipleModelMethod){
case MAJORITY_VOTE:
aggregator.add(key, 1d);
break;
case WEIGHTED_MAJORITY_VOTE:
aggregator.add(key, segmentResult.getWeight());
break;
default:
throw new UnsupportedFeatureException(segmentation, multipleModelMethod);
}
}
return aggregator.sumMap();
}
/**
 * Aggregates per-class probability distributions from segment results.
 *
 * <p>Each segment's target value must expose {@code HasProbability}. MAX and
 * MEDIAN pick per-class extremes/medians; AVERAGE and WEIGHTED_AVERAGE divide
 * the (weighted) sums by the accumulated denominator.</p>
 */
static
private Map aggregateProbabilities(Segmentation segmentation, List segmentResults){
ProbabilityAggregator aggregator = new ProbabilityAggregator();
MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();
double denominator = 0d;
for(SegmentResultMap segmentResult : segmentResults){
Object targetValue = segmentResult.getTargetValue();
HasProbability hasProbability = TypeUtil.cast(HasProbability.class, targetValue);
switch(multipleModelMethod){
case MAX:
case MEDIAN:
aggregator.add(hasProbability);
break;
case AVERAGE:
aggregator.add(hasProbability);
denominator += 1d;
break;
case WEIGHTED_AVERAGE:
double weight = segmentResult.getWeight();
aggregator.add(hasProbability, weight);
denominator += weight;
break;
default:
throw new UnsupportedFeatureException(segmentation, multipleModelMethod);
}
}
switch(multipleModelMethod){
case MAX:
return aggregator.maxMap();
case MEDIAN:
return aggregator.medianMap();
case AVERAGE:
case WEIGHTED_AVERAGE:
return aggregator.averageMap(denominator);
default:
throw new UnsupportedFeatureException(segmentation, multipleModelMethod);
}
}
/**
 * SELECT_ALL: merges all segment results into a multimap keyed by field name,
 * so each field maps to the list of per-segment values. All segments must
 * produce exactly the same key set (otherwise the lists would be ragged).
 *
 * @throws EvaluationException if a segment's key set differs from the first one
 */
static
private Map selectAll(List segmentResults){
ListMultimap result = ArrayListMultimap.create();
Set keys = null;
for(SegmentResultMap segmentResult : segmentResults){
if(keys == null){
keys = new LinkedHashSet<>(segmentResult.keySet());
} // End if
// Ensure that all List values in the ListMultimap contain the same number of elements
if(!(keys).equals(segmentResult.keySet())){
throw new EvaluationException();
}
for(FieldName key : keys){
result.put(key, segmentResult.get(key));
}
}
return result.asMap();
}
// Multiple-model methods that each mining function is allowed to aggregate with
// (structural methods SELECT_*/MODEL_CHAIN are handled separately in getSegmentationResult)
private static final Set REGRESSION_METHODS = EnumSet.of(MultipleModelMethodType.SUM, MultipleModelMethodType.MEDIAN, MultipleModelMethodType.AVERAGE, MultipleModelMethodType.WEIGHTED_AVERAGE);
private static final Set CLASSIFICATION_METHODS = EnumSet.of(MultipleModelMethodType.MAJORITY_VOTE, MultipleModelMethodType.WEIGHTED_MAJORITY_VOTE, MultipleModelMethodType.SUM, MultipleModelMethodType.MEDIAN, MultipleModelMethodType.AVERAGE, MultipleModelMethodType.WEIGHTED_AVERAGE);
private static final Set CLUSTERING_METHODS = EnumSet.of(MultipleModelMethodType.MAJORITY_VOTE, MultipleModelMethodType.WEIGHTED_MAJORITY_VOTE);
// Per-model cache of Segment entity registries. Weak keys let a cached registry
// be garbage-collected together with its MiningModel.
private static final LoadingCache> entityCache = CacheBuilder.newBuilder()
.weakKeys()
.build(new CacheLoader>(){
@Override
public BiMap load(MiningModel miningModel){
Segmentation segmentation = miningModel.getSegmentation();
return EntityUtil.buildBiMap(segmentation.getSegments());
}
});
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy