org.jpmml.rexp.ModelConverter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-rexp Show documentation
Show all versions of pmml-rexp Show documentation
JPMML R to PMML converter
The newest version!
/*
* Copyright (c) 2016 Villu Ruusmann
*
* This file is part of JPMML-R
*
* JPMML-R is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-R is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-R. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.rexp;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import com.google.common.base.Function;
import com.google.common.collect.Lists;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.VerificationField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureImportanceMap;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.Schema;
abstract
public class ModelConverter extends Converter {
public ModelConverter(R object){
super(object);
}
abstract
public void encodeSchema(RExpEncoder encoder);
abstract
public Model encodeModel(Schema schema);
public Model encode(Schema schema){
Model model = encodeModel(schema);
if(this instanceof HasFeatureImportances){
HasFeatureImportances hasFeatureImportances = (HasFeatureImportances)this;
FeatureImportanceMap featureImportances = hasFeatureImportances.getFeatureImportances(schema);
if(featureImportances != null && !featureImportances.isEmpty()){
ModelEncoder encoder = schema.getEncoder();
Collection> entries = featureImportances.entrySet();
for(Map.Entry entry : entries){
encoder.addFeatureImportance(model, entry.getKey(), entry.getValue());
}
}
}
return model;
}
@Override
public PMML encodePMML(RExpEncoder encoder){
RExp object = getObject();
RGenericVector verification = null;
if(object instanceof S4Object){
S4Object model = (S4Object)object;
verification = model.getGenericAttribute("verification", false);
} else
if(object instanceof RGenericVector){
RGenericVector model = (RGenericVector)object;
verification = model.getGenericElement("verification", false);
}
encodeSchema(encoder);
Schema schema = encoder.createSchema();
Model model = encode(schema);
verification:
if(verification != null){
RDoubleVector precision = verification.getDoubleElement("precision");
RDoubleVector zeroThreshold = verification.getDoubleElement("zeroThreshold");
VerificationMap data = new VerificationMap(precision.asScalar(), zeroThreshold.asScalar());
RGenericVector activeValues = verification.getGenericElement("active_values");
RGenericVector targetValues = verification.getGenericElement("target_values", false);
RGenericVector outputValues = verification.getGenericElement("output_values", false);
if(activeValues != null){
data.putInputData(encodeActiveValues(activeValues));
} // End if
if(targetValues != null && outputValues == null){
ScalarLabel scalarLabel = (ScalarLabel)schema.getLabel();
String name = scalarLabel.getName();
Collection verificationFields = data.keySet();
for(Iterator verificationFieldIt = verificationFields.iterator(); verificationFieldIt.hasNext(); ){
VerificationField verificationField = verificationFieldIt.next();
if((verificationField.requireField()).equals(name)){
verificationFieldIt.remove();
}
}
data.putResultData(encodeTargetValues(targetValues, scalarLabel));
} else
if(outputValues != null){
data.putResultData(encodeOutputValues(outputValues));
} else
{
break verification;
}
model.setModelVerification(ModelUtil.createModelVerification(data));
}
PMML pmml = encoder.encodePMML(model);
return pmml;
}
protected Map> encodeActiveValues(RGenericVector dataFrame){
return encodeVerificationData(dataFrame);
}
protected Map> encodeTargetValues(RGenericVector dataFrame, ScalarLabel scalarLabel){
List columns = dataFrame.getValues();
String name = scalarLabel.getName();
return encodeVerificationData(columns, Collections.singletonList(name));
}
protected Map> encodeOutputValues(RGenericVector dataFrame){
return encodeVerificationData(dataFrame);
}
static
protected Map> encodeVerificationData(RGenericVector dataFrame){
List columns = dataFrame.getValues();
RStringVector columnNames = dataFrame.names();
return encodeVerificationData(columns, columnNames.getDequotedValues());
}
static
protected Map> encodeVerificationData(List extends RExp> columns, List names){
Map> result = new LinkedHashMap<>();
for(int i = 0; i < columns.size(); i++){
String name = names.get(i);
RVector> column = (RVector>)columns.get(i);
List> values;
if(column instanceof RDoubleVector){
Function function = new Function(){
@Override
public Double apply(Double value){
if(value.isNaN()){
return null;
}
return value;
}
};
values = Lists.transform((List)column.getValues(), function);
} else
if(column instanceof RFactorVector){
RFactorVector factor = (RFactorVector)column;
values = factor.getFactorValues();
} else
{
values = column.getValues();
}
VerificationField verificationField = ModelUtil.createVerificationField(name);
result.put(verificationField, values);
}
return result;
}
static
private class VerificationMap extends LinkedHashMap> {
private Double precision = null;
private Double zeroThreshold = null;
public VerificationMap(Double precision, Double zeroThreshold){
setPrecision(precision);
setZeroThreshold(zeroThreshold);
}
public void putInputData(Map> map){
putAll(map);
}
public void putResultData(Map> map){
Double precision = getPrecision();
Double zeroThreshold = getZeroThreshold();
Collection verificationFields = map.keySet();
for(VerificationField verificationField : verificationFields){
verificationField
.setPrecision(precision)
.setZeroThreshold(zeroThreshold);
}
putAll(map);
}
public double getPrecision(){
return this.precision;
}
private void setPrecision(double precision){
this.precision = precision;
}
public double getZeroThreshold(){
return this.zeroThreshold;
}
private void setZeroThreshold(double zeroThreshold){
this.zeroThreshold = zeroThreshold;
}
}
}