All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.shifu.guagua.example.nn.Gradient Maven / Gradle / Ivy

/*
 * Copyright [2013-2014] PayPal Software Foundation
 *  
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *  
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.guagua.example.nn;

import java.util.Arrays;

import org.encog.engine.network.activation.ActivationFunction;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.BasicNetwork;

/**
 * {@link Gradient} is copied from Encog framework. The reason is that we original Gradient don't pop up
 * {@link #gradients} outside. While we need gradients accumulated into {@link NNMaster} to update NN weights.
 */
public class Gradient {

    /**
     * The network to train.
     */
    private FlatNetwork network;

    /**
     * The error calculation method.
     */
    private final ErrorCalculation errorCalculation = new ErrorCalculation();

    /**
     * The actual values from the neural network.
     */
    private final double[] actual;

    /**
     * The deltas for each layer.
     */
    private final double[] layerDelta;

    /**
     * The neuron counts, per layer.
     */
    private final int[] layerCounts;

    /**
     * The feed counts, per layer.
     */
    private final int[] layerFeedCounts;

    /**
     * The layer indexes.
     */
    private final int[] layerIndex;

    /**
     * The index to each layer's weights and thresholds.
     */
    private final int[] weightIndex;

    /**
     * The output from each layer.
     */
    private final double[] layerOutput;

    /**
     * The sums.
     */
    private final double[] layerSums;

    /**
     * The gradients.
     */
    private double[] gradients;

    /**
     * The weights and thresholds.
     */
    private double[] weights;

    /**
     * The pair to use for training.
     */
    private final MLDataPair pair;

    /**
     * The training data.
     */
    private final MLDataSet training;

    /**
     * error
     */
    private double error;

    /**
     * Derivative add constant. Used to combat flat spot.
     */
    private double[] flatSpot;

    /**
     * The error function to use.
     */
    private final ErrorFunction errorFunction;

    /**
     * Construct a gradient worker.
     * 
     * @param theNetwork
     *            The network to train.
     * @param theOwner
     *            The owner that is doing the training.
     * @param theTraining
     *            The training data.
     * @param theLow
     *            The low index to use in the training data.
     * @param theHigh
     *            The high index to use in the training data.
     */
    public Gradient(final FlatNetwork theNetwork, final MLDataSet theTraining, final double[] flatSpot, ErrorFunction ef) {
        this.network = theNetwork;
        this.training = theTraining;
        this.flatSpot = flatSpot;
        this.errorFunction = ef;

        this.layerDelta = new double[getNetwork().getLayerOutput().length];
        this.gradients = new double[getNetwork().getWeights().length];
        this.actual = new double[getNetwork().getOutputCount()];

        this.weights = getNetwork().getWeights();
        this.layerIndex = getNetwork().getLayerIndex();
        this.layerCounts = getNetwork().getLayerCounts();
        this.weightIndex = getNetwork().getWeightIndex();
        this.layerOutput = getNetwork().getLayerOutput();
        this.layerSums = getNetwork().getLayerSums();
        this.layerFeedCounts = getNetwork().getLayerFeedCounts();

        this.pair = BasicMLDataPair.createPair(getNetwork().getInputCount(), getNetwork().getOutputCount());
    }

    /**
     * Process one training set element.
     * 
     * @param input
     *            The network input.
     * @param ideal
     *            The ideal values.
     * @param s
     *            The significance.
     */
    private void process(final double[] input, final double[] ideal, double s) {
        this.getNetwork().compute(input, this.actual);

        this.errorCalculation.updateError(this.actual, ideal, s);
        this.errorFunction.calculateError(ideal, actual, this.getLayerDelta());

        for(int i = 0; i < this.actual.length; i++) {
            this.getLayerDelta()[i] = ((this.getNetwork().getActivationFunctions()[0].derivativeFunction(
                    this.layerSums[i], this.layerOutput[i]) + this.flatSpot[0])) * (this.getLayerDelta()[i] * s);
        }

        for(int i = this.getNetwork().getBeginTraining(); i < this.getNetwork().getEndTraining(); i++) {
            processLevel(i);
        }
    }

    /**
     * Process one level.
     * 
     * @param currentLevel
     *            The level.
     */
    private void processLevel(final int currentLevel) {
        final int fromLayerIndex = this.layerIndex[currentLevel + 1];
        final int toLayerIndex = this.layerIndex[currentLevel];
        final int fromLayerSize = this.layerCounts[currentLevel + 1];
        final int toLayerSize = this.layerFeedCounts[currentLevel];

        final int index = this.weightIndex[currentLevel];
        final ActivationFunction activation = this.getNetwork().getActivationFunctions()[currentLevel + 1];
        final double currentFlatSpot = this.flatSpot[currentLevel + 1];

        // handle weights
        int yi = fromLayerIndex;
        for(int y = 0; y < fromLayerSize; y++) {
            final double output = this.layerOutput[yi];
            double sum = 0;
            int xi = toLayerIndex;
            int wi = index + y;
            for(int x = 0; x < toLayerSize; x++) {
                this.gradients[wi] += output * this.getLayerDelta()[xi];
                sum += this.weights[wi] * this.getLayerDelta()[xi];
                wi += fromLayerSize;
                xi++;
            }

            this.getLayerDelta()[yi] = sum
                    * (activation.derivativeFunction(this.layerSums[yi], this.layerOutput[yi]) + currentFlatSpot);
            yi++;
        }
    }

    /**
     * Perform the gradient calculation
     */
    public final void run() {
        try {
            // reset errors and gradients firstly
            this.errorCalculation.reset();
            Arrays.fill(this.gradients, 0.0);

            for(int i = 0; i < this.training.getRecordCount(); i++) {
                this.training.getRecord(i, this.pair);
                process(this.pair.getInputArray(), this.pair.getIdealArray(), pair.getSignificance());
            }
            this.error = this.errorCalculation.calculate();

        } catch (final Throwable ex) {
            throw new RuntimeException(ex);
        }
    }

    public ErrorCalculation getErrorCalculation() {
        return errorCalculation;
    }

    /**
     * @return the gradients
     */
    public double[] getGradients() {
        return this.gradients;
    }

    /**
     * @return the error
     */
    public double getError() {
        return error;
    }

    /**
     * @return the weights
     */
    public double[] getWeights() {
        return weights;
    }

    /**
     * @param weights
     *            the weights to set
     */
    public void setWeights(double[] weights) {
        this.weights = weights;
        this.getNetwork().setWeights(weights);
    }

    public void setParams(BasicNetwork network) {
        this.network = network.getFlat();
        this.weights = network.getFlat().getWeights();
    }

    public FlatNetwork getNetwork() {
        return network;
    }

    public double[] getLayerDelta() {
        return layerDelta;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy