com.simiacryptus.mindseye.layers.cudnn.ImgTileAssemblyLayer Maven / Gradle / Ivy
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.layers.cudnn;
import com.google.gson.JsonObject;
import com.simiacryptus.lang.ref.ReferenceCounting;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.lang.cudnn.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.*;
import java.util.stream.Stream;
@SuppressWarnings("serial")
public class ImgTileAssemblyLayer extends LayerBase implements MultiPrecision {
private static final Logger log = LoggerFactory.getLogger(ImgTileAssemblyLayer.class);
private int columns;
private int rows;
private Precision precision = CudaSettings.INSTANCE().defaultPrecision;
private boolean parallel;
private ImgTileAssemblyLayer() {
}
public ImgTileAssemblyLayer(int columns, int rows) {
this.columns = columns;
this.rows = rows;
}
protected ImgTileAssemblyLayer(@Nonnull final JsonObject json, Map rs) {
super(json);
columns = json.get("columns").getAsInt();
rows = json.get("rows").getAsInt();
this.parallel = json.get("parallel").getAsBoolean();
this.precision = Precision.valueOf(json.getAsJsonPrimitive("precision").getAsString());
}
public static ImgTileAssemblyLayer fromJson(@Nonnull final JsonObject json, Map rs) {
return new ImgTileAssemblyLayer(json, rs);
}
@Nonnull
public Layer getCompatibilityLayer() {
return this.as(com.simiacryptus.mindseye.layers.java.ImgTileAssemblyLayer.class);
}
@Nullable
@Override
public Result evalAndFree(@Nonnull final Result... inObj) {
if (!CudaSystem.isEnabled()) return getCompatibilityLayer().evalAndFree(inObj);
if (1 == inObj.length) {
return inObj[0];
}
int[] inputDimensions = inObj[0].getData().getDimensions();
assert 3 == inputDimensions.length;
final int length = inObj[0].getData().length();
int[] outputDims = getOutputDims(inObj);
final TensorList outputData = CudaSystem.run(gpu -> {
assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
assert outputDims[0] > 0;
assert outputDims[1] > 0;
assert outputDims[2] > 0;
@Nonnull final CudaMemory outputBuffer = gpu.allocate(
(long) length * outputDims[2] * outputDims[1] * outputDims[0] * precision.size, MemoryType.Managed.ifEnabled(), false);
int totalWidth = 0;
int totalHeight = 0;
int inputIndex = 0;
List copies = new ArrayList<>();
for (int row = 0; row < rows; row++) {
int positionX = 0;
int rowHeight = 0;
for (int col = 0; col < columns; col++) {
int[] tileDimensions = inObj[inputIndex].getData().getDimensions();
rowHeight = Math.max(rowHeight, tileDimensions[1]);
copies.add(new CopyParams(gpu, inObj, outputBuffer, length, outputDims, tileDimensions, inputIndex, positionX, totalHeight));
positionX += tileDimensions[0];
inputIndex += 1;
assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
}
totalHeight += rowHeight;
totalWidth = Math.max(totalWidth, positionX);
}
assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
Stream stream = copies.stream();
if (!CoreSettings.INSTANCE().isSingleThreaded() && parallel) stream = stream.parallel();
stream.forEach(this::copy);
Arrays.stream(inObj).forEach(r -> r.getData().freeRef());
CudaDevice.CudaTensorDescriptor descriptor = gpu.newTensorDescriptor(precision, length, outputDims[2], outputDims[1], outputDims[0]);
CudaTensor ptr = CudaTensor.wrap(outputBuffer, descriptor, precision);
return CudaTensorList.wrap(ptr, length, outputDims, precision);
}, Arrays.stream(inObj).map(Result::getData).toArray());
return new Result(outputData, (@Nonnull final DeltaSet buffer, @Nonnull final TensorList error) -> {
if (!Arrays.equals(error.getDimensions(), outputData.getDimensions())) {
throw new AssertionError(Arrays.toString(error.getDimensions()) + " != " + Arrays.toString(outputData.getDimensions()));
}
if (error.length() != outputData.length()) {
throw new AssertionError(error.length() + " != " + outputData.length());
}
assert error.length() == length;
int totalHeight = 0;
int inputIndex = 0;
List tasks = new ArrayList<>();
for (int row = 0; row < rows; row++) {
int positionX = 0;
int rowHeight = 0;
for (int col = 0; col < columns; col++) {
Result in = inObj[inputIndex];
int[] tileDimensions = in.getData().getDimensions();
rowHeight = Math.max(rowHeight, tileDimensions[1]);
if (inObj[inputIndex].isAlive()) {
tasks.add(new BackpropParams(inObj, buffer, error, outputDims, tileDimensions, length, positionX, totalHeight, inputIndex));
}
positionX += tileDimensions[0];
inputIndex += 1;
}
totalHeight += rowHeight;
}
Stream stream = tasks.stream();
if (!CoreSettings.INSTANCE().isSingleThreaded() && parallel) stream = stream.parallel();
stream.forEach(this::backprop);
error.freeRef();
}) {
@Override
protected void _free() {
Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
}
@Override
public boolean isAlive() {
return Arrays.stream(inObj).anyMatch(x -> x.isAlive());
}
};
}
public void backprop(final BackpropParams backpropParams) {
final TensorList passbackTensorList = CudaSystem.run(gpu -> {
CudaTensor ptr = copy(gpu, backpropParams.error, backpropParams.tileDimensions, backpropParams.outputDims, backpropParams.length, -backpropParams.positionX, -backpropParams.totalHeight);
return CudaTensorList.wrap(ptr, backpropParams.length, backpropParams.tileDimensions, precision);
}, backpropParams.error);
backpropParams.inObj[backpropParams.inputIndex].accumulate(backpropParams.buffer, passbackTensorList);
}
public CudaTensor copy(final CudnnHandle gpu, final TensorList error, final int[] tileDimensions, final int[] outputDims, final int length, final int positionX, final int positionY) {
@Nullable final CudaTensor errorPtr = gpu.getTensor(error, precision, MemoryType.Device, false);
@Nonnull final CudaMemory passbackBuffer = gpu.allocate(
(long) length * tileDimensions[2] * tileDimensions[1] * tileDimensions[0] * precision.size, MemoryType.Managed.ifEnabled(), false);
copy(gpu, length, outputDims, errorPtr, tileDimensions, passbackBuffer, positionX, positionY);
errorPtr.freeRef();
CudaDevice.CudaTensorDescriptor descriptor = gpu.newTensorDescriptor(precision, length, tileDimensions[2], tileDimensions[1], tileDimensions[0]);
return CudaTensor.wrap(passbackBuffer, descriptor, precision);
}
public void copy(final CopyParams copyParams) {
CudnnHandle gpu = copyParams.gpu;
gpu.initThread();
assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
@Nullable final CudaTensor inputBuffer = gpu.getTensor(copyParams.inObj[copyParams.inputIndex].getData(), precision, MemoryType.Device, false);
copy(gpu, copyParams.length, copyParams.tileDimensions, inputBuffer, copyParams.outputDims, copyParams.outputBuffer, copyParams.positionX, copyParams.totalHeight);
inputBuffer.freeRef();
}
private int[] getOutputDims(final Result[] inObj) {
int bands = inObj[0].getData().getDimensions()[2];
int totalWidth = 0;
int totalHeight = 0;
int inputIndex = 0;
for (int row = 0; row < rows; row++) {
int positionX = 0;
int rowHeight = 0;
for (int col = 0; col < columns; col++) {
int[] dimensions = inObj[inputIndex].getData().getDimensions();
rowHeight = Math.max(rowHeight, dimensions[1]);
positionX += dimensions[0];
inputIndex += 1;
}
totalHeight += rowHeight;
totalWidth = Math.max(totalWidth, positionX);
}
return new int[]{totalWidth, totalHeight, bands};
}
public int[] copy(@Nonnull CudnnHandle gpu, int length, @Nonnull int[] sourceDimensions, @Nonnull CudaTensor source, @Nonnull int[] destinationDimensions, @Nonnull CudaMemory destination, int positionX, int positionY) {
if (3 != sourceDimensions.length) throw new IllegalArgumentException("inputDimensions.length");
if (3 != destinationDimensions.length) throw new IllegalArgumentException("dimOut.length");
int bands = sourceDimensions[2];
if (bands != destinationDimensions[2])
throw new IllegalArgumentException(String.format("%d != %d", bands, destinationDimensions[2]));
//log.info(String.format("offset=%d,%d", offsetX, offsetY));
@Nonnull final int[] viewDim = getViewDimensions(sourceDimensions, destinationDimensions, new int[]{positionX, positionY, 0});
@Nonnull final CudaDevice.CudaTensorDescriptor sourceViewDescriptor = gpu.newTensorDescriptor(
precision,//
length,//
viewDim[2],//
viewDim[1],//
viewDim[0],//
source.descriptor.nStride,//
source.descriptor.cStride,//
source.descriptor.hStride,//
source.descriptor.wStride);
@Nonnull final CudaDevice.CudaTensorDescriptor destinationViewDescriptor = gpu.newTensorDescriptor(
precision,//
length,//
viewDim[2],//
viewDim[1],//
viewDim[0],//
destinationDimensions[2] * destinationDimensions[1] * destinationDimensions[0],//
destinationDimensions[1] * destinationDimensions[0],//
destinationDimensions[0],//
1);
int sourceOffset = 0;
int destinationOffset = 0;
if (positionX > 0) {
destinationOffset += Math.abs(positionX);
} else {
sourceOffset += source.descriptor.wStride * Math.abs(positionX);
}
if (positionY > 0) {
destinationOffset += destinationDimensions[0] * Math.abs((positionY));
} else {
sourceOffset += source.descriptor.hStride * (Math.abs(positionY));
}
assert sourceOffset >= 0;
assert destinationOffset >= 0;
assert sourceOffset + Tensor.length(viewDim) <= (source.descriptor.nStride * length);
assert destinationOffset + Tensor.length(viewDim) <= Tensor.length(destinationDimensions);
CudaMemory sourceMemory = source.getMemory(gpu);
CudaSystem.handle(gpu.cudnnTransformTensor(
precision.getPointer(1.0),
sourceViewDescriptor.getPtr(), sourceMemory.getPtr().withByteOffset(sourceOffset * precision.size),
precision.getPointer(1.0),
destinationViewDescriptor.getPtr(), destination.getPtr().withByteOffset(destinationOffset * precision.size)
));
assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
sourceMemory.dirty();
destination.dirty();
sourceMemory.freeRef();
Arrays.stream(new ReferenceCounting[]{sourceViewDescriptor, destinationViewDescriptor}).forEach(ReferenceCounting::freeRef);
return viewDim;
}
@Nonnull
public int[] getViewDimensions(int[] sourceDimensions, int[] destinationDimensions, int[] offset) {
@Nonnull final int[] viewDim = new int[3];
Arrays.parallelSetAll(viewDim, i ->
Math.min(sourceDimensions[i] + offset[i], destinationDimensions[i]) -
Math.max(offset[i], 0)
);
return viewDim;
}
@Nonnull
@Override
public JsonObject getJson(Map resources, DataSerializer dataSerializer) {
@Nonnull final JsonObject json = super.getJsonStub();
json.addProperty("rows", rows);
json.addProperty("columns", columns);
json.addProperty("precision", precision.name());
json.addProperty("parallel", isParallel());
return json;
}
@Nonnull
@Override
public List state() {
return Arrays.asList();
}
@Override
public Precision getPrecision() {
return precision;
}
@Nonnull
@Override
public ImgTileAssemblyLayer setPrecision(final Precision precision) {
this.precision = precision;
return this;
}
public boolean isParallel() {
return parallel;
}
public ImgTileAssemblyLayer setParallel(final boolean parallel) {
this.parallel = parallel;
return this;
}
private static class CopyParams {
public final int length;
public final int[] outputDims;
public final CudnnHandle gpu;
public final CudaMemory outputBuffer;
public final int totalHeight;
public final int inputIndex;
public final int positionX;
public final int[] tileDimensions;
@Nonnull
public final Result[] inObj;
private CopyParams(final CudnnHandle gpu, @Nonnull final Result[] inObj, final CudaMemory outputBuffer, final int length, final int[] outputDims, final int[] tileDimensions, final int inputIndex, final int positionX, final int totalHeight) {
this.length = length;
this.outputDims = outputDims;
this.gpu = gpu;
this.outputBuffer = outputBuffer;
this.totalHeight = totalHeight;
this.inputIndex = inputIndex;
this.positionX = positionX;
this.tileDimensions = tileDimensions;
this.inObj = inObj;
}
}
private static class BackpropParams {
@Nonnull
public final Result[] inObj;
@Nonnull
public final DeltaSet buffer;
@Nonnull
public final TensorList error;
public final int[] outputDims;
public final int[] tileDimensions;
public final int length;
public final int positionX;
public final int totalHeight;
public final int inputIndex;
private BackpropParams(@Nonnull final Result[] inObj, @Nonnull final DeltaSet buffer, @Nonnull final TensorList error, final int[] outputDims, final int[] tileDimensions, final int length, final int positionX, final int totalHeight, final int inputIndex) {
this.inObj = inObj;
this.buffer = buffer;
this.error = error;
this.outputDims = outputDims;
this.tileDimensions = tileDimensions;
this.length = length;
this.positionX = positionX;
this.totalHeight = totalHeight;
this.inputIndex = inputIndex;
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy