// Source listing: org.jpmml.evaluator.mining.MiningModelUtil
// Artifact: pmml-evaluator (JPMML class model evaluator), newest version as listed for Maven / Gradle / Ivy.
/*
* Copyright (c) 2017 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.evaluator.mining;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.DataType;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.HasProbability;
import org.jpmml.evaluator.ProbabilityAggregator;
import org.jpmml.evaluator.TypeCheckException;
import org.jpmml.evaluator.TypeUtil;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueAggregator;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.VoteAggregator;
public class MiningModelUtil {
/**
 * Hidden constructor: this class only exposes static helper methods.
 */
private MiningModelUtil(){
	// Intentionally empty; instances are never created
}
/**
 * Exposes the predictions of a segmentation as a {@link SegmentResult}, if applicable.
 *
 * <p>Only the SELECT_FIRST, SELECT_ALL, MODEL_CHAIN and MULTI_MODEL_CHAIN methods are
 * considered. For every other method, {@code null} is returned.</p>
 *
 * @param multipleModelMethod the segmentation's multiple model method
 * @param predictions the predictions to inspect; may be {@code null}
 * @return {@code predictions} cast to {@link SegmentResult} when the method is one of the
 * methods listed above and {@code predictions} is a {@link SegmentResult} instance;
 * {@code null} otherwise
 */
static
public SegmentResult asSegmentResult(Segmentation.MultipleModelMethod multipleModelMethod, Map<?, ?> predictions){

	switch(multipleModelMethod){
		case SELECT_FIRST:
		case SELECT_ALL:
		case MODEL_CHAIN:
		case MULTI_MODEL_CHAIN:
			// instanceof is false for null, so a null map yields null
			return (predictions instanceof SegmentResult) ? (SegmentResult)predictions : null;
		default:
			return null;
	}
}
/**
 * Aggregates the target values of the given segment results into a single numeric value.
 *
 * <p>Supported multiple model methods are AVERAGE, SUM and MEDIAN, plus their weighted
 * counterparts WEIGHTED_AVERAGE, WEIGHTED_SUM and WEIGHTED_MEDIAN. The weighted methods pair
 * each target value with {@link SegmentResult#getWeight()}.</p>
 *
 * <p>A target value that decodes to {@code null} is a missing prediction. It is handled
 * according to {@code missingPredictionTreatment}:</p>
 * <ul>
 * <li>RETURN_MISSING and CONTINUE: the aggregation returns {@code null} immediately.</li>
 * <li>SKIP_SEGMENT: the segment is skipped. A lazily created {@code Fraction} is updated with
 * the segment and {@code missingThreshold}; if that update reports {@code true}, the aggregation
 * returns {@code null}.</li>
 * </ul>
 *
 * <p>Non-numeric target values are cast to {@link DataType#DOUBLE}.</p>
 *
 * @param valueFactory factory used by the aggregator (and the missing-fraction tracker)
 * @param multipleModelMethod the aggregation method
 * @param missingPredictionTreatment how to treat segments whose target value is missing
 * @param missingThreshold threshold handed to the missing-fraction tracker (SKIP_SEGMENT only)
 * @param segmentResults the segment results to aggregate
 * @return the aggregated value, or {@code null} when a missing prediction makes the result missing
 * @throws IllegalArgumentException if {@code multipleModelMethod} is not one of the supported
 * methods, or {@code missingPredictionTreatment} is unsupported when a missing prediction occurs
 * @throws TypeCheckException if a target value cannot be cast to DOUBLE; it is rethrown via
 * {@code ensureContext} with the offending segment
 */
static
public <V extends Number> Value<V> aggregateValues(ValueFactory<V> valueFactory, Segmentation.MultipleModelMethod multipleModelMethod, Segmentation.MissingPredictionTreatment missingPredictionTreatment, Number missingThreshold, List<SegmentResult> segmentResults){
	ValueAggregator<V> aggregator;
	// Weighted methods consume (value, weight) pairs; decided once rather than per segment
	boolean weighted;

	switch(multipleModelMethod){
		case AVERAGE:
		case SUM:
			aggregator = new ValueAggregator.UnivariateStatistic<>(valueFactory);
			weighted = false;
			break;
		case MEDIAN:
			aggregator = new ValueAggregator.Median<>(valueFactory, segmentResults.size());
			weighted = false;
			break;
		case WEIGHTED_AVERAGE:
		case WEIGHTED_SUM:
			aggregator = new ValueAggregator.WeightedUnivariateStatistic<>(valueFactory);
			weighted = true;
			break;
		case WEIGHTED_MEDIAN:
			aggregator = new ValueAggregator.WeightedMedian<>(valueFactory, segmentResults.size());
			weighted = true;
			break;
		default:
			throw new IllegalArgumentException("Unsupported multiple model method: " + multipleModelMethod);
	}

	// Created lazily: only needed once a missing prediction is skipped
	Fraction<V> missingFraction = null;

	segmentResults:
	for(SegmentResult segmentResult : segmentResults){
		Object targetValue = EvaluatorUtil.decode(segmentResult.getTargetValue());

		if(targetValue == null){

			switch(missingPredictionTreatment){
				case RETURN_MISSING:
				case CONTINUE:
					// A single missing prediction makes the numeric aggregate missing
					return null;
				case SKIP_SEGMENT:
					if(missingFraction == null){
						missingFraction = new Fraction<>(valueFactory, segmentResults);
					}

					// NOTE(review): presumably true once the missing fraction exceeds missingThreshold - see Fraction
					if(missingFraction.update(segmentResult, missingThreshold)){
						return null;
					}
					continue segmentResults;
				default:
					throw new IllegalArgumentException("Unsupported missing prediction treatment: " + missingPredictionTreatment);
			}
		}

		Number value;

		try {
			if(targetValue instanceof Number){
				value = (Number)targetValue;
			} else

			{
				value = (Number)TypeUtil.cast(DataType.DOUBLE, targetValue);
			}
		} catch(TypeCheckException tce){
			throw tce.ensureContext(segmentResult.getSegment());
		}

		if(weighted){
			aggregator.add(value, segmentResult.getWeight());
		} else

		{
			aggregator.add(value);
		}
	}

	switch(multipleModelMethod){
		case AVERAGE:
			return aggregator.average();
		case WEIGHTED_AVERAGE:
			return aggregator.weightedAverage();
		case SUM:
			return aggregator.sum();
		case WEIGHTED_SUM:
			return aggregator.weightedSum();
		case MEDIAN:
			return aggregator.median();
		case WEIGHTED_MEDIAN:
			return aggregator.weightedMedian();
		default:
			// Unreachable: unsupported methods were rejected above
			throw new IllegalArgumentException("Unsupported multiple model method: " + multipleModelMethod);
	}
}
static
public ValueMap