// Source: com.simiacryptus.mindseye.layers.cudnn.conv.ExplodedConvolutionGrid (Maven / Gradle / Ivy artifact listing)
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.layers.cudnn.conv;
import com.simiacryptus.mindseye.lang.DeltaSet;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.cudnn.CudaSettings;
import com.simiacryptus.mindseye.layers.cudnn.ImgLinearSubnetLayer;
import com.simiacryptus.mindseye.layers.cudnn.ImgZeroPaddingLayer;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.DAGNode;
import com.simiacryptus.mindseye.network.InnerNode;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.ref.lang.RefAware;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.RefCollectors;
import com.simiacryptus.ref.wrappers.RefFunction;
import com.simiacryptus.ref.wrappers.RefIntStream;
import com.simiacryptus.ref.wrappers.RefList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.UUID;
import java.util.function.Consumer;
class ExplodedConvolutionGrid extends ReferenceCountingBase {
private static final Logger log = LoggerFactory.getLogger(ExplodedConvolutionGrid.class);
@Nonnull
public final RefList subLayers;
@Nonnull
public final ConvolutionParams convolutionParams;
public ExplodedConvolutionGrid(@Nonnull ConvolutionParams convolutionParams, int maxBandBatch) {
this.convolutionParams = convolutionParams;
int bandWidth = maxBandBatch == 0 ? convolutionParams.inputBands : maxBandBatch;
int rows = (int) Math.ceil((double) convolutionParams.inputBands / bandWidth);
subLayers = RefIntStream.range(0, rows).map(x -> x * bandWidth)
.mapToObj(fromBand -> {
int toBand = Math.min(convolutionParams.inputBands, fromBand + bandWidth);
if (fromBand >= toBand)
throw new RuntimeException(fromBand + " >= " + toBand);
return new ExplodedConvolutionLeg(convolutionParams, fromBand, toBand);
}).collect(RefCollectors.toList());
}
@Nonnull
public PipelineNetwork getNetwork() {
assertAlive();
@Nonnull
PipelineNetwork network = new PipelineNetwork(1);
add(network.getInput(0), network.addRef());
return network;
}
public void write(@Nonnull Tensor filter) {
assert filter.rms() > 0;
if (1 == subLayers.size()) {
ExplodedConvolutionLeg leg = subLayers.get(0);
leg.write(filter);
leg.freeRef();
} else {
subLayers.forEach(leg -> {
@Nonnull
int[] legDims = {convolutionParams.masterFilterDimensions[0], convolutionParams.masterFilterDimensions[1],
leg.getInputBands() * convolutionParams.outputBands};
@Nonnull
Tensor template = new Tensor(legDims);
@Nullable
Tensor tensor = template.mapCoords(RefUtil.wrapInterface(c -> {
int[] coords = c.getCoords();
return filter.get(coords[0], coords[1], getFilterBand(leg.addRef(), coords[2]));
}, leg.addRef(), filter.addRef()), false);
template.freeRef();
assert tensor.rms() > 0;
leg.write(tensor);
leg.freeRef();
});
filter.freeRef();
}
}
public Tensor read(@Nonnull @RefAware RefFunction extractor) {
if (1 == subLayers.size()) {
Tensor tensor = extractor.apply(subLayers.get(0));
RefUtil.freeRef(extractor);
return tensor;
} else {
@Nonnull final Tensor filterDelta = new Tensor(convolutionParams.masterFilterDimensions);
subLayers.forEach(leg -> {
Tensor tensor = extractor.apply(leg == null ? null : leg.addRef());
tensor.forEach(RefUtil.wrapInterface((v, c) -> {
int[] coords = c.getCoords();
filterDelta.set(coords[0], coords[1], getFilterBand(leg == null ? null : leg.addRef(), coords[2]), v);
}, leg, filterDelta.addRef()), false);
tensor.freeRef();
});
RefUtil.freeRef(extractor);
return filterDelta;
}
}
public Tensor read() {
return read(l -> {
Tensor read = l.read();
l.freeRef();
return read;
});
}
public Tensor read(@Nonnull DeltaSet deltaSet, boolean remove) {
return read(RefUtil.wrapInterface(l -> {
Tensor tensor = l.read(deltaSet.addRef(), remove);
l.freeRef();
return tensor;
}, deltaSet));
}
public void add(@Nonnull DAGNode input, DAGNetwork network) {
assertAlive();
int defaultPaddingX = 0;
int defaultPaddingY = 0;
boolean customPaddingX = this.convolutionParams.paddingX != null && convolutionParams.paddingX != defaultPaddingX;
boolean customPaddingY = this.convolutionParams.paddingY != null && convolutionParams.paddingY != defaultPaddingY;
DAGNode paddedInput = null;
if (customPaddingX || customPaddingY) {
int x;
if (this.convolutionParams.paddingX < -defaultPaddingX) {
x = this.convolutionParams.paddingX + defaultPaddingX;
} else if (this.convolutionParams.paddingX > defaultPaddingX) {
x = this.convolutionParams.paddingX - defaultPaddingX;
} else {
x = 0;
}
int y;
if (this.convolutionParams.paddingY < -defaultPaddingY) {
y = this.convolutionParams.paddingY + defaultPaddingY;
} else if (this.convolutionParams.paddingY > defaultPaddingY) {
y = this.convolutionParams.paddingY - defaultPaddingY;
} else {
y = 0;
}
ImgZeroPaddingLayer zeroPaddingLayer = new ImgZeroPaddingLayer(x, y);
zeroPaddingLayer.setPrecision(convolutionParams.precision);
RefUtil.freeRef(paddedInput);
paddedInput = network.add(zeroPaddingLayer, input.addRef());
} else {
RefUtil.freeRef(paddedInput);
paddedInput = input.addRef();
}
input.freeRef();
final InnerNode output;
if (subLayers.size() == 1) {
ExplodedConvolutionLeg leg = subLayers.get(0);
output = (InnerNode) leg.add(paddedInput, network.addRef());
leg.freeRef();
} else {
ImgLinearSubnetLayer linearSubnetLayer = new ImgLinearSubnetLayer();
subLayers.forEach(RefUtil.wrapInterface((Consumer super ExplodedConvolutionLeg>) leg -> {
PipelineNetwork subnet = new PipelineNetwork();
RefUtil.freeRef(leg.add(subnet.getHead(), subnet.addRef()));
linearSubnetLayer.add(leg.fromBand, leg.toBand, subnet);
leg.freeRef();
}, linearSubnetLayer.addRef()));
boolean isParallel = CudaSettings.INSTANCE().conv_para_1;
linearSubnetLayer.setPrecision(convolutionParams.precision);
linearSubnetLayer.setParallel(isParallel);
assert network != null;
output = network.add(linearSubnetLayer, paddedInput);
output.setParallel(isParallel);
}
if (customPaddingX || customPaddingY) {
int x = !customPaddingX ? 0 : this.convolutionParams.paddingX - defaultPaddingX;
int y = !customPaddingY ? 0 : this.convolutionParams.paddingY - defaultPaddingY;
if (x > 0)
x = 0;
if (y > 0)
y = 0;
if (x != 0 || y != 0) {
ImgZeroPaddingLayer zeroPaddingLayer = new ImgZeroPaddingLayer(x, y);
zeroPaddingLayer.setPrecision(convolutionParams.precision);
RefUtil.freeRef(network.add(zeroPaddingLayer, output));
} else {
if (null != output)
output.freeRef();
}
} else {
if (null != output)
output.freeRef();
}
if (null != network)
network.freeRef();
}
public void _free() {
subLayers.freeRef();
super._free();
}
@Nonnull
public @Override
@SuppressWarnings("unused")
ExplodedConvolutionGrid addRef() {
return (ExplodedConvolutionGrid) super.addRef();
}
private int getFilterBand(@Nonnull ExplodedConvolutionLeg leg, int legFilterBand) {
int filterBand = legFilterBand;
filterBand = filterBand + convolutionParams.outputBands * leg.fromBand;
leg.freeRef();
return filterBand;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy (repository-listing footer; not part of the original source)