// org.jpmml.evaluator.RegressionModelEvaluator Maven / Gradle / Ivy
/*
* Copyright (c) 2013 Villu Ruusmann
*
* This file is part of JPMML-Evaluator
*
* JPMML-Evaluator is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Evaluator is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Evaluator. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.evaluator;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.CategoricalPredictor;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NumericPredictor;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PredictorTerm;
import org.dmg.pmml.RegressionModel;
import org.dmg.pmml.RegressionNormalizationMethodType;
import org.dmg.pmml.RegressionTable;
/**
 * Evaluator for PMML {@link RegressionModel} elements.
 *
 * <p>
 * Two mining functions are handled: regression (a single regression table whose
 * result is optionally normalized) and classification (one regression table per
 * target category, whose results are turned into a probability distribution).
 * Any other mining function is rejected with an {@link UnsupportedFeatureException}.
 * </p>
 */
public class RegressionModelEvaluator extends ModelEvaluator<RegressionModel> {

    /**
     * Creates an evaluator for the given PMML document; the model is resolved by the
     * superclass from {@code RegressionModel.class}.
     *
     * @param pmml the PMML document
     */
    public RegressionModelEvaluator(PMML pmml){
        super(pmml, RegressionModel.class);
    }

    /**
     * Creates an evaluator for an explicitly supplied regression model.
     *
     * @param pmml the PMML document
     * @param regressionModel the model to evaluate
     */
    public RegressionModelEvaluator(PMML pmml, RegressionModel regressionModel){
        super(pmml, regressionModel);
    }

    /**
     * @return the fixed summary string {@code "Regression"}
     */
    @Override
    public String getSummary(){
        return "Regression";
    }

    /**
     * Evaluates the model and passes the predictions through {@link OutputUtil#evaluate}.
     *
     * @param context the evaluation context supplying the input field values
     * @return the result of {@code OutputUtil.evaluate(predictions, context)}
     * @throws InvalidResultException if the model is flagged as not scorable
     * @throws UnsupportedFeatureException if the mining function is neither
     *         regression nor classification
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context){
        RegressionModel regressionModel = getModel();
        if(!regressionModel.isScorable()){
            throw new InvalidResultException(regressionModel);
        }

        Map<FieldName, ?> predictions;

        MiningFunctionType miningFunction = regressionModel.getFunctionName();
        switch(miningFunction){
            case REGRESSION:
                predictions = evaluateRegression(context);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(context);
                break;
            default:
                throw new UnsupportedFeatureException(regressionModel, miningFunction);
        }

        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Evaluates a regression-type model, which must declare exactly one regression table.
     *
     * @param context the evaluation context
     * @return the target prediction, or the default regression result when the
     *         regression table evaluates to a missing value
     * @throws InvalidFeatureException if the number of regression tables is not one
     */
    private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext context){
        RegressionModel regressionModel = getModel();

        // The model-level target field name takes precedence over the evaluator's target field
        FieldName targetField = regressionModel.getTargetFieldName();
        if(targetField == null){
            targetField = getTargetField();
        }

        List<RegressionTable> regressionTables = regressionModel.getRegressionTables();
        if(regressionTables.size() != 1){
            throw new InvalidFeatureException(regressionModel);
        }

        RegressionTable regressionTable = regressionTables.get(0);

        Double result = evaluateRegressionTable(regressionTable, context);
        if(result == null){
            return TargetUtil.evaluateRegressionDefault(context);
        }

        result = normalizeRegressionResult(result);

        return TargetUtil.evaluateRegression(Collections.singletonMap(targetField, result), context);
    }

    /**
     * Evaluates a classification-type model: every regression table yields the raw value
     * of its target category, and the raw values are then converted to probabilities
     * according to the operational type of the target field.
     *
     * @param context the evaluation context
     * @return the target prediction, or the default classification result when any
     *         regression table evaluates to a missing value
     * @throws MissingFieldException if the target field has no data field definition
     * @throws InvalidFeatureException if the target field is continuous, if there are no
     *         regression tables, if the number of declared target categories differs from
     *         the number of regression tables, or if a regression table lacks a target category
     * @throws UnsupportedFeatureException if the operational type is not handled
     */
    private Map<FieldName, ?> evaluateClassification(ModelEvaluationContext context){
        RegressionModel regressionModel = getModel();

        FieldName targetField = regressionModel.getTargetFieldName();
        if(targetField == null){
            targetField = getTargetField();
        }

        DataField dataField = getDataField(targetField);
        if(dataField == null){
            throw new MissingFieldException(targetField, regressionModel);
        }

        OpType opType = dataField.getOpType();
        switch(opType){
            case CONTINUOUS:
                throw new InvalidFeatureException(dataField);
            case CATEGORICAL:
            case ORDINAL:
                break;
            default:
                throw new UnsupportedFeatureException(dataField, opType);
        }

        List<RegressionTable> regressionTables = regressionModel.getRegressionTables();
        if(regressionTables.isEmpty()){
            throw new InvalidFeatureException(regressionModel);
        }

        // An empty list of declared categories is tolerated; a non-empty one must match the tables
        List<String> targetCategories = FieldValueUtil.getTargetCategories(dataField);
        if(!targetCategories.isEmpty() && targetCategories.size() != regressionTables.size()){
            throw new InvalidFeatureException(dataField);
        }

        // Insertion order matters: the binomial case treats the first entry specially
        Map<String, Double> values = new LinkedHashMap<>();

        for(RegressionTable regressionTable : regressionTables){
            String targetCategory = regressionTable.getTargetCategory();
            if(targetCategory == null){
                throw new InvalidFeatureException(regressionTable);
            }

            Double value = evaluateRegressionTable(regressionTable, context);

            // "If one or more RegressionTable elements cannot be evaluated, then the predictions are defined by the priorProbability values of the Target element"
            if(value == null){
                return TargetUtil.evaluateClassificationDefault(context);
            }

            values.put(targetCategory, value);
        }

        switch(opType){
            case CATEGORICAL:
                // "The binary logistic regression is a special case"
                if(regressionTables.size() == 2){
                    computeBinomialProbabilities(values);
                } else {
                    computeMultinomialProbabilities(values);
                }
                break;
            case ORDINAL:
                computeOrdinalProbabilities(values, targetCategories);
                break;
            default:
                throw new UnsupportedFeatureException(dataField, opType);
        }

        ProbabilityDistribution result = new ProbabilityDistribution();
        result.putAll(values);

        return TargetUtil.evaluateClassification(Collections.singletonMap(targetField, result), context);
    }

    /**
     * Computes the raw value of a regression table: the intercept plus the contributions
     * of its numeric predictors, categorical predictors and predictor terms.
     *
     * @param regressionTable the regression table to evaluate
     * @param context the evaluation context supplying field values
     * @return the computed value, or {@code null} when a numeric predictor field or a
     *         predictor term field is missing
     * @throws InvalidFeatureException if a predictor term has no field references
     */
    private Double evaluateRegressionTable(RegressionTable regressionTable, EvaluationContext context){
        double result = 0d;

        result += regressionTable.getIntercept();

        List<NumericPredictor> numericPredictors = regressionTable.getNumericPredictors();
        for(NumericPredictor numericPredictor : numericPredictors){
            FieldValue value = context.evaluate(numericPredictor.getName());

            // "If the input value is missing, then the result evaluates to a missing value"
            if(value == null){
                return null;
            }

            result += numericPredictor.getCoefficient() * Math.pow((value.asNumber()).doubleValue(), numericPredictor.getExponent());
        }

        List<CategoricalPredictor> categoricalPredictors = regressionTable.getCategoricalPredictors();
        for(CategoricalPredictor categoricalPredictor : categoricalPredictors){
            FieldValue value = context.evaluate(categoricalPredictor.getName());

            // "If the input value is missing, then the product is ignored"
            if(value == null){
                continue;
            }

            // The coefficient contributes only when the field value matches the predictor
            boolean equals = value.equals(categoricalPredictor);

            result += categoricalPredictor.getCoefficient() * (equals ? 1d : 0d);
        }

        List<PredictorTerm> predictorTerms = regressionTable.getPredictorTerms();
        for(PredictorTerm predictorTerm : predictorTerms){
            double product = predictorTerm.getCoefficient();

            List<FieldRef> fieldRefs = predictorTerm.getFieldRefs();
            if(fieldRefs.isEmpty()){
                throw new InvalidFeatureException(predictorTerm);
            }

            for(FieldRef fieldRef : fieldRefs){
                FieldValue value = ExpressionUtil.evaluate(fieldRef, context);

                // "If the input value is missing, then the result evaluates to a missing value"
                if(value == null){
                    return null;
                }

                product *= (value.asNumber()).doubleValue();
            }

            result += product;
        }

        return Double.valueOf(result);
    }

    /**
     * Applies the model's normalization method to a regression result.
     *
     * @param value the raw regression table value
     * @return {@code value} for NONE, {@code 1 / (1 + exp(-value))} for SOFTMAX and LOGIT,
     *         {@code exp(value)} for EXP
     * @throws UnsupportedFeatureException for any other normalization method
     */
    private Double normalizeRegressionResult(Double value){
        RegressionModel regressionModel = getModel();

        RegressionNormalizationMethodType regressionNormalizationMethod = regressionModel.getNormalizationMethod();
        switch(regressionNormalizationMethod){
            case NONE:
                return value;
            case SOFTMAX:
            case LOGIT:
                return 1d / (1d + Math.exp(-value));
            case EXP:
                return Math.exp(value);
            default:
                throw new UnsupportedFeatureException(regressionModel, regressionNormalizationMethod);
        }
    }

    /**
     * Converts the raw values of a two-category model into probabilities, in place.
     * The first entry (in iteration order) is normalized; the second entry becomes its
     * complement.
     *
     * @param values raw values keyed by target category; modified in place
     * @throws EvaluationException if the map holds more than two entries
     */
    private void computeBinomialProbabilities(Map<String, Double> values){
        Double probability = 0d;

        int i = 0;

        Collection<Map.Entry<String, Double>> entries = values.entrySet();
        for(Map.Entry<String, Double> entry : entries){

            if(i == 0){
                // The probability of the first category is calculated
                probability = normalizeClassificationResult(entry.getValue(), 2);

                entry.setValue(probability);
            } else if(i == 1){
                // The probability of the second category is obtained by subtracting the probability of the first category from 1.0
                entry.setValue(1d - probability);
            } else {
                throw new EvaluationException();
            }

            i++;
        }
    }

    /**
     * Converts the raw values of a multi-category categorical model into probabilities,
     * in place. NONE leaves the values untouched; SIMPLEMAX and SOFTMAX delegate to
     * {@link Classification}; every other method is applied per value, after which the
     * values are passed to {@code Classification.normalize}.
     *
     * @param values raw values keyed by target category; modified in place
     */
    private void computeMultinomialProbabilities(Map<String, Double> values){
        RegressionModel regressionModel = getModel();

        RegressionNormalizationMethodType regressionNormalizationMethod = regressionModel.getNormalizationMethod();
        switch(regressionNormalizationMethod){
            case NONE:
                return;
            case SIMPLEMAX:
                Classification.normalize(values);
                return;
            case SOFTMAX:
                Classification.normalizeSoftMax(values);
                return;
            default:
                break;
        }

        Collection<Map.Entry<String, Double>> entries = values.entrySet();
        for(Map.Entry<String, Double> entry : entries){
            entry.setValue(normalizeClassificationResult(entry.getValue(), values.size()));
        }

        Classification.normalize(values);
    }

    /**
     * Converts the raw values of an ordinal model into probabilities, in place. NONE
     * leaves the values untouched. For the remaining supported methods each value is
     * normalized and the normalized values are then treated as cumulative probabilities
     * by {@link #calculateCategoryProbabilities(Map, List)}.
     *
     * @param values raw values keyed by target category; modified in place
     * @param targetCategories the target categories, in the order used for differencing
     * @throws InvalidFeatureException if the normalization method is SIMPLEMAX or SOFTMAX
     */
    private void computeOrdinalProbabilities(Map<String, Double> values, List<String> targetCategories){
        RegressionModel regressionModel = getModel();

        RegressionNormalizationMethodType regressionNormalizationMethod = regressionModel.getNormalizationMethod();
        switch(regressionNormalizationMethod){
            case NONE:
                return;
            case SIMPLEMAX:
            case SOFTMAX:
                throw new InvalidFeatureException(regressionModel);
            default:
                break;
        }

        Collection<Map.Entry<String, Double>> entries = values.entrySet();
        for(Map.Entry<String, Double> entry : entries){
            entry.setValue(normalizeClassificationResult(entry.getValue(), values.size()));
        }

        calculateCategoryProbabilities(values, targetCategories);
    }

    /**
     * Applies the model's normalization method to a single classification value.
     *
     * @param value the raw regression table value
     * @param classes the number of target categories; SOFTMAX is accepted only for two
     * @return the normalized value
     * @throws InvalidFeatureException if the method is SIMPLEMAX, or SOFTMAX with
     *         {@code classes != 2}
     * @throws UnsupportedFeatureException for any method not listed in the switch
     */
    private Double normalizeClassificationResult(Double value, int classes){
        RegressionModel regressionModel = getModel();

        RegressionNormalizationMethodType regressionNormalizationMethod = regressionModel.getNormalizationMethod();
        switch(regressionNormalizationMethod){
            case NONE:
                return value;
            case SIMPLEMAX:
                throw new InvalidFeatureException(regressionModel);
            case SOFTMAX:
                if(classes != 2){
                    throw new InvalidFeatureException(regressionModel);
                }
                // Falls through
            case LOGIT:
                return 1d / (1d + Math.exp(-value));
            case PROBIT:
                return NormalDistributionUtil.cumulativeProbability(value);
            case CLOGLOG:
                return 1d - Math.exp(-Math.exp(value));
            case LOGLOG:
                return Math.exp(-Math.exp(-value));
            case CAUCHIT:
                return 0.5d + (1d / Math.PI) * Math.atan(value);
            default:
                throw new UnsupportedFeatureException(regressionModel, regressionNormalizationMethod);
        }
    }

    /**
     * Replaces cumulative probabilities with per-category probabilities, in place.
     *
     * <p>
     * For every category except the last, the new value is the category's cumulative
     * probability minus the cumulative probability of the preceding category (zero for
     * the first). When there is more than one category, the last category receives
     * {@code 1 - offset}, where {@code offset} is the cumulative probability of the
     * second-to-last category. With fewer than two categories the map is left unchanged.
     * </p>
     *
     * @param map cumulative probabilities keyed by category; modified in place
     * @param categories the categories in cumulative order
     * @throws EvaluationException if a non-final category has no value, has a value
     *         greater than one, or has a value smaller than that of its predecessor
     */
    public static void calculateCategoryProbabilities(Map<String, Double> map, List<String> categories){
        double offset = 0d;

        for(int i = 0; i < categories.size() - 1; i++){
            String category = categories.get(i);

            Double cumulativeProbability = map.get(category);
            if(cumulativeProbability == null || cumulativeProbability > 1d){
                throw new EvaluationException();
            }

            Double probability = (cumulativeProbability - offset);
            if(probability < 0d){
                throw new EvaluationException();
            }

            map.put(category, probability);

            offset = cumulativeProbability;
        }

        if(categories.size() > 1){
            String category = categories.get(categories.size() - 1);

            map.put(category, 1d - offset);
        }
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy