com.simiacryptus.mindseye.layers.cudnn.ImgBandSelectLayer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-cudnn Show documentation
Show all versions of mindseye-cudnn Show documentation
CuDNN Neural Network Components
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.layers.cudnn;
import com.google.gson.JsonObject;
import com.simiacryptus.lang.ref.ReferenceCounting;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.lang.cudnn.*;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.IntStream;
import java.util.stream.Stream;
/**
 * GPU (CuDNN) layer that selects a contiguous range of bands {@code [from, to)} from a single
 * 3-D image input.
 *
 * <p>Entry {@code [2]} of the input's dimension array is used as the band axis. The output keeps
 * the first two dimensions of the input and has {@code to - from} bands. When CUDA is disabled,
 * evaluation is delegated to the pure-Java
 * {@code com.simiacryptus.mindseye.layers.java.ImgBandSelectLayer} returned by
 * {@link #getCompatibilityLayer()}.
 */
@SuppressWarnings("serial")
public class ImgBandSelectLayer extends LayerBase implements MultiPrecision {
  /** First band index to select (inclusive). */
  private int from;
  /** Band index at which selection stops (exclusive). */
  private int to;
  /** Precision used for GPU tensors; initialized from {@code CudaSettings.INSTANCE().defaultPrecision}. */
  private Precision precision = CudaSettings.INSTANCE().defaultPrecision;

  /**
   * Creates a layer selecting bands {@code from} (inclusive) through {@code to} (exclusive).
   *
   * @param from first band index to keep (inclusive)
   * @param to   band index at which selection stops (exclusive)
   */
  public ImgBandSelectLayer(int from, int to) {
    this.setFrom(from);
    this.setTo(to);
  }

  /**
   * Restores a layer from the JSON written by {@link #getJson(Map, DataSerializer)}.
   *
   * @param json serialized form containing {@code "from"}, {@code "to"} and {@code "precision"}
   */
  protected ImgBandSelectLayer(@Nonnull final JsonObject json) {
    super(json);
    setFrom(json.get("from").getAsInt());
    setTo(json.get("to").getAsInt());
    precision = Precision.valueOf(json.get("precision").getAsString());
  }

  /**
   * Factory used for JSON deserialization.
   *
   * @param json serialized layer
   * @param rs   resource map (not read by this layer)
   * @return a new layer restored from {@code json}
   */
  public static ImgBandSelectLayer fromJson(@Nonnull final JsonObject json, Map rs) {
    return new ImgBandSelectLayer(json);
  }

  /**
   * Returns an equivalent pure-Java layer that selects the same band indices
   * ({@code from .. to-1}). It is used as the fallback when CUDA is disabled.
   *
   * @return a new Java band-select layer
   */
  @Nonnull
  public Layer getCompatibilityLayer() {
    return new com.simiacryptus.mindseye.layers.java.ImgBandSelectLayer(
        IntStream.range(getFrom(), getTo()).toArray());
  }

  /**
   * Evaluates the band selection on the GPU.
   *
   * <p><b>Forward pass.</b> No CuDNN transform is invoked. The input's device memory is wrapped
   * at a byte offset of {@code cStride * from * precision.size}, using the input's own strides and
   * the output's reduced band count.
   *
   * <p><b>Backward pass.</b> A buffer the size of the full input is allocated, and the incoming
   * delta is written into the selected band window via {@code cudnnTransformTensor} (alpha 1,
   * beta 0). Bands outside the window keep whatever {@code allocate(..., false)} produced.
   * NOTE(review): presumably zero-filled; confirm against {@code CudaDevice}.
   *
   * @param inObj exactly one 3-D image input; the returned result frees these inputs when it is
   *              itself freed
   * @return a result wrapping the selected bands, with a backward pass that routes gradients to
   *         the corresponding input bands
   */
  @Nullable
  @Override
  public Result evalAndFree(@Nonnull final Result... inObj) {
    // Snapshot the band range so the forward view and the backward gradient placement agree,
    // even if setFrom/setTo is called between evaluation and backpropagation.
    final int bandFrom = getFrom();
    final int bandTo = getTo();
    assert bandFrom < bandTo;
    assert bandFrom >= 0;
    assert bandTo > 0;
    assert 1 == inObj.length;
    final Result in0 = inObj[0];
    assert 3 == in0.getData().getDimensions().length;
    if (!CudaSystem.isEnabled()) return getCompatibilityLayer().evalAndFree(inObj);
    final TensorList inputData = in0.getData();
    @Nonnull final int[] inputDimensions = inputData.getDimensions();
    // The selected window must lie inside the input's band range; otherwise the GPU view built
    // below would address memory beyond the input's bands.
    assert bandTo <= inputDimensions[2] : bandTo + " > " + inputDimensions[2];
    final int length = inputData.length();
    @Nonnull final int[] outputDimensions = Arrays.copyOf(inputDimensions, 3);
    outputDimensions[2] = bandTo - bandFrom;
    return new Result(CudaSystem.run(gpu -> {
      @Nullable final CudaTensor cudaInput = gpu.getTensor(inputData, precision, MemoryType.Device, false);
      inputData.freeRef();
      // Offset of the first selected band within each image, in bytes.
      final int byteOffset = cudaInput.descriptor.cStride * bandFrom * precision.size;
      // Reuse the input's strides so the view skips the unselected bands of every image.
      @Nonnull final CudaDevice.CudaTensorDescriptor inputDescriptor = gpu.newTensorDescriptor(
          precision, length, outputDimensions[2], outputDimensions[1], outputDimensions[0], //
          cudaInput.descriptor.nStride, //
          cudaInput.descriptor.cStride, //
          cudaInput.descriptor.hStride, //
          cudaInput.descriptor.wStride);
      CudaMemory cudaInputMemory = cudaInput.getMemory(gpu);
      assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
      CudaTensor cudaTensor = CudaTensor.wrap(cudaInputMemory.withByteOffset(byteOffset), inputDescriptor, precision);
      Stream.of(cudaInput, cudaInputMemory).forEach(ReferenceCounting::freeRef);
      return CudaTensorList.wrap(cudaTensor, length, outputDimensions, precision);
    }, inputData), (@Nonnull final DeltaSet buffer, @Nonnull final TensorList delta) -> {
      if (!Arrays.equals(delta.getDimensions(), outputDimensions)) {
        throw new AssertionError(Arrays.toString(delta.getDimensions()) + " != " + Arrays.toString(outputDimensions));
      }
      if (in0.isAlive()) {
        final TensorList passbackTensorList = CudaSystem.run(gpu -> {
          // Densely packed view of the selected bands inside a full-size input buffer.
          @Nonnull final CudaDevice.CudaTensorDescriptor viewDescriptor = gpu.newTensorDescriptor(
              precision, length, outputDimensions[2], outputDimensions[1], outputDimensions[0], //
              inputDimensions[2] * inputDimensions[1] * inputDimensions[0], //
              inputDimensions[1] * inputDimensions[0], //
              inputDimensions[0], //
              1);
          // Densely packed descriptor of the full input shape.
          @Nonnull final CudaDevice.CudaTensorDescriptor inputDescriptor = gpu.newTensorDescriptor(
              precision, length, inputDimensions[2], inputDimensions[1], inputDimensions[0], //
              inputDimensions[2] * inputDimensions[1] * inputDimensions[0], //
              inputDimensions[1] * inputDimensions[0], //
              inputDimensions[0], //
              1);
          final int byteOffset = viewDescriptor.cStride * bandFrom * precision.size;
          assert delta.length() == length;
          @Nullable final CudaTensor errorPtr = gpu.getTensor(delta, precision, MemoryType.Device, false);
          delta.freeRef();
          // Widen to long before multiplying: the int product overflows for large batches/images.
          final long passbackBytes = (long) length * inputDimensions[2] * inputDimensions[1] * inputDimensions[0] * precision.size;
          @Nonnull final CudaMemory passbackBuffer = gpu.allocate(passbackBytes, MemoryType.Managed.ifEnabled(), false);
          CudaMemory errorPtrMemory = errorPtr.getMemory(gpu);
          gpu.cudnnTransformTensor(
              precision.getPointer(1.0), errorPtr.descriptor.getPtr(), errorPtrMemory.getPtr(),
              precision.getPointer(0.0), viewDescriptor.getPtr(), passbackBuffer.getPtr().withByteOffset(byteOffset)
          );
          errorPtrMemory.dirty();
          passbackBuffer.dirty();
          errorPtrMemory.freeRef();
          CudaTensor cudaTensor = CudaTensor.wrap(passbackBuffer, inputDescriptor, precision);
          Stream.of(errorPtr, viewDescriptor).forEach(ReferenceCounting::freeRef);
          return CudaTensorList.wrap(cudaTensor, length, inputDimensions, precision);
        }, delta);
        in0.accumulate(buffer, passbackTensorList);
      } else {
        delta.freeRef();
      }
    }) {
      @Override
      public void accumulate(final DeltaSet buffer, final TensorList delta) {
        getAccumulator().accept(buffer, delta);
      }

      @Override
      protected void _free() {
        Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
      }

      @Override
      public boolean isAlive() {
        return Arrays.stream(inObj).anyMatch(x -> x.isAlive());
      }
    };
  }

  /**
   * Serializes this layer.
   *
   * @param resources      resource map (not written to by this layer)
   * @param dataSerializer data serializer (not used by this layer)
   * @return the JSON stub plus {@code "from"}, {@code "to"} and {@code "precision"}
   */
  @Nonnull
  @Override
  public JsonObject getJson(Map resources, DataSerializer dataSerializer) {
    @Nonnull final JsonObject json = super.getJsonStub();
    json.addProperty("from", getFrom());
    json.addProperty("to", getTo());
    json.addProperty("precision", precision.name());
    return json;
  }

  /** @return first selected band index (inclusive) */
  public int getFrom() {
    return from;
  }

  /**
   * Sets the first selected band index (inclusive).
   *
   * @param from new start index
   * @return this layer
   */
  @Nonnull
  public ImgBandSelectLayer setFrom(final int from) {
    this.from = from;
    return this;
  }

  /** @return precision used for GPU tensors */
  @Override
  public Precision getPrecision() {
    return precision;
  }

  /**
   * Sets the precision used for GPU tensors.
   *
   * @param precision new precision
   * @return this layer
   */
  @Nonnull
  @Override
  public ImgBandSelectLayer setPrecision(final Precision precision) {
    this.precision = precision;
    return this;
  }

  /**
   * This layer exposes no state arrays.
   *
   * @return an empty list
   */
  @Nonnull
  @Override
  public List state() {
    return Arrays.asList();
  }

  /** @return band index at which selection stops (exclusive) */
  public int getTo() {
    return to;
  }

  /**
   * Sets the band index at which selection stops (exclusive).
   *
   * @param to new end index
   * @return this layer
   */
  @Nonnull
  public ImgBandSelectLayer setTo(int to) {
    this.to = to;
    return this;
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy