All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.bigml.mimir.deepnet.layers.twod.InitialConvolution2D Maven / Gradle / Ivy

The newest version!
package org.bigml.mimir.deepnet.layers.twod;

import org.bigml.mimir.math.gpu.Convolution2DKernel;

/**
 * A 2D-deepnet layer specifying the first convolution of the input image in
 * a CNN.  This convolution requires some special processing because parameters
 * for this first convolution assume an input depth of three (RGB), whereas we
 * upconvert this to depth 4, adding an "always opaque" transparency channel
 * for better compatibility with vector processing on the GPU.  We thus add
 * dummy parameters to upconvert the kernel similarly.
 *
 * The layer itself is immutable and thread-safety is guaranteed by the use of
 * the OutputTensor class.
 *
 * @see OutputTensor
 * @author  Charles Parker
 */
public class InitialConvolution2D extends Convolution2D {

    /** Depth of the filters as supplied by the model: one weight per RGB channel. */
    private static final int RGB_CHANNELS = 3;

    /** Depth of the kernel after appending the dummy transparency channel. */
    private static final int RGBA_CHANNELS = RGB_CHANNELS + 1;

    /**
     * Weight assigned to the appended transparency channel; being zero, that
     * channel's value is multiplied away and never affects the convolution.
     */
    private static final float TRANSPARENCY_WEIGHT = 0f;

    /**
     * Creates the first convolution layer of a CNN.
     *
     * @param filters      filter weights indexed
     *                     {@code [height][width][depth][filter]}; the depth
     *                     dimension must have length 3 (RGB)
     * @param biases       per-filter biases, passed through to {@link Convolution2D}
     * @param strides      convolution strides, passed through to {@link Convolution2D}
     * @param samePadding  padding mode, passed through to {@link Convolution2D}
     * @throws IllegalArgumentException if {@code filters} is null, empty, or
     *         its depth dimension is not 3
     */
    public InitialConvolution2D(
            double[][][][] filters,
            double[] biases,
            int[] strides,
            boolean samePadding) {

        // Validate before delegating: a plain `assert` is skipped when
        // assertions are disabled (the JVM default), which would let a
        // non-RGB filter bank through with a mismatched kernel shape.
        super(requireRgbFilters(filters), biases, strides, samePadding);

        // The kernel carries one extra (dummy) channel; see unrollFilters.
        _kernelShape[2] = RGBA_CHANNELS;
        _programType = Convolution2DKernel.getProgramType(RGBA_CHANNELS);
    }

    /**
     * Checks that {@code filters} has the RGB input depth this layer expects.
     * Static so it can run inside the {@code super(...)} argument list, before
     * any superclass initialization.
     *
     * @param filters filter weights indexed {@code [height][width][depth][filter]}
     * @return {@code filters}, unchanged
     * @throws IllegalArgumentException if the filter bank is null, empty, or
     *         its depth dimension is not 3
     */
    private static double[][][][] requireRgbFilters(double[][][][] filters) {
        if (filters == null || filters.length == 0 || filters[0].length == 0) {
            throw new IllegalArgumentException(
                    "Initial convolution requires a non-empty filter bank");
        }

        int depth = filters[0][0].length;
        if (depth != RGB_CHANNELS) {
            throw new IllegalArgumentException(
                    "Initial convolution expects " + RGB_CHANNELS
                    + " input channels (RGB), got " + depth);
        }

        return filters;
    }

    /**
     * Flattens the filter bank into a single float array, appending one
     * zero weight (for the dummy transparency channel) after the depth
     * weights of every (height, width) position.
     *
     * <p>The output is ordered by filter, then height, then width, then depth,
     * so {@code filters[i][j][d][n]} is written to index
     * {@code ((n * filtH + i) * filtW + j) * (filtD + 1) + d}, and the slot at
     * offset {@code filtD} within each group holds the padding weight.
     *
     * <p>NOTE(review): presumably invoked from the {@code Convolution2D}
     * constructor, in which case it must not read instance fields declared in
     * this subclass (the constants used here are compile-time constants and
     * are safe) — confirm against the superclass.
     *
     * @param filters filter weights indexed {@code [height][width][depth][filter]}
     * @return an array of {@code nFilters * filtH * filtW * (filtD + 1)} floats
     */
    @Override
    public float[] unrollFilters(double[][][][] filters) {
        int filtH = filters.length;
        int filtW = filters[0].length;
        int filtD = filters[0][0].length;
        int nFilters = filters[0][0][0].length;
        int paddedD = filtD + 1;

        float[] output = new float[nFilters * filtH * filtW * paddedD];
        int oIdx = 0;

        for (int n = 0; n < nFilters; n++) {
            for (int i = 0; i < filtH; i++) {
                for (int j = 0; j < filtW; j++) {
                    for (int d = 0; d < filtD; d++) {
                        output[oIdx++] = (float) filters[i][j][d][n];
                    }
                    // Dummy weight for the appended transparency channel.
                    output[oIdx++] = TRANSPARENCY_WEIGHT;
                }
            }
        }

        return output;
    }

    private static final long serialVersionUID = 1L;
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy