/*
* Copyright (c) 2018 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.layers.cudnn.conv;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.lang.cudnn.*;
import com.simiacryptus.util.Util;
import jcuda.jcudnn.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.DoubleSupplier;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
/**
* This convolution layer only supports an equal number of input and output bands. It is the
* foundational component of ConvolutionLayer, since the CudaSystem API has this restriction
* (in recent versions).
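* <p>
* Usage sketch (illustrative; not part of the original source). A 3x3 kernel mapping 3 input
* bands to 3 output bands needs 3 * 3 = 9 kernel bands:
* <pre>{@code
* SimpleConvolutionLayer layer = new SimpleConvolutionLayer(3, 3, 9)
*     .setPaddingXY(1, 1)
*     .setWeightsLog(-2);
* }</pre>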
*/
@SuppressWarnings("serial")
public class SimpleConvolutionLayer extends LayerBase implements MultiPrecision<SimpleConvolutionLayer> {
/**
* The Log.
*/
static final Logger log = LoggerFactory.getLogger(SimpleConvolutionLayer.class);
/**
* The Kernel.
*/
public final Tensor kernel;
/**
* Per-device cache of the kernel weights, keyed by device id.
*/
private final Map<Integer, CudaMemory> gpuFilters = new ConcurrentHashMap<>();
private int paddingX;
private int paddingY;
private Precision precision = Precision.Double;
private int strideX = 1;
private int strideY = 1;
/**
* Instantiates a new convolution layer.
*/
protected SimpleConvolutionLayer() {
this(null);
}
/**
* Instantiates a new convolution layer.
*
* @param width the width
* @param height the height
* @param bands the bands
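* <p>
* Note: width and height must be odd (see the asserts below), and judging from the cuDNN
* filter descriptor built in evalAndFree, {@code bands} is the product of the input and
* output band counts (e.g. 9 for 3 input and 3 output bands).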
*/
public SimpleConvolutionLayer(final int width, final int height, final int bands) {
this(new Tensor(width, height, bands));
kernel.freeRef();
assert 0 == (width - 1) % 2 : "Simple kernels must have odd width";
assert 0 == (height - 1) % 2 : "Simple kernels must have odd height";
}
/**
* Instantiates a new convolution layer.
*
* @param json the json
* @param resources the resources
*/
protected SimpleConvolutionLayer(@Nonnull final JsonObject json, Map<CharSequence, byte[]> resources) {
super(json);
kernel = Tensor.fromJson(json.get("filter"), resources);
strideX = json.get("strideX").getAsInt();
strideY = json.get("strideY").getAsInt();
setPaddingX(json.get("paddingX").getAsInt());
setPaddingY(json.get("paddingY").getAsInt());
precision = Precision.valueOf(json.get("precision").getAsString());
}
/**
* Instantiates a new convolution layer.
*
* @param kernel the filter
*/
protected SimpleConvolutionLayer(@Nonnull final Tensor kernel) {
super();
@Nonnull int[] kernelSize = kernel.getDimensions();
if (kernelSize.length != 3) throw new IllegalArgumentException();
if (kernelSize[0] <= 0) throw new IllegalArgumentException();
if (kernelSize[1] <= 0) throw new IllegalArgumentException();
if (kernelSize[2] <= 0) throw new IllegalArgumentException();
this.kernel = kernel;
this.kernel.addRef(this);
this.setPaddingX((int) Math.ceil((kernelSize[0] - 1) / 2.0));
this.setPaddingY((int) Math.ceil((kernelSize[1] - 1) / 2.0));
}
/**
* Creates a convolution layer from JSON.
*
* @param json the json
* @param rs the rs
* @return the convolution layer
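* <p>
* Round-trip sketch (illustrative; {@code resources} and {@code dataSerializer} are
* hypothetical placeholders for the serialization context):
* <pre>{@code
* JsonObject json = layer.getJson(resources, dataSerializer);
* SimpleConvolutionLayer copy = SimpleConvolutionLayer.fromJson(json, resources);
* }</pre>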
*/
public static SimpleConvolutionLayer fromJson(@Nonnull final JsonObject json, Map<CharSequence, byte[]> rs) {
return new SimpleConvolutionLayer(json, rs);
}
/**
* Reverses an int array in place.
*
* @param array the array
* @return the same array, reversed
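* <p>e.g. {@code reverse(1, 2, 3)} returns {@code {3, 2, 1}}; the reversal happens in place.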
*/
@Nonnull
public static int[] reverse(@Nonnull int... array) {
for (int i = 0; i < array.length / 2; i++) {
int j = array[array.length - (i + 1)];
array[array.length - (i + 1)] = array[i];
array[i] = j;
}
return array;
}
/**
* Adds weights to the convolution kernel.
*
* @param f the weight increment supplier
* @return this layer
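* <p>e.g. {@code layer.addWeights(new java.util.Random()::nextGaussian);}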
*/
@Nonnull
public SimpleConvolutionLayer addWeights(@Nonnull final DoubleSupplier f) {
Util.add(f, kernel.getData());
return this;
}
private boolean cmp(final int[] outputSize, @Nonnull final int[] outputDims) {
if (4 != outputDims.length) return false;
if (outputSize[0] != outputDims[3]) return false;
if (outputSize[1] != outputDims[2]) return false;
return outputSize[2] == outputDims[1];
}
@Nullable
@Override
public Result evalAndFree(@Nonnull final Result... inObj) {
if (!CudaSystem.isEnabled()) return getCompatibilityLayer().eval(inObj);
final Result input = inObj[0];
final TensorList inputData = input.getData();
@Nonnull final int[] inputSize = inputData.getDimensions();
@Nonnull final int[] kernelSize = kernel.getDimensions();
final int[] outputSize = getOutputSize(inputSize);
final int length = inputData.length();
kernel.addRef();
SimpleConvolutionLayer.this.addRef();
return new Result(CudaSystem.run(gpu -> {
assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
@Nullable final CudaTensor inputTensor = gpu.getTensor(inputData, precision, MemoryType.Device, false);
final CudaResource<cudnnFilterDescriptor> filterDescriptor = gpu.newFilterDescriptor(
precision, cudnnTensorFormat.CUDNN_TENSOR_NCHW, outputSize[2], inputSize[2], kernelSize[1], kernelSize[0]);
final CudaResource<cudnnConvolutionDescriptor> convolutionDescriptor = gpu.newConvolutions2dDescriptor(cudnnConvolutionMode.CUDNN_CONVOLUTION, precision,
paddingY, paddingX,
strideY, strideX,
1, 1);
final int[] outputDims = IntStream.of(reverse(CudaSystem.getOutputDims(inputTensor.descriptor.getPtr(), filterDescriptor.getPtr(), convolutionDescriptor.getPtr()))).limit(3).toArray();
final CudaDevice.CudaTensorDescriptor outputDescriptor = gpu.newTensorDescriptor(precision, length,
outputDims[2], outputDims[1], outputDims[0],
outputDims[2] * outputDims[1] * outputDims[0], outputDims[1] * outputDims[0], outputDims[0], 1);
final int forwardAlgorithm = getForwardAlgorithm(gpu, inputTensor, filterDescriptor, convolutionDescriptor, outputDescriptor);
final CudaMemory forwardWorkspace = gpu.allocateForwardWorkspace(
inputTensor.descriptor.getPtr(), filterDescriptor.getPtr(), convolutionDescriptor.getPtr(), outputDescriptor.getPtr(), forwardAlgorithm, 1);
try {
assert 0 < kernel.getData().length;
assert kernelSize[0] * kernelSize[1] * kernelSize[2] == kernel.getData().length;
@Nonnull CudaMemory filterPtr = getCudaFilter(gpu);
@Nonnull final CudaMemory outputBuffer = gpu.allocate(
(long) Tensor.length(outputDims) * length * precision.size, MemoryType.Managed.normalize(), true);
CudaMemory inputTensorMemory = inputTensor.getMemory(gpu);
// inputTensorMemory.synchronize();
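// cudnnConvolutionForward computes y = alpha * conv(x, w) + beta * y;
// with alpha = 1.0 and beta = 0.0 the output buffer is simply overwritten.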
CudaSystem.handle(gpu.cudnnConvolutionForward(precision.getPointer(1.0),
inputTensor.descriptor.getPtr(), inputTensorMemory.getPtr(),
filterDescriptor.getPtr(), filterPtr.getPtr(),
convolutionDescriptor.getPtr(),
forwardAlgorithm,
null == forwardWorkspace ? null : forwardWorkspace.getPtr(),
null == forwardWorkspace ? 0 : forwardWorkspace.size,
precision.getPointer(0.0), outputDescriptor.getPtr(), outputBuffer.getPtr()));
assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
forwardWorkspace.dirty();
filterPtr.dirty();
outputBuffer.dirty();
inputTensorMemory.dirty();
// inputTensorMemory.synchronize();
inputTensorMemory.freeRef();
filterPtr.freeRef();
outputDescriptor.addRef();
return CudaTensorList.wrap(CudaTensor.wrap(outputBuffer, outputDescriptor, precision), length, outputDims, precision);
} catch (@Nonnull final Throwable e) {
throw new ComponentException(String.format("Error in convolution %s x %s", Arrays.toString(inputSize), Arrays.toString(kernelSize)), e);
} finally {
Stream.of(inputTensor, filterDescriptor, outputDescriptor, forwardWorkspace, convolutionDescriptor).forEach(ReferenceCounting::freeRef);
}
}, inputData), (@Nonnull final DeltaSet<UUID> buffer, @Nonnull final TensorList delta) -> {
delta.assertAlive();
buffer.assertAlive();
inputData.assertAlive();
assert delta.length() == length;
delta.addRef();
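// The backward pass is split in two: learnFn accumulates the kernel gradient into the
// delta buffer (unless the layer is frozen), and backpropFn propagates the gradient
// back to the input.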
Runnable learnFn = () -> {
if (!isFrozen()) {
@Nonnull final Tensor weightGradient = CudaSystem.run(gpu -> {
@Nullable final CudaTensor deltaTensor = gpu.getTensor(delta, precision, MemoryType.Device, true);
delta.freeRef();
@Nullable final CudaTensor inputTensor = gpu.getTensor(inputData, precision, MemoryType.Device, false);
final CudaResource<cudnnFilterDescriptor> filterDescriptor = gpu.newFilterDescriptor(
precision, cudnnTensorFormat.CUDNN_TENSOR_NCHW, outputSize[2], inputSize[2], kernelSize[1], kernelSize[0]);
final CudaResource<cudnnConvolutionDescriptor> convolutionDescriptor = gpu.newConvolutions2dDescriptor(cudnnConvolutionMode.CUDNN_CONVOLUTION, precision,
paddingY, paddingX,
strideY, strideX,
1, 1);
final int backwardFilterAlgorithm = getBackwardFilterAlgorithm(gpu, deltaTensor, inputTensor, filterDescriptor, convolutionDescriptor);
final CudaMemory backwardsFilterWorkSpace = gpu.allocateBackwardFilterWorkspace(
inputTensor.descriptor.getPtr(), filterDescriptor.getPtr(),
convolutionDescriptor.getPtr(), deltaTensor.descriptor.getPtr(), backwardFilterAlgorithm, 1);
@Nonnull CudaMemory filterPtr = gpu.allocate((long) kernel.length() * precision.size, MemoryType.Device, true);
try {
CudaMemory inputTensorMemory = inputTensor.getMemory(gpu);
CudaMemory deltaTensorMemory = deltaTensor.getMemory(gpu, MemoryType.Managed.normalize());
// inputTensorMemory.synchronize();
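// cudnnConvolutionBackwardFilter computes dw = alpha * dConv/dw(x, dy) + beta * dw;
// with alpha = 1.0 and beta = 0.0 the kernel gradient overwrites filterPtr.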
CudaSystem.handle(gpu.cudnnConvolutionBackwardFilter(precision.getPointer(1.0),
inputTensor.descriptor.getPtr(), inputTensorMemory.getPtr(),
deltaTensor.descriptor.getPtr(), deltaTensorMemory.getPtr(),
convolutionDescriptor.getPtr(),
backwardFilterAlgorithm,
backwardsFilterWorkSpace.getPtr(),
backwardsFilterWorkSpace.size,
precision.getPointer(0.0), filterDescriptor.getPtr(), filterPtr.getPtr()));
filterPtr.dirty();
deltaTensorMemory.dirty();
inputTensorMemory.dirty();
backwardsFilterWorkSpace.dirty();
// backwardsFilterWorkSpace.synchronize();
inputTensorMemory.freeRef();
deltaTensorMemory.freeRef();
return filterPtr.read(precision, kernel.getDimensions());
} catch (@Nonnull final Throwable e) {
throw new ComponentException(String.format("Error in convolution %s x %s => %s", Arrays.toString(inputSize), Arrays.toString(kernelSize), Arrays.toString(outputSize)), e);
} finally {
inputTensor.freeRef();
filterPtr.freeRef();
deltaTensor.freeRef();
Stream.of(filterDescriptor, convolutionDescriptor, backwardsFilterWorkSpace).forEach(ReferenceCounting::freeRef);
}
}, delta);
buffer.get(SimpleConvolutionLayer.this.getId(), kernel.getData()).addInPlace(weightGradient.getData()).freeRef();
weightGradient.freeRef();
clearCudaFilters();
} else {
delta.freeRef();
}
};
Runnable backpropFn = () -> {
if (input.isAlive()) {
final TensorList inputBufferTensors = CudaSystem.run(gpu -> {
final CudaDevice.CudaTensorDescriptor inputDescriptor = gpu.newTensorDescriptor(precision, length, inputSize[2], inputSize[1], inputSize[0], inputSize[2] * inputSize[1] * inputSize[0], inputSize[1] * inputSize[0], inputSize[0], 1);
final CudaResource<cudnnFilterDescriptor> filterDescriptor = gpu.newFilterDescriptor(
precision, cudnnTensorFormat.CUDNN_TENSOR_NCHW, outputSize[2], inputSize[2], kernelSize[1], kernelSize[0]);
final CudaResource<cudnnConvolutionDescriptor> convolutionDescriptor = gpu.newConvolutions2dDescriptor(cudnnConvolutionMode.CUDNN_CONVOLUTION, precision,
paddingY, paddingX,
strideY, strideX,
1, 1);
@Nullable final CudaTensor deltaTensor = gpu.getTensor(delta, precision, MemoryType.Device, false);
delta.freeRef();
final int backwardDataAlgorithm = getBackwardDataAlgorithm(gpu, inputDescriptor, filterDescriptor, convolutionDescriptor, deltaTensor);
final CudaMemory backwardsDataWorkSpace = gpu.allocateBackwardDataWorkspace(
inputDescriptor.getPtr(), filterDescriptor.getPtr(),
convolutionDescriptor.getPtr(), deltaTensor.descriptor.getPtr(), backwardDataAlgorithm, 1);
@Nonnull final CudaMemory filterPtr = getCudaFilter(gpu);
try {
@Nonnull final CudaMemory passbackMemory = gpu.allocate((long) Tensor.length(inputData.getDimensions()) * length * precision.size, MemoryType.Managed.normalize(), true);
CudaMemory deltaTensorMemory = deltaTensor.getMemory(gpu);
// deltaTensorMemory.synchronize();
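// cudnnConvolutionBackwardData computes dx = alpha * dConv/dx(w, dy) + beta * dx;
// with alpha = 1.0 and beta = 0.0 the passback buffer is simply overwritten.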
CudaSystem.handle(gpu.cudnnConvolutionBackwardData(precision.getPointer(1.0),
filterDescriptor.getPtr(), filterPtr.getPtr(),
deltaTensor.descriptor.getPtr(), deltaTensorMemory.getPtr(),
convolutionDescriptor.getPtr(),
backwardDataAlgorithm,
backwardsDataWorkSpace.getPtr(),
backwardsDataWorkSpace.size,
precision.getPointer(0.0), inputDescriptor.getPtr(), passbackMemory.getPtr()));
passbackMemory.dirty();
backwardsDataWorkSpace.dirty();
deltaTensorMemory.dirty();
// deltaTensorMemory.synchronize();
filterPtr.dirty();
deltaTensorMemory.freeRef();
inputDescriptor.addRef();
return CudaTensorList.wrap(CudaTensor.wrap(passbackMemory, inputDescriptor, precision), length, inputSize, precision);
} catch (@Nonnull final Throwable e) {
throw new ComponentException(String.format("Error in convolution %s x %s => %s", Arrays.toString(inputSize), Arrays.toString(kernelSize), Arrays.toString(outputSize)), e);
} finally {
filterPtr.freeRef();
deltaTensor.freeRef();
Stream.of(inputDescriptor, filterDescriptor, convolutionDescriptor, backwardsDataWorkSpace).forEach(ReferenceCounting::freeRef);
}
}, delta);
if (null != inputBufferTensors) {
input.accumulate(buffer, inputBufferTensors);
}
} else {
delta.freeRef();
}
};
Stream.of(learnFn, backpropFn).forEach(Runnable::run);
}) {
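// This anonymous Result holds references to the kernel, this layer, the input data,
// and inObj; all of them are released in _free() below.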
@Override
public final void accumulate(DeltaSet<UUID> buffer, TensorList delta) {
getAccumulator().accept(buffer, delta);
}
@Override
protected void _free() {
kernel.assertAlive();
kernel.freeRef();
inputData.freeRef();
Arrays.stream(inObj).forEach(ReferenceCounting::freeRef);
SimpleConvolutionLayer.this.freeRef();
}
@Override
public boolean isAlive() {
return input.isAlive() || !isFrozen();
}
};
}
/**
* Gets forward algorithm.
*
* @param gpu the gpu
* @param inputTensor the input tensor
* @param filterDescriptor the filter descriptor
* @param convolutionDescriptor the convolution descriptor
* @param outputDescriptor the output descriptor
* @return the forward algorithm
*/
public int getForwardAlgorithm(final CudnnHandle gpu, final CudaTensor inputTensor, final CudaResource<cudnnFilterDescriptor> filterDescriptor, final CudaResource<cudnnConvolutionDescriptor> convolutionDescriptor, final CudaDevice.CudaTensorDescriptor outputDescriptor) {
// return cudnnConvolutionFwdAlgo.CUDNN_CONVOLUTION_FWD_ALGO_FFT;
return gpu.getForwardAlgorithm(
inputTensor.descriptor.getPtr(), filterDescriptor.getPtr(), convolutionDescriptor.getPtr(),
outputDescriptor.getPtr(), CudaSettings.INSTANCE().getConvolutionWorkspaceSizeLimit());
}
/**
* Gets backward filter algorithm.
*
* @param gpu the gpu
* @param deltaTensor the delta tensor
* @param inputTensor the input tensor
* @param filterDescriptor the filter descriptor
* @param convolutionDescriptor the convolution descriptor
* @return the backward filter algorithm
*/
public int getBackwardFilterAlgorithm(final CudnnHandle gpu, final CudaTensor deltaTensor, final CudaTensor inputTensor, final CudaResource<cudnnFilterDescriptor> filterDescriptor, final CudaResource<cudnnConvolutionDescriptor> convolutionDescriptor) {
return gpu.getBackwardFilterAlgorithm(
inputTensor.descriptor.getPtr(), filterDescriptor.getPtr(), convolutionDescriptor.getPtr(), deltaTensor.descriptor.getPtr(), CudaSettings.INSTANCE().getConvolutionWorkspaceSizeLimit());
}
/**
* Gets backward data algorithm.
*
* @param gpu the gpu
* @param inputDescriptor the input descriptor
* @param filterDescriptor the filter descriptor
* @param convolutionDescriptor the convolution descriptor
* @param deltaTensor the delta tensor
* @return the backward data algorithm
*/
public int getBackwardDataAlgorithm(final CudnnHandle gpu, final CudaDevice.CudaTensorDescriptor inputDescriptor, final CudaResource<cudnnFilterDescriptor> filterDescriptor, final CudaResource<cudnnConvolutionDescriptor> convolutionDescriptor, final CudaTensor deltaTensor) {
return cudnnConvolutionBwdDataAlgo.CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
//return gpu.getBackwardDataAlgorithm(inputDescriptor.getPtr(), filterDescriptor.getPtr(), convolutionDescriptor.getPtr(), deltaTensor.descriptor.getPtr(), CudaSettings.INSTANCE.getConvolutionWorkspaceSizeLimit());
}
/**
* Evict device data long.
*
* @param deviceId the device id
* @return the long
*/
public long evictDeviceData(final int deviceId) {
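// Reclaim the cached filter only if this map held the last reference; otherwise another
// consumer is still using it, so put it back and report zero bytes reclaimed.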
CudaMemory remove = gpuFilters.remove(deviceId);
if (null != remove) {
if (1 == remove.currentRefCount()) {
remove.freeRef();
return remove.size;
} else {
CudaMemory race = gpuFilters.put(deviceId, remove);
if (race != null) race.freeRef();
return 0;
}
} else {
return 0;
}
}
@Nonnull
private synchronized CudaMemory getCudaFilter(final CudaDevice deviceNumber) {
return CudaSettings.INSTANCE().isConvolutionCache() ? getCudaFilter_cached(deviceNumber) : getCudaFilter_instance(deviceNumber);
}
@Nonnull
private synchronized CudaMemory getCudaFilter_instance(final CudaDevice deviceNumber) {
double[] data = kernel.getData();
return deviceNumber.allocate((long) data.length * precision.size, MemoryType.Device, true).write(precision, data);
}
@Nonnull
private CudaMemory getCudaFilter_cached(final CudaDevice deviceNumber) {
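// One cached device-resident copy of the kernel per device id. The containsKey/get pair
// is not atomic on its own; callers are serialized through the synchronized getCudaFilter.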
CudaMemory cudaMemory;
if (gpuFilters.containsKey(deviceNumber.getDeviceId())) {
cudaMemory = gpuFilters.get(deviceNumber.getDeviceId());
} else {
double[] data = kernel.getData();
cudaMemory = deviceNumber.allocate((long) data.length * precision.size, MemoryType.Device, true).write(precision, data);
CudaMemory replaced = gpuFilters.put(deviceNumber.getDeviceId(), cudaMemory);
if (null != replaced) replaced.freeRef();
}
cudaMemory.addRef();
return cudaMemory;
}
private void clearCudaFilters() {
gpuFilters.keySet().stream().collect(Collectors.toList()).forEach(i -> {
CudaMemory cudaMemory = gpuFilters.remove(i);
if (null != cudaMemory) cudaMemory.freeRef();
});
}
@Override
protected void _free() {
kernel.freeRef(this);
clearCudaFilters();
super._free();
}
/**
* Gets the compatibility layer.
*
* @return the compatibility layer (the fallback below is currently commented out; this returns null)
*/
@Nonnull
public Layer getCompatibilityLayer() {
// log.info(String.format("Using compatibility key for %s", this));
// int bands = (int) Math.sqrt(this.kernel.getDimensions()[2]);
// @Nonnull final com.simiacryptus.mindseye.layers.aparapi.ConvolutionLayer convolutionLayer = new com.simiacryptus.mindseye.layers.aparapi.ConvolutionLayer(this.kernel.getDimensions()[0], this.kernel.getDimensions()[1], this.kernel.getDimensions()[2], true);
// @Nonnull final Tensor tensor = new Tensor(kernel.getDimensions());
// tensor.setByCoord(c -> {
// final int band = c.getCoords()[2];
// final int bandX = band % bands;
// final int bandY = (band - bandX) / bands;
// assert band == bandX + bandY * bands;
// final int bandT = bandY + bandX * bands;
// return kernel.get(c.getCoords()[0], c.getCoords()[1], bandT);
// });
// convolutionLayer.kernel.set(tensor);
// return new LayerBase() {
// @Nonnull
// @Override
// public Result eval(@Nonnull Result... array) {
// Arrays.stream(array).forEach(x -> x.addRef());
// @Nonnull Result result = convolutionLayer.eval(array);
// return new Result(result.getData(), (DeltaSet buffer, TensorList data) -> {
// throw new IllegalStateException();
// }) {
//
//
// @Override
// protected void _free() {
// Arrays.stream(array).forEach(x -> x.freeRef());
// }
//
// @Override
// public boolean isAlive() {
// return false;
// }
// };
// }
//
// @Nonnull
// @Override
// public JsonObject getJson(Map resources, DataSerializer dataSerializer) {
// throw new IllegalStateException();
// }
//
// @Nonnull
// @Override
// public List state() {
// throw new IllegalStateException();
// }
// };
return null;
}
@Nonnull
@Override
public JsonObject getJson(Map<CharSequence, byte[]> resources, @Nonnull DataSerializer dataSerializer) {
@Nonnull final JsonObject json = super.getJsonStub();
JsonElement value;
try {
value = kernel.toJson(resources, dataSerializer);
} catch (Throwable e) {
throw new RuntimeException("Error serializing convolution" + Arrays.toString(this.kernel.getDimensions()), e);
}
json.add("filter", value);
json.addProperty("strideX", strideX);
json.addProperty("strideY", strideY);
json.addProperty("paddingX", getPaddingX());
json.addProperty("paddingY", getPaddingY());
json.addProperty("precision", precision.name());
return json;
}
/**
* Computes the output dimensions for the given input dimensions.
*
* @param inputSize the input size
* @return the output dimensions
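* <p>
* Worked example: a 32x32x3 input with a 3x3x12 kernel and padding (1, 1) yields 32x32x4,
* since each spatial dimension is input - (kernel - 1) + 2 * padding = 32 - 2 + 2 = 32 and
* the band count is kernelBands / inputBands = 12 / 3 = 4.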
*/
public int[] getOutputSize(final int... inputSize) {
@Nonnull final int[] kernelSize = kernel.getDimensions();
try {
return IntStream.range(0, kernelSize.length).map(i -> {
int x;
if (i == kernelSize.length - 1) {
//assert kernelSize[i] == inputSize[i];
x = kernelSize[i] / inputSize[i];
} else {
int padding;
if (i == 0) {
padding = this.paddingX;
} else if (i == 1) {
padding = this.paddingY;
} else {
throw new IllegalStateException();
}
x = inputSize[i] - (kernelSize[i] - 1) + padding * 2;
}
assert 0 < x;
return x;
}).toArray();
} catch (Throwable e) {
throw new RuntimeException(String.format("Error apply convolution %s x %s (%s)", Arrays.toString(inputSize), Arrays.toString(kernelSize), getName()), e);
}
}
@Override
public Precision getPrecision() {
return precision;
}
@Nonnull
@Override
public SimpleConvolutionLayer setPrecision(final Precision precision) {
clearCudaFilters();
this.precision = precision;
return this;
}
/**
* Gets stride x.
*
* @return the stride x
*/
public int getStrideX() {
return strideX;
}
/**
* Sets stride x.
*
* @param strideX the stride x
* @return the stride x
*/
@Nonnull
public SimpleConvolutionLayer setStrideX(final int strideX) {
this.strideX = strideX;
return this;
}
/**
* Gets stride y.
*
* @return the stride y
*/
public int getStrideY() {
return strideY;
}
/**
* Sets stride y.
*
* @param strideY the stride y
* @return the stride y
*/
@Nonnull
public SimpleConvolutionLayer setStrideY(final int strideY) {
this.strideY = strideY;
return this;
}
/**
* Sets weights from a supplier.
*
* @param f the weight supplier
* @return this layer
*/
@Nonnull
public SimpleConvolutionLayer set(@Nonnull final DoubleSupplier f) {
kernel.coordStream(true).parallel().forEach(c -> {
kernel.set(c, f.getAsDouble());
});
return this;
}
/**
* Sets weights by coordinate.
*
* @param f the per-coordinate weight function
* @return this layer
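* <p>e.g. {@code layer.set(c -> c.getCoords()[2] == 0 ? 1.0 : 0.0)} initializes only the
* first band plane (assuming {@code Coordinate#getCoords()} yields {x, y, band}).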
*/
@Nonnull
public SimpleConvolutionLayer set(@Nonnull final ToDoubleFunction<Coordinate> f) {
kernel.coordStream(true).parallel().forEach(c -> {
kernel.set(c, f.applyAsDouble(c));
});
return this;
}
@Nonnull
@Override
public List<double[]> state() {
return Arrays.asList(kernel.getData());
}
/**
* Gets padding x.
*
* @return the padding x
*/
public int getPaddingX() {
return paddingX;
}
/**
* Sets padding x.
*
* @param paddingX the padding x
* @return the padding x
*/
@Nonnull
public SimpleConvolutionLayer setPaddingX(int paddingX) {
this.paddingX = paddingX;
return this;
}
/**
* Gets padding y.
*
* @return the padding y
*/
public int getPaddingY() {
return paddingY;
}
/**
* Sets padding y.
*
* @param paddingY the padding y
* @return the padding y
*/
@Nonnull
public SimpleConvolutionLayer setPaddingY(int paddingY) {
this.paddingY = paddingY;
return this;
}
/**
* Sets padding xy.
*
* @param x the x
* @param y the y
* @return the padding xy
*/
@Nonnull
public SimpleConvolutionLayer setPaddingXY(int x, int y) {
return setPaddingX(x).setPaddingY(y);
}
/**
* Sets weights to random values with magnitudes on the order of 10^f.
*
* @param f the log10 of the weight magnitude
* @return this layer
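* <p>e.g. {@code setWeightsLog(-2)} draws each weight uniformly from (-0.005, 0.005), i.e.
* magnitudes on the order of 1e-2.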
*/
@Nonnull
public SimpleConvolutionLayer setWeightsLog(double f) {
return set(() -> Math.pow(10, f) * (Math.random() - 0.5));
}
/**
* Sets the kernel values.
*
* @param kernel the kernel
*/
public void set(@Nonnull Tensor kernel) {
this.kernel.set(kernel);
}
/**
* Gets the kernel dimensions.
*
* @return the int [ ]
*/
public int[] getKernelDimensions() {
return kernel.getDimensions();
}
@Override
public boolean assertAlive() {
if (!super.assertAlive()) {
assert false;
return false;
}
if (!kernel.assertAlive()) {
assert false;
return false;
}
return true;
}
}