All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.regression.NeuralNetwork Maven / Gradle / Ivy

There is a newer version: 2024.12.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 * Modifications copyright (C) 2017 Sam Erickson
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

package smile.regression;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;

import java.io.Serializable;

/**
 * Multilayer perceptron neural network for regression.
 * An MLP consists of several layers of nodes, interconnected through weighted
 * acyclic arcs from each preceding layer to the following, without lateral or
 * feedback connections. Each node calculates a transformed weighted linear
 * combination of its inputs (output activations from the preceding layer), with
 * one of the weights acting as a trainable bias connected to a constant input.
 * The transformation, called activation function, is a bounded non-decreasing
 * (non-linear) function, such as the sigmoid functions (ranges from 0 to 1).
 * Another popular activation function is hyperbolic tangent which is actually
 * equivalent to the sigmoid function in shape but ranges from -1 to 1.
 * <p>
 * In this regression network the chosen activation function is applied to the
 * hidden layers only. The output layer has exactly one unit with a linear
 * (identity) activation, and training minimizes the squared error
 * {@code 0.5 * (y - f(x))^2} by stochastic gradient descent with optional
 * momentum and weight decay.
 * <p>
 * Instances are not thread-safe: {@link #predict(double[])} as well as the
 * {@code learn} methods write intermediate values into the network's layers.
 *
 * @author Sam Erickson
 */
public class NeuralNetwork implements OnlineRegression {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(NeuralNetwork.class);

    /**
     * Momentum factor used by the constructors that do not take one explicitly.
     */
    private static final double DEFAULT_MOMENTUM = 0.9;
    /**
     * Weight decay factor used by the constructors that do not take one explicitly.
     */
    private static final double DEFAULT_WEIGHT_DECAY = 0.0001;

    /**
     * Activation function applied by the hidden layers. The output layer is
     * always linear.
     */
    public enum ActivationFunction {
        /**
         * Logistic sigmoid activation function (default): sigma(v)=1/(1+exp(-v))
         */
        LOGISTIC_SIGMOID,
        /**
         * Hyperbolic tangent activation function: f(v)=tanh(v)
         */
        TANH
    }

    /**
     * A layer of units. Declared static so that a layer does not hold a hidden
     * reference to a network instance; otherwise layers created in
     * {@link #clone()} would be bound to (and serialized with) the original network.
     */
    private static class Layer implements Serializable {
        private static final long serialVersionUID = 1L;

        /**
         * number of units in this layer
         */
        int units;
        /**
         * output of ith unit; the extra slot at index {@code units} is fixed
         * to 1.0 and acts as the bias input of the next layer
         */
        double[] output;
        /**
         * error term of ith unit, filled in during back-propagation
         */
        double[] error;
        /**
         * connection weights to ith unit from previous layer; column
         * {@code previous.units} is the bias weight. null for the input layer.
         */
        double[][] weight;
        /**
         * last weight changes for momentum; null for the input layer.
         */
        double[][] delta;
    }

    /**
     * The type of activation function in the hidden layers. The output layer
     * is linear.
     */
    private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
    /**
     * The dimensionality of data.
     */
    private int p;
    /**
     * layers of this net
     */
    private Layer[] net;
    /**
     * input layer
     */
    private Layer inputLayer;
    /**
     * output layer
     */
    private Layer outputLayer;
    /**
     * learning rate
     */
    private double eta = 0.1;
    /**
     * momentum factor
     */
    private double alpha = 0.0;
    /**
     * weight decay factor, which is also a regularization term.
     */
    private double lambda = 0.0;

    /**
     * Trainer for neural networks.
     */
    public static class Trainer extends RegressionTrainer {
        /**
         * The type of activation function in the hidden layers.
         */
        private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
        /**
         * The number of units in each layer.
         */
        private int[] numUnits;
        /**
         * learning rate
         */
        private double eta = 0.1;
        /**
         * momentum factor
         */
        private double alpha = 0.0;
        /**
         * weight decay factor, which is also a regularization term.
         */
        private double lambda = 0.0;
        /**
         * The number of epochs of stochastic learning.
         */
        private int epochs = 25;

        /**
         * Constructor. The default activation function is the logistic sigmoid function.
         *
         * @param numUnits the number of units in each layer.
         */
        public Trainer(int... numUnits) {
            this(ActivationFunction.LOGISTIC_SIGMOID, numUnits);
        }

        /**
         * Constructor.
         *
         * @param activation the activation function of the hidden layers.
         * @param numUnits the number of units in each layer; at least two
         *                 layers, each with at least one unit, and exactly
         *                 one unit in the output layer.
         * @throws IllegalArgumentException if {@code numUnits} is invalid.
         */
        public Trainer(ActivationFunction activation, int... numUnits) {
            checkUnits(numUnits);
            this.activationFunction = activation;
            // Defensive copy: the caller must not be able to change the
            // validated layer sizes afterwards.
            this.numUnits = numUnits.clone();
        }

        /**
         * Sets the learning rate.
         * @param eta the learning rate.
         */
        public Trainer setLearningRate(double eta) {
            checkLearningRate(eta);
            this.eta = eta;
            return this;
        }

        /**
         * Sets the momentum factor.
         * @param alpha the momentum factor.
         */
        public Trainer setMomentum(double alpha) {
            checkMomentum(alpha);
            this.alpha = alpha;
            return this;
        }

        /**
         * Sets the weight decay factor. After each weight update, every weight
         * is simply ''decayed'' or shrunk according w = w * (1 - eta * lambda).
         * @param lambda the weight decay for regularization.
         */
        public Trainer setWeightDecay(double lambda) {
            checkWeightDecay(lambda);
            this.lambda = lambda;
            return this;
        }

        /**
         * Sets the number of epochs of stochastic learning.
         * @param epochs the number of epochs of stochastic learning.
         */
        public Trainer setNumEpochs(int epochs) {
            if (epochs < 1) {
                throw new IllegalArgumentException("Invalid numer of epochs of stochastic learning:" + epochs);
            }

            this.epochs = epochs;
            return this;
        }

        /**
         * Builds a new network with this trainer's settings and trains it on
         * the given data for the configured number of epochs.
         */
        @Override
        public NeuralNetwork train(double[][] x, double[] y) {
            NeuralNetwork net = new NeuralNetwork(activationFunction, alpha, lambda, numUnits);
            net.setLearningRate(eta);

            for (int i = 1; i <= epochs; i++) {
                net.learn(x, y);
                logger.info("Neural network learns epoch {}", i);
            }

            return net;
        }
    }

    /**
     * Constructor. The default activation function is the logistic sigmoid function.
     *
     * @param numUnits the number of units in each layer.
     */
    public NeuralNetwork(int... numUnits) {
        this(ActivationFunction.LOGISTIC_SIGMOID, numUnits);
    }

    /**
     * Constructor with momentum factor 0.9 and weight decay 0.0001.
     *
     * @param activation the activation function of the hidden layers.
     * @param numUnits the number of units in each layer.
     */
    public NeuralNetwork(ActivationFunction activation, int... numUnits) {
        // Argument order of the delegate is (alpha = momentum, lambda = weight decay).
        // Both defaults lie inside the ranges accepted by setMomentum/setWeightDecay.
        this(activation, DEFAULT_MOMENTUM, DEFAULT_WEIGHT_DECAY, numUnits);
    }

    /**
     * Constructor.
     *
     * @param activation the activation function of the hidden layers.
     * @param alpha the momentum factor, in [0, 1).
     * @param lambda the weight decay factor, in [0, 0.1].
     * @param numUnits the number of units in each layer; at least two layers,
     *                 each with at least one unit, and exactly one unit in the
     *                 output layer.
     * @throws IllegalArgumentException if any argument is out of range.
     * @throws NullPointerException if {@code activation} is null.
     */
    public NeuralNetwork(ActivationFunction activation, double alpha, double lambda, int... numUnits) {
        checkUnits(numUnits);
        if (activation == null) {
            throw new NullPointerException("activation");
        }
        checkMomentum(alpha);
        checkWeightDecay(lambda);

        int numLayers = numUnits.length;

        this.activationFunction = activation;

        this.alpha = alpha;
        this.lambda = lambda;

        this.p = numUnits[0];

        net = new Layer[numLayers];
        for (int i = 0; i < numLayers; i++) {
            net[i] = new Layer();
            net[i].units = numUnits[i];
            net[i].output = new double[numUnits[i] + 1];
            net[i].error = new double[numUnits[i] + 1];
            // Constant bias input for the next layer.
            net[i].output[numUnits[i]] = 1.0;
        }

        inputLayer = net[0];
        outputLayer = net[numLayers - 1];

        // Initialize random weights uniformly in [-1/sqrt(fan-in), 1/sqrt(fan-in)].
        for (int l = 1; l < numLayers; l++) {
            net[l].weight = new double[numUnits[l]][numUnits[l - 1] + 1];
            net[l].delta = new double[numUnits[l]][numUnits[l - 1] + 1];
            double r = 1.0 / Math.sqrt(net[l - 1].units);
            for (int i = 0; i < net[l].units; i++) {
                for (int j = 0; j <= net[l - 1].units; j++) {
                    net[l].weight[i][j] = Math.random(-r, r);
                }
            }
        }
    }

    /**
     * Private constructor for clone purpose.
     */
    private NeuralNetwork() {

    }

    /**
     * Returns a deep copy of this network, including weights, momentum state
     * and hyper-parameters.
     */
    @Override
    public NeuralNetwork clone() {
        NeuralNetwork copycat = new NeuralNetwork();

        copycat.activationFunction = activationFunction;
        copycat.p = p;
        copycat.eta = eta;
        copycat.alpha = alpha;
        copycat.lambda = lambda;

        int numLayers = net.length;
        copycat.net = new Layer[numLayers];
        for (int i = 0; i < numLayers; i++) {
            copycat.net[i] = new Layer();
            copycat.net[i].units = net[i].units;
            copycat.net[i].output = net[i].output.clone();
            copycat.net[i].error = net[i].error.clone();
            if (i > 0) {
                copycat.net[i].weight = Math.clone(net[i].weight);
                copycat.net[i].delta = Math.clone(net[i].delta);
            }
        }

        copycat.inputLayer = copycat.net[0];
        copycat.outputLayer = copycat.net[numLayers - 1];

        return copycat;
    }

    /**
     * Sets the learning rate.
     * @param eta the learning rate, must be positive.
     */
    public void setLearningRate(double eta) {
        checkLearningRate(eta);
        this.eta = eta;
    }

    /**
     * Returns the learning rate.
     */
    public double getLearningRate() {
        return eta;
    }

    /**
     * Sets the momentum factor.
     * @param alpha the momentum factor, in [0, 1).
     */
    public void setMomentum(double alpha) {
        checkMomentum(alpha);
        this.alpha = alpha;
    }

    /**
     * Returns the momentum factor.
     */
    public double getMomentum() {
        return alpha;
    }

    /**
     * Sets the weight decay factor. After each weight update, every weight
     * is simply ''decayed'' or shrunk according w = w * (1 - eta * lambda).
     * @param lambda the weight decay for regularization, in [0, 0.1].
     */
    public void setWeightDecay(double lambda) {
        checkWeightDecay(lambda);
        this.lambda = lambda;
    }

    /**
     * Returns the weight decay factor.
     */
    public double getWeightDecay() {
        return lambda;
    }

    /**
     * Returns the weights of a layer. The returned array is the network's
     * internal storage, not a copy.
     * @param layer the layer of neural network, 0 for input layer. The input
     *              layer has no incoming weights, so {@code getWeight(0)} returns null.
     */
    public double[][] getWeight(int layer) {
        return net[layer].weight;
    }

    /**
     * Sets the input vector into the input layer.
     * @param x the input vector.
     */
    private void setInput(double[] x) {
        if (x.length != inputLayer.units) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, inputLayer.units));
        }
        System.arraycopy(x, 0, inputLayer.output, 0, inputLayer.units);
    }

    /**
     * Propagates signals from a lower layer to the next upper layer. The
     * output layer is linear; every other layer applies the activation function.
     * @param lower the lower layer where signals are from.
     * @param upper the upper layer where signals are propagated to.
     */
    private void propagate(Layer lower, Layer upper) {
        boolean linear = upper == outputLayer;
        for (int i = 0; i < upper.units; i++) {
            double sum = 0.0;
            // j == lower.units is the bias (lower.output[lower.units] == 1.0).
            for (int j = 0; j <= lower.units; j++) {
                sum += upper.weight[i][j] * lower.output[j];
            }

            if (linear) {
                upper.output[i] = sum;
            } else if (activationFunction == ActivationFunction.TANH) {
                // tanh(v) == 2 * logistic(2v) - 1
                upper.output[i] = (2 * Math.logistic(2 * sum)) - 1;
            } else {
                upper.output[i] = Math.logistic(sum);
            }
        }
    }

    /**
     * Propagates the signals through the neural network.
     */
    private void propagate() {
        for (int l = 0; l < net.length - 1; l++) {
            propagate(net[l], net[l + 1]);
        }
    }

    /**
     * Compute the network output error and store the residual in the output
     * layer's error term.
     * @param output the desired output.
     * @return the error 0.5 * (output - prediction)^2.
     */
    private double computeOutputError(double output) {
        return computeOutputError(output, outputLayer.error);
    }

    /**
     * Compute the network output error.
     * @param output the desired output.
     * @param gradient the array to store, at index 0, the residual
     *                 {@code output - prediction} (the negative gradient of the
     *                 loss with respect to the linear output).
     * @return the error defined by loss function, 0.5 * residual^2.
     */
    private double computeOutputError(double output, double[] gradient) {
        double g = output - outputLayer.output[0];
        gradient[0] = g;
        return 0.5 * g * g;
    }

    /**
     * Propagates the errors back from an upper layer to the next lower layer.
     * Only the real units of the lower layer get an error term; the bias slot
     * is never read by weight adjustment or further back-propagation.
     * @param upper the upper layer where errors are from.
     * @param lower the lower layer where errors are propagated back to.
     */
    private void backpropagate(Layer upper, Layer lower) {
        for (int i = 0; i < lower.units; i++) {
            double out = lower.output[i];
            double err = 0;
            for (int j = 0; j < upper.units; j++) {
                err += upper.weight[j][i] * upper.error[j];
            }
            if (activationFunction == ActivationFunction.TANH) {
                // d tanh(v)/dv expressed in terms of the output.
                lower.error[i] = (1 - (out * out)) * err;
            } else {
                // d logistic(v)/dv expressed in terms of the output.
                lower.error[i] = out * (1.0 - out) * err;
            }
        }
    }

    /**
     * Propagates the errors back through the hidden layers. The input layer's
     * error is never used, so propagation stops at the first hidden layer.
     */
    private void backpropagate() {
        for (int l = net.length - 1; l > 1; l--) {
            backpropagate(net[l], net[l - 1]);
        }
    }

    /**
     * Adjust network weights by back-propagation algorithm, with momentum
     * and (non-bias) weight decay.
     */
    private void adjustWeights() {
        double decay = 1.0 - eta * lambda;
        for (int l = 1; l < net.length; l++) {
            Layer lower = net[l - 1];
            Layer upper = net[l];
            for (int i = 0; i < upper.units; i++) {
                double err = upper.error[i];
                for (int j = 0; j <= lower.units; j++) {
                    double out = lower.output[j];
                    double delta = (1 - alpha) * eta * err * out + alpha * upper.delta[i][j];
                    upper.delta[i][j] = delta;
                    upper.weight[i][j] += delta;
                    // Bias weights (j == lower.units) are not decayed.
                    if (lambda != 0.0 && j < lower.units) {
                        upper.weight[i][j] *= decay;
                    }
                }
            }
        }
    }

    /**
     * Predicts the target value of an instance. Note that this method writes
     * into the network's internal buffers and is NOT multi-thread safe.
     * @param x the input vector.
     * @return the network output.
     */
    @Override
    public double predict(double[] x) {
        setInput(x);
        propagate();
        return outputLayer.output[0];
    }


    /**
     * Update the neural network with given instance and associated target value.
     * Note that this method is NOT multi-thread safe.
     * @param x the training instance.
     * @param y the target value.
     * @param weight a positive weight value associated with the training instance.
     * @return the weighted training error before back-propagation.
     */
    public double learn(double[] x, double y, double weight) {
        setInput(x);
        propagate();

        double err = weight * computeOutputError(y);

        if (weight != 1.0) {
            // Scale the output residual so the whole update is weighted.
            outputLayer.error[0] *= weight;
        }

        backpropagate();
        adjustWeights();
        return err;
    }

    /**
     * Update the neural network with given instance and target value, using
     * unit instance weight.
     */
    @Override
    public void learn(double[] x, double y) {
        learn(x, y, 1.0);
    }


    /**
     * Trains the neural network with the given dataset for one epoch by
     * stochastic gradient descent, visiting the instances in random order.
     *
     * @param x training instances.
     * @param y training target values, one per instance.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public void learn(double[][] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        int n = x.length;
        int[] index = Math.permutate(n);
        for (int i = 0; i < n; i++) {
            learn(x[index[i]], y[index[i]]);
        }
    }

    /**
     * Validates the layer sizes: at least two layers, every layer with at
     * least one unit, and a single output unit.
     */
    private static void checkUnits(int[] numUnits) {
        int numLayers = numUnits.length;
        if (numLayers < 2) {
            throw new IllegalArgumentException("Invalid number of layers: " + numLayers);
        }

        for (int i = 0; i < numLayers; i++) {
            if (numUnits[i] < 1) {
                throw new IllegalArgumentException(String.format("Invalid number of units of layer %d: %d", i + 1, numUnits[i]));
            }
        }

        if (numUnits[numLayers - 1] != 1) {
            throw new IllegalArgumentException(String.format("Invalid number of units in output layer %d", numUnits[numLayers - 1]));
        }
    }

    /**
     * Validates a learning rate (must be positive).
     */
    private static void checkLearningRate(double eta) {
        if (eta <= 0) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
    }

    /**
     * Validates a momentum factor (must be in [0, 1)).
     */
    private static void checkMomentum(double alpha) {
        if (alpha < 0.0 || alpha >= 1.0) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
    }

    /**
     * Validates a weight decay factor (must be in [0, 0.1]).
     */
    private static void checkWeightDecay(double lambda) {
        if (lambda < 0.0 || lambda > 0.1) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
    }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy