/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.layers.cudnn.conv;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.lang.cudnn.CudaSettings;
import com.simiacryptus.mindseye.lang.cudnn.CudaSystem;
import com.simiacryptus.mindseye.lang.cudnn.MultiPrecision;
import com.simiacryptus.mindseye.lang.cudnn.Precision;
import com.simiacryptus.mindseye.layers.Explodable;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.DoubleSupplier;
import java.util.function.IntToDoubleFunction;
/**
* This is the general convolution layer, allowing any number of input and output bands at high scale. It implements an
* explosion operation to produce a convolution network whose components have a manageable size and the same overall
* function.
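* <p>A minimal usage sketch (illustrative sizes and settings, not taken from this file):</p>
* <pre>{@code
* // 3x3 kernel mapping 3 input bands to 64 output bands
* ConvolutionLayer layer = new ConvolutionLayer(3, 3, 3, 64)
*     .setStrideXY(1, 1)
*     .setPaddingXY(1, 1)
*     .setWeightsLog(-2);
* Layer network = layer.explode(); // equivalent PipelineNetwork of size-limited convolutions
* }</pre>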
*/
@SuppressWarnings("serial")
public class ConvolutionLayer extends LayerBase implements MultiPrecision<ConvolutionLayer>, Explodable {
@Nullable
private final Tensor kernel;
private final int inputBands;
private final int outputBands;
private int strideX = 1;
private int strideY = 1;
@Nullable
private Integer paddingX = null;
@Nullable
private Integer paddingY = null;
private Precision precision = Precision.Double;
private int batchBands = 0;
/**
* Instantiates a new Convolution layer.
*/
protected ConvolutionLayer() {
this(1, 1, 1, 1);
}
/**
* Instantiates a new Convolution layer.
*
* @param width the width
* @param height the height
* @param inputBands the input bands
* @param outputBands the output bands
*/
public ConvolutionLayer(final int width, final int height, final int inputBands, final int outputBands) {
super();
assert 0 < width;
assert 0 < height;
assert 0 < inputBands;
assert 0 < outputBands;
this.kernel = new Tensor(width, height, inputBands * outputBands);
if (getKernel().getDimensions().length != 3) throw new IllegalArgumentException();
if (getKernel().getDimensions()[0] <= 0) throw new IllegalArgumentException();
if (getKernel().getDimensions()[1] <= 0) throw new IllegalArgumentException();
if (getKernel().getDimensions()[2] <= 0) throw new IllegalArgumentException();
this.inputBands = inputBands;
this.outputBands = outputBands;
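// Heuristic: choose batchBands so that a width*height*batchBands^2 sub-kernel stays within the configured CUDA filter-element budget.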
int batchBands = (int) Math.sqrt(CudaSettings.INSTANCE().getMaxFilterElements() / (width * height));
//batchBands = binaryFriendly(batchBands, 3);
setBatchBands(batchBands);
}
/**
* Instantiates a new Convolution layer.
*
* @param json the json
* @param resources the resources
*/
protected ConvolutionLayer(@Nonnull final JsonObject json, Map<CharSequence, byte[]> resources) {
super(json);
this.kernel = Tensor.fromJson(json.get("filter"), resources);
assert getKernel().isValid();
this.setBatchBands(json.get("batchBands").getAsInt());
this.setStrideX(json.get("strideX").getAsInt());
this.setStrideY(json.get("strideY").getAsInt());
JsonElement paddingX = json.get("paddingX");
if (null != paddingX && paddingX.isJsonPrimitive()) this.setPaddingX((paddingX.getAsInt()));
JsonElement paddingY = json.get("paddingY");
if (null != paddingY && paddingY.isJsonPrimitive()) this.setPaddingY((paddingY.getAsInt()));
this.precision = Precision.valueOf(json.get("precision").getAsString());
this.inputBands = json.get("inputBands").getAsInt();
this.outputBands = json.get("outputBands").getAsInt();
}
/**
* Rounds the value down on a logarithmic scale, at a resolution determined by the given number of bits.
*
* @param value the value
* @param bits the bits
* @return the int
*/
public static int binaryFriendly(final int value, final int bits) {
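// Equivalent to exp(floor(log(value) * bits) / bits): log(value) is floored to a multiple of 1/bits, then exponentiated.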
return (int) Math.pow(2, (Math.floor(Math.log(value) * bits) / bits) / Math.log(2));
}
/**
* Adds values drawn from the supplier to each element of the data array.
*
* @param f the value supplier
* @param data the target array
*/
public static void add(@Nonnull final DoubleSupplier f, @Nonnull final double[] data) {
for (int i = 0; i < data.length; i++) {
data[i] += f.getAsDouble();
}
}
/**
* Creates a convolution layer from JSON.
*
* @param json the json
* @param rs the resources
* @return the convolution layer
*/
public static ConvolutionLayer fromJson(@Nonnull final JsonObject json, Map<CharSequence, byte[]> rs) {
return new ConvolutionLayer(json, rs);
}
@Override
protected void _free() {
kernel.freeRef();
super._free();
}
/**
* Adds weights drawn from the supplier to the kernel.
*
* @param f the value supplier
* @return this convolution layer
*/
@Nonnull
public ConvolutionLayer addWeights(@Nonnull final DoubleSupplier f) {
ConvolutionLayer.add(f, getKernel().getData());
return this;
}
/**
* Gets the compatibility (non-CUDA) layer.
*
* @return the compatibility layer
*/
@Nonnull
public Layer getCompatibilityLayer() {
// No CPU fallback is wired up here; fail loudly rather than return null from a @Nonnull method.
throw new RuntimeException("Not Implemented");
// return this.as(com.simiacryptus.mindseye.layers.aparapi.ConvolutionLayer.class);
}
/**
* Explodes this layer into an equivalent network of smaller convolutions.
*
* @return the exploded layer
*/
@Nonnull
@Override
public Layer explode() {
@Nonnull ExplodedConvolutionGrid explodedNetwork = getExplodedNetwork();
try {
@Nonnull Layer network = explodedNetwork.getNetwork();
network.setName(getName());
return network;
} finally {
explodedNetwork.freeRef();
}
}
/**
* Gets exploded network.
*
* @return the exploded network
*/
@Nonnull
public ExplodedConvolutionGrid getExplodedNetwork() {
assertAlive();
int batchBands = getBatchBands();
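// A batchBands of 0 means unset; fall back to covering all bands in a single partition.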
if (0 == batchBands) {
batchBands = Math.max(inputBands, outputBands);
}
// if (batchBands > outputBands * 2) {
// batchBands = outputBands;
// }
return new ExplodedConvolutionGrid(getConvolutionParams(), batchBands).write(kernel);
}
/**
* Gets convolution params.
*
* @return the convolution params
*/
@Nonnull
public ConvolutionParams getConvolutionParams() {
return new ConvolutionParams(inputBands, outputBands, precision, strideX, strideY, paddingX, paddingY, kernel.getDimensions());
}
@Nullable
@Override
public Result evalAndFree(@Nonnull final Result... inObj) {
final Tensor kernel = getKernel();
kernel.addRef();
assert kernel.isValid();
assert 1 == inObj.length;
assert 3 == inObj[0].getData().getDimensions().length;
assert inputBands == inObj[0].getData().getDimensions()[2] : Arrays.toString(inObj[0].getData().getDimensions()) + "[2] != " + inputBands;
if (!CudaSystem.isEnabled()) return getCompatibilityLayer().evalAndFree(inObj);
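// CUDA path: evaluate by delegating to the exploded sub-network (a PipelineNetwork of size-limited convolutions).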
@Nonnull ExplodedConvolutionGrid grid = getExplodedNetwork();
@Nonnull PipelineNetwork network = grid.getNetwork();
final Result result;
try {
if (isFrozen()) {
network.freeze();
}
result = network.evalAndFree(inObj);
} finally {
network.freeRef();
}
final TensorList resultData = result.getData();
assert inObj[0].getData().length() == resultData.length();
assert 3 == resultData.getDimensions().length;
assert outputBands == resultData.getDimensions()[2];
ConvolutionLayer.this.addRef();
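// Wrap the result so that, on backprop, deltas read back from the exploded grid are accumulated into this layer's kernel.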
return new Result(resultData, (@Nonnull final DeltaSet<UUID> deltaSet, @Nonnull final TensorList delta) -> {
result.accumulate(deltaSet, delta);
if (!isFrozen()) {
Tensor read = grid.read(deltaSet, true);
deltaSet.get(ConvolutionLayer.this.getId(), kernel.getData()).addInPlace(read.getData()).freeRef();
read.freeRef();
}
}) {
@Override
public void accumulate(final DeltaSet<UUID> buffer, final TensorList delta) {
getAccumulator().accept(buffer, delta);
}
@Override
protected void _free() {
grid.freeRef();
result.freeRef();
kernel.freeRef();
ConvolutionLayer.this.freeRef();
}
@Override
public boolean isAlive() {
return result.isAlive();
}
};
}
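// Illustrative serialized form (example values; base-layer fields from getJsonStub() omitted,
// and "filter" holds the kernel tensor as encoded by the given DataSerializer):
// { "filter": {...}, "batchBands": 64, "strideX": 1, "strideY": 1,
//   "paddingX": 1, "paddingY": 1, "precision": "Double", "inputBands": 3, "outputBands": 64 }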
@Nonnull
@Override
public JsonObject getJson(Map<CharSequence, byte[]> resources, @Nonnull DataSerializer dataSerializer) {
@Nonnull final JsonObject json = super.getJsonStub();
json.add("filter", getKernel().toJson(resources, dataSerializer));
json.addProperty("batchBands", getBatchBands());
json.addProperty("strideX", getStrideX());
json.addProperty("strideY", getStrideY());
json.addProperty("paddingX", getPaddingX());
json.addProperty("paddingY", getPaddingY());
json.addProperty("precision", precision.name());
json.addProperty("inputBands", inputBands);
json.addProperty("outputBands", outputBands);
return json;
}
@Override
public Precision getPrecision() {
return precision;
}
@Nonnull
@Override
public ConvolutionLayer setPrecision(final Precision precision) {
this.precision = precision;
return this;
}
/**
* Sets the kernel weights from the given supplier.
*
* @param f the value supplier
* @return this convolution layer
*/
@Nonnull
public ConvolutionLayer set(@Nonnull final DoubleSupplier f) {
return set(i -> f.getAsDouble());
}
/**
* Sets the kernel from the given tensor.
*
* @param tensor the tensor
* @return this convolution layer
*/
@Nonnull
public ConvolutionLayer set(@Nonnull final Tensor tensor) {
getKernel().set(tensor);
return this;
}
/**
* Sets the kernel from the given tensor, then frees the tensor.
*
* @param tensor the tensor
* @return this convolution layer
*/
@Nonnull
public ConvolutionLayer setAndFree(@Nonnull final Tensor tensor) {
set(tensor);
tensor.freeRef();
return this;
}
/**
* Sets the kernel using the given index-to-value function.
*
* @param f the value function
* @return this convolution layer
*/
@Nonnull
public ConvolutionLayer set(@Nonnull final IntToDoubleFunction f) {
getKernel().set(f);
return this;
}
@Nonnull
@Override
public List<double[]> state() {
return Arrays.asList(getKernel().getData());
}
/**
* Gets stride x.
*
* @return the stride x
*/
public int getStrideX() {
return strideX;
}
/**
* Sets stride x.
*
* @param strideX the stride x
* @return the stride x
*/
@Nonnull
public ConvolutionLayer setStrideX(int strideX) {
this.strideX = strideX;
return this;
}
/**
* Gets stride y.
*
* @return the stride y
*/
public int getStrideY() {
return strideY;
}
/**
* Sets stride y.
*
* @param strideY the stride y
* @return the stride y
*/
@Nonnull
public ConvolutionLayer setStrideY(int strideY) {
this.strideY = strideY;
return this;
}
/**
* Initializes the kernel weights with uniform random values scaled by 10^f.
*
* @param f the log (base 10) of the weight magnitude
* @return this convolution layer
*/
@Nonnull
public ConvolutionLayer setWeightsLog(double f) {
return set(() -> Math.pow(10, f) * (Math.random() - 0.5));
}
/**
* Sets stride xy.
*
* @param x the x
* @param y the y
* @return the stride xy
*/
@Nonnull
public ConvolutionLayer setStrideXY(int x, int y) {
return setStrideX(x).setStrideY(y);
}
/**
* Sets padding xy.
*
* @param x the x
* @param y the y
* @return the padding xy
*/
@Nonnull
public ConvolutionLayer setPaddingXY(Integer x, Integer y) {
return setPaddingX(x).setPaddingY(y);
}
/**
* Gets padding x.
*
* @return the padding x
*/
@Nullable
public Integer getPaddingX() {
return paddingX;
}
/**
* Sets padding x.
*
* @param paddingX the padding x
* @return the padding x
*/
@Nonnull
public ConvolutionLayer setPaddingX(Integer paddingX) {
this.paddingX = paddingX;
return this;
}
/**
* Gets padding y.
*
* @return the padding y
*/
@Nullable
public Integer getPaddingY() {
return paddingY;
}
/**
* Sets padding y.
*
* @param paddingY the padding y
* @return the padding y
*/
@Nonnull
public ConvolutionLayer setPaddingY(Integer paddingY) {
this.paddingY = paddingY;
return this;
}
/**
* Gets the kernel.
*
* @return the kernel
*/
@Nullable
public Tensor getKernel() {
return kernel;
}
/**
* Gets batch bands.
*
* @return the batch bands
*/
public int getBatchBands() {
return batchBands;
}
/**
* Sets batch bands.
*
* @param batchBands the batch bands
* @return the batch bands
*/
@Nonnull
public ConvolutionLayer setBatchBands(int batchBands) {
this.batchBands = batchBands;
return this;
}
}