// org.jpmml.manager.NeuralNetworkManager Maven / Gradle / Ivy
/*
* Copyright (c) 2011 University of Tartu
*/
package org.jpmml.manager;
import java.util.*;
import org.dmg.pmml.*;
import com.google.common.collect.*;
import static com.google.common.base.Preconditions.*;
/**
 * Manager for reading and building a {@link NeuralNetwork} model element:
 * neural inputs, neural layers with their neurons and connections, and neural outputs.
 *
 * <p>
 * NOTE(review): the supertypes are written here without type arguments. If {@code ModelManager}
 * and/or {@code HasEntityRegistry} are declared generic, the arguments should be restored
 * on this header - confirm against their declarations.
 * </p>
 */
public class NeuralNetworkManager extends ModelManager implements HasEntityRegistry {

	private NeuralNetwork neuralNetwork = null;

	/**
	 * Creates a manager without a model. A model can be created later using
	 * {@link #createModel(MiningFunctionType, ActivationFunctionType)}.
	 */
	public NeuralNetworkManager() {
	}

	/**
	 * Creates a manager for the {@link NeuralNetwork} that {@code find} locates
	 * in the content of the specified PMML document.
	 *
	 * @param pmml The PMML document
	 */
	public NeuralNetworkManager(PMML pmml) {
		this(pmml, find(pmml.getContent(), NeuralNetwork.class));
	}

	/**
	 * Creates a manager for the specified model.
	 *
	 * @param pmml The PMML document
	 * @param neuralNetwork The model to manage. May be <code>null</code>, in which case
	 * {@link #getModel()} fails until {@link #createModel(MiningFunctionType, ActivationFunctionType)} is called
	 */
	public NeuralNetworkManager(PMML pmml, NeuralNetwork neuralNetwork) {
		super(pmml);

		this.neuralNetwork = neuralNetwork;
	}

	/**
	 * @return The constant string <code>"Neural network"</code>
	 */
	@Override
	public String getSummary(){
		return "Neural network";
	}

	/**
	 * @return The managed model
	 *
	 * @throws IllegalStateException If no model has been set or created
	 *
	 * @see #createModel(MiningFunctionType, ActivationFunctionType)
	 */
	@Override
	public NeuralNetwork getModel() {
		checkState(this.neuralNetwork != null, "Neural network model has not been set or created");

		return this.neuralNetwork;
	}

	/**
	 * Creates a new model with an empty mining schema and empty neural inputs,
	 * and adds it to the list returned by <code>getModels()</code>.
	 *
	 * @param miningFunction The mining function of the new model
	 * @param activationFunction The activation function of the new model
	 *
	 * @return The newly created model
	 *
	 * @throws IllegalStateException If this manager already has a model
	 *
	 * @see #getModel()
	 */
	public NeuralNetwork createModel(MiningFunctionType miningFunction, ActivationFunctionType activationFunction) {
		checkState(this.neuralNetwork == null, "Neural network model already exists");

		this.neuralNetwork = new NeuralNetwork(new MiningSchema(), new NeuralInputs(), miningFunction, activationFunction);

		getModels().add(this.neuralNetwork);

		return this.neuralNetwork;
	}

	/**
	 * @return The list of neural inputs held by the model's <code>NeuralInputs</code> element.
	 * {@link #addNeuralInput(String, NormContinuous)} appends to this list
	 *
	 * @throws IllegalStateException If no model has been set or created
	 */
	public List<NeuralInput> getNeuralInputs() {
		NeuralNetwork neuralNetwork = getModel();

		return (neuralNetwork.getNeuralInputs()).getNeuralInputs();
	}

	/**
	 * Creates a neural input backed by a continuous, double-typed derived field
	 * whose expression is the specified normalization, and appends it to the model.
	 *
	 * @param id Unique identifier
	 * @param normContinuous The expression of the derived field
	 *
	 * @return The newly created neural input
	 *
	 * @see #getEntityRegistry()
	 */
	public NeuralInput addNeuralInput(String id, NormContinuous normContinuous) {
		DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE);
		derivedField.setExpression(normContinuous);

		NeuralInput neuralInput = new NeuralInput(derivedField, id);

		getNeuralInputs().add(neuralInput);

		return neuralInput;
	}

	/**
	 * @return The list of neural layers of the model.
	 * {@link #addNeuralLayer()} appends to this list
	 *
	 * @throws IllegalStateException If no model has been set or created
	 */
	public List<NeuralLayer> getNeuralLayers(){
		NeuralNetwork neuralNetwork = getModel();

		return neuralNetwork.getNeuralLayers();
	}

	/**
	 * Creates an empty neural layer and appends it to the model.
	 *
	 * @return The newly created neural layer
	 */
	public NeuralLayer addNeuralLayer() {
		NeuralLayer neuralLayer = new NeuralLayer();

		getNeuralLayers().add(neuralLayer);

		return neuralLayer;
	}

	/**
	 * Builds a registry of all neural inputs and all neurons of all neural layers.
	 * A new map is built on every invocation.
	 *
	 * <p>
	 * NOTE(review): the map is left raw because its type arguments are not recoverable from this file
	 * (presumably <code>String</code> identifiers mapped to a common entity supertype) -
	 * confirm against <code>HasEntityRegistry</code> and <code>EntityUtil#put</code>.
	 * </p>
	 *
	 * @return The registry populated via <code>EntityUtil#put</code>
	 *
	 * @throws IllegalStateException If no model has been set or created
	 */
	@Override
	public BiMap getEntityRegistry(){
		BiMap result = HashBiMap.create();

		// Element types are spelled out so that the for-each loops iterate typed elements, not Object
		List<NeuralInput> neuralInputs = getNeuralInputs();
		for(NeuralInput neuralInput : neuralInputs){
			EntityUtil.put(neuralInput, result);
		}

		List<NeuralLayer> neuralLayers = getNeuralLayers();
		for(NeuralLayer neuralLayer : neuralLayers){
			List<Neuron> neurons = neuralLayer.getNeurons();

			for(Neuron neuron : neurons){
				EntityUtil.put(neuron, result);
			}
		}

		return result;
	}

	/**
	 * Creates a neuron with the specified bias and appends it to the specified neural layer.
	 *
	 * @param neuralLayer The neural layer that receives the neuron
	 * @param id Unique identifier
	 * @param bias The bias of the neuron
	 *
	 * @return The newly created neuron
	 *
	 * @see #getEntityRegistry()
	 */
	static
	public Neuron addNeuron(NeuralLayer neuralLayer, String id, Double bias) {
		Neuron neuron = new Neuron(id);
		neuron.setBias(bias);

		(neuralLayer.getNeurons()).add(neuron);

		return neuron;
	}

	/**
	 * Appends a weighted connection to the target neuron; the connection refers to
	 * the source neural input by its identifier.
	 *
	 * @param from The source neural input
	 * @param to The target neuron
	 * @param weight The weight of the connection
	 */
	static
	public void addConnection(NeuralInput from, Neuron to, double weight) {
		Connection connection = new Connection(from.getId(), weight);

		(to.getConnections()).add(connection);
	}

	/**
	 * Appends a weighted connection to the target neuron; the connection refers to
	 * the source neuron by its identifier.
	 *
	 * @param from The source neuron
	 * @param to The target neuron
	 * @param weight The weight of the connection
	 */
	static
	public void addConnection(Neuron from, Neuron to, double weight) {
		Connection connection = new Connection(from.getId(), weight);

		(to.getConnections()).add(connection);
	}

	/**
	 * Returns the list of neural outputs of the model. If the model has no
	 * <code>NeuralOutputs</code> element yet, an empty one is created and set on the model first.
	 *
	 * @return The list of neural outputs.
	 * {@link #addNeuralOutput(Neuron, NormContinuous)} appends to this list
	 *
	 * @throws IllegalStateException If no model has been set or created
	 */
	public List<NeuralOutput> getOrCreateNeuralOutputs() {
		NeuralNetwork neuralNetwork = getModel();

		NeuralOutputs neuralOutputs = neuralNetwork.getNeuralOutputs();
		if(neuralOutputs == null){
			neuralOutputs = new NeuralOutputs();

			neuralNetwork.setNeuralOutputs(neuralOutputs);
		}

		return neuralOutputs.getNeuralOutputs();
	}

	/**
	 * Creates a neural output that refers to the specified neuron by its identifier and is backed by
	 * a continuous, double-typed derived field whose expression is the specified normalization,
	 * and appends it to the model.
	 *
	 * @param neuron The neuron whose identifier the neural output refers to
	 * @param normCountinuous The expression of the derived field
	 *
	 * @return The newly created neural output
	 */
	public NeuralOutput addNeuralOutput(Neuron neuron, NormContinuous normCountinuous) {
		DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE);
		derivedField.setExpression(normCountinuous);

		NeuralOutput output = new NeuralOutput(derivedField, neuron.getId());

		getOrCreateNeuralOutputs().add(output);

		return output;
	}
}