// com.simiacryptus.mindseye.layers.cudnn.conv.ExplodedConvolutionLeg (Maven / Gradle / Ivy)
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.layers.cudnn.conv;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.lang.cudnn.CudaSettings;
import com.simiacryptus.mindseye.lang.cudnn.Precision;
import com.simiacryptus.mindseye.layers.cudnn.ImgConcatLayer;
import com.simiacryptus.mindseye.layers.cudnn.ImgTileSubnetLayer;
import com.simiacryptus.mindseye.layers.cudnn.ImgZeroPaddingLayer;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.DAGNode;
import com.simiacryptus.mindseye.network.InnerNode;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.ref.lang.RefAware;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.UUID;
import java.util.function.Function;
class ExplodedConvolutionLeg extends ReferenceCountingBase {
private static final Logger log = LoggerFactory.getLogger(ExplodedConvolutionLeg.class);
@Nonnull
public final ConvolutionParams convolutionParams;
@Nonnull
public final RefList subLayers;
@Nonnull
public final RefList subKernels = new RefArrayList<>();
public final int fromBand;
public final int toBand;
public ExplodedConvolutionLeg(@Nonnull ConvolutionParams convolutionParams, int fromBand, int toBand) {
this.fromBand = fromBand;
this.toBand = toBand;
this.convolutionParams = convolutionParams;
this.subLayers = new RefArrayList<>();
int inputBands = getInputBands();
final int inputBandsSq = inputBands * inputBands;
@Nonnull final int[] filterDimensions = RefArrays.copyOf(this.convolutionParams.masterFilterDimensions,
this.convolutionParams.masterFilterDimensions.length);
filterDimensions[2] = inputBands * this.convolutionParams.outputBands;
for (int offset = 0; offset < filterDimensions[2]; offset += inputBandsSq) {
int paddingX = (convolutionParams.masterFilterDimensions[0] - 1) / 2;
int paddingY = (convolutionParams.masterFilterDimensions[1] - 1) / 2;
SimpleConvolutionLayer convolutionLayer = new SimpleConvolutionLayer(filterDimensions[0], filterDimensions[1], inputBandsSq);
convolutionLayer.setStrideX(this.convolutionParams.strideX);
convolutionLayer.setStrideY(this.convolutionParams.strideY);
convolutionLayer.setPrecision(this.convolutionParams.precision);
PipelineNetwork stackableConv = new PipelineNetwork(1);
if (paddingY != 0 || paddingX != 0)
stackableConv.add(new ImgZeroPaddingLayer(paddingX, paddingY)).freeRef();
RefUtil.freeRef(stackableConv.add(convolutionLayer.addRef()));
if (paddingY != 0 || paddingX != 0) {
final Layer nextHead = new ImgZeroPaddingLayer(-paddingX, -paddingY);
RefUtil.freeRef(stackableConv.add(nextHead.addRef()));
nextHead.freeRef();
}
Precision precision = convolutionLayer.getPrecision();
int[] kernelDimensions = convolutionLayer.getKernelDimensions();
subKernels.add(convolutionLayer);
this.subLayers.add(getTileSubnet(stackableConv,
Math.max(filterDimensions[0], filterDimensions[1]), kernelDimensions, precision));
}
}
public int getInputBands() {
return this.toBand - this.fromBand;
}
public void write(@Nonnull Tensor filter) {
assert filter.rms() > 0;
int inputBands = getInputBands();
@Nonnull final int[] filterDimensions = RefArrays.copyOf(this.convolutionParams.masterFilterDimensions,
this.convolutionParams.masterFilterDimensions.length);
int outputBands = this.convolutionParams.outputBands;
int squareOutputBands = (int) (Math.ceil(convolutionParams.outputBands * 1.0 / inputBands) * inputBands);
assert squareOutputBands >= convolutionParams.outputBands : RefString.format("%d >= %d", squareOutputBands,
convolutionParams.outputBands);
assert squareOutputBands % inputBands == 0 : RefString.format("%d %% %d", squareOutputBands, inputBands);
filterDimensions[2] = inputBands * outputBands;
assert RefArrays.equals(filter.getDimensions(), filterDimensions) : RefArrays.toString(filter.getDimensions())
+ " != " + RefArrays.toString(filterDimensions);
final int inputBandsSq = inputBands * inputBands;
assert subLayers.size() > 0;
RefIntStream.range(0, subLayers.size()).parallel().forEach(layerNumber -> {
final int filterBandOffset = layerNumber * inputBandsSq;
Tensor kernel = new Tensor(filterDimensions[0], filterDimensions[1], inputBandsSq);
kernel.setByCoord(c -> {
int[] coords = c.getCoords();
int filterBand = getFilterBand(filterBandOffset, coords[2], squareOutputBands);
if (filterBand < filterDimensions[2]) {
return filter.get(coords[0], coords[1], filterBand);
} else {
return 0;
}
}, true);
assert kernel.rms() > 0;
SimpleConvolutionLayer simpleConvolutionLayer = subKernels.get(layerNumber);
simpleConvolutionLayer.set(kernel);
simpleConvolutionLayer.freeRef();
});
filter.freeRef();
}
@Nonnull
public Tensor read(@Nonnull @RefAware Function extractor) {
int inputBands = getInputBands();
@Nonnull final int[] filterDimensions = RefArrays.copyOf(this.convolutionParams.masterFilterDimensions,
this.convolutionParams.masterFilterDimensions.length);
filterDimensions[2] = inputBands * this.convolutionParams.outputBands;
int outputBands = convolutionParams.outputBands;
int squareOutputBands = (int) (Math.ceil(convolutionParams.outputBands * 1.0 / inputBands) * inputBands);
assert squareOutputBands >= convolutionParams.outputBands : RefString.format("%d >= %d", squareOutputBands,
convolutionParams.outputBands);
assert squareOutputBands % inputBands == 0 : RefString.format("%d %% %d", squareOutputBands, inputBands);
@Nonnull
Tensor resultDelta = new Tensor(filterDimensions[0], filterDimensions[1], inputBands * outputBands);
for (int layerNumber = 0; layerNumber < subLayers.size(); layerNumber++) {
int _layerNumber = layerNumber;
Tensor deltaTensor = extractor.apply(subKernels.get(layerNumber));
if (null != deltaTensor) {
deltaTensor.forEach(RefUtil.wrapInterface((v, c) -> {
int[] coords = c.getCoords();
int filterBand = getFilterBand(_layerNumber * inputBands * inputBands, coords[2], squareOutputBands);
if (filterBand < filterDimensions[2]) {
resultDelta.set(coords[0], coords[1], filterBand, v);
}
}, resultDelta.addRef()), false);
}
if (null != deltaTensor)
deltaTensor.freeRef();
}
RefUtil.freeRef(extractor);
return resultDelta;
}
public int getFilterBand(int filterBandOffset, int cellFilterBand, int squareOutputBands) {
int inputBands = getInputBands();
assert cellFilterBand >= 0;
assert cellFilterBand < inputBands * inputBands;
assert filterBandOffset < inputBands * squareOutputBands;
int filterBand = cellFilterBand + filterBandOffset;
filterBand = Coordinate.transposeXY(inputBands, convolutionParams.outputBands, filterBand);
return filterBand;
}
@Nonnull
public Tensor read(@Nonnull DeltaSet deltaSet, boolean remove) {
return read(RefUtil.wrapInterface(sublayer -> {
RefMap> map = deltaSet.getMap();
Delta uuidDelta = map.get(sublayer.getId());
assert uuidDelta != null;
final Delta subnetDelta;
if (remove) {
subnetDelta = map.remove(sublayer.addRef());
uuidDelta.freeRef();
} else {
subnetDelta = uuidDelta;
}
map.freeRef();
if (null == subnetDelta) {
String toString = sublayer.toString();
sublayer.freeRef();
throw new RuntimeException("No Delta for " + toString);
}
Tensor kernel = new Tensor(subnetDelta.getDelta(), sublayer.getKernelDimensions());
sublayer.freeRef();
subnetDelta.freeRef();
return kernel;
}, deltaSet));
}
@Nonnull
public Tensor read() {
return read(sublayer -> {
Tensor kernel = sublayer.getKernel();
sublayer.freeRef();
return kernel;
});
}
@Nullable
public DAGNode add(@Nonnull final DAGNode input, DAGNetwork network) {
assertAlive();
if (getInputBands() == this.convolutionParams.outputBands) {
assert 1 == subLayers.size();
assert network != null;
InnerNode node = network.add(subLayers.get(0), input);
network.freeRef();
return node;
} else {
ImgConcatLayer concatLayer = new ImgConcatLayer();
concatLayer.setMaxBands(this.convolutionParams.outputBands);
concatLayer.setPrecision(this.convolutionParams.precision);
assert network != null;
concatLayer.setParallel(CudaSettings.INSTANCE().conv_para_2);
InnerNode node = network.add(concatLayer,
subLayers.stream().map(RefUtil.wrapInterface((Function super Layer, ? extends InnerNode>) layer -> {
return network.add(layer, input.addRef());
}, network, input)).toArray(i -> new DAGNode[i]));
node.setParallel(CudaSettings.INSTANCE().conv_para_2);
return node;
}
}
@Nonnull
@Override
public String toString() {
return "ExplodedConvolutionLeg{" + "fromBand=" + fromBand + ", toBand=" + toBand + '}';
}
public void _free() {
subKernels.freeRef();
subLayers.freeRef();
super._free();
}
@Nonnull
public @Override
@SuppressWarnings("unused")
ExplodedConvolutionLeg addRef() {
return (ExplodedConvolutionLeg) super.addRef();
}
@Nonnull
private ImgTileSubnetLayer getTileSubnet(@Nullable final Layer network, final int bands, final int[] kernelDimensions,
final Precision precision) {
int maxSize = (int) Math.sqrt(CudaSettings.INSTANCE().maxIoElements / bands);
int width = kernelDimensions[0];
int height = kernelDimensions[1];
ImgTileSubnetLayer subnetLayer = new ImgTileSubnetLayer(network, maxSize,
maxSize, maxSize - (width - 1) / 2, maxSize - (height - 1) / 2);
subnetLayer.setParallel(CudaSettings.INSTANCE().conv_para_3);
subnetLayer.setPrecision(precision);
return subnetLayer;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy