All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.layers.java.ImgViewLayer Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.layers.java;

import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import com.simiacryptus.mindseye.lang.*;
import org.apache.commons.math3.util.FastMath;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.*;
import java.util.stream.IntStream;

/**
 * Extracts a spatial view from each 3-dimensional (x, y, channel) input tensor using nearest-neighbour sampling.
 *
 * <p>Each output pixel (x, y) is mapped back to an input position in two steps. First, {@code (offsetX, offsetY)} is
 * added. Then the position is rotated by {@code rotationRadians} about {@code (rotationCenterX, rotationCenterY)} and
 * rounded to the nearest integer pixel. Positions outside the input read as zero, unless {@code wrap} is enabled; in
 * that case coordinates are taken modulo the input width/height.
 *
 * <p>The output spatial size comes from {@link #getViewDimensions(int[], int[], int[])}. When wrapping, it is exactly
 * {@code sizeX} x {@code sizeY}; otherwise it is clipped against the input extent.
 *
 * <p>An optional {@code channelSelector} chooses the output channels, using signed 1-based entries:
 * <ul>
 *   <li>an entry {@code k > 0} copies input channel {@code k - 1};</li>
 *   <li>an entry {@code k < 0} emits {@code negativeBias - input[-k - 1]};</li>
 *   <li>{@code 0} is not a valid entry.</li>
 * </ul>
 * Without a selector, the output has the same channels as the input.
 */
@SuppressWarnings("serial")
public class ImgViewLayer extends LayerBase {

  // Bias used for inverted (negative-selector) channels: output = negativeBias - input.
  private double negativeBias = 255;
  // When true, out-of-range sample coordinates wrap around (toroidal); otherwise they read as zero.
  private boolean wrap;
  private int sizeX;
  private int sizeY;
  private int offsetX;
  private int offsetY;
  private int rotationCenterX;
  private int rotationCenterY;
  // Signed, 1-based channel selector (see class doc); null means "pass all channels through".
  private int[] channelSelector;
  private double rotationRadians;

  /**
   * Creates a non-wrapping view with zero offset.
   *
   * @param sizeX the view width
   * @param sizeY the view height
   */
  public ImgViewLayer(final int sizeX, final int sizeY) {
    this(sizeX, sizeY, false);
  }

  /**
   * Creates a view with zero offset.
   *
   * @param sizeX the view width
   * @param sizeY the view height
   * @param wrap  whether out-of-range coordinates wrap around the input
   */
  public ImgViewLayer(final int sizeX, final int sizeY, boolean wrap) {
    this(sizeX, sizeY, 0, 0, wrap);
  }

  /**
   * Creates a non-wrapping view at the given offset.
   *
   * @param sizeX   the view width
   * @param sizeY   the view height
   * @param offsetX the x offset added to every output coordinate
   * @param offsetY the y offset added to every output coordinate
   */
  public ImgViewLayer(final int sizeX, final int sizeY, final int offsetX, final int offsetY) {
    this(sizeX, sizeY, offsetX, offsetY, false);
  }

  /**
   * Creates a view at the given offset.
   *
   * @param sizeX   the view width
   * @param sizeY   the view height
   * @param offsetX the x offset added to every output coordinate
   * @param offsetY the y offset added to every output coordinate
   * @param wrap    whether out-of-range coordinates wrap around the input
   */
  public ImgViewLayer(final int sizeX, final int sizeY, final int offsetX, final int offsetY, final boolean wrap) {
    super();
    setSizeX(sizeX).setSizeY(sizeY).setOffsetX(offsetX).setOffsetY(offsetY).setWrap(wrap);
  }

  /**
   * Restores a layer from the JSON form produced by {@link #getJson(Map, DataSerializer)}.
   *
   * <p>The {@code sizeX}, {@code sizeY}, {@code offsetX} and {@code offsetY} fields are required. The other fields are
   * optional and keep their defaults when absent, so JSON written by older versions still loads. The channel selector
   * is read from {@code "channelSelector"} (the key {@code getJson} writes), falling back to the legacy
   * {@code "channelPermutationFilter"} key.
   *
   * @param json the json
   */
  protected ImgViewLayer(@Nonnull final JsonObject json) {
    super(json);
    setSizeX(json.getAsJsonPrimitive("sizeX").getAsInt());
    setSizeY(json.getAsJsonPrimitive("sizeY").getAsInt());
    setOffsetX(json.getAsJsonPrimitive("offsetX").getAsInt());
    setOffsetY(json.getAsJsonPrimitive("offsetY").getAsInt());
    setNegativeBias(optDouble(json, "negativeBias", getNegativeBias()));
    setRotationCenterX(optInt(json, "rotationCenterX", getRotationCenterX()));
    setRotationCenterY(optInt(json, "rotationCenterY", getRotationCenterY()));
    setRotationRadians(optDouble(json, "rotationRadians", getRotationRadians()));
    JsonArray selectorJson = json.getAsJsonArray("channelSelector");
    if (null == selectorJson) {
      // Legacy key: earlier readers expected this name, so keep accepting it.
      selectorJson = json.getAsJsonArray("channelPermutationFilter");
    }
    if (null != selectorJson) {
      final int[] selector = new int[selectorJson.size()];
      for (int i = 0; i < selector.length; i++) {
        selector[i] = selectorJson.get(i).getAsInt();
      }
      setChannelSelector(selector);
    }
    JsonPrimitive toroidal = json.getAsJsonPrimitive("wrap");
    this.setWrap(null != toroidal && toroidal.getAsBoolean());
  }

  /** Reads an optional int member, returning {@code fallback} when the key is absent. */
  private static int optInt(@Nonnull final JsonObject json, final String key, final int fallback) {
    final JsonPrimitive value = json.getAsJsonPrimitive(key);
    return null == value ? fallback : value.getAsInt();
  }

  /** Reads an optional double member, returning {@code fallback} when the key is absent. */
  private static double optDouble(@Nonnull final JsonObject json, final String key, final double fallback) {
    final JsonPrimitive value = json.getAsJsonPrimitive(key);
    return null == value ? fallback : value.getAsDouble();
  }

  /** Returns true when (x, y) lies inside a width x height grid. */
  private static boolean inBounds(int width, int height, int x, int y) {
    return x >= 0 && x < width && y >= 0 && y < height;
  }

  /**
   * Resolves the signed, 1-based source channel for an output channel.
   *
   * @param selector      the channel selector, or null for identity
   * @param outputChannel the zero-based output channel
   * @return a positive value {@code k} for input channel {@code k - 1}; a negative value for an inverted channel
   */
  private static int sourceChannel(final int[] selector, final int outputChannel) {
    return null == selector ? outputChannel + 1 : selector[outputChannel];
  }

  /**
   * Adds {@code value} to {@code tensor[x, y, channel]}, wrapping or discarding out-of-range coordinates.
   *
   * <p>Accumulates rather than overwrites, because several output cells may map to the same input cell.
   */
  private static void accumulate(@Nonnull Tensor tensor, int width, int height, int x, int y, int channel, boolean wrap, double value) {
    assert channel >= 0 : channel;
    if (wrap) {
      x = Math.floorMod(x, width);
      y = Math.floorMod(y, height);
    }
    if (inBounds(width, height, x, y)) {
      tensor.set(x, y, channel, tensor.get(x, y, channel) + value);
    }
  }

  /**
   * Reads {@code tensor[x, y, channel]}, wrapping out-of-range coordinates when {@code wrap} is set.
   * Otherwise, out-of-range coordinates return 0.
   */
  private static double get(@Nonnull Tensor tensor, int width, int height, int x, int y, int channel, boolean wrap) {
    assert channel >= 0 : channel;
    if (wrap) {
      x = Math.floorMod(x, width);
      y = Math.floorMod(y, height);
    }
    return inBounds(width, height, x, y) ? tensor.get(x, y, channel) : 0.0;
  }

  /**
   * Deserializes a layer from JSON.
   *
   * @param json the json
   * @param rs   the resource map (unused)
   * @return the restored layer
   */
  public static ImgViewLayer fromJson(@Nonnull final JsonObject json, Map rs) {
    return new ImgViewLayer(json);
  }

  /**
   * Fills {@code outputData} by sampling {@code inputData} through {@link #coordinateMapping(double...)}.
   *
   * @param inputData  a 3-dimensional input tensor
   * @param outputData a 3-dimensional output tensor; every cell is written
   * @return {@code outputData}
   */
  @Nonnull
  protected Tensor fwd(
      @Nonnull final Tensor inputData,
      @Nonnull final Tensor outputData
  ) {
    @Nonnull final int[] inDim = inputData.getDimensions();
    @Nonnull final int[] outDim = outputData.getDimensions();
    final int[] selector = channelSelector;
    assert 3 == inDim.length;
    assert 3 == outDim.length;
    // With a selector, the output channel count follows the selector rather than the input.
    assert (null == selector ? inDim[2] == outDim[2] : outDim[2] <= selector.length)
        : Arrays.toString(inDim) + "; " + Arrays.toString(outDim);
    // Each output cell is written exactly once, so a parallel traversal is safe here.
    outputData.coordStream(true).forEach((c) -> {
      final int[] coords = c.getCoords();
      final double[] xy = coordinateMapping(coords[0], coords[1]);
      final int x = (int) Math.round(xy[0]);
      final int y = (int) Math.round(xy[1]);
      final int channel = sourceChannel(selector, coords[2]);
      if (0 < channel) {
        outputData.set(c, get(inputData, inDim[0], inDim[1], x, y, channel - 1, wrap));
      } else {
        outputData.set(c, getNegativeBias() - get(inputData, inDim[0], inDim[1], x, y, -channel - 1, wrap));
      }
    });
    return outputData;
  }

  /**
   * Back-propagates {@code outputDelta} into {@code inputDelta}, using the same mapping as {@link #fwd}.
   *
   * <p>Gradients are added into {@code inputDelta}. {@code eval} passes a freshly allocated tensor, which is assumed
   * to be zero-filled.
   *
   * @param outputDelta the gradient with respect to the layer output
   * @param inputDelta  the tensor receiving the gradient with respect to the layer input
   * @return {@code inputDelta}
   */
  @Nonnull
  protected Tensor bck(
      @Nonnull final Tensor outputDelta,
      @Nonnull final Tensor inputDelta
  ) {
    @Nonnull final int[] outDeltaDims = outputDelta.getDimensions();
    @Nonnull final int[] inputDeltaDims = inputDelta.getDimensions();
    final int[] selector = channelSelector;
    assert 3 == outDeltaDims.length;
    assert 3 == inputDeltaDims.length;
    assert (null == selector ? outDeltaDims[2] == inputDeltaDims[2] : outDeltaDims[2] <= selector.length)
        : Arrays.toString(outDeltaDims) + "; " + Arrays.toString(inputDeltaDims);
    // Sequential on purpose: several output cells can map to the same input cell. This happens through rounding
    // under rotation, through wrapping, or through repeated selector entries. The read-modify-write in accumulate()
    // must therefore not race.
    outputDelta.coordStream(false).forEach((c) -> {
      final int[] outCoord = c.getCoords();
      final double[] inCoords = coordinateMapping(outCoord[0], outCoord[1]);
      final int x = (int) Math.round(inCoords[0]);
      final int y = (int) Math.round(inCoords[1]);
      final int channel = sourceChannel(selector, outCoord[2]);
      if (0 < channel) {
        accumulate(inputDelta, inputDeltaDims[0], inputDeltaDims[1], x, y, channel - 1, wrap, outputDelta.get(c));
      } else {
        // d(negativeBias - v)/dv = -1
        accumulate(inputDelta, inputDeltaDims[0], inputDeltaDims[1], x, y, -channel - 1, wrap, -outputDelta.get(c));
      }
    });
    return inputDelta;
  }

  /**
   * Maps an output (x, y) position to a (fractional) input position.
   *
   * <p>The mapping adds the offset, then rotates by {@code rotationRadians} about the rotation center. Note that the
   * passed array is modified in place and returned.
   *
   * @param xy an array whose first two elements are x and y
   * @return {@code xy}, holding the mapped coordinates
   */
  protected double[] coordinateMapping(double... xy) {
    xy[0] += offsetX;
    xy[1] += offsetY;
    xy[0] -= rotationCenterX;
    xy[1] -= rotationCenterY;
    double x1 = xy[0];
    double y1 = xy[1];
    double sin = FastMath.sin(rotationRadians);
    double cos = FastMath.cos(rotationRadians);
    xy[0] = cos * x1 - sin * y1;
    xy[1] = sin * x1 + cos * y1;
    xy[0] += rotationCenterX;
    xy[1] += rotationCenterY;
    return xy;
  }

  @Nonnull
  @Override
  public Result eval(@Nonnull final Result... inObj) {
    // Hold a reference on every input for the lifetime of the returned Result; released in _free().
    Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
    final Result input = inObj[0];
    final TensorList batch = input.getData();
    @Nonnull final int[] inputDims = batch.getDimensions();
    assert 3 == inputDims.length;
    @Nonnull final int[] dimOut = getViewDimensions(inputDims, new int[]{getSizeX(), getSizeY(), inputDims[2]}, new int[]{getOffsetX(), getOffsetY(), 0});
    if (null != channelSelector) dimOut[2] = channelSelector.length;
    return new Result(TensorArray.wrap(IntStream.range(0, batch.length())
        .mapToObj(dataIndex -> {
          @Nonnull final Tensor outputData = new Tensor(dimOut);
          Tensor inputData = batch.get(dataIndex);
          fwd(inputData, outputData);
          inputData.freeRef();
          return outputData;
        })
        .toArray(i -> new Tensor[i])), (@Nonnull final DeltaSet buffer, @Nonnull final TensorList error) -> {
      if (input.isAlive()) {
        @Nonnull TensorArray tensorArray = TensorArray.wrap(IntStream.range(0, error.length())
            .mapToObj(dataIndex -> {
              @Nullable final Tensor err = error.get(dataIndex);
              @Nonnull final Tensor passback = new Tensor(inputDims);
              bck(err, passback);
              err.freeRef();
              return passback;
            }).toArray(i -> new Tensor[i]));
        input.accumulate(buffer, tensorArray);
      }
      error.freeRef();
    }) {

      @Override
      protected void _free() {
        Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
      }

      @Override
      public boolean isAlive() {
        return input.isAlive() || !isFrozen();
      }
    };
  }

  /**
   * Computes the output dimensions of the view.
   *
   * <p>When wrapping, the result is {@code destinationDimensions}. Otherwise, each axis is clipped as
   * {@code min(source, destination + offset) - max(offset, 0)}.
   *
   * @param sourceDimensions      the source dimensions
   * @param destinationDimensions the requested destination dimensions
   * @param offset                the per-axis offset
   * @return a new 3-element dimension array
   */
  @Nonnull
  public int[] getViewDimensions(int[] sourceDimensions, int[] destinationDimensions, int[] offset) {
    @Nonnull final int[] viewDim = new int[3];
    // Three elements only: a sequential fill avoids pointless fork/join overhead.
    Arrays.setAll(viewDim, i -> isWrap() ? (destinationDimensions[i]) : (Math.min(
        sourceDimensions[i],
        destinationDimensions[i] + offset[i]
    ) - Math.max(offset[i], 0)));
    return viewDim;
  }

  @Nonnull
  @Override
  public JsonObject getJson(Map resources, DataSerializer dataSerializer) {
    @Nonnull final JsonObject json = super.getJsonStub();
    json.addProperty("sizeX", getSizeX());
    json.addProperty("sizeY", getSizeY());
    json.addProperty("offsetX", getOffsetX());
    json.addProperty("offsetY", getOffsetY());
    json.addProperty("negativeBias", getNegativeBias());
    json.addProperty("rotationCenterX", getRotationCenterX());
    json.addProperty("rotationCenterY", getRotationCenterY());
    json.addProperty("rotationRadians", getRotationRadians());
    json.addProperty("wrap", isWrap());
    if (null != getChannelSelector()) {
      JsonArray selectorJson = new JsonArray();
      for (int i : getChannelSelector()) {
        selectorJson.add(i);
      }
      json.add("channelSelector", selectorJson);
    }
    return json;
  }

  /** This layer has no trainable state. */
  @Nonnull
  @Override
  public List state() {
    return new ArrayList<>();
  }


  /** @return the requested view width */
  public int getSizeX() {
    return sizeX;
  }

  /** @param sizeX the requested view width @return this layer */
  public ImgViewLayer setSizeX(int sizeX) {
    this.sizeX = sizeX;
    return this;
  }

  /** @return the requested view height */
  public int getSizeY() {
    return sizeY;
  }

  /** @param sizeY the requested view height @return this layer */
  public ImgViewLayer setSizeY(int sizeY) {
    this.sizeY = sizeY;
    return this;
  }

  /** @return the x offset added to output coordinates */
  public int getOffsetX() {
    return offsetX;
  }

  /** @param offsetX the x offset added to output coordinates @return this layer */
  public ImgViewLayer setOffsetX(int offsetX) {
    this.offsetX = offsetX;
    return this;
  }

  /** @return the y offset added to output coordinates */
  public int getOffsetY() {
    return offsetY;
  }

  /** @param offsetY the y offset added to output coordinates @return this layer */
  public ImgViewLayer setOffsetY(int offsetY) {
    this.offsetY = offsetY;
    return this;
  }

  /** @return the x coordinate of the rotation center */
  public int getRotationCenterX() {
    return rotationCenterX;
  }

  /** @param rotationCenterX the x coordinate of the rotation center @return this layer */
  public ImgViewLayer setRotationCenterX(int rotationCenterX) {
    this.rotationCenterX = rotationCenterX;
    return this;
  }

  /** @return the y coordinate of the rotation center */
  public int getRotationCenterY() {
    return rotationCenterY;
  }

  /** @param rotationCenterY the y coordinate of the rotation center @return this layer */
  public ImgViewLayer setRotationCenterY(int rotationCenterY) {
    this.rotationCenterY = rotationCenterY;
    return this;
  }

  /** @return the rotation angle, in radians */
  public double getRotationRadians() {
    return rotationRadians;
  }

  /** @param rotationRadians the rotation angle, in radians @return this layer */
  public ImgViewLayer setRotationRadians(double rotationRadians) {
    this.rotationRadians = rotationRadians;
    return this;
  }

  /** @return whether out-of-range coordinates wrap around the input */
  public boolean isWrap() {
    return wrap;
  }

  /** @param wrap whether out-of-range coordinates wrap around the input @return this layer */
  public ImgViewLayer setWrap(boolean wrap) {
    this.wrap = wrap;
    return this;
  }

  /** @return the signed 1-based channel selector, or null for identity (the live array, not a copy) */
  public int[] getChannelSelector() {
    return channelSelector;
  }

  /** @param channelSelector the signed 1-based channel selector, or null for identity @return this layer */
  public ImgViewLayer setChannelSelector(int... channelSelector) {
    this.channelSelector = channelSelector;
    return this;
  }

  /** @return the bias used for inverted channels ({@code negativeBias - input}) */
  public double getNegativeBias() {
    return negativeBias;
  }

  /** @param negativeBias the bias used for inverted channels @return this layer */
  public ImgViewLayer setNegativeBias(double negativeBias) {
    this.negativeBias = negativeBias;
    return this;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy