org.jpmml.rexp.PreProcessEncoder Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-rexp Show documentation
JPMML R to PMML converter
The newest version!
/*
* Copyright (c) 2016 Villu Ruusmann
*
* This file is part of JPMML-R
*
* JPMML-R is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-R is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-R.  If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.rexp;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMMLFunctions;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.FortranMatrixUtil;
import org.jpmml.converter.ValueUtil;
public class PreProcessEncoder extends TransformerEncoder {
private Map> ranges = Collections.emptyMap();
private Map mean = Collections.emptyMap();
private Map std = Collections.emptyMap();
private Map median = Collections.emptyMap();
public PreProcessEncoder(RGenericVector preProcess){
super(preProcess);
RGenericVector method = preProcess.getGenericElement("method");
RStringVector methodNames = method.names();
for(int i = 0; i < methodNames.size(); i++){
String methodName = methodNames.getValue(i);
switch(methodName){
case "ignore":
break;
case "range":
this.ranges = createArguments(preProcess.getDoubleElement("ranges"), 2);
break;
case "center":
this.mean = createArguments(preProcess.getDoubleElement("mean"));
break;
case "scale":
this.std = createArguments(preProcess.getDoubleElement("std"));
break;
case "medianImpute":
this.median = createArguments(preProcess.getDoubleElement("median"));
break;
default:
throw new IllegalArgumentException(methodName);
}
}
}
@Override
public void addFeature(Feature feature){
String name = FeatureUtil.getName(feature);
DataField dataField = getDataField(name);
if(dataField != null){
Expression expression = feature.ref();
Expression transformedExpression = encodeExpression(name, expression);
if(!(expression).equals(transformedExpression)){
DerivedField derivedField = createDerivedField(FieldNameUtil.create("preProcess", feature), OpType.CONTINUOUS, DataType.DOUBLE, transformedExpression);
feature = new ContinuousFeature(PreProcessEncoder.this, derivedField);
}
}
super.addFeature(feature);
}
private Expression encodeExpression(String name, Expression expression){
List ranges = this.ranges.get(name);
if(ranges != null){
Double min = ranges.get(0);
Double max = ranges.get(1);
if(!ValueUtil.isZero(min)){
expression = ExpressionUtil.createApply(PMMLFunctions.SUBTRACT, expression, ExpressionUtil.createConstant(min));
} // End if
if(!ValueUtil.isOne(max - min)){
expression = ExpressionUtil.createApply(PMMLFunctions.DIVIDE, expression, ExpressionUtil.createConstant(max - min));
}
}
Double mean = this.mean.get(name);
if(mean != null && !ValueUtil.isZero(mean)){
expression = ExpressionUtil.createApply(PMMLFunctions.SUBTRACT, expression, ExpressionUtil.createConstant(mean));
}
Double std = this.std.get(name);
if(std != null && !ValueUtil.isOne(std)){
expression = ExpressionUtil.createApply(PMMLFunctions.DIVIDE, expression, ExpressionUtil.createConstant(std));
}
Double median = this.median.get(name);
if(median != null){
expression = ExpressionUtil.createApply(PMMLFunctions.IF,
ExpressionUtil.createApply(PMMLFunctions.ISNOTMISSING, new FieldRef(name)),
expression,
ExpressionUtil.createConstant(median)
);
}
return expression;
}
static
private Map createArguments(RDoubleVector values){
Map result = new LinkedHashMap<>();
RStringVector names = values.names();
for(int i = 0; i < names.size(); i++){
String name = names.getValue(i);
result.put(name, values.getValue(i));
}
return result;
}
static
private Map> createArguments(RDoubleVector values, int rows){
Map> result = new LinkedHashMap<>();
RStringVector rowNames = values.dimnames(0);
RStringVector columnNames = values.dimnames(1);
for(int i = 0; i < columnNames.size(); i++){
String name = columnNames.getValue(i);
result.put(name, FortranMatrixUtil.getColumn(values.getValues(), rows, columnNames.size(), i));
}
return result;
}
}