// org.jpmml.rexp.NNetConverter — part of the pmml-rexp artifact (JPMML R to PMML converter).
/*
* Copyright (c) 2018 Villu Ruusmann
*
* This file is part of JPMML-R
*
* JPMML-R is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-R is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
 * along with JPMML-R. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.rexp;
import java.util.ArrayList;
import java.util.List;
import com.google.common.collect.Iterables;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.neural_network.NeuralEntity;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.neural_network.NeuralNetworkUtil;
public class NNetConverter extends ModelConverter {
public NNetConverter(RGenericVector nnet){
super(nnet);
}
@Override
public void encodeSchema(RExpEncoder encoder){
RGenericVector nnet = getObject();
RStringVector lev = nnet.getStringElement("lev", false);
RExp terms = nnet.getElement("terms");
RGenericVector xlevels = nnet.getGenericElement("xlevels");
RStringVector coefnames = nnet.getStringElement("coefnames");
FormulaContext context = new XLevelsFormulaContext(xlevels);
Formula formula = FormulaUtil.createFormula(terms, context, encoder);
FormulaUtil.setLabel(formula, terms, lev, encoder);
FormulaUtil.addFeatures(formula, coefnames, true, encoder);
}
@Override
public Model encodeModel(Schema schema){
RGenericVector nnet = getObject();
RDoubleVector n = nnet.getDoubleElement("n");
RBooleanVector linout = nnet.getBooleanElement("linout", false);
RBooleanVector softmax = nnet.getBooleanElement("softmax", false);
RBooleanVector censored = nnet.getBooleanElement("censored", false);
RDoubleVector wts = nnet.getDoubleElement("wts");
RStringVector lev = nnet.getStringElement("lev", false);
n.checkSize(3);
Label label = schema.getLabel();
List extends Feature> features = schema.getFeatures();
MiningFunction miningFunction;
if(lev == null){
if(linout != null && !linout.asScalar()){
throw new IllegalArgumentException();
}
miningFunction = MiningFunction.REGRESSION;
} else
{
miningFunction = MiningFunction.CLASSIFICATION;
}
int nInput = ValueUtil.asInt(n.getValue(0));
SchemaUtil.checkSize(nInput, features);
NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(features, DataType.DOUBLE);
int offset = 0;
List neuralLayers = new ArrayList<>();
List extends NeuralEntity> entities = neuralInputs.getNeuralInputs();
int nHidden = ValueUtil.asInt(n.getValue(1));
if(nHidden > 0){
NeuralLayer neuralLayer = encodeNeuralLayer("hidden", nHidden, entities, wts, offset)
.setActivationFunction(NeuralNetwork.ActivationFunction.LOGISTIC);
offset += (nHidden * (entities.size() + 1));
neuralLayers.add(neuralLayer);
entities = neuralLayer.getNeurons();
}
int nOutput = ValueUtil.asInt(n.getValue(2));
if(nOutput == 1){
NeuralLayer neuralLayer = encodeNeuralLayer("output", nOutput, entities, wts, offset);
offset += (nOutput * (entities.size() + 1));
neuralLayers.add(neuralLayer);
entities = neuralLayer.getNeurons();
switch(miningFunction){
case REGRESSION:
break;
case CLASSIFICATION:
{
List transformationNeuralLayers = NeuralNetworkUtil.createBinaryLogisticTransformation(Iterables.getOnlyElement(entities));
neuralLayers.addAll(transformationNeuralLayers);
neuralLayer = Iterables.getLast(transformationNeuralLayers);
entities = neuralLayer.getNeurons();
}
break;
}
} else
if(nOutput > 1){
NeuralLayer neuralLayer = encodeNeuralLayer("output", nOutput, entities, wts, offset);
if(softmax != null && softmax.asScalar()){
if(censored != null && censored.asScalar()){
throw new IllegalArgumentException();
}
neuralLayer.setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
}
offset += (nOutput * (entities.size() + 1));
neuralLayers.add(neuralLayer);
entities = neuralLayer.getNeurons();
} else
{
throw new IllegalArgumentException();
}
NeuralNetwork neuralNetwork = new NeuralNetwork(miningFunction, NeuralNetwork.ActivationFunction.IDENTITY, ModelUtil.createMiningSchema(label), neuralInputs, neuralLayers);
switch(miningFunction){
case REGRESSION:
neuralNetwork
.setNeuralOutputs(NeuralNetworkUtil.createRegressionNeuralOutputs(entities, (ContinuousLabel)label));
break;
case CLASSIFICATION:
neuralNetwork
.setNeuralOutputs(NeuralNetworkUtil.createClassificationNeuralOutputs(entities, (CategoricalLabel)label))
.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, (CategoricalLabel)label));
break;
}
return neuralNetwork;
}
static
private NeuralLayer encodeNeuralLayer(String prefix, int n, List extends NeuralEntity> entities, RDoubleVector wts, int offset){
NeuralLayer neuralLayer = new NeuralLayer();
for(int i = 0; i < n; i++){
List weights = (wts.getValues()).subList(offset + 1, offset + (entities.size() + 1));
Double bias = wts.getValue(offset);
Neuron neuron = NeuralNetworkUtil.createNeuron(entities, weights, bias)
.setId(prefix + "/" + String.valueOf(i + 1));
neuralLayer.addNeurons(neuron);
offset += (entities.size() + 1);
}
return neuralLayer;
}
}