// com.simiacryptus.mindseye.layers.cudnn.conv.ExplodedConvolutionGrid Maven / Gradle / Ivy
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.layers.cudnn.conv;
import com.simiacryptus.lang.ref.ReferenceCountingBase;
import com.simiacryptus.mindseye.lang.DeltaSet;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.cudnn.CudaSettings;
import com.simiacryptus.mindseye.layers.cudnn.ImgLinearSubnetLayer;
import com.simiacryptus.mindseye.layers.cudnn.ImgZeroPaddingLayer;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.DAGNode;
import com.simiacryptus.mindseye.network.InnerNode;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.List;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
 * Splits a convolution over its input bands into one or more
 * {@link ExplodedConvolutionLeg}s, each covering a contiguous band range
 * {@code [fromBand, toBand)}, and knows how to scatter a master filter tensor
 * into the legs, gather it back, and wire the legs into a network.
 */
class ExplodedConvolutionGrid extends ReferenceCountingBase {
  private static final Logger log = LoggerFactory.getLogger(ExplodedConvolutionGrid.class);

  /** One leg per input-band range, in ascending {@code fromBand} order. */
  public final List<ExplodedConvolutionLeg> subLayers;

  /** Parameters of the overall (un-split) convolution. */
  @Nonnull
  public final ConvolutionParams convolutionParams;

  /**
   * Builds the legs for the given convolution.
   *
   * @param convolutionParams parameters of the overall convolution
   * @param maxBandBatch      maximum number of input bands per leg; {@code 0}
   *                          means a single leg covering all input bands
   * @throws RuntimeException if a computed band range would be empty
   */
  public ExplodedConvolutionGrid(@Nonnull ConvolutionParams convolutionParams, int maxBandBatch) {
    this.convolutionParams = convolutionParams;
    int bandWidth = (maxBandBatch == 0) ? convolutionParams.inputBands : maxBandBatch;
    // Number of legs needed so that every input band falls into exactly one leg.
    int rows = (int) Math.ceil((double) convolutionParams.inputBands / bandWidth);
    subLayers = IntStream.range(0, rows).map(x -> x * bandWidth).mapToObj(fromBand -> {
      // The last leg may be narrower than bandWidth.
      int toBand = Math.min(convolutionParams.inputBands, fromBand + bandWidth);
      if (fromBand >= toBand) throw new RuntimeException(fromBand + " >= " + toBand);
      return new ExplodedConvolutionLeg(convolutionParams, fromBand, toBand);
    }).collect(Collectors.toList());
  }

  /** Releases the reference held on every leg, then delegates to the superclass. */
  @Override
  protected void _free() {
    subLayers.forEach(ReferenceCountingBase::freeRef);
    super._free();
  }

  /**
   * Distributes a master filter tensor across the legs.
   * <p>
   * With a single leg the filter is handed over unchanged. Otherwise a tensor of
   * dimensions {@code {masterFilterDimensions[0], masterFilterDimensions[1],
   * leg.getInputBands() * outputBands}} is built for each leg, filled from the
   * matching slice of {@code filter}, written to the leg and then released.
   *
   * @param filter the master filter tensor to read values from
   * @return this grid
   */
  @Nonnull
  public ExplodedConvolutionGrid write(@Nonnull Tensor filter) {
    if (1 == subLayers.size()) {
      subLayers.get(0).write(filter);
    } else {
      for (@Nonnull ExplodedConvolutionLeg leg : subLayers) {
        @Nonnull int[] legDims = {
            convolutionParams.masterFilterDimensions[0],
            convolutionParams.masterFilterDimensions[1],
            leg.getInputBands() * convolutionParams.outputBands
        };
        // The template only supplies the coordinate space; its values are not used.
        @Nonnull Tensor template = new Tensor(legDims);
        @Nullable Tensor tensor = template.mapCoords(c -> {
          int[] coords = c.getCoords();
          return filter.get(coords[0], coords[1], getFilterBand(leg, coords[2]));
        }, false);
        template.freeRef();
        leg.write(tensor);
        tensor.freeRef();
      }
    }
    return this;
  }

  /**
   * Gathers per-leg tensors into one tensor shaped like the master filter.
   * <p>
   * With a single leg the extractor's result is returned directly. Otherwise a
   * new tensor of {@code masterFilterDimensions} is allocated, each leg's
   * extracted tensor is copied into its band slice, and the leg tensor is
   * released.
   *
   * @param extractor produces the tensor to gather from each leg
   * @return the extractor result (single leg) or the newly assembled tensor
   */
  public Tensor read(@Nonnull Function<ExplodedConvolutionLeg, Tensor> extractor) {
    if (1 == subLayers.size()) {
      return extractor.apply(subLayers.get(0));
    } else {
      @Nonnull final Tensor filterDelta = new Tensor(convolutionParams.masterFilterDimensions);
      for (@Nonnull ExplodedConvolutionLeg leg : subLayers) {
        Tensor tensor = extractor.apply(leg);
        tensor.forEach((v, c) -> {
          int[] coords = c.getCoords();
          filterDelta.set(coords[0], coords[1], getFilterBand(leg, coords[2]), v);
        }, false);
        tensor.freeRef();
      }
      return filterDelta;
    }
  }

  /**
   * Gathers the result of each leg's no-argument {@code read()}.
   *
   * @return the assembled tensor, see {@link #read(Function)}
   */
  public Tensor read() {
    return read(l -> l.read());
  }

  /**
   * Gathers the result of each leg's {@code read(deltaSet, remove)}.
   *
   * @param deltaSet forwarded unchanged to every leg
   * @param remove   forwarded unchanged to every leg
   * @return the assembled tensor, see {@link #read(Function)}
   */
  public Tensor read(@Nonnull DeltaSet deltaSet, boolean remove) {
    return read(l -> l.read(deltaSet, remove));
  }

  /**
   * Maps a band index within a leg's filter tensor to the band index within the
   * master filter by offsetting it with {@code outputBands * leg.fromBand}.
   */
  private int getFilterBand(@Nonnull ExplodedConvolutionLeg leg, int legFilterBand) {
    return legFilterBand + convolutionParams.outputBands * leg.fromBand;
  }

  /**
   * Creates a single-input {@link PipelineNetwork} and adds this grid to it,
   * fed from the network's input 0.
   *
   * @return the new network
   */
  @Nonnull
  public PipelineNetwork getNetwork() {
    assertAlive();
    @Nonnull PipelineNetwork network = new PipelineNetwork(1);
    // Only the wiring side effect is needed; drop the returned node reference.
    add(network.getInput(0)).freeRef();
    return network;
  }

  /**
   * Appends this grid to the network that owns {@code input}.
   * <p>
   * The chain is: an optional {@link ImgZeroPaddingLayer} before the
   * convolution, then either the single leg or an {@link ImgLinearSubnetLayer}
   * holding one subnet per leg, then an optional negative-only
   * {@link ImgZeroPaddingLayer} after the convolution.
   * <p>
   * A {@code null} padding on one axis is treated as "no custom padding" for
   * that axis, even when the other axis has custom padding.
   *
   * @param input node whose output feeds the convolution
   * @return the last node added (or the convolution node if no trailing padding is needed)
   */
  public DAGNode add(@Nonnull DAGNode input) {
    assertAlive();
    DAGNetwork network = input.getNetwork();
    int defaultPaddingX = 0;
    int defaultPaddingY = 0;
    boolean customPaddingX = this.convolutionParams.paddingX != null && convolutionParams.paddingX != defaultPaddingX;
    boolean customPaddingY = this.convolutionParams.paddingY != null && convolutionParams.paddingY != defaultPaddingY;

    // Pre-convolution padding. Each axis is guarded individually so that a null
    // padding on one axis cannot be unboxed just because the other is custom.
    // NOTE(review): with defaultPadding == 0 a negative padding is applied here
    // and again after the convolution - confirm that this is intended.
    final DAGNode paddedInput;
    if (customPaddingX || customPaddingY) {
      int x = customPaddingX ? prePadding(convolutionParams.paddingX, defaultPaddingX) : 0;
      int y = customPaddingY ? prePadding(convolutionParams.paddingY, defaultPaddingY) : 0;
      if (x != 0 || y != 0) {
        paddedInput = network.wrap(new ImgZeroPaddingLayer(x, y).setPrecision(convolutionParams.precision), input);
      } else {
        paddedInput = input;
      }
    } else {
      paddedInput = input;
    }

    InnerNode output;
    if (subLayers.size() == 1) {
      output = (InnerNode) subLayers.get(0).add(paddedInput);
    } else {
      // Each leg gets its own subnet, registered under the leg's band range.
      ImgLinearSubnetLayer linearSubnetLayer = new ImgLinearSubnetLayer();
      subLayers.forEach(leg -> {
        PipelineNetwork subnet = new PipelineNetwork();
        leg.add(subnet.getHead());
        linearSubnetLayer.add(leg.fromBand, leg.toBand, subnet);
      });
      boolean isParallel = CudaSettings.INSTANCE().isConv_para_1();
      linearSubnetLayer.setPrecision(convolutionParams.precision).setParallel(isParallel);
      output = network.wrap(linearSubnetLayer, paddedInput).setParallel(isParallel);
    }

    // Post-convolution padding: only negative amounts are applied.
    if (customPaddingX || customPaddingY) {
      int x = customPaddingX ? postPadding(convolutionParams.paddingX, defaultPaddingX) : 0;
      int y = customPaddingY ? postPadding(convolutionParams.paddingY, defaultPaddingY) : 0;
      if (x != 0 || y != 0) {
        return network.wrap(new ImgZeroPaddingLayer(x, y).setPrecision(convolutionParams.precision), output);
      }
    }
    return output;
  }

  /**
   * Padding to apply before the convolution: the requested padding moved toward
   * zero by {@code defaultPadding}, or zero when it lies within
   * {@code [-defaultPadding, defaultPadding]}.
   */
  private static int prePadding(int padding, int defaultPadding) {
    if (padding < -defaultPadding) {
      return padding + defaultPadding;
    }
    if (padding > defaultPadding) {
      return padding - defaultPadding;
    }
    return 0;
  }

  /** Padding to apply after the convolution: {@code padding - defaultPadding}, clamped to at most zero. */
  private static int postPadding(int padding, int defaultPadding) {
    return Math.min(0, padding - defaultPadding);
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy