All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.sparkml.model.MultilayerPerceptronClassificationModelConverter Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2016 Villu Ruusmann
 *
 * This file is part of JPMML-SparkML
 *
 * JPMML-SparkML is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SparkML is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SparkML.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.neural_network.NeuralEntity;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.neural_network.NeuralNetworkUtil;
import org.jpmml.sparkml.ProbabilisticClassificationModelConverter;

/**
 * Converts a Spark ML {@link MultilayerPerceptronClassificationModel} into a PMML {@link NeuralNetwork}.
 *
 * <p>
 * The flat weight vector of the model is read layer by layer. For a layer with
 * {@code rows} inputs and {@code columns} outputs, this converter reads
 * {@code rows * columns} connection weights (the weight between input {@code row}
 * and output {@code column} is taken from offset {@code (row * columns) + column}),
 * followed by {@code columns} bias values.
 * </p>
 */
public class MultilayerPerceptronClassificationModelConverter extends ProbabilisticClassificationModelConverter<MultilayerPerceptronClassificationModel> {

	public MultilayerPerceptronClassificationModelConverter(MultilayerPerceptronClassificationModel model){
		super(model);
	}

	/**
	 * Encodes the model as a classification-type neural network.
	 *
	 * <p>
	 * The network is created with the logistic activation function as its default;
	 * the last layer overrides it with the identity activation function and softmax normalization.
	 * Neurons are identified as {@code <layer>/<one-based position within layer>}.
	 * </p>
	 *
	 * @param schema The schema, whose label is cast to {@link CategoricalLabel}.
	 * The size of the label is checked against the size of the last layer,
	 * and the number of features is checked against the size of the first layer.
	 *
	 * @return The neural network.
	 *
	 * @throws IllegalArgumentException If the number of weights consumed according to the layer sizes
	 * does not equal the size of the weight vector.
	 */
	@Override
	public NeuralNetwork encodeModel(Schema schema){
		MultilayerPerceptronClassificationModel model = getModel();

		int[] layers = model.getLayers();
		Vector weights = model.weights();

		CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
		List<? extends Feature> features = schema.getFeatures();

		SchemaUtil.checkSize(layers[layers.length - 1], categoricalLabel);
		SchemaUtil.checkSize(layers[0], features);

		NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(features, DataType.DOUBLE);

		// The entities that feed the layer currently being built; starts with the neural inputs,
		// and is replaced with the neurons of each finished layer
		List<? extends NeuralEntity> entities = neuralInputs.getNeuralInputs();

		List<NeuralLayer> neuralLayers = new ArrayList<>();

		// Read cursor into the flat weight vector
		int weightPos = 0;

		for(int layer = 1; layer < layers.length; layer++){
			NeuralLayer neuralLayer = new NeuralLayer();

			int rows = entities.size();
			int columns = layers[layer];

			// One weight vector per output neuron (ie. per column)
			List<List<Double>> weightMatrix = new ArrayList<>(columns);

			for(int column = 0; column < columns; column++){
				List<Double> weightVector = new ArrayList<>(rows);

				for(int row = 0; row < rows; row++){
					weightVector.add(weights.apply(weightPos + (row * columns) + column));
				}

				weightMatrix.add(weightVector);
			}

			// Skip past the connection weights; the bias values of this layer follow
			weightPos += (rows * columns);

			for(int column = 0; column < columns; column++){
				List<Double> weightVector = weightMatrix.get(column);
				Double bias = weights.apply(weightPos);

				Neuron neuron = NeuralNetworkUtil.createNeuron(entities, weightVector, bias)
					.setId(String.valueOf(layer) + "/" + String.valueOf(column + 1));

				neuralLayer.addNeurons(neuron);

				weightPos++;
			}

			// The last layer does not use the network-level default activation function
			if(layer == (layers.length - 1)){
				neuralLayer
					.setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY)
					.setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
			}

			neuralLayers.add(neuralLayer);

			entities = neuralLayer.getNeurons();
		}

		// Every element of the weight vector must have been consumed exactly once
		if(weightPos != weights.size()){
			throw new IllegalArgumentException("Layer sizes imply " + weightPos + " weights, but the weight vector has " + weights.size());
		}

		NeuralNetwork neuralNetwork = new NeuralNetwork(MiningFunction.CLASSIFICATION, NeuralNetwork.ActivationFunction.LOGISTIC, ModelUtil.createMiningSchema(categoricalLabel), neuralInputs, neuralLayers)
			.setNeuralOutputs(NeuralNetworkUtil.createClassificationNeuralOutputs(entities, categoricalLabel));

		return neuralNetwork;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy