All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.layers.cudnn.conv.ConvolutionLayer Maven / Gradle / Ivy

/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.layers.cudnn.conv;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.lang.cudnn.CudaSettings;
import com.simiacryptus.mindseye.lang.cudnn.CudaSystem;
import com.simiacryptus.mindseye.lang.cudnn.MultiPrecision;
import com.simiacryptus.mindseye.lang.cudnn.Precision;
import com.simiacryptus.mindseye.layers.Explodable;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.ref.wrappers.RefString;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Map;
import java.util.UUID;
import java.util.function.DoubleSupplier;
import java.util.function.IntToDoubleFunction;
import java.util.function.ToDoubleFunction;

@SuppressWarnings("serial")
/**
 * Convolution layer evaluated on the GPU through CuDNN.
 *
 * <p>The kernel is held as a single tensor of dimensions
 * {@code [width, height, inputBands * outputBands]}. Evaluation builds an
 * {@link ExplodedConvolutionGrid} from the current parameters and kernel, then delegates to the
 * network that grid produces. Unless the layer is frozen, the kernel gradient is read back out of
 * the grid during backpropagation.
 */
@SuppressWarnings("serial")
public class ConvolutionLayer extends LayerBase implements MultiPrecision, Explodable {

  /** Convolution kernel; owned by this layer and released in {@link #_free()}. */
  @Nullable
  private final Tensor kernel;
  private final int inputBands;
  private final int outputBands;
  private int strideX = 1;
  private int strideY = 1;
  /** Horizontal padding; {@code null} means it was never explicitly set. */
  @Nullable
  private Integer paddingX = null;
  /** Vertical padding; {@code null} means it was never explicitly set. */
  @Nullable
  private Integer paddingY = null;
  private Precision precision = CudaSettings.INSTANCE().getDefaultPrecision();
  /** Band batch size for the exploded grid; 0 means "use max(inputBands, outputBands)". */
  private int batchBands = 0;

  /** Creates a layer with a 1x1 kernel mapping one input band to one output band. */
  protected ConvolutionLayer() {
    this(1, 1, 1, 1);
  }

  /**
   * Creates a layer whose kernel is {@code width x height} and maps {@code inputBands} input bands
   * to {@code outputBands} output bands.
   *
   * @param width       kernel width, must be positive
   * @param height      kernel height, must be positive
   * @param inputBands  number of input bands, must be positive
   * @param outputBands number of output bands, must be positive
   * @throws IllegalArgumentException if the allocated kernel is not 3-dimensional or has a
   *                                  non-positive extent
   */
  public ConvolutionLayer(final int width, final int height, final int inputBands, final int outputBands) {
    super();
    assert 0 < width;
    assert 0 < height;
    assert 0 < inputBands;
    assert 0 < outputBands;
    this.kernel = new Tensor(width, height, inputBands * outputBands);
    int[] kernelDimensions = getKernelDimensions();
    if (kernelDimensions.length != 3) {
      throw new IllegalArgumentException(
          "Expected a 3-dimensional kernel, got " + RefArrays.toString(kernelDimensions));
    }
    for (int i = 0; i < 3; i++) {
      if (kernelDimensions[i] <= 0) {
        throw new IllegalArgumentException(
            "Kernel dimension " + i + " must be positive: " + RefArrays.toString(kernelDimensions));
      }
    }
    this.inputBands = inputBands;
    this.outputBands = outputBands;
    // Pick batchBands so that batchBands^2 * width * height stays roughly within the configured
    // maximum number of filter elements.
    setBatchBands((int) Math.sqrt(CudaSettings.INSTANCE().maxFilterElements / (width * height)));
  }

  /**
   * Restores a layer from JSON produced by {@link #getJson(Map, DataSerializer)}.
   *
   * @param json      serialized layer
   * @param resources resource map passed through to {@link Tensor#fromJson}
   */
  protected ConvolutionLayer(@Nonnull final JsonObject json, Map resources) {
    super(json);
    this.kernel = Tensor.fromJson(json.get("filter"), resources);
    assert kernel != null;
    assert kernel.isValid();
    assert kernel.rms() > 0;
    setBatchBands(json.get("batchBands").getAsInt());
    setStrideX(json.get("strideX").getAsInt());
    setStrideY(json.get("strideY").getAsInt());
    // Padding is optional: a missing or JSON-null value leaves the field unset (null).
    JsonElement paddingX = json.get("paddingX");
    if (null != paddingX && paddingX.isJsonPrimitive()) {
      setPaddingX(paddingX.getAsInt());
    }
    JsonElement paddingY = json.get("paddingY");
    if (null != paddingY && paddingY.isJsonPrimitive()) {
      setPaddingY(paddingY.getAsInt());
    }
    this.precision = Precision.valueOf(json.get("precision").getAsString());
    this.inputBands = json.get("inputBands").getAsInt();
    this.outputBands = json.get("outputBands").getAsInt();
  }

  /** @return the configured band batch size; 0 means "derive from the band counts" */
  public int getBatchBands() {
    return batchBands;
  }

  /** @param batchBands band batch size for the exploded grid; 0 means "derive from the band counts" */
  public void setBatchBands(int batchBands) {
    this.batchBands = batchBands;
  }

  /**
   * Returns a CPU implementation to fall back on when CUDA is disabled.
   *
   * @return currently always {@code null}, because no fallback implementation is wired up
   */
  @Nullable
  public Layer getCompatibilityLayer() {
    // TODO: restore a CPU fallback, e.g. this.as(com.simiacryptus.mindseye.layers.aparapi.ConvolutionLayer.class)
    return null;
  }

  /** @return a fresh snapshot of this layer's convolution parameters and kernel dimensions */
  @Nonnull
  public ConvolutionParams getConvolutionParams() {
    return new ConvolutionParams(inputBands, outputBands, precision, strideX, strideY, paddingX, paddingY, getKernelDimensions());
  }

  /**
   * Builds an exploded convolution grid for the current parameters and writes the kernel into it.
   * The caller owns the returned grid and must release it.
   */
  @Nonnull
  public ExplodedConvolutionGrid getExplodedNetwork() {
    assertAlive();
    int batchBands = getBatchBands();
    if (0 == batchBands) {
      batchBands = Math.max(inputBands, outputBands);
    }
    ExplodedConvolutionGrid grid = new ExplodedConvolutionGrid(getConvolutionParams(), batchBands);
    grid.write(getKernel());
    return grid;
  }

  /** @return a new reference to the kernel; the caller must release it */
  @Nullable
  public Tensor getKernel() {
    return kernel.addRef();
  }

  /** @return the kernel's dimensions */
  public int[] getKernelDimensions() {
    assert kernel != null;
    return kernel.getDimensions();
  }

  /**
   * Returns a short description of the form {@code Conv [w/strideX x h/strideY, bands]}.
   * A 4-dimensional kernel additionally reports {@code inBands -> outBands}.
   */
  @Nullable
  @Override
  public String getName() {
    int[] kernelDimensions = getKernelDimensions();
    if (kernelDimensions.length == 4) {
      return RefString.format("Conv [%d/%d x %d/%d, %d -> %d]", kernelDimensions[0], strideX, kernelDimensions[1],
          strideY, kernelDimensions[2], kernelDimensions[3]);
    } else {
      return RefString.format("Conv [%d/%d x %d/%d, %d]", kernelDimensions[0], strideX, kernelDimensions[1], strideY,
          kernelDimensions[2]);
    }
  }

  /** @return the horizontal padding, or {@code null} if never set */
  @Nullable
  public Integer getPaddingX() {
    return paddingX;
  }

  /** @param paddingX horizontal padding; {@code null} clears it */
  public void setPaddingX(Integer paddingX) {
    this.paddingX = paddingX;
  }

  /** @return the vertical padding, or {@code null} if never set */
  @Nullable
  public Integer getPaddingY() {
    return paddingY;
  }

  /** @param paddingY vertical padding; {@code null} clears it */
  public void setPaddingY(Integer paddingY) {
    this.paddingY = paddingY;
  }

  @Override
  public Precision getPrecision() {
    return precision;
  }

  @Override
  public void setPrecision(final Precision precision) {
    this.precision = precision;
  }

  public int getStrideX() {
    return strideX;
  }

  public void setStrideX(int strideX) {
    this.strideX = strideX;
  }

  public int getStrideY() {
    return strideY;
  }

  public void setStrideY(int strideY) {
    this.strideY = strideY;
  }

  /** Fills the kernel from a coordinate-to-value function; the result must not be all zeros. */
  public void setByCoord(ToDoubleFunction coordinateToDoubleFunction) {
    kernel.setByCoord(coordinateToDoubleFunction);
    assert kernel.rms() > 0;
  }

  /** JSON factory used when deserializing layers. */
  @Nonnull
  @SuppressWarnings("unused")
  public static ConvolutionLayer fromJson(@Nonnull final JsonObject json, Map rs) {
    return new ConvolutionLayer(json, rs);
  }

  /** Returns the exploded grid's network, named after this layer with a trailing {@code +}. */
  @Nonnull
  @Override
  public Layer explode() {
    @Nonnull
    ExplodedConvolutionGrid explodedNetwork = getExplodedNetwork();
    @Nonnull
    Layer network = explodedNetwork.getNetwork();
    network.setName(getName() + "+");
    explodedNetwork.freeRef();
    return network;
  }

  /**
   * Evaluates the convolution on a single 3-dimensional input whose third dimension equals
   * {@code inputBands}.
   *
   * @param inObj exactly one input; ownership of the references is taken by this call
   * @return the convolved output, whose accumulator also feeds the kernel gradient unless frozen
   * @throws IllegalStateException if CUDA is disabled and no compatibility layer is available
   */
  @Nullable
  @Override
  public Result eval(@Nonnull final Result... inObj) {
    final Tensor kernel = getKernel();
    assert kernel != null;
    assert kernel.isValid();
    assert 1 == inObj.length;
    TensorList data0 = inObj[0].getData();
    int[] data0Dimensions = data0.getDimensions();
    assert 3 == data0Dimensions.length;
    assert inputBands == data0Dimensions[2] : RefArrays.toString(data0Dimensions) + "[2] != " + inputBands;
    if (!CudaSystem.isEnabled()) {
      kernel.freeRef();
      data0.freeRef();
      Layer compatibilityLayer = getCompatibilityLayer();
      if (null == compatibilityLayer) {
        // There is no CPU fallback, so fail with a clear message instead of an NPE.
        // The inputs handed to us are released first.
        RefUtil.freeRef(inObj);
        throw new IllegalStateException("CUDA is disabled and no compatibility layer is available for " + getName());
      }
      Result result = compatibilityLayer.eval(inObj);
      compatibilityLayer.freeRef();
      return result;
    }
    @Nonnull
    ExplodedConvolutionGrid grid = getExplodedNetwork();
    @Nonnull
    PipelineNetwork network = grid.getNetwork();
    if (isFrozen()) {
      network.freeze();
    }
    final Result result = network.eval(RefUtil.addRef(inObj));
    network.freeRef();
    assert result != null;
    final TensorList resultData = result.getData();
    assert data0.length() == resultData.length();
    assert 3 == resultData.getDimensions().length;
    assert outputBands == resultData.getDimensions()[2];
    data0.freeRef();
    RefUtil.freeRef(inObj);
    boolean alive = result.isAlive();
    // kernel and grid are handed over to the accumulator, which releases them in its _free().
    Result.Accumulator accumulator = new Accumulator(kernel, grid, isFrozen(), result.getAccumulator(), getId());
    result.freeRef();
    return new Result(resultData, accumulator, alive || !isFrozen());
  }

  /** Serializes the kernel and all configuration needed by the JSON constructor. */
  @Nonnull
  @Override
  public JsonObject getJson(Map resources, @Nonnull DataSerializer dataSerializer) {
    @Nonnull final JsonObject json = super.getJsonStub();
    json.add("filter", kernel.getJson(resources, dataSerializer));
    json.addProperty("batchBands", getBatchBands());
    json.addProperty("strideX", getStrideX());
    json.addProperty("strideY", getStrideY());
    json.addProperty("paddingX", getPaddingX());
    json.addProperty("paddingY", getPaddingY());
    json.addProperty("precision", precision.name());
    json.addProperty("inputBands", inputBands);
    json.addProperty("outputBands", outputBands);
    return json;
  }

  /** Fills every kernel element with successive values from {@code f}. */
  public void set(@Nonnull DoubleSupplier f) {
    set(i -> f.getAsDouble());
  }

  /** Copies {@code tensor} into the kernel; the result must not be all zeros. */
  public void set(@Nonnull Tensor tensor) {
    kernel.set(tensor);
    assert kernel.rms() > 0;
  }

  /** Fills the kernel by element index; the result must not be all zeros. */
  public void set(@Nonnull IntToDoubleFunction f) {
    kernel.set(f);
    assert kernel.rms() > 0;
  }

  /** @return the kernel's backing data as this layer's only state entry */
  @Nonnull
  @Override
  public RefList state() {
    return RefArrays.asList(kernel.getData());
  }

  public void setStrideXY(int x, int y) {
    setStrideX(x);
    setStrideY(y);
  }

  public void setPaddingXY(Integer x, Integer y) {
    setPaddingX(x);
    setPaddingY(y);
  }

  @Override
  public void _free() {
    if (null != kernel)
      kernel.freeRef();
    super._free();
  }

  @Nonnull
  @Override
  @SuppressWarnings("unused")
  public ConvolutionLayer addRef() {
    return (ConvolutionLayer) super.addRef();
  }

  /**
   * Backpropagation hook.
   *
   * <p>It forwards deltas to the exploded network's accumulator. Unless frozen, it also adds the
   * kernel gradient read from the grid into {@code deltaSet} under this layer's id. It owns
   * {@code kernel}, {@code grid} and {@code accumulator}.
   */
  private static class Accumulator extends Result.Accumulator {

    private final Tensor kernel;
    private final ExplodedConvolutionGrid grid;
    private final boolean frozen;
    private final Result.Accumulator accumulator;
    private final UUID id;

    public Accumulator(Tensor kernel, ExplodedConvolutionGrid grid, boolean frozen, Result.Accumulator accumulator, UUID id) {
      this.kernel = kernel;
      this.grid = grid;
      this.frozen = frozen;
      this.accumulator = accumulator;
      this.id = id;
    }

    @Override
    public void accept(@Nullable DeltaSet deltaSet, @Nullable TensorList delta) {
      assert deltaSet != null;
      // Forward to the exploded network's accumulator first.
      // grid.read(...) below then reads the kernel gradient out of the same deltaSet.
      // (NOTE(review): assumes that call is what populates it; confirm in ExplodedConvolutionGrid.)
      this.accumulator.accept(deltaSet.addRef(), delta);
      if (!frozen) {
        Tensor read = grid.read(deltaSet.addRef(), true);
        Delta uuidDelta = deltaSet.get(id, kernel.addRef());
        assert uuidDelta != null;
        uuidDelta.addInPlace(read);
        uuidDelta.freeRef();
      }
      // deltaSet has already been dereferenced above, so it is non-null here.
      deltaSet.freeRef();
    }

    @Override
    @SuppressWarnings("unused")
    public void _free() {
      super._free();
      kernel.freeRef();
      grid.freeRef();
      accumulator.freeRef();
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy