All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.layers.cudnn.ImgBandSelectLayer Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.layers.cudnn;

import com.google.gson.JsonObject;
import com.simiacryptus.lang.ref.ReferenceCounting;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.lang.cudnn.*;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * cuDNN-backed layer that selects the contiguous band (channel) range {@code [from, to)} from a
 * list of rank-3 image tensors.
 *
 * <p>The output keeps the first two dimensions of the input and has {@code to - from} bands. On the
 * GPU, the forward pass does not copy data: it wraps the input's device memory at a byte offset with
 * a descriptor that reuses the input's strides. The backward pass writes the incoming delta into the
 * selected band window of a newly allocated buffer shaped like the input.
 *
 * <p>When CUDA is not enabled, evaluation is delegated to
 * {@link com.simiacryptus.mindseye.layers.java.ImgBandSelectLayer} via
 * {@link #getCompatibilityLayer()}.
 */
@SuppressWarnings("serial")
public class ImgBandSelectLayer extends LayerBase implements MultiPrecision {

  /** First selected band (inclusive). */
  private int from;
  /** End of the selected band range (exclusive). */
  private int to;
  /** Numeric precision used for GPU buffers and descriptors. */
  private Precision precision = CudaSettings.INSTANCE().defaultPrecision;

  /**
   * Creates a layer that selects bands {@code [from, to)}.
   *
   * @param from first selected band, inclusive
   * @param to   end of the selected band range, exclusive
   */
  public ImgBandSelectLayer(int from, int to) {
    this.setFrom(from);
    this.setTo(to);
  }

  /**
   * Deserializes a layer from the JSON written by {@link #getJson(Map, DataSerializer)}.
   *
   * <p>If the {@code "precision"} property is absent, the default precision from the field
   * initializer is kept instead of failing with a {@link NullPointerException}.
   *
   * @param json serialized layer
   */
  protected ImgBandSelectLayer(@Nonnull final JsonObject json) {
    super(json);
    setFrom(json.get("from").getAsInt());
    setTo(json.get("to").getAsInt());
    if (json.has("precision")) {
      precision = Precision.valueOf(json.get("precision").getAsString());
    }
  }

  /**
   * Factory used for JSON deserialization.
   *
   * @param json serialized layer
   * @param rs   resource map (not used by this layer)
   * @return the deserialized layer
   */
  public static ImgBandSelectLayer fromJson(@Nonnull final JsonObject json, Map rs) {
    return new ImgBandSelectLayer(json);
  }

  /**
   * Returns the pure-Java equivalent of this layer, selecting the explicit band indices
   * {@code from, from + 1, ..., to - 1}.
   *
   * @return a new Java-implemented band-select layer
   */
  @Nonnull
  public Layer getCompatibilityLayer() {
    return new com.simiacryptus.mindseye.layers.java.ImgBandSelectLayer(IntStream.range(getFrom(), getTo()).toArray());
  }

  /**
   * Verifies that {@code [from, to)} is a non-empty band range inside an input with
   * {@code inputBands} bands.
   *
   * @param inputBands number of bands (third dimension) of the input
   * @throws IllegalArgumentException if the range is empty, starts below zero, or extends past the
   *     input's bands
   */
  private void checkBandRange(final int inputBands) {
    final int fromBand = getFrom();
    final int toBand = getTo();
    if (fromBand < 0 || fromBand >= toBand || toBand > inputBands) {
      throw new IllegalArgumentException(
          "Invalid band range [" + fromBand + ", " + toBand + ") for input with " + inputBands + " bands");
    }
  }

  /**
   * Selects bands {@code [from, to)} from the single input.
   *
   * <p>If CUDA is disabled, evaluation is delegated to {@link #getCompatibilityLayer()}. Otherwise,
   * the returned result frees {@code inObj} when it is itself freed.
   *
   * @param inObj exactly one result whose data is a list of rank-3 tensors
   * @return the band-selected result
   * @throws IllegalArgumentException on the GPU path, if {@code [from, to)} does not fit inside the
   *     input's bands
   */
  @Nullable
  @Override
  public Result evalAndFree(@Nonnull final Result... inObj) {
    assert getFrom() < getTo();
    assert getFrom() >= 0;
    assert 1 == inObj.length;
    final Result in0 = inObj[0];
    assert 3 == in0.getData().getDimensions().length;
    if (!CudaSystem.isEnabled()) return getCompatibilityLayer().evalAndFree(inObj);
    final TensorList inputData = in0.getData();
    @Nonnull final int[] inputDimensions = inputData.getDimensions();
    // Checked unconditionally (asserts may be disabled): a range reaching past the input's bands
    // would make the GPU views below address memory outside the selected input bands.
    checkBandRange(inputDimensions[2]);
    final int length = inputData.length();
    @Nonnull final int[] outputDimensions = Arrays.copyOf(inputDimensions, 3);
    outputDimensions[2] = getTo() - getFrom();
    return new Result(CudaSystem.run(gpu -> {
      @Nullable final CudaTensor cudaInput = gpu.getTensor(inputData, precision, MemoryType.Device, false);
      inputData.freeRef();
      // The selected window starts 'from' channel strides into the input buffer.
      final int byteOffset = cudaInput.descriptor.cStride * getFrom() * precision.size;
      // Output shape combined with the input's strides: a zero-copy view of bands [from, to).
      @Nonnull final CudaDevice.CudaTensorDescriptor inputDescriptor = gpu.newTensorDescriptor(
          precision, length, outputDimensions[2], outputDimensions[1], outputDimensions[0], //
          cudaInput.descriptor.nStride, //
          cudaInput.descriptor.cStride, //
          cudaInput.descriptor.hStride, //
          cudaInput.descriptor.wStride);
      CudaMemory cudaInputMemory = cudaInput.getMemory(gpu);
      assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
      CudaTensor cudaTensor = CudaTensor.wrap(cudaInputMemory.withByteOffset(byteOffset), inputDescriptor, precision);
      Stream.of(cudaInput, cudaInputMemory).forEach(ReferenceCounting::freeRef);
      return CudaTensorList.wrap(cudaTensor, length, outputDimensions, precision);
    }, inputData), (@Nonnull final DeltaSet buffer, @Nonnull final TensorList delta) -> {
      if (!Arrays.equals(delta.getDimensions(), outputDimensions)) {
        throw new AssertionError(Arrays.toString(delta.getDimensions()) + " != " + Arrays.toString(outputDimensions));
      }
      if (in0.isAlive()) {
        final TensorList passbackTensorList = CudaSystem.run(gpu -> {
          // Output shape with packed full-input strides: a window over bands [from, to) of the
          // full-size gradient buffer (combined with byteOffset below).
          @Nonnull final CudaDevice.CudaTensorDescriptor viewDescriptor = gpu.newTensorDescriptor(
              precision, length, outputDimensions[2], outputDimensions[1], outputDimensions[0], //
              inputDimensions[2] * inputDimensions[1] * inputDimensions[0], //
              inputDimensions[1] * inputDimensions[0], //
              inputDimensions[0], //
              1);
          // Packed descriptor for the full-size input gradient.
          @Nonnull final CudaDevice.CudaTensorDescriptor inputDescriptor = gpu.newTensorDescriptor(
              precision, length, inputDimensions[2], inputDimensions[1], inputDimensions[0], //
              inputDimensions[2] * inputDimensions[1] * inputDimensions[0], //
              inputDimensions[1] * inputDimensions[0], //
              inputDimensions[0], //
              1);
          final int byteOffset = viewDescriptor.cStride * getFrom() * precision.size;
          assert delta.length() == length;
          @Nullable final CudaTensor errorPtr = gpu.getTensor(delta, precision, MemoryType.Device, false);
          delta.freeRef();
          // Computed in long arithmetic: the equivalent int product overflows for large inputs.
          final long passbackBytes = (long) length * inputDimensions[2] * inputDimensions[1] * inputDimensions[0] * precision.size;
          // NOTE(review): the final 'false' presumably requests a cleared buffer, so bands outside
          // [from, to) receive zero gradient - confirm against CudaDevice.allocate.
          @Nonnull final CudaMemory passbackBuffer = gpu.allocate(passbackBytes, MemoryType.Managed.ifEnabled(), false);
          CudaMemory errorPtrMemory = errorPtr.getMemory(gpu);
          // Copy the delta into the selected band window of the full-size gradient buffer.
          gpu.cudnnTransformTensor(
              precision.getPointer(1.0), errorPtr.descriptor.getPtr(), errorPtrMemory.getPtr(),
              precision.getPointer(0.0), viewDescriptor.getPtr(), passbackBuffer.getPtr().withByteOffset(byteOffset)
          );
          errorPtrMemory.dirty();
          passbackBuffer.dirty();
          errorPtrMemory.freeRef();
          CudaTensor cudaTensor = CudaTensor.wrap(passbackBuffer, inputDescriptor, precision);
          Stream.of(errorPtr, viewDescriptor).forEach(ReferenceCounting::freeRef);
          return CudaTensorList.wrap(cudaTensor, length, inputDimensions, precision);
        }, delta);
        in0.accumulate(buffer, passbackTensorList);
      } else {
        delta.freeRef();
      }
    }) {

      @Override
      public void accumulate(final DeltaSet buffer, final TensorList delta) {
        getAccumulator().accept(buffer, delta);
      }

      @Override
      protected void _free() {
        Arrays.stream(inObj).forEach(Result::freeRef);
      }

      @Override
      public boolean isAlive() {
        return Arrays.stream(inObj).anyMatch(Result::isAlive);
      }
    };
  }

  /**
   * Serializes the band range and precision on top of the base layer's JSON stub.
   *
   * @param resources      resource map (not used by this layer)
   * @param dataSerializer data serializer (not used by this layer)
   * @return the JSON representation of this layer
   */
  @Nonnull
  @Override
  public JsonObject getJson(Map resources, DataSerializer dataSerializer) {
    @Nonnull final JsonObject json = super.getJsonStub();
    json.addProperty("from", getFrom());
    json.addProperty("to", getTo());
    json.addProperty("precision", precision.name());
    return json;
  }

  /**
   * @return the first selected band (inclusive)
   */
  public int getFrom() {
    return from;
  }

  /**
   * Sets the first selected band (inclusive).
   *
   * @param from first selected band
   * @return this layer, for chaining
   */
  @Nonnull
  public ImgBandSelectLayer setFrom(final int from) {
    this.from = from;
    return this;
  }

  @Override
  public Precision getPrecision() {
    return precision;
  }

  /**
   * Sets the numeric precision used for GPU evaluation.
   *
   * @param precision new precision
   * @return this layer, for chaining
   */
  @Nonnull
  @Override
  public ImgBandSelectLayer setPrecision(final Precision precision) {
    this.precision = precision;
    return this;
  }

  /**
   * Returns an empty list; this layer has no parameter arrays (its only fields are the band range
   * and precision).
   *
   * @return an empty list
   */
  @Nonnull
  @Override
  public List state() {
    return Arrays.asList();
  }

  /**
   * @return the end of the selected band range (exclusive)
   */
  public int getTo() {
    return to;
  }

  /**
   * Sets the end of the selected band range (exclusive).
   *
   * @param to end of the selected band range
   * @return this layer, for chaining
   */
  @Nonnull
  public ImgBandSelectLayer setTo(int to) {
    this.to = to;
    return this;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy