
org.jpmml.rexp.HurdleConverter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-rexp Show documentation
Show all versions of pmml-rexp Show documentation
JPMML R to PMML converter
/*
* Copyright (c) 2024 Villu Ruusmann
*
* This file is part of JPMML-R
*
* JPMML-R is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-R is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-R. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.rexp;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLFunctions;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.rexp.evaluator.RExpFunctions;
/**
 * Converter that turns a hurdle model object into a PMML {@link MiningModel}
 * made of three chained segments: the "zero" component, the "count" component
 * and a final model that combines their predictions.
 *
 * NOTE(review): several declarations below use raw {@code List}/{@code Map} types;
 * the generic parameters appear to be missing - confirm against the superclass signatures.
 */
public class HurdleConverter extends MixtureModelConverter {

    public HurdleConverter(RGenericVector hurdle){
        super(hurdle);
    }

    /**
     * Encodes the zero and count components, derives the combined prediction
     * from their predicted output fields, and wraps everything into a model chain.
     *
     * @param encoder the encoder used for creating derived fields and the final PMML document.
     * @return the PMML document produced by {@code encoder.encodePMML} for the model chain.
     */
    @Override
    public PMML encodePMML(RExpEncoder encoder){
        // Zero component, with its prediction exposed as a derived field
        Model zeroModel = encodeComponent(MixtureModelConverter.NAME_ZERO, encoder);
        OutputField zeroOutput = createComponentPredictField(MixtureModelConverter.NAME_ZERO);
        encoder.createDerivedField(zeroModel, zeroOutput, true);

        // Count component, with its prediction exposed as a derived field
        Model countModel = encodeComponent(MixtureModelConverter.NAME_COUNT, encoder);
        OutputField countOutput = createComponentPredictField(MixtureModelConverter.NAME_COUNT);
        encoder.createDerivedField(countModel, countOutput, true);

        // adjusted zero = exp(ln(zero) - STATS_PPOIS(0, count))
        // NOTE(review): STATS_PPOIS presumably yields a log-scale Poisson tail probability - confirm against RExpFunctions
        Expression logZero = ExpressionUtil.createApply(PMMLFunctions.LN, new FieldRef(zeroOutput));
        Expression ppois = ExpressionUtil.createApply(RExpFunctions.STATS_PPOIS, ExpressionUtil.createConstant(0), new FieldRef(countOutput));
        Expression logRatio = ExpressionUtil.createApply(PMMLFunctions.SUBTRACT, logZero, ppois);
        Expression adjustment = ExpressionUtil.createApply(PMMLFunctions.EXP, logRatio);

        String adjustedZeroName = FieldNameUtil.create("adjusted", zeroOutput.requireName());
        DerivedField adjustedZero = encoder.createDerivedField(adjustedZeroName, OpType.CONTINUOUS, DataType.DOUBLE, adjustment);

        // target = count * adjusted zero
        Expression product = ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, new FieldRef(countOutput), new FieldRef(adjustedZero));

        String adjustedCountName = FieldNameUtil.create("adjusted", countOutput.requireName());
        DerivedField target = encoder.createDerivedField(adjustedCountName, OpType.CONTINUOUS, DataType.DOUBLE, product);

        // The raw count prediction is additionally handed to encodeTarget under a "truncated" name
        ContinuousLabel label = getLabel();
        Map outputFields = Collections.singletonMap(FieldNameUtil.create("truncated", label.getName()), countOutput);

        Model fullModel = encodeTarget(target, outputFields, encoder);

        List models = Arrays.asList(zeroModel, countModel, fullModel);

        Segmentation segmentation = new Segmentation(Segmentation.MultipleModelMethod.MODEL_CHAIN, null);
        segmentation.addSegments(
            createChainSegment(MixtureModelConverter.NAME_ZERO, zeroModel),
            createChainSegment(MixtureModelConverter.NAME_COUNT, countModel),
            createChainSegment(MixtureModelConverter.NAME_FULL, fullModel)
        );

        MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, MiningModelUtil.createMiningSchema(models));
        miningModel.setSegmentation(segmentation);

        return encoder.encodePMML(miningModel);
    }

    /**
     * Builds the regression model of the zero component.
     * Only the "binomial" distribution is supported; it is encoded with the logit normalization method.
     *
     * @throws IllegalArgumentException if the zero distribution name is anything other than "binomial".
     */
    @Override
    protected Model encodeZeroComponent(List features, List coefficients, Double intercept, Schema schema){
        String distName = resolveDistName(MixtureModelConverter.NAME_ZERO);

        if(!distName.equals("binomial")){
            throw new IllegalArgumentException(distName);
        }

        return RegressionModelUtil.createRegression(features, coefficients, intercept, RegressionModel.NormalizationMethod.LOGIT, schema);
    }

    /**
     * Builds the regression model of the count component.
     * Only the "poisson" distribution is supported; it is encoded with the exp normalization method.
     *
     * @throws IllegalArgumentException if the count distribution name is anything other than "poisson".
     */
    @Override
    protected Model encodeCountComponent(List features, List coefficients, Double intercept, Schema schema){
        String distName = resolveDistName(MixtureModelConverter.NAME_COUNT);

        if(!distName.equals("poisson")){
            throw new IllegalArgumentException(distName);
        }

        return RegressionModelUtil.createRegression(features, coefficients, intercept, RegressionModel.NormalizationMethod.EXP, schema);
    }

    // Reads the scalar distribution name of the given component from the "dist" element of the model object
    private String resolveDistName(String componentName){
        RGenericVector hurdle = getObject();

        RStringVector dist = hurdle.getGenericElement("dist").getStringElement(componentName);

        return dist.asScalar();
    }

    // Creates a continuous double predicted output field named after the given component
    private static OutputField createComponentPredictField(String componentName){
        return ModelUtil.createPredictedField(FieldNameUtil.create("predict", componentName), OpType.CONTINUOUS, DataType.DOUBLE);
    }

    // Creates an always-true segment that carries the given model under the given identifier
    private static Segment createChainSegment(String id, Model model){
        Segment segment = new Segment(True.INSTANCE, model);
        segment.setId(id);

        return segment;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy