org.jpmml.translator.regression.RegressionModelTranslator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-transpiler Show documentation
Show all versions of pmml-transpiler Show documentation
JPMML class model transpiler
/*
* Copyright (c) 2019 Villu Ruusmann
*
* This file is part of JPMML-Transpiler
*
* JPMML-Transpiler is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Transpiler is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Transpiler. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.translator.regression;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimaps;
import com.sun.codemodel.JBlock;
import com.sun.codemodel.JClass;
import com.sun.codemodel.JDefinedClass;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JExpression;
import com.sun.codemodel.JFieldVar;
import com.sun.codemodel.JForEach;
import com.sun.codemodel.JForLoop;
import com.sun.codemodel.JMethod;
import com.sun.codemodel.JVar;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.TextIndex;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.regression.NumericPredictor;
import org.dmg.pmml.regression.PredictorTerm;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TokenizedString;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.VoteDistribution;
import org.jpmml.evaluator.java.JavaModel;
import org.jpmml.evaluator.regression.RegressionModelUtil;
import org.jpmml.model.InvalidElementException;
import org.jpmml.model.UnsupportedAttributeException;
import org.jpmml.model.UnsupportedElementListException;
import org.jpmml.translator.FieldInfo;
import org.jpmml.translator.FieldInfoMap;
import org.jpmml.translator.FunctionInvocation;
import org.jpmml.translator.IdentifierUtil;
import org.jpmml.translator.JBinaryFileInitializer;
import org.jpmml.translator.JDirectInitializer;
import org.jpmml.translator.MethodScope;
import org.jpmml.translator.ModelTranslator;
import org.jpmml.translator.Modifiers;
import org.jpmml.translator.OperableRef;
import org.jpmml.translator.PMMLObjectUtil;
import org.jpmml.translator.Scope;
import org.jpmml.translator.TextIndexUtil;
import org.jpmml.translator.TranslationContext;
import org.jpmml.translator.ValueBuilder;
import org.jpmml.translator.ValueMapBuilder;
public class RegressionModelTranslator extends ModelTranslator {
/**
 * Creates a translator for the given regression model, rejecting models that use unsupported features.
 *
 * @param pmml The PMML document; handed to the superclass constructor.
 * @param regressionModel The regression model to be translated.
 *
 * @throws UnsupportedAttributeException If the mining function is neither {@code REGRESSION} nor {@code CLASSIFICATION}.
 * @throws UnsupportedElementListException If any regression table declares predictor terms.
 */
public RegressionModelTranslator(PMML pmml, RegressionModel regressionModel){
	super(pmml, regressionModel);

	MiningFunction miningFunction = regressionModel.requireMiningFunction();
	switch(miningFunction){
		case REGRESSION:
		case CLASSIFICATION:
			break;
		default:
			throw new UnsupportedAttributeException(regressionModel, miningFunction);
	}

	// The element type must be spelled out; a raw List cannot be iterated with a typed loop variable
	List<RegressionTable> regressionTables = regressionModel.requireRegressionTables();
	for(RegressionTable regressionTable : regressionTables){

		// Predictor terms are not translated; fail at construction time, on the first table that has them
		if(regressionTable.hasPredictorTerms()){
			List<PredictorTerm> predictorTerms = regressionTable.getPredictorTerms();

			throw new UnsupportedElementListException(predictorTerms);
		}
	}
}
/**
 * Generates the evaluator method for a regression-type model.
 *
 * The model's single regression table is translated inside the scope of the new method,
 * and the result is returned from that method via {@link #computeValue}.
 *
 * @param context The translation context that receives the generated code.
 *
 * @return The generated evaluator method, created for {@code Value.class}.
 */
@Override
public JMethod translateRegressor(TranslationContext context){
	RegressionModel regressionModel = getModel();

	// The element type must be spelled out; with a raw List, getOnlyElement would yield Object
	List<RegressionTable> regressionTables = regressionModel.getRegressionTables();

	FieldInfoMap fieldInfos = getFieldInfos(new HashSet<>(regressionTables));

	// Fails unless there is exactly one regression table
	RegressionTable regressionTable = Iterables.getOnlyElement(regressionTables);

	JMethod evaluateMethod = createEvaluatorMethod(Value.class, regressionTable, true, context);

	try {
		context.pushScope(new MethodScope(evaluateMethod));

		ValueBuilder valueBuilder = translateRegressionTable(regressionTable, fieldInfos, context);

		computeValue(valueBuilder, regressionModel, context);
	} finally {
		context.popScope();
	}

	return evaluateMethod;
}
/**
 * Generates the evaluator method for a classification-type model.
 *
 * One evaluator method is generated per regression table, each returning the value of that table.
 * The outer method puts an invocation of each of them into a value map, keyed by the table's target category,
 * and returns the classification built by {@link #computeClassification}.
 *
 * @param context The translation context that receives the generated code.
 *
 * @return The generated outer evaluator method, created for {@code Classification.class}.
 */
@Override
public JMethod translateClassifier(TranslationContext context){
	RegressionModel regressionModel = getModel();

	// The element type must be spelled out; a raw List cannot be iterated with a typed loop variable
	List<RegressionTable> regressionTables = regressionModel.getRegressionTables();

	FieldInfoMap fieldInfos = getFieldInfos(new HashSet<>(regressionTables));

	JMethod evaluateListMethod = createEvaluatorMethod(Classification.class, regressionTables, true, context);

	try {
		context.pushScope(new MethodScope(evaluateListMethod));

		ValueMapBuilder valueMapBuilder = new ValueMapBuilder(context)
			.construct("values");

		for(RegressionTable regressionTable : regressionTables){
			JMethod evaluateMethod = createEvaluatorMethod(Value.class, regressionTable, true, context);

			// Nested scope: the per-table method body is emitted while the outer method scope stays on the stack
			try {
				context.pushScope(new MethodScope(evaluateMethod));

				ValueBuilder valueBuilder = translateRegressionTable(regressionTable, fieldInfos, context);

				context._return(valueBuilder.getVariable());
			} finally {
				context.popScope();
			}

			// Back in the outer method scope: record the per-table invocation under its target category
			valueMapBuilder.update("put", regressionTable.getTargetCategory(), createEvaluatorMethodInvocation(evaluateMethod, context));
		}

		computeClassification(valueMapBuilder, regressionModel, context);
	} finally {
		context.popScope();
	}

	return evaluateListMethod;
}
/**
 * Emits the tail of a regression evaluator method: an optional normalization step, followed by the return statement.
 *
 * @param valueBuilder The builder that holds the regression result variable.
 * @param regressionModel The model whose normalization method is consulted.
 * @param context The translation context that receives the return statement.
 */
static
public void computeValue(ValueBuilder valueBuilder, RegressionModel regressionModel, TranslationContext context){
	RegressionModel.NormalizationMethod method = regressionModel.getNormalizationMethod();

	// NONE leaves the value untouched; every other method is delegated to RegressionModelUtil.
	// The enum is dereferenced on purpose, so that a null value fails here exactly like a switch statement would.
	boolean normalize = !method.equals(RegressionModel.NormalizationMethod.NONE);
	if(normalize){
		valueBuilder.staticUpdate(RegressionModelUtil.class, "normalizeRegressionResult", method);
	}

	context._return(valueBuilder.getVariable());
}
/**
 * Emits the tail of a classification evaluator method: a probability computation step over the value map,
 * followed by a return statement that wraps the value map into a classification object.
 *
 * The computation step is chosen by the number of regression tables and the normalization method:
 * two tables use {@code computeBinomialProbabilities} for NONE and the link-function style methods,
 * and {@code computeMultinomialProbabilities} for SIMPLEMAX and SOFTMAX;
 * more than two tables use {@code computeMultinomialProbabilities} for NONE, SIMPLEMAX and SOFTMAX.
 *
 * The result is a {@code ProbabilityDistribution} if the model declares exactly as many
 * {@code PROBABILITY} output fields as it has regression tables, and a {@code VoteDistribution} otherwise.
 *
 * @param valueMapBuilder The builder that holds the per-category value map.
 * @param regressionModel The model whose normalization method, regression tables and output are consulted.
 * @param context The translation context that receives the return statement.
 *
 * @throws InvalidElementException If there are fewer than two regression tables,
 * or if the normalization method is not applicable to the number of regression tables.
 */
static
public void computeClassification(ValueMapBuilder valueMapBuilder, RegressionModel regressionModel, TranslationContext context){
	RegressionModel.NormalizationMethod normalizationMethod = regressionModel.getNormalizationMethod();

	List<RegressionTable> regressionTables = regressionModel.getRegressionTables();

	Output output = regressionModel.getOutput();

	if(regressionTables.size() == 2){

		switch(normalizationMethod){
			case NONE:
			case LOGIT:
			case PROBIT:
			case CLOGLOG:
			case LOGLOG:
			case CAUCHIT:
				valueMapBuilder.staticUpdate(RegressionModelUtil.class, "computeBinomialProbabilities", normalizationMethod);
				break;
			case SIMPLEMAX:
			case SOFTMAX:
				valueMapBuilder.staticUpdate(RegressionModelUtil.class, "computeMultinomialProbabilities", normalizationMethod);
				break;
			default:
				throw new InvalidElementException(regressionModel);
		}
	} else

	if(regressionTables.size() > 2){

		switch(normalizationMethod){
			case NONE:
			case SIMPLEMAX:
			case SOFTMAX:
				valueMapBuilder.staticUpdate(RegressionModelUtil.class, "computeMultinomialProbabilities", normalizationMethod);
				break;
			default:
				throw new InvalidElementException(regressionModel);
		}
	} else

	{
		throw new InvalidElementException(regressionModel);
	}

	boolean probabilistic = false;

	if(output != null && output.hasOutputFields()){
		// The element type must be spelled out; with a raw List, the lambda parameter would be a plain Object
		List<OutputField> outputFields = output.getOutputFields();

		// Only the number of probability output fields matters, so count them instead of collecting them
		long probabilityOutputFieldCount = outputFields.stream()
			.filter(outputField -> (outputField.getResultFeature() == ResultFeature.PROBABILITY))
			.count();

		probabilistic = (regressionTables.size() == probabilityOutputFieldCount);
	}

	JExpression classificationExpr;

	if(probabilistic){
		classificationExpr = context._new(ProbabilityDistribution.class, valueMapBuilder);
	} else

	{
		classificationExpr = context._new(VoteDistribution.class, valueMapBuilder);
	}

	context._return(classificationExpr);
}
static
public ValueBuilder translateRegressionTable(RegressionTable regressionTable, FieldInfoMap fieldInfos, TranslationContext context){
ValueBuilder valueBuilder = new ValueBuilder(context)
.declare(IdentifierUtil.create("result", regressionTable), context.getValueFactoryVariable().newValue());
if(regressionTable.hasNumericPredictors()){
List numericPredictors = regressionTable.getNumericPredictors();
ListMultimap tfTerms = ArrayListMultimap.create();
for(NumericPredictor numericPredictor : numericPredictors){
FieldInfo fieldInfo = fieldInfos.require(numericPredictor);
FunctionInvocation functionInvocation = fieldInfo.getFunctionInvocation();
if((functionInvocation instanceof FunctionInvocation.Tf) || (functionInvocation instanceof FunctionInvocation.TfIdf)){
FunctionInvocationPredictor tfTerm = new FunctionInvocationPredictor(numericPredictor, functionInvocation);
FunctionInvocation.Tf tf = tfTerm.getTf();
tfTerms.put(tf.getTextField(), tfTerm);
continue;
}
Number coefficient = numericPredictor.requireCoefficient();
Integer exponent = numericPredictor.getExponent();
OperableRef operableRef = context.ensureOperable(fieldInfo, (method) -> true);
if(exponent != null && exponent.intValue() != 1){
valueBuilder.update("add", coefficient, operableRef.getExpression(), exponent);
} else
{
if(coefficient.doubleValue() != 1d){
valueBuilder.update("add", coefficient, operableRef.getExpression());
} else
{
valueBuilder.update("add", operableRef.getExpression());
}
}
}
addTermFrequencies(regressionTable, valueBuilder, Multimaps.asMap(tfTerms), fieldInfos, context);
} // End if
if(regressionTable.hasCategoricalPredictors()){
Map> fieldCategoricalPredictors = regressionTable.getCategoricalPredictors().stream()
.collect(Collectors.groupingBy(categoricalPredictor -> categoricalPredictor.requireField(), Collectors.toList()));
JDefinedClass modelFuncInterface = ensureFunctionalInterface(Number.class, context);
List evaluateCategoryMethods = new ArrayList<>();
JBinaryFileInitializer resourceInitializer = null;
Collection>> entries = fieldCategoricalPredictors.entrySet();
for(Map.Entry> entry : entries){
String name = entry.getKey();
List categoricalPredictors = entry.getValue();
FieldInfo fieldInfo = fieldInfos.require(name);
JMethod evaluateCategoryMethod = createEvaluatorMethod(Number.class, categoricalPredictors, false, context);
try {
context.pushScope(new MethodScope(evaluateCategoryMethod));
OperableRef operableRef = context.ensureOperable(fieldInfo, (method) -> true);
Map
© 2015 - 2024 Weber Informatics LLC | Privacy Policy