All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.bigml.mimir.deepnet.layers.twod.ConvolutionBlock2D Maven / Gradle / Ivy

The newest version!
package org.bigml.mimir.deepnet.layers.twod;

import org.bigml.mimir.deepnet.layers.Activation;
import org.bigml.mimir.deepnet.layers.BatchNormalize;
import org.bigml.mimir.deepnet.layers.Activation.ActivationFn;
import org.bigml.mimir.math.gpu.ConvolutionBlock2DKernel;
import org.bigml.mimir.math.gpu.Device;

/**
 * A 2D-deepnet layer specifying a convolution, followed by batch normalization
 * and then activation.  While more directly implemented as three layers in
 * sequence, the backend has a specialized GPU-kernel that outsources the
 * entirety of the computation to the GPU without multiple copies to and from
 * the GPU memory.
 *
 * The layer itself is immutable: the batch-normalization parameters are
 * copied at construction time and held in final fields.  Thread-safety is
 * guaranteed by the use of the OutputTensor class.
 *
 * @see OutputTensor
 * @author  Charles Parker
 */
public class ConvolutionBlock2D extends AbstractConvolution2D {
    /**
     * Fuses a convolution with the batch normalization and activation that
     * follow it.
     *
     * The batch-normalization arrays are copied, so later changes to the
     * arrays held by {@code bn} cannot alter this layer.
     *
     * @param conv the convolution whose configuration is adopted through the
     *             superclass constructor
     * @param bn   the batch normalization applied to each filter output
     * @param act  the activation applied after batch normalization
     */
    public ConvolutionBlock2D(
            Convolution2D conv, BatchNormalize bn, Activation act) {

        super(conv);

        // The GPU program variant is chosen from _kernelShape[2], which is
        // populated by the superclass constructor above.
        _programType = ConvolutionBlock2DKernel.getProgramType(_kernelShape[2]);

        // Copy rather than alias so this layer stays immutable even if the
        // source BatchNormalize's arrays are modified afterwards.
        _mean = copyOf(bn.getMean());
        _stdev = copyOf(bn.getStDev());
        _beta = copyOf(bn.getBeta());
        _gamma = copyOf(bn.getGamma());
        _afn = act.getFunction();
        // Inherited field; carries over the activation layer's index.
        _index = act.getIndex();
    }

    /**
     * Runs the block on a flattened input.
     *
     * When {@code deviceIndex} names an available device, the output of the
     * superclass run is returned unchanged.  Per the class description, the
     * fused GPU kernel performs the whole convolution / normalization /
     * activation sequence in that case.  Otherwise the CPU result (convolution
     * plus batch normalization, see {@link #kernelsForPixel}) still needs the
     * activation, which is applied here.
     *
     * @param rawInput    the flattened input
     * @param deviceIndex index of the GPU device to use; any value outside
     *                    {@code [0, Device.numberOfDevices())} selects the
     *                    CPU path
     * @return the activated output of the block
     */
    @Override
    public float[] run(float[] rawInput, int deviceIndex) {
        float[] output = super.run(rawInput, deviceIndex);
        boolean onDevice =
            deviceIndex >= 0 && deviceIndex < Device.numberOfDevices();

        if (onDevice) {
            return output;
        }
        return Activation.activate(output, _afn);
    }

    /**
     * CPU computation for one output pixel.
     *
     * For every filter {@code f}, computes {@code _biases[f]} plus the dot
     * product of the filter with the input window, then applies batch
     * normalization.  The activation is NOT applied here; {@link #run} does
     * that on the CPU path.
     *
     * Layout implied by the index arithmetic:
     * <ul>
     *   <li>Each filter's weights are contiguous in {@code _filters}, ordered
     *       by filter row, then column, then input channel.</li>
     *   <li>Within one input row, a pixel's channels are contiguous.
     *       Successive window rows are {@code _rowLength} apart.</li>
     * </ul>
     *
     * @param input    the flattened input
     * @param row      index offset of the window's first row in {@code input}
     *                 (NOTE(review): used directly as an array offset, so
     *                 presumably already scaled by the caller; confirm)
     * @param col      offset of the window's first column within that row
     * @param output   destination array
     * @param outStart index in {@code output} of this pixel's first filter
     *                 value; one value is written per filter
     */
    @Override
    protected void kernelsForPixel(
            float[] input, int row, int col, float[] output, int outStart) {

        // Runs continuously through _filters across all filters.
        int fIdx = 0;

        for (int f = 0; f < _biases.length; f++) {
            int cRow = row;
            float pVal = _biases[f];

            for (int i = 0; i < _filterH; i++) {
                int pIdx = cRow + col;

                for (int j = 0; j < _filterW; j++) {
                    for (int d = 0; d < _inputChannels; d++) {
                        pVal += input[pIdx] * _filters[fIdx];
                        pIdx++;
                        fIdx++;
                    }
                }
                // Move to the next input row covered by the window.
                cRow += _rowLength;
            }

            // Batch normalization, in the same evaluation order as before so
            // the floating-point results are unchanged.
            pVal = (pVal - _mean[f]) / _stdev[f];
            pVal = _gamma[f] * pVal + _beta[f];
            output[outStart + f] = pVal;
        }
    }

    /**
     * Returns the per-filter batch-normalization means.  The internal array
     * is returned without copying; callers must not modify it.
     */
    public float[] getMean() {
        return _mean;
    }

    /**
     * Returns the per-filter batch-normalization divisors.  The internal
     * array is returned without copying; callers must not modify it.
     */
    public float[] getStDev() {
        return _stdev;
    }

    /**
     * Returns the per-filter batch-normalization shifts.  The internal array
     * is returned without copying; callers must not modify it.
     */
    public float[] getBeta() {
        return _beta;
    }

    /**
     * Returns the per-filter batch-normalization scales.  The internal array
     * is returned without copying; callers must not modify it.
     */
    public float[] getGamma() {
        return _gamma;
    }

    /** Returns the activation function applied after batch normalization. */
    public ActivationFn getFunction() {
        return _afn;
    }

    /** Null-tolerant array copy used to decouple from the source layer. */
    private static float[] copyOf(float[] values) {
        return values == null ? null : values.clone();
    }

    private final float[] _mean;
    private final float[] _stdev;
    private final float[] _gamma;
    private final float[] _beta;
    private final ActivationFn _afn;

    private static final long serialVersionUID = 1L;
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy