/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *   
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *  
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package smile.classification;

import smile.math.Math;

/**
 * Multilayer perceptron neural network. 
 * An MLP consists of several layers of nodes, interconnected through weighted
 * acyclic arcs from each preceding layer to the following, without lateral or
 * feedback connections. Each node calculates a transformed weighted linear
 * combination of its inputs (output activations from the preceding layer), with
 * one of the weights acting as a trainable bias connected to a constant input.
 * The transformation, called the activation function, is a bounded
 * non-decreasing (non-linear) function such as the logistic sigmoid, which
 * ranges from 0 to 1. Another popular activation function is the hyperbolic
 * tangent, which has the same shape as the logistic sigmoid but ranges from
 * -1 to 1. More specialized activation functions include the radial basis
 * functions used in RBF networks.
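 * <p>
 * For reference (a precise form of the relation just described):
 * <pre>
 *     σ(x)    = 1 / (1 + exp(-x))      (logistic sigmoid, range (0, 1))
 *     tanh(x) = 2 σ(2x) - 1            (same shape, range (-1, 1))
 * </pre>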
 * <p>
 * The representational capabilities of an MLP are determined by the range of
 * mappings it may implement through weight variation. Single-layer
 * perceptrons are capable of solving only linearly separable problems. With
 * the logistic sigmoid as the activation function, a single-layer network is
 * identical to the logistic regression model.
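 * <p>
 * Concretely, with weight vector w and bias b, a single sigmoid output unit
 * computes
 * <pre>
 *     P(y = 1 | x) = 1 / (1 + exp(-(w · x + b)))
 * </pre>
 * which is exactly the logistic regression model.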
 * <p>
 * The universal approximation theorem for neural networks states that every
 * continuous function that maps intervals of real numbers to some output
 * interval of real numbers can be approximated arbitrarily closely by a
 * multi-layer perceptron with just one hidden layer. This result holds only
 * for restricted classes of activation functions, which can be extremely
 * complex and non-smooth for subtle mathematical reasons, whereas smoothness
 * is important for gradient descent learning. Moreover, the proof is not
 * constructive regarding the number of neurons required or the settings of
 * the weights. In practice, therefore, networks modeling complex systems
 * tend to use more layers and more neurons per layer.
 * <p>
 * The most popular algorithm to train MLPs is back-propagation, a gradient
 * descent method. Based on the chain rule, the algorithm propagates the
 * error back through the network and adjusts the weights of each connection
 * so as to reduce the value of the error function by some small amount. For
 * this reason, back-propagation can only be applied to networks with
 * differentiable activation functions.
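 * <p>
 * For the logistic sigmoid units used here, the error terms take a simple
 * form (a sketch; t denotes the target, o a unit output, and w the weights):
 * <pre>
 *     output unit:  δ = (t - o) · o · (1 - o)            (least mean squares)
 *     hidden unit:  δ_j = o_j · (1 - o_j) · Σ_k w_kj δ_k
 * </pre>
 * These are the quantities computed by computeOutputError() and
 * backpropagate() below.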
 * <p>
 * During error back-propagation, we usually multiply the gradient by a small
 * number η, called the learning rate, which is carefully selected to ensure
 * that the network converges to a local minimum of the error function fast
 * enough, without producing oscillations. One way to avoid oscillation at
 * large η is to make the change in a weight dependent on the past weight
 * change by adding a momentum term.
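 * <p>
 * With momentum factor α, the weight change applied by this implementation is
 * <pre>
 *     Δw(t) = (1 - α) · η · δ · o + α · Δw(t-1)
 * </pre>
 * where δ is the error term of the receiving unit and o the output of the
 * sending unit; α = 0 recovers plain gradient descent.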
 * <p>
 * Although the back-propagation algorithm may perform gradient descent on
 * the total error of all instances in batch mode, the learning rule is often
 * applied to each instance separately in an online (stochastic) fashion.
 * There is empirical evidence that stochastic updates result in faster
 * convergence.
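 * <p>
 * Concretely, one epoch of the stochastic rule, as implemented by
 * {@link #learn(double[][], int[])}, visits the instances in random order
 * and updates the weights after each one:
 * <pre>{@code
 * int[] index = Math.permutate(n);      // random visiting order
 * for (int i = 0; i < n; i++) {
 *     learn(x[index[i]], y[index[i]]);  // per-instance weight update
 * }
 * }</pre>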
 * <p>
 * A problem that arises in practice is over-fitting, which occurs in
 * convoluted or over-specified networks whose capacity significantly exceeds
 * the free parameters actually needed. There are two general approaches to
 * avoiding this problem: the first is to use cross-validation and similar
 * techniques to check for over-fitting and to select hyper-parameters that
 * minimize the generalization error. The second is to use some form of
 * regularization, which emerges naturally in a Bayesian framework, where
 * regularization can be performed by placing a larger prior probability on
 * simpler models, but also in statistical learning theory, where the goal is
 * to minimize two quantities: the "empirical risk" and the "structural risk".
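 * <p>
 * The regularization offered by this class is weight decay: after each
 * update, every non-bias weight is shrunk according to
 * <pre>
 *     w ← w · (1 - η · λ)
 * </pre>
 * where λ is the decay factor set by {@link #setWeightDecay(double)}.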
 * <p>
 * For neural networks, the input variables should usually be
 * scaled/standardized. Commonly, each input variable is scaled into the
 * interval [0, 1] or standardized to mean 0 and standard deviation 1.
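 * <p>
 * For example, standardization may be sketched as follows (the data matrix
 * {@code x} is hypothetical):
 * <pre>{@code
 * for (int j = 0; j < x[0].length; j++) {
 *     double mu = 0.0, sd = 0.0;
 *     for (double[] xi : x) mu += xi[j];
 *     mu /= x.length;
 *     for (double[] xi : x) sd += (xi[j] - mu) * (xi[j] - mu);
 *     sd = Math.sqrt(sd / (x.length - 1));
 *     for (double[] xi : x) xi[j] = (xi[j] - mu) / sd; // mean 0, sd 1
 * }
 * }</pre>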
 * <p>
 * For penalty functions and output units, the following natural pairings are
 * recommended:

    *
  • linear output units and a least squares penalty function. *
  • a two-class cross-entropy penalty function and a logistic * activation function. *
  • a multi-class cross-entropy penalty function and a softmax * activation function. *
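 * <p>
 * For example, a network with one hidden layer of 10 units for a k-class
 * problem on p-dimensional data may be trained as follows (a sketch; the
 * variables {@code x}, {@code y}, {@code p} and {@code k} are hypothetical):
 * <pre>{@code
 * NeuralNetwork net = new NeuralNetwork(NeuralNetwork.ErrorFunction.CROSS_ENTROPY, p, 10, k);
 * for (int epoch = 0; epoch < 25; epoch++) {
 *     net.learn(x, y); // one epoch of stochastic gradient descent
 * }
 * int label = net.predict(x[0]);
 * }</pre>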
 *
 * @author Haifeng Li
 */
public class NeuralNetwork implements OnlineClassifier<double[]> {

    /**
     * The types of error functions.
     */
    public enum ErrorFunction {
        /**
         * Least mean squares error function.
         */
        LEAST_MEAN_SQUARES,

        /**
         * Cross entropy error function for output as probabilities.
         */
        CROSS_ENTROPY;
    }

    /**
     * The types of activation functions in the output layer. In this
     * implementation, the hidden layers always employ the logistic sigmoid
     * activation function.
     */
    public enum ActivationFunction {
        /**
         * Linear activation function.
         */
        LINEAR,

        /**
         * Logistic sigmoid activation function. For multi-class classification,
         * each unit in the output layer corresponds to a class. For binary
         * classification and the cross entropy error function, there is only
         * one output unit whose value can be regarded as a posterior probability.
         */
        LOGISTIC_SIGMOID,

        /**
         * Softmax activation for the multi-class cross entropy objective
         * function. The values of the units in the output layer can be
         * regarded as posterior probabilities of each class.
         */
        SOFTMAX;
    }

    /**
     * A layer of a feed forward neural network.
     */
    private class Layer {
        /**
         * number of units in this layer
         */
        int units;
        /**
         * output of ith unit
         */
        double[] output;
        /**
         * error term of ith unit
         */
        double[] error;
        /**
         * connection weights to ith unit from previous layer
         */
        double[][] weight;
        /**
         * last weight changes for momentum
         */
        double[][] delta;
    }

    /**
     * The type of error function of network.
     */
    private ErrorFunction errorFunction = ErrorFunction.LEAST_MEAN_SQUARES;
    /**
     * The type of activation function in output layer.
     */
    private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
    /**
     * The dimensionality of data.
     */
    private int p;
    /**
     * The number of classes.
     */
    private int k;
    /**
     * layers of this net
     */
    private Layer[] net;
    /**
     * input layer
     */
    private Layer inputLayer;
    /**
     * output layer
     */
    private Layer outputLayer;
    /**
     * learning rate
     */
    private double eta = 0.1;
    /**
     * momentum factor
     */
    private double alpha = 0.0;
    /**
     * weight decay factor, which is also a regularization term.
     */
    private double lambda = 0.0;
    /**
     * The buffer to store target value of training instance.
     */
    private double[] target;

    /**
     * Trainer for neural networks.
     */
    public static class Trainer extends ClassifierTrainer<double[]> {
        /**
         * The type of error function of network.
         */
        private ErrorFunction errorFunction = ErrorFunction.LEAST_MEAN_SQUARES;
        /**
         * The type of activation function in output layer.
         */
        private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
        /**
         * The number of units in each layer.
         */
        private int[] numUnits;
        /**
         * learning rate
         */
        private double eta = 0.1;
        /**
         * momentum factor
         */
        private double alpha = 0.0;
        /**
         * weight decay factor, which is also a regularization term.
         */
        private double lambda = 0.0;
        /**
         * The number of epochs of stochastic learning.
         */
        private int epochs = 25;

        /**
         * Constructor. The activation function of the output layer will be
         * chosen by natural pairing based on the error function and the
         * number of classes.
         *
         * @param error the error function.
         * @param numUnits the number of units in each layer.
         */
        public Trainer(ErrorFunction error, int... numUnits) {
            this(error, natural(error, numUnits[numUnits.length - 1]), numUnits);
        }
        /**
         * Constructor.
         *
         * @param error the error function.
         * @param activation the activation function of output layer.
         * @param numUnits the number of units in each layer.
         */
        public Trainer(ErrorFunction error, ActivationFunction activation, int... numUnits) {
            int numLayers = numUnits.length;
            if (numLayers < 2) {
                throw new IllegalArgumentException("Invalid number of layers: " + numLayers);
            }

            for (int i = 0; i < numLayers; i++) {
                if (numUnits[i] < 1) {
                    throw new IllegalArgumentException(String.format("Invalid number of units of layer %d: %d", i + 1, numUnits[i]));
                }
            }

            if (error == ErrorFunction.LEAST_MEAN_SQUARES) {
                if (activation == ActivationFunction.SOFTMAX) {
                    throw new IllegalArgumentException("Softmax activation function is invalid for least mean squares error.");
                }
            }

            if (error == ErrorFunction.CROSS_ENTROPY) {
                if (activation == ActivationFunction.LINEAR) {
                    throw new IllegalArgumentException("Linear activation function is invalid with cross entropy error.");
                }

                if (activation == ActivationFunction.SOFTMAX && numUnits[numLayers - 1] == 1) {
                    throw new IllegalArgumentException("Softmax activation function is for multi-class.");
                }

                if (activation == ActivationFunction.LOGISTIC_SIGMOID && numUnits[numLayers - 1] != 1) {
                    throw new IllegalArgumentException("For cross entropy error, logistic sigmoid output is for binary classification.");
                }
            }

            this.errorFunction = error;
            this.activationFunction = activation;
            this.numUnits = numUnits;
        }

        /**
         * Sets the learning rate.
         * @param eta the learning rate.
         */
        public Trainer setLearningRate(double eta) {
            if (eta <= 0) {
                throw new IllegalArgumentException("Invalid learning rate: " + eta);
            }
            this.eta = eta;
            return this;
        }

        /**
         * Sets the momentum factor.
         * @param alpha the momentum factor.
         */
        public Trainer setMomentum(double alpha) {
            if (alpha < 0.0 || alpha >= 1.0) {
                throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
            }
            this.alpha = alpha;
            return this;
        }

        /**
         * Sets the weight decay factor. After each weight update, every weight
         * is simply ''decayed'' or shrunk according to w = w * (1 - eta * lambda).
         * @param lambda the weight decay for regularization.
         */
        public Trainer setWeightDecay(double lambda) {
            if (lambda < 0.0 || lambda > 0.1) {
                throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
            }
            this.lambda = lambda;
            return this;
        }

        /**
         * Sets the number of epochs of stochastic learning.
         * @param epochs the number of epochs of stochastic learning.
         */
        public Trainer setNumEpochs(int epochs) {
            if (epochs < 1) {
                throw new IllegalArgumentException("Invalid number of epochs of stochastic learning: " + epochs);
            }
            this.epochs = epochs;
            return this;
        }

        @Override
        public NeuralNetwork train(double[][] x, int[] y) {
            NeuralNetwork net = new NeuralNetwork(errorFunction, activationFunction, numUnits);
            net.setLearningRate(eta);
            net.setMomentum(alpha);
            net.setWeightDecay(lambda);

            for (int i = 1; i <= epochs; i++) {
                net.learn(x, y);
                System.out.format("Neural network learning done epoch %d\n", i);
            }

            return net;
        }
    }
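    // A sketch of how the Trainer may be used (x, y, p and k stand for
    // hypothetical data, labels, input dimension and class count):
    //
    //   NeuralNetwork net = new NeuralNetwork.Trainer(NeuralNetwork.ErrorFunction.CROSS_ENTROPY, p, 10, k)
    //           .setLearningRate(0.1)
    //           .setNumEpochs(25)
    //           .train(x, y);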
    /**
     * Constructor. The activation function of the output layer will be chosen
     * by natural pairing based on the error function and the number of
     * classes.
     *
     * @param error the error function.
     * @param numUnits the number of units in each layer.
     */
    public NeuralNetwork(ErrorFunction error, int... numUnits) {
        this(error, natural(error, numUnits[numUnits.length - 1]), numUnits);
    }

    /**
     * Returns the activation function of the output layer based on natural pairing.
     * @param error the error function.
     * @param k the number of output nodes.
     * @return the activation function of the output layer based on natural pairing.
     */
    private static ActivationFunction natural(ErrorFunction error, int k) {
        if (error == ErrorFunction.CROSS_ENTROPY) {
            if (k == 1) {
                return ActivationFunction.LOGISTIC_SIGMOID;
            } else {
                return ActivationFunction.SOFTMAX;
            }
        } else {
            return ActivationFunction.LOGISTIC_SIGMOID;
        }
    }

    /**
     * Constructor.
     *
     * @param error the error function.
     * @param activation the activation function of output layer.
     * @param numUnits the number of units in each layer.
     */
    public NeuralNetwork(ErrorFunction error, ActivationFunction activation, int... numUnits) {
        int numLayers = numUnits.length;
        if (numLayers < 2) {
            throw new IllegalArgumentException("Invalid number of layers: " + numLayers);
        }

        for (int i = 0; i < numLayers; i++) {
            if (numUnits[i] < 1) {
                throw new IllegalArgumentException(String.format("Invalid number of units of layer %d: %d", i + 1, numUnits[i]));
            }
        }

        if (error == ErrorFunction.LEAST_MEAN_SQUARES) {
            if (activation == ActivationFunction.SOFTMAX) {
                throw new IllegalArgumentException("Softmax activation function is invalid for least mean squares error.");
            }
        }

        if (error == ErrorFunction.CROSS_ENTROPY) {
            if (activation == ActivationFunction.LINEAR) {
                throw new IllegalArgumentException("Linear activation function is invalid with cross entropy error.");
            }

            if (activation == ActivationFunction.SOFTMAX && numUnits[numLayers - 1] == 1) {
                throw new IllegalArgumentException("Softmax activation function is for multi-class.");
            }

            if (activation == ActivationFunction.LOGISTIC_SIGMOID && numUnits[numLayers - 1] != 1) {
                throw new IllegalArgumentException("For cross entropy error, logistic sigmoid output is for binary classification.");
            }
        }

        this.errorFunction = error;
        this.activationFunction = activation;
        if (error == ErrorFunction.CROSS_ENTROPY) {
            this.alpha = 0.0;
            this.lambda = 0.0;
        }

        this.p = numUnits[0];
        this.k = numUnits[numLayers - 1] == 1 ? 2 : numUnits[numLayers - 1];

        this.target = new double[numUnits[numLayers - 1]];

        net = new Layer[numLayers];
        for (int i = 0; i < numLayers; i++) {
            net[i] = new Layer();
            net[i].units = numUnits[i];
            net[i].output = new double[numUnits[i] + 1];
            net[i].error = new double[numUnits[i] + 1];
            net[i].output[numUnits[i]] = 1.0;
        }

        inputLayer = net[0];
        outputLayer = net[numLayers - 1];

        // Initialize random weights.
        for (int l = 1; l < numLayers; l++) {
            net[l].weight = new double[numUnits[l]][numUnits[l - 1] + 1];
            net[l].delta = new double[numUnits[l]][numUnits[l - 1] + 1];
            double r = 1.0 / Math.sqrt(net[l - 1].units);
            for (int i = 0; i < net[l].units; i++) {
                for (int j = 0; j <= net[l - 1].units; j++) {
                    net[l].weight[i][j] = Math.random(-r, r);
                }
            }
        }
    }
    /**
     * Private constructor for clone purpose.
     */
    private NeuralNetwork() {
    }

    @Override
    public NeuralNetwork clone() {
        NeuralNetwork copycat = new NeuralNetwork();

        copycat.errorFunction = errorFunction;
        copycat.activationFunction = activationFunction;
        copycat.p = p;
        copycat.k = k;
        copycat.eta = eta;
        copycat.alpha = alpha;
        copycat.lambda = lambda;
        copycat.target = target.clone();

        int numLayers = net.length;
        copycat.net = new Layer[numLayers];
        for (int i = 0; i < numLayers; i++) {
            copycat.net[i] = new Layer();
            copycat.net[i].units = net[i].units;
            copycat.net[i].output = net[i].output.clone();
            copycat.net[i].error = net[i].error.clone();
            if (i > 0) {
                copycat.net[i].weight = Math.clone(net[i].weight);
                copycat.net[i].delta = Math.clone(net[i].delta);
            }
        }

        copycat.inputLayer = copycat.net[0];
        copycat.outputLayer = copycat.net[numLayers - 1];

        return copycat;
    }

    /**
     * Sets the learning rate.
     * @param eta the learning rate.
     */
    public void setLearningRate(double eta) {
        if (eta <= 0) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
        this.eta = eta;
    }

    /**
     * Returns the learning rate.
     */
    public double getLearningRate() {
        return eta;
    }

    /**
     * Sets the momentum factor.
     * @param alpha the momentum factor.
     */
    public void setMomentum(double alpha) {
        if (alpha < 0.0 || alpha >= 1.0) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
        this.alpha = alpha;
    }

    /**
     * Returns the momentum factor.
     */
    public double getMomentum() {
        return alpha;
    }

    /**
     * Sets the weight decay factor. After each weight update, every weight
     * is simply ''decayed'' or shrunk according to w = w * (1 - eta * lambda).
     * @param lambda the weight decay for regularization.
     */
    public void setWeightDecay(double lambda) {
        if (lambda < 0.0 || lambda > 0.1) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * Returns the weight decay factor.
     */
    public double getWeightDecay() {
        return lambda;
    }

    /**
     * Sets the input vector into the input layer.
     * @param x the input vector.
     */
    private void setInput(double[] x) {
        if (x.length != inputLayer.units) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, inputLayer.units));
        }
        System.arraycopy(x, 0, inputLayer.output, 0, inputLayer.units);
    }

    /**
     * Copies the output vector into the given array.
     * @param y the output vector.
     */
    private void getOutput(double[] y) {
        if (y.length != outputLayer.units) {
            throw new IllegalArgumentException(String.format("Invalid output vector size: %d, expected: %d", y.length, outputLayer.units));
        }
        System.arraycopy(outputLayer.output, 0, y, 0, outputLayer.units);
    }

    /**
     * Propagates signals from a lower layer to the next upper layer.
     * @param lower the lower layer where signals are from.
     * @param upper the upper layer where signals are propagated to.
     */
    private void propagate(Layer lower, Layer upper) {
        for (int i = 0; i < upper.units; i++) {
            double sum = 0.0;
            for (int j = 0; j <= lower.units; j++) {
                sum += upper.weight[i][j] * lower.output[j];
            }

            if (upper != outputLayer || activationFunction == ActivationFunction.LOGISTIC_SIGMOID) {
                upper.output[i] = Math.logistic(sum);
            } else {
                if (activationFunction == ActivationFunction.LINEAR || activationFunction == ActivationFunction.SOFTMAX) {
                    upper.output[i] = sum;
                } else {
                    throw new UnsupportedOperationException("Unsupported activation function.");
                }
            }
        }

        if (upper == outputLayer && activationFunction == ActivationFunction.SOFTMAX) {
            softmax();
        }
    }
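    // Background: softmax is invariant to shifting all inputs by a constant m,
    //     exp(x_i - m) / sum_j exp(x_j - m) = exp(x_i) / sum_j exp(x_j),
    // so subtracting the maximum output, as softmax() below does, changes
    // nothing mathematically but prevents overflow in exp().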
    /**
     * Calculates the softmax activation function in the output layer without
     * overflow.
     */
    private void softmax() {
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < outputLayer.units; i++) {
            if (outputLayer.output[i] > max) {
                max = outputLayer.output[i];
            }
        }

        double sum = 0.0;
        for (int i = 0; i < outputLayer.units; i++) {
            double out = Math.exp(outputLayer.output[i] - max);
            outputLayer.output[i] = out;
            sum += out;
        }

        for (int i = 0; i < outputLayer.units; i++) {
            outputLayer.output[i] /= sum;
        }
    }

    /**
     * Propagates the signals through the neural network.
     */
    private void propagate() {
        for (int l = 0; l < net.length - 1; l++) {
            propagate(net[l], net[l + 1]);
        }
    }

    /**
     * Returns the natural log without underflow.
     */
    private static double log(double x) {
        double y = 0.0;
        if (x < 1E-300) {
            y = -690.7755;
        } else {
            y = Math.log(x);
        }
        return y;
    }

    /**
     * Computes the network output error.
     * @param output the desired output.
     */
    private double computeOutputError(double[] output) {
        return computeOutputError(output, outputLayer.error);
    }

    /**
     * Computes the network output error.
     * @param output the desired output.
     * @param gradient the array to store gradient on output.
     * @return the error defined by loss function.
     */
    private double computeOutputError(double[] output, double[] gradient) {
        if (output.length != outputLayer.units) {
            throw new IllegalArgumentException(String.format("Invalid output vector size: %d, expected: %d", output.length, outputLayer.units));
        }

        double error = 0.0;
        for (int i = 0; i < outputLayer.units; i++) {
            double out = outputLayer.output[i];
            double g = output[i] - out;
            if (errorFunction == ErrorFunction.LEAST_MEAN_SQUARES && activationFunction == ActivationFunction.LOGISTIC_SIGMOID) {
                g *= out * (1.0 - out);
            }

            if (errorFunction == ErrorFunction.LEAST_MEAN_SQUARES) {
                error += 0.5 * g * g;
            } else if (errorFunction == ErrorFunction.CROSS_ENTROPY) {
                if (activationFunction == ActivationFunction.SOFTMAX) {
                    error -= output[i] * log(out);
                } else if (activationFunction == ActivationFunction.LOGISTIC_SIGMOID) {
                    error = -output[i] * log(out) - (1.0 - output[i]) * log(1.0 - out);
                }
            }

            gradient[i] = g;
        }

        return error;
    }

    /**
     * Propagates the errors back from an upper layer to the next lower layer.
     * @param upper the upper layer where errors come from.
     * @param lower the lower layer where errors are propagated back to.
     */
    private void backpropagate(Layer upper, Layer lower) {
        for (int i = 0; i <= lower.units; i++) {
            double out = lower.output[i];
            double err = 0;
            for (int j = 0; j < upper.units; j++) {
                err += upper.weight[j][i] * upper.error[j];
            }
            lower.error[i] = out * (1.0 - out) * err;
        }
    }

    /**
     * Propagates the errors back through the network.
     */
    private void backpropagate() {
        for (int l = net.length; --l > 0;) {
            backpropagate(net[l], net[l - 1]);
        }
    }

    /**
     * Adjusts the network weights by the back-propagation algorithm.
     */
    private void adjustWeights() {
        for (int l = 1; l < net.length; l++) {
            for (int i = 0; i < net[l].units; i++) {
                for (int j = 0; j <= net[l - 1].units; j++) {
                    double out = net[l - 1].output[j];
                    double err = net[l].error[i];
                    double delta = (1 - alpha) * eta * err * out + alpha * net[l].delta[i][j];
                    net[l].delta[i][j] = delta;
                    net[l].weight[i][j] += delta;
                    if (lambda != 0.0 && j < net[l - 1].units) {
                        net[l].weight[i][j] *= (1.0 - eta * lambda);
                    }
                }
            }
        }
    }
    /**
     * Predicts the target value of a given instance. Note that this method
     * is NOT multi-thread safe.
     * @param x the instance.
     * @param y the array to store network output on output. For softmax
     * activation function, these are estimated posterior probabilities.
     * @return the predicted class label.
     */
    @Override
    public int predict(double[] x, double[] y) {
        setInput(x);
        propagate();
        getOutput(y);

        if (outputLayer.units == 1) {
            // A single sigmoid output encodes class 0 as target 1.0 during
            // training, so a large output means class 0.
            if (outputLayer.output[0] > 0.5) {
                return 0;
            } else {
                return 1;
            }
        }

        double max = Double.NEGATIVE_INFINITY;
        int label = -1;
        for (int i = 0; i < outputLayer.units; i++) {
            if (outputLayer.output[i] > max) {
                max = outputLayer.output[i];
                label = i;
            }
        }
        return label;
    }

    /**
     * Predicts the class of a given instance. Note that this method is NOT
     * multi-thread safe.
     * @param x the instance.
     * @return the predicted class label.
     */
    @Override
    public int predict(double[] x) {
        setInput(x);
        propagate();

        if (outputLayer.units == 1) {
            if (outputLayer.output[0] > 0.5) {
                return 0;
            } else {
                return 1;
            }
        }

        double max = Double.NEGATIVE_INFINITY;
        int label = -1;
        for (int i = 0; i < outputLayer.units; i++) {
            if (outputLayer.output[i] > max) {
                max = outputLayer.output[i];
                label = i;
            }
        }
        return label;
    }

    /**
     * Updates the neural network with a given instance and the associated
     * target value. Note that this method is NOT multi-thread safe.
     * @param x the training instance.
     * @param y the target value.
     * @param weight a positive weight value associated with the training instance.
     * @return the weighted training error before back-propagation.
     */
    public double learn(double[] x, double[] y, double weight) {
        setInput(x);
        propagate();

        double err = weight * computeOutputError(y);
        if (weight != 1.0) {
            for (int i = 0; i < outputLayer.units; i++) {
                outputLayer.error[i] *= weight;
            }
        }

        backpropagate();
        adjustWeights();
        return err;
    }

    @Override
    public void learn(double[] x, int y) {
        learn(x, y, 1.0);
    }

    /**
     * Online update the neural network with a new training instance.
     * Note that this method is NOT multi-thread safe.
     *
     * @param x training instance.
     * @param y training label.
     * @param weight a positive weight value associated with the training instance.
     */
    public void learn(double[] x, int y, double weight) {
        if (weight < 0.0) {
            throw new IllegalArgumentException("Invalid weight: " + weight);
        }

        if (weight == 0.0) {
            System.out.println("Ignore the training instance with zero weight.");
            return;
        }

        if (y < 0) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }

        if (outputLayer.units == 1 && y > 1) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }

        if (outputLayer.units > 1 && y >= outputLayer.units) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }

        if (errorFunction == ErrorFunction.CROSS_ENTROPY) {
            if (activationFunction == ActivationFunction.LOGISTIC_SIGMOID) {
                if (y == 0) {
                    target[0] = 1.0;
                } else {
                    target[0] = 0.0;
                }
            } else {
                for (int i = 0; i < target.length; i++) {
                    target[i] = 0.0;
                }
                target[y] = 1.0;
            }
        } else {
            for (int i = 0; i < target.length; i++) {
                target[i] = 0.1;
            }
            target[y] = 0.9;
        }

        learn(x, target, weight);
    }

    /**
     * Trains the neural network with the given dataset for one epoch by
     * stochastic gradient descent.
     *
     * @param x training instances.
     * @param y training labels in [0, k), where k is the number of classes.
     */
    public void learn(double[][] x, int[] y) {
        int n = x.length;
        int[] index = Math.permutate(n);
        for (int i = 0; i < n; i++) {
            learn(x[index[i]], y[index[i]]);
        }
    }
}



