All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.art.ops.ContentPCAMatcher Maven / Gradle / Ivy

/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.art.ops;

import com.simiacryptus.mindseye.art.VisualModifier;
import com.simiacryptus.mindseye.art.util.PCA;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.layers.ValueLayer;
import com.simiacryptus.mindseye.layers.cudnn.*;
import com.simiacryptus.mindseye.layers.cudnn.conv.ConvolutionLayer;
import com.simiacryptus.mindseye.layers.java.ImgBandScaleLayer;
import com.simiacryptus.mindseye.layers.java.LinearActivationLayer;
import com.simiacryptus.mindseye.layers.java.NthPowerActivationLayer;
import com.simiacryptus.mindseye.network.DAGNode;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class ContentPCAMatcher implements VisualModifier {
  private static final Logger log = LoggerFactory.getLogger(ContentPCAMatcher.class);
  private int minValue = -1;
  private int maxValue = 1;
  private boolean averaging = true;
  private int bands = 16;

  @Override
  public PipelineNetwork build(PipelineNetwork network, Tensor... image) {
    network = network.copyPipeline();
    Tensor baseContent = network.eval(image).getDataAndFree().getAndFree(0);
    int[] contentDimensions = baseContent.getDimensions();
    List components;
    PipelineNetwork signalProjection;
    try {
      PCA pca = new PCA().setRecenter(true).setRescale(false).setEigenvaluePower(0.0);
      Tensor channelMeans = pca.getChannelMeans(baseContent);
      Tensor channelRms = pca.getChannelRms(baseContent, contentDimensions[2], channelMeans);
      double[] covariance = PCA.bandCovariance(baseContent.getPixelStream(), PCA.countPixels(baseContent), channelMeans.getData(), channelRms.getData());
      signalProjection = PipelineNetwork.wrap(1,
          new ImgBandBiasLayer(channelMeans.scaleInPlace(-1)),
          new ImgBandScaleLayer(channelRms.mapAndFree(x -> 1 / x).getData())
      );
      channelMeans.freeRef();
      components = PCA.pca(covariance, pca.getEigenvaluePower()).stream().collect(Collectors.toList());
    } catch (Throwable e) {
      log.info("Error processing PCA for dimensions " + Arrays.toString(contentDimensions), e);
      PipelineNetwork pipelineNetwork = new PipelineNetwork(1);
      pipelineNetwork.wrap(new ValueLayer(new Tensor(0.0)), new DAGNode[]{});
      return pipelineNetwork;
    }
    int bands = Math.min(getBands(), contentDimensions[2]);

    Tensor prefixPattern = signalProjection.eval(baseContent).getDataAndFree().getAndFree(0);
    channelStats(getConvolutionLayer1(new ConvolutionLayer(1, 1, contentDimensions[2], bands).setPaddingXY(0, 0), components, bands).explodeAndFree().eval(prefixPattern).getDataAndFree().getAndFree(0), bands);
    channelStats(getConvolutionLayer2(new ConvolutionLayer(1, 1, contentDimensions[2], bands).setPaddingXY(0, 0), components, bands).explodeAndFree().eval(prefixPattern).getDataAndFree().getAndFree(0), bands);

    signalProjection.wrap(getConvolutionLayer1(new ConvolutionLayer(1, 1, contentDimensions[2], bands).setPaddingXY(0, 0), components, bands).explodeAndFree());
    Tensor spacialPattern = signalProjection.eval(baseContent).getDataAndFree().getAndFree(0);
    channelStats(spacialPattern, bands);

    double mag = spacialPattern.rms();
    DAGNode head = signalProjection.getHead();
    DAGNode constNode = signalProjection.constValueWrap(spacialPattern.scaleInPlace(-1));
    signalProjection.wrap(new SumInputsLayer().setName("Difference"), head, constNode).freeRef();
    signalProjection.wrap(PipelineNetwork.wrap(1,
        new SquareActivationLayer(),
        isAveraging() ? new AvgReducerLayer() : new SumReducerLayer(),
        new LinearActivationLayer().setScale(Math.pow(mag, -2))
//        ,new NthPowerActivationLayer().setPower(0.5)
    ).setName(String.format("RMS / %.0E", mag)));

    network.wrap(signalProjection.setName(String.format("PCA Content Match"))).freeRef();
    return (PipelineNetwork) network.freeze();
  }

  public void channelStats(Tensor spacialPattern, int bands) {
    double[] means = IntStream.range(0, bands).mapToDouble(band -> {
      return spacialPattern.selectBand(band).mean();
    }).toArray();
    double[] stdDevs = IntStream.range(0, bands).mapToDouble(band -> {
      Tensor bandPattern = spacialPattern.selectBand(band);
      return Math.sqrt(Math.pow(bandPattern.rms(), 2) - Math.pow(bandPattern.mean(), 2));
    }).toArray();
    log.info("Means: " + Arrays.toString(means) + "; StdDev: " + Arrays.toString(stdDevs));
  }

  @NotNull
  public ConvolutionLayer getConvolutionLayer1(ConvolutionLayer convolutionLayer, List components, int stride) {
    convolutionLayer.getKernel().setByCoord(c -> {
      int[] coords = c.getCoords();
      return components.get(coords[2] % stride).get(coords[2] / stride);
    });
    return convolutionLayer;
  }

  @NotNull
  public ConvolutionLayer getConvolutionLayer2(ConvolutionLayer convolutionLayer, List components, int stride) {
    convolutionLayer.getKernel().setByCoord(c -> {
      int[] coords = c.getCoords();
      return components.get(coords[2] / stride).get(coords[2] % stride);
    });
    return convolutionLayer;
  }

  public boolean isAveraging() {
    return averaging;
  }

  public ContentPCAMatcher setAveraging(boolean averaging) {
    this.averaging = averaging;
    return this;
  }

  public int getMinValue() {
    return minValue;
  }

  public ContentPCAMatcher setMinValue(int minValue) {
    this.minValue = minValue;
    return this;
  }

  public int getMaxValue() {
    return maxValue;
  }

  public ContentPCAMatcher setMaxValue(int maxValue) {
    this.maxValue = maxValue;
    return this;
  }

  public int getBands() {
    return bands;
  }

  public ContentPCAMatcher setBands(int bands) {
    this.bands = bands;
    return this;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy