// All downloads are free. Search and download functionality uses the official Maven repository.

// com.simiacryptus.mindseye.art.photo.affinity.ContextAffinity (Maven / Gradle / Ivy)

/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.art.photo.affinity;

import com.simiacryptus.mindseye.art.photo.MultivariateFrameOfReference;
import com.simiacryptus.mindseye.art.photo.topology.IteratedRasterTopology;
import com.simiacryptus.mindseye.art.photo.topology.RasterTopology;
import com.simiacryptus.mindseye.lang.CoreSettings;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import org.ejml.simple.SimpleMatrix;

import javax.annotation.Nonnull;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public abstract class ContextAffinity extends ReferenceCountingBase implements RasterAffinity {
  @Nonnull
  protected final Tensor content;
  @Nonnull
  protected final int[] dimensions;
  private RasterTopology topology;
  private double mixing = 0.5;
  private int graphPower1 = 3;
  private int graphPower2 = 2;

  public ContextAffinity(@Nonnull Tensor content) {
    this.content = content;
    this.dimensions = this.content.getDimensions();
  }

  public int getGraphPower1() {
    return graphPower1;
  }

  public void setGraphPower1(int graphPower1) {
    this.graphPower1 = graphPower1;
  }

  public int getGraphPower2() {
    return graphPower2;
  }

  public void setGraphPower2(int graphPower2) {
    this.graphPower2 = graphPower2;
  }

  public double getMixing() {
    return mixing;
  }

  public void setMixing(double mixing) {
    this.mixing = mixing;
  }

  public RasterTopology getTopology() {
    return topology.addRef();
  }

  public void setTopology(RasterTopology topology) {
    this.topology = topology;
  }

  @Nonnull
  public static SimpleMatrix toMatrix(@Nonnull double[] array) {
    return new SimpleMatrix(array.length, 1, false, array);
  }

  public static SimpleMatrix mix(@Nonnull SimpleMatrix means_local, @Nonnull SimpleMatrix means_global, double pos) {
    assert pos >= 0;
    assert pos <= 1;
    return means_local.scale(pos).plus(means_global.scale(1 - pos));
  }

  @Nonnull
  public static SimpleMatrix covariance(@Nonnull SimpleMatrix means, @Nonnull SimpleMatrix rms, @Nonnull Supplier> supplier,
                                        int size) {
    final SimpleMatrix cov = new SimpleMatrix(size, size);
    IntStream stream = IntStream.range(0, size);
    if (!CoreSettings.INSTANCE().singleThreaded) stream = stream.parallel();
    stream.forEach(c1 -> {
      final double mean1 = means.get(c1);
      final double rms1 = rms.get(c1);
      IntStream.range(0, size).forEach(c2 -> {
        final double mean2 = means.get(c2);
        final double rms2 = rms.get(c2);
        final double covariance = supplier.get().mapToDouble(p -> (p[c1] - mean1) / rms1 * ((p[c2] - mean2) / rms2))
            .average().getAsDouble();
        cov.set(c1, c2, covariance);
      });
    });
    return cov;
  }

  @Nonnull
  public static SimpleMatrix magnitude(@Nonnull SimpleMatrix means, @Nonnull Supplier> stream, int size) {
    final SimpleMatrix rms = new SimpleMatrix(size, 1);
    IntStream.range(0, size).forEach(c -> rms.set(c, 0,
        //        256
        //        1.0 * Math.sqrt(neighborhood.stream().mapToDouble(p -> p[c] - means.get(c)).map(p -> p * p).average().getAsDouble())
        Math.sqrt(stream.get().mapToDouble(p -> p[c] - means.get(c)).map(a -> Math.abs(a)).max().getAsDouble())));
    return rms;
  }

  @Nonnull
  public static SimpleMatrix means(@Nonnull Supplier> stream, int size) {
    final SimpleMatrix means = new SimpleMatrix(size, 1);
    IntStream.range(0, size).forEach(c -> means.set(c, 0, stream.get().mapToDouble(p -> p[c]).average().orElse(0)));
    return means;
  }

  @Nonnull
  private static IntStream expand(@Nonnull List graphEdges, @Nonnull IntStream stream, int iterations) {
    if (iterations <= 0) {
      return stream;
    } else if (iterations == 1) {
      return expand(graphEdges, stream);
    } else {
      return expand(graphEdges, expand(graphEdges, stream, iterations - 1));
    }
  }

  @Nonnull
  private static IntStream expand(@Nonnull List graphEdges, @Nonnull IntStream intStream) {
    return intStream.mapToObj(index -> graphEdges.get(index)).flatMapToInt(data -> Arrays.stream(data)).distinct();
  }

  @Override
  public List affinityList(@Nonnull List graphEdges) {
    final int channels = dimensions[2];
    final int pixels = dimensions[0] * dimensions[1];
    final List iteratedGraph = IteratedRasterTopology.iterate(graphEdges, getGraphPower1());
    MultivariateFrameOfReference region_global = new MultivariateFrameOfReference(
        () -> content.getPixelStream().parallel(), channels);
    return IntStream.range(0, pixels).parallel().mapToObj(i -> Arrays.stream(graphEdges.get(i)).mapToDouble(j -> {
      if (i == j) {
        return 1;
      } else {
        final List neighborhood = expand(iteratedGraph, IntStream.of(i, j), getGraphPower2())
            .mapToObj(i1 -> pixel(i1)).collect(Collectors.toList());
        if (neighborhood.isEmpty())
          return 1;
        MultivariateFrameOfReference mix = new MultivariateFrameOfReference(region_global,
            new MultivariateFrameOfReference(() -> neighborhood.stream(), channels), getMixing());
        return dist(toMatrix(mix.adjust(pixel(i))), toMatrix(mix.adjust(pixel(j))), mix.cov, neighborhood.size()
        );
      }
    }).toArray()).collect(Collectors.toList());
  }

  public @SuppressWarnings("unused")
  void _free() {
    super._free();
    topology.freeRef();
    content.freeRef();
  }

  @Nonnull
  public @Override
  @SuppressWarnings("unused")
  ContextAffinity addRef() {
    return (ContextAffinity) super.addRef();
  }

  protected double[] adjust(double[] pixel_i, @Nonnull SimpleMatrix means, @Nonnull SimpleMatrix rms) {
    return IntStream.range(0, dimensions[2]).mapToDouble(c -> (pixel_i[c] - means.get(c)) / rms.get(c)).toArray();
  }

  protected abstract double dist(SimpleMatrix vector_i, SimpleMatrix vector_j, SimpleMatrix cov, int neighborhoodSize);

  protected double[] pixel(int i) {
    RasterTopology topology = getTopology();
    final int[] coords = topology.getCoordsFromIndex(i);
    topology.freeRef();
    return IntStream.range(0, dimensions[2]).mapToDouble(c -> {
      return content.get(coords[0], coords[1], c);
    }).toArray();
  }

}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy