All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.layers.cudnn.ImgTileAssemblyLayer Maven / Gradle / Ivy

/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.layers.cudnn;

import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.lang.cudnn.*;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.*;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Map;
import java.util.UUID;

/**
 * cuDNN-backed layer that assembles a {@code rows} x {@code columns} grid of image tiles into one image.
 *
 * <p>Inputs are consumed in row-major order (input index {@code row * columns + col}). Within a row each tile is
 * placed immediately to the right of the previous one; each row starts below the tallest tile of the previous row.
 * The output dimensions are {@code [widest row, sum of row heights, bands]}, where the band count is taken from the
 * first input and every tile must have that same band count. When CUDA is disabled, evaluation is delegated to the
 * pure-Java {@link com.simiacryptus.mindseye.layers.java.ImgTileAssemblyLayer}.
 */
@SuppressWarnings("serial")
public class ImgTileAssemblyLayer extends LayerBase implements MultiPrecision {
  private static final Logger log = LoggerFactory.getLogger(ImgTileAssemblyLayer.class);

  /** Number of tiles per row. */
  private int columns;
  /** Number of tile rows. */
  private int rows;
  /** Floating-point precision used for GPU buffers and tensor descriptors. */
  private Precision precision = CudaSettings.INSTANCE().getDefaultPrecision();
  /** Whether per-tile forward copies and per-tile backprop tasks may run on a parallel stream. */
  private boolean parallel;

  private ImgTileAssemblyLayer() {
  }

  /**
   * Creates a layer that assembles {@code rows * columns} input tiles.
   *
   * @param columns number of tiles per row
   * @param rows    number of tile rows
   */
  public ImgTileAssemblyLayer(int columns, int rows) {
    this.columns = columns;
    this.rows = rows;
  }

  /**
   * Deserializes a layer from the JSON written by {@link #getJson(Map, DataSerializer)}.
   *
   * <p>The {@code parallel} and {@code precision} keys are optional: when absent, {@code parallel} is
   * {@code false} and {@code precision} keeps the CUDA default precision, so JSON lacking them still loads.
   *
   * @param json serialized layer
   */
  protected ImgTileAssemblyLayer(@Nonnull final JsonObject json) {
    super(json);
    columns = json.get("columns").getAsInt();
    rows = json.get("rows").getAsInt();
    this.parallel = json.has("parallel") && json.get("parallel").getAsBoolean();
    if (json.has("precision")) {
      this.precision = Precision.valueOf(json.getAsJsonPrimitive("precision").getAsString());
    }
  }

  /**
   * Returns the pure-Java equivalent of this layer, used when CUDA is not available.
   *
   * @return the compatibility layer (caller owns the returned reference)
   */
  @Nonnull
  public Layer getCompatibilityLayer() {
    return this.as(com.simiacryptus.mindseye.layers.java.ImgTileAssemblyLayer.class);
  }

  @Override
  public Precision getPrecision() {
    return precision;
  }

  @Override
  public void setPrecision(final Precision precision) {
    this.precision = precision;
  }

  /**
   * @return whether per-tile copy/backprop work may be run on a parallel stream
   */
  public boolean isParallel() {
    return parallel;
  }

  /**
   * @param parallel whether per-tile copy/backprop work may be run on a parallel stream
   */
  public void setParallel(boolean parallel) {
    this.parallel = parallel;
  }

  /**
   * Factory used by the layer deserialization machinery.
   *
   * @param json serialized layer
   * @param rs   serialized resources (unused)
   * @return the deserialized layer
   */
  @Nonnull
  @SuppressWarnings("unused")
  public static ImgTileAssemblyLayer fromJson(@Nonnull final JsonObject json, Map<CharSequence, byte[]> rs) {
    return new ImgTileAssemblyLayer(json);
  }

  /**
   * Assembles the input tiles into a single image on the GPU.
   *
   * <p>A single input is passed through unchanged. When CUDA is disabled, the call is delegated to
   * {@link #getCompatibilityLayer()}.
   *
   * @param inObj the tiles in row-major order; references are consumed
   * @return the assembled result, whose accumulator scatters the error back to each live input tile
   */
  @Nullable
  @Override
  public Result eval(@Nonnull final Result... inObj) {
    if (!CudaSystem.isEnabled()) {
      Layer compatibilityLayer = getCompatibilityLayer();
      Result result = compatibilityLayer.eval(inObj);
      compatibilityLayer.freeRef();
      return result;
    }
    if (1 == inObj.length) {
      Result result = inObj[0].addRef();
      RefUtil.freeRef(inObj);
      return result;
    }
    TensorList data0 = inObj[0].getData();
    int[] inputDimensions = data0.getDimensions();
    final int length = data0.length();
    data0.freeRef();
    assert 3 == inputDimensions.length;
    int[] outputDims = getOutputDims(RefUtil.addRef(inObj));
    final TensorList outputData = fwd(length, outputDims, RefUtil.addRef(inObj));
    Result.Accumulator accumulator = new Accumulator(this.addRef(), outputData.addRef(), length, outputDims, rows, columns, parallel, RefUtil.addRef(inObj));
    boolean isAlive = RefArrays.stream(inObj).anyMatch(x -> {
      boolean alive = x.isAlive();
      x.freeRef();
      return alive;
    });
    return new Result(outputData, accumulator, isAlive);
  }

  /**
   * Back-propagates one tile: extracts that tile's window from the full-size error and accumulates it into the
   * corresponding input.
   *
   * @param backpropParams the tile to process; the reference is consumed
   */
  public void backprop(@Nonnull final BackpropParams backpropParams) {
    final TensorList passbackTensorList = CudaSystem
        .run(RefUtil.wrapInterface((RefFunction<CudnnHandle, CudaTensorList>) gpu -> {
          // Negated tile position: read the tile's window out of the larger error image.
          CudaTensor ptr = copy(gpu, backpropParams.error.addRef(), backpropParams.tileDimensions,
              backpropParams.outputDims, backpropParams.length, -backpropParams.positionX, -backpropParams.totalHeight);
          CudaTensorList tileTensorList = new CudaTensorList(ptr == null ? null : ptr.addRef(), backpropParams.length,
              backpropParams.tileDimensions, precision);
          if (null != ptr)
            ptr.freeRef();
          return tileTensorList;
        }, backpropParams.addRef()), backpropParams.error.addRef());
    // The accumulator forwards a null buffer as null, so it must not be dereferenced here.
    backpropParams.inObj[backpropParams.inputIndex].accumulate(
        backpropParams.buffer == null ? null : backpropParams.buffer.addRef(),
        passbackTensorList == null ? null : passbackTensorList.addRef());
    backpropParams.freeRef();
    if (null != passbackTensorList)
      passbackTensorList.freeRef();
  }

  /**
   * Copies a tile-sized window of {@code error} into a newly allocated device buffer.
   *
   * @param gpu            CUDA handle; the reference is consumed
   * @param error          source data with dimensions {@code outputDims}; the reference is consumed
   * @param tileDimensions dimensions of the returned tensor
   * @param outputDims     dimensions of {@code error}
   * @param length         batch size
   * @param positionX      horizontal placement passed to the region copy (negative = read offset into the source)
   * @param positionY      vertical placement passed to the region copy (negative = read offset into the source)
   * @return a tensor of shape {@code tileDimensions} holding the copied window
   */
  @Nullable
  public CudaTensor copy(@Nonnull final CudnnHandle gpu, @Nullable final TensorList error, @Nonnull final int[] tileDimensions,
                         @Nonnull final int[] outputDims, final int length, final int positionX, final int positionY) {
    @Nullable final CudaTensor errorPtr = gpu.getTensor(error == null ? null : error.addRef(), precision, MemoryType.Device,
        false);
    if (null != error)
      error.freeRef();
    @Nonnull final CudaMemory passbackBuffer = gpu.allocate(
        (long) length * tileDimensions[2] * tileDimensions[1] * tileDimensions[0] * precision.size,
        MemoryType.Managed.ifEnabled(), false);
    copy(gpu.addRef(), length, outputDims, errorPtr, tileDimensions,
        passbackBuffer.addRef(), positionX, positionY);
    CudaDevice.CudaTensorDescriptor descriptor = gpu.newTensorDescriptor(precision, length, tileDimensions[2],
        tileDimensions[1], tileDimensions[0]);
    gpu.freeRef();
    return new CudaTensor(passbackBuffer, descriptor, precision);
  }

  /**
   * Forward step for one tile: copies the tile's input data into the shared output buffer at
   * ({@code positionX}, {@code totalHeight}).
   *
   * @param copyParams the tile to copy; the reference is consumed
   */
  public void copy(@Nonnull final CopyParams copyParams) {
    CudnnHandle gpu = copyParams.gpu.addRef();
    gpu.initThread();
    assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
    @Nullable final CudaTensor inputBuffer = gpu.getTensor(copyParams.inObj[copyParams.inputIndex].getData(), precision,
        MemoryType.Device, false);
    copy(gpu, copyParams.length, copyParams.tileDimensions, inputBuffer,
        copyParams.outputDims, copyParams.outputBuffer.addRef(), copyParams.positionX, copyParams.totalHeight);
    copyParams.freeRef();
  }

  /**
   * Copies the overlapping region of {@code source} into {@code destination} via {@code cudnnTransformTensor}.
   *
   * <p>A positive position offsets the write into the destination; a non-positive position instead offsets the
   * read from the source by its magnitude. The transform uses alpha = beta = 1, so copied values are added to the
   * existing destination contents.
   *
   * @param gpu                   CUDA handle; the reference is consumed
   * @param length                batch size
   * @param sourceDimensions      rank-3 source dimensions
   * @param source                source tensor; the reference is consumed
   * @param destinationDimensions rank-3 destination dimensions (densely packed)
   * @param destination           destination memory; the reference is consumed
   * @param positionX             horizontal placement
   * @param positionY             vertical placement
   * @throws IllegalArgumentException if either dimension array is not rank 3 or the band counts differ
   */
  public void copy(@Nonnull CudnnHandle gpu, int length, @Nonnull int[] sourceDimensions, @Nonnull CudaTensor source,
                   @Nonnull int[] destinationDimensions, @Nonnull CudaMemory destination, int positionX, int positionY) {
    if (3 != sourceDimensions.length) {
      source.freeRef();
      destination.freeRef();
      gpu.freeRef();
      throw new IllegalArgumentException("inputDimensions.length");
    }
    if (3 != destinationDimensions.length) {
      source.freeRef();
      destination.freeRef();
      gpu.freeRef();
      throw new IllegalArgumentException("dimOut.length");
    }
    int bands = sourceDimensions[2];
    if (bands != destinationDimensions[2]) {
      source.freeRef();
      destination.freeRef();
      gpu.freeRef();
      throw new IllegalArgumentException(RefString.format("%d != %d", bands, destinationDimensions[2]));
    }
    @Nonnull final int[] viewDim = getViewDimensions(sourceDimensions, destinationDimensions,
        new int[]{positionX, positionY, 0});
    // Source view keeps the source tensor's own strides.
    final CudaDevice.CudaTensorDescriptor sourceViewDescriptor = gpu.newTensorDescriptor(precision, //
        length, //
        viewDim[2], //
        viewDim[1], //
        viewDim[0], //
        source.descriptor.nStride, //
        source.descriptor.cStride, //
        source.descriptor.hStride, //
        source.descriptor.wStride);
    // Destination view uses densely packed strides derived from destinationDimensions.
    final CudaDevice.CudaTensorDescriptor destinationViewDescriptor = gpu.newTensorDescriptor(precision, //
        length, //
        viewDim[2], //
        viewDim[1], //
        viewDim[0], //
        destinationDimensions[2] * destinationDimensions[1] * destinationDimensions[0], //
        destinationDimensions[1] * destinationDimensions[0], //
        destinationDimensions[0], //
        1);
    int sourceOffset = 0;
    int destinationOffset = 0;

    if (positionX > 0) {
      destinationOffset += Math.abs(positionX);
    } else {
      sourceOffset += source.descriptor.wStride * Math.abs(positionX);
    }
    if (positionY > 0) {
      destinationOffset += destinationDimensions[0] * Math.abs(positionY);
    } else {
      sourceOffset += source.descriptor.hStride * Math.abs(positionY);
    }
    assert sourceOffset >= 0;
    assert destinationOffset >= 0;
    assert sourceOffset + Tensor.length(viewDim) <= source.descriptor.nStride * length;
    assert destinationOffset + Tensor.length(viewDim) <= Tensor.length(destinationDimensions);

    CudaMemory sourceMemory = source.getMemory(gpu.addRef());
    source.freeRef();
    assert sourceMemory != null;
    // Byte offsets are computed in long arithmetic, matching the allocation sizes used elsewhere in this class.
    CudaSystem.handle(gpu.cudnnTransformTensor(precision.getPointer(1.0), sourceViewDescriptor.getPtr(),
        sourceMemory.getPtr().withByteOffset((long) sourceOffset * precision.size), precision.getPointer(1.0),
        destinationViewDescriptor.getPtr(), destination.getPtr().withByteOffset((long) destinationOffset * precision.size)));
    destinationViewDescriptor.freeRef();
    sourceViewDescriptor.freeRef();
    assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
    sourceMemory.dirty();
    sourceMemory.freeRef();
    gpu.freeRef();
    destination.dirty();
    destination.freeRef();
  }

  /**
   * Computes the extent of the region shared by a source placed at {@code offset} and the destination.
   *
   * <p>For each axis {@code i < 3}: {@code min(source[i] + offset[i], destination[i]) - max(offset[i], 0)}.
   *
   * @param sourceDimensions      source dimensions (at least 3 entries)
   * @param destinationDimensions destination dimensions (at least 3 entries)
   * @param offset                placement of the source within the destination (at least 3 entries)
   * @return a new 3-element array with the overlap extent per axis
   */
  @Nonnull
  public int[] getViewDimensions(int[] sourceDimensions, int[] destinationDimensions, int[] offset) {
    @Nonnull final int[] viewDim = new int[3];
    // Plain loop: three elements do not justify a parallel fork/join.
    for (int i = 0; i < viewDim.length; i++) {
      viewDim[i] = Math.min(sourceDimensions[i] + offset[i], destinationDimensions[i]) - Math.max(offset[i], 0);
    }
    return viewDim;
  }

  @Nonnull
  @Override
  public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
    @Nonnull final JsonObject json = super.getJsonStub();
    json.addProperty("rows", rows);
    json.addProperty("columns", columns);
    json.addProperty("precision", precision.name());
    json.addProperty("parallel", isParallel());
    return json;
  }

  /**
   * @return an empty list; this layer has no trainable state
   */
  @Nonnull
  @Override
  public RefList<double[]> state() {
    return RefArrays.asList();
  }

  public @SuppressWarnings("unused")
  void _free() {
    super._free();
  }

  @Nonnull
  public @Override
  @SuppressWarnings("unused")
  ImgTileAssemblyLayer addRef() {
    return (ImgTileAssemblyLayer) super.addRef();
  }

  /**
   * Runs the forward assembly on the GPU: allocates the output buffer and copies each tile into place.
   *
   * @param length     batch size
   * @param outputDims assembled output dimensions, as computed by {@link #getOutputDims(Result[])}
   * @param inObj      input tiles in row-major order; references are consumed
   * @return the assembled output data
   */
  @NotNull
  private TensorList fwd(int length, int[] outputDims, @Nonnull Result[] inObj) {
    return CudaSystem.run(RefUtil.wrapInterface((RefFunction<CudnnHandle, CudaTensorList>) gpu -> {
          assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
          assert outputDims[0] > 0;
          assert outputDims[1] > 0;
          assert outputDims[2] > 0;
          @Nonnull final CudaMemory outputBuffer = gpu.allocate(
              (long) length * outputDims[2] * outputDims[1] * outputDims[0] * precision.size,
              MemoryType.Managed.ifEnabled(), false);
          int totalHeight = 0;
          int inputIndex = 0;
          RefList<CopyParams> copies = new RefArrayList<>();
          for (int row = 0; row < rows; row++) {
            int positionX = 0;
            int rowHeight = 0;
            for (int col = 0; col < columns; col++) {
              TensorList tileData = inObj[inputIndex].getData();
              int[] tileDimensions = tileData.getDimensions();
              tileData.freeRef();
              rowHeight = Math.max(rowHeight, tileDimensions[1]);
              copies.add(new CopyParams(gpu.addRef(), RefUtil.addRef(inObj), outputBuffer.addRef(),
                  length, outputDims, tileDimensions, inputIndex, positionX, totalHeight));
              positionX += tileDimensions[0];
              inputIndex += 1;
              assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
            }
            totalHeight += rowHeight;
          }
          assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
          RefStream<CopyParams> stream = copies.stream();
          if (!CoreSettings.INSTANCE().singleThreaded && parallel)
            stream = stream.parallel();
          stream.forEach(copyParams -> copy(copyParams));
          copies.freeRef();
          CudaTensorList cudaTensorList = new CudaTensorList(
              new CudaTensor(outputBuffer,
                  gpu.newTensorDescriptor(precision, length, outputDims[2], outputDims[1], outputDims[0]),
                  precision),
              length, outputDims, precision);
          gpu.freeRef();
          return cudaTensorList;
        }, RefUtil.addRef(inObj)),
        RefArrays.stream(inObj).map(result -> {
          TensorList data = result.getData();
          result.freeRef();
          return data;
        }).toArray());
  }

  /**
   * Computes the assembled output dimensions {@code [widest row, sum of row heights, bands]}.
   *
   * @param inObj input tiles in row-major order; references are consumed
   * @return the output dimensions; the band count comes from the first input
   */
  @Nonnull
  private int[] getOutputDims(@Nullable final Result[] inObj) {
    assert inObj != null;
    TensorList firstData = inObj[0].getData();
    int bands = firstData.getDimensions()[2];
    firstData.freeRef();
    int totalWidth = 0;
    int totalHeight = 0;
    int inputIndex = 0;
    for (int row = 0; row < rows; row++) {
      int positionX = 0;
      int rowHeight = 0;
      for (int col = 0; col < columns; col++) {
        TensorList tileData = inObj[inputIndex].getData();
        int[] dimensions = tileData.getDimensions();
        tileData.freeRef();
        rowHeight = Math.max(rowHeight, dimensions[1]);
        positionX += dimensions[0];
        inputIndex += 1;
      }
      totalHeight += rowHeight;
      totalWidth = Math.max(totalWidth, positionX);
    }
    RefUtil.freeRef(inObj);
    return new int[]{totalWidth, totalHeight, bands};
  }

  /**
   * Reference-counted description of one forward tile copy into the shared output buffer.
   */
  private static class CopyParams extends ReferenceCountingBase {
    public final int length;
    public final int[] outputDims;
    public final CudnnHandle gpu;
    @Nonnull
    public final CudaMemory outputBuffer;
    /** Vertical position of the tile's row within the output. */
    public final int totalHeight;
    public final int inputIndex;
    /** Horizontal position of the tile within its row. */
    public final int positionX;
    public final int[] tileDimensions;
    @Nonnull
    public final Result[] inObj;

    private CopyParams(final CudnnHandle gpu, @Nonnull final Result[] inObj, @Nullable final CudaMemory outputBuffer,
                       final int length, final int[] outputDims, final int[] tileDimensions, final int inputIndex, final int positionX,
                       final int totalHeight) {
      this.length = length;
      this.outputDims = outputDims;
      this.gpu = gpu;
      this.outputBuffer = outputBuffer;
      this.totalHeight = totalHeight;
      this.inputIndex = inputIndex;
      this.positionX = positionX;
      this.tileDimensions = tileDimensions;
      this.inObj = inObj;
    }

    public @SuppressWarnings("unused")
    void _free() {
      super._free();
      RefUtil.freeRef(inObj);
      if (null != gpu) gpu.freeRef();
      outputBuffer.freeRef();
    }

    @Nonnull
    public @Override
    @SuppressWarnings("unused")
    CopyParams addRef() {
      return (CopyParams) super.addRef();
    }
  }

  /**
   * Reference-counted description of one tile's backward pass.
   */
  private static class BackpropParams extends ReferenceCountingBase {
    @Nonnull
    public final Result[] inObj;
    /** Delta buffer forwarded to {@code accumulate}; may be null when the accumulator was given none. */
    @Nullable
    public final DeltaSet<UUID> buffer;
    @Nonnull
    public final TensorList error;
    public final int[] outputDims;
    public final int[] tileDimensions;
    public final int length;
    public final int positionX;
    public final int totalHeight;
    public final int inputIndex;

    private BackpropParams(@Nonnull final Result[] inObj, @Nullable final DeltaSet<UUID> buffer,
                           @Nonnull final TensorList error, final int[] outputDims, final int[] tileDimensions, final int length,
                           final int positionX, final int totalHeight, final int inputIndex) {
      this.inObj = inObj;
      this.buffer = buffer;
      this.error = error;
      this.outputDims = outputDims;
      this.tileDimensions = tileDimensions;
      this.length = length;
      this.positionX = positionX;
      this.totalHeight = totalHeight;
      this.inputIndex = inputIndex;
    }

    public @SuppressWarnings("unused")
    void _free() {
      super._free();
      error.freeRef();
      if (null != buffer)
        buffer.freeRef();
      RefUtil.freeRef(inObj);
    }

    @Nonnull
    public @Override
    @SuppressWarnings("unused")
    BackpropParams addRef() {
      return (BackpropParams) super.addRef();
    }
  }

  /**
   * Back-propagation handler: validates the incoming error against the forward output and dispatches one
   * {@link BackpropParams} task per live input tile.
   */
  private static class Accumulator extends Result.Accumulator {

    private final ImgTileAssemblyLayer imgTileAssemblyLayer;
    private final TensorList outputData;
    private final int length;
    private final int[] outputDims;
    private final Result[] inObj;
    private final int rows;
    private final int columns;
    private final boolean parallel;

    public Accumulator(ImgTileAssemblyLayer imgTileAssemblyLayer, TensorList outputData, int length, int[] outputDims, int rows, int columns, boolean parallel, Result... inObj) {
      this.imgTileAssemblyLayer = imgTileAssemblyLayer;
      this.outputData = outputData;
      this.length = length;
      this.outputDims = outputDims;
      this.inObj = inObj;
      this.rows = rows;
      this.columns = columns;
      this.parallel = parallel;
    }

    /**
     * Scatters {@code error} back to each live input tile.
     *
     * @param buffer delta buffer (may be null); the reference is consumed
     * @param error  error with the forward output's dimensions and length; the reference is consumed
     * @throws AssertionError if the error's dimensions or length differ from the forward output
     */
    @Override
    public void accept(@Nullable DeltaSet<UUID> buffer, @Nonnull TensorList error) {
      if (!RefArrays.equals(error.getDimensions(), outputData.getDimensions())) {
        if (null != buffer)
          buffer.freeRef();
        AssertionError dimensionMismatch = new AssertionError(RefArrays.toString(error.getDimensions()) + " != "
            + RefArrays.toString(outputData.getDimensions()));
        error.freeRef();
        throw dimensionMismatch;
      }
      if (error.length() != outputData.length()) {
        if (null != buffer)
          buffer.freeRef();
        AssertionError lengthMismatch = new AssertionError(error.length() + " != " + outputData.length());
        error.freeRef();
        throw lengthMismatch;
      }
      assert error.length() == length;
      int totalHeight = 0;
      int inputIndex = 0;
      RefList<BackpropParams> tasks = new RefArrayList<>();
      for (int row = 0; row < rows; row++) {
        int positionX = 0;
        int rowHeight = 0;
        for (int col = 0; col < columns; col++) {
          Result in = inObj[inputIndex].addRef();
          TensorList tileData = in.getData();
          int[] tileDimensions = tileData.getDimensions();
          tileData.freeRef();
          in.freeRef();
          rowHeight = Math.max(rowHeight, tileDimensions[1]);
          // Only tiles that still need gradients get a backprop task.
          if (inObj[inputIndex].isAlive()) {
            tasks.add(new BackpropParams(RefUtil.addRef(inObj), buffer == null ? null : buffer.addRef(),
                error.addRef(), outputDims, tileDimensions, length, positionX,
                totalHeight, inputIndex));
          }
          positionX += tileDimensions[0];
          inputIndex += 1;
        }
        totalHeight += rowHeight;
      }
      error.freeRef();
      if (null != buffer)
        buffer.freeRef();
      RefStream<BackpropParams> stream = tasks.stream();
      tasks.freeRef();
      if (!CoreSettings.INSTANCE().singleThreaded && parallel)
        stream = stream.parallel();
      stream.forEach(backpropParams -> imgTileAssemblyLayer.backprop(backpropParams));
    }

    public @SuppressWarnings("unused")
    void _free() {
      super._free();
      RefUtil.freeRef(inObj);
      imgTileAssemblyLayer.freeRef();
      outputData.freeRef();
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy