All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.art.photo.RegionAssembler Maven / Gradle / Ivy

/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.art.photo;

import com.simiacryptus.mindseye.art.photo.cuda.SparseMatrixFloat;
import com.simiacryptus.mindseye.art.photo.topology.RasterTopology;
import com.simiacryptus.mindseye.lang.CoreSettings;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.ref.lang.RefAware;
import com.simiacryptus.ref.lang.RefUtil;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.*;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static java.lang.Math.log;

public abstract class RegionAssembler implements Comparator {

  public final HashSet regions = new HashSet<>();
  private final int pixels;
  private final int regionCount;
  private final Predicate connectionFilter;

  public RegionAssembler(@Nonnull SparseMatrixFloat graph, @Nonnull int[] pixelMap, @Nonnull Map assignments,
                         @Nonnull @RefAware IntFunction pixelFunction, @Nonnull @RefAware IntFunction coordFunction,
                         Predicate connectionFilter) {
    graph.assertSymmetric();
    this.regionCount = graph.rows;
    this.pixels = pixelMap.length;
    this.connectionFilter = connectionFilter;
    final Region[] rawRegions = IntStream.range(0, graph.rows).mapToObj(row -> new Region(row))
        .toArray(i -> new Region[i]);
    final Map> pixelAssignmentMap = IntStream.range(0, pixelMap.length).mapToObj(x -> x)
        .collect(Collectors.groupingBy(x -> pixelMap[x], Collectors.toList()));
    IntStream stream = Arrays.stream(graph.activeRows());
    if (!CoreSettings.INSTANCE().singleThreaded) stream = stream.parallel();
    final List collect = stream.mapToObj(row -> {
      assert this.regionCount > row;
      final int[] cols = graph.getCols(row);
      final float[] vals = graph.getVals(row);
      assert cols.length == vals.length;
      final Region region = rawRegions[row];
      final Integer assignment = assignments.get(row);
      if (null != assignment)
        region.marks.add(assignment);
      if (cols.length != 0) {
        region.connections_addAll(IntStream.range(0, cols.length).filter(x -> Double.isFinite(vals[x]) && vals[x] != 0)
            .mapToObj(x -> new Connection(region, rawRegions[cols[x]], vals[x])).filter(x -> x.to != x.from)
            .collect(Collectors.toList()));
      }
      final List pixels = pixelAssignmentMap.get(row);
      if (pixels != null) {
        pixels.stream().map(value1 -> pixelFunction.apply(value1)).collect(Collectors.toList()).forEach(doubles1 -> region.colorStats.accept(doubles1));
        pixels.stream().map(value -> coordFunction.apply(value)).collect(Collectors.toList()).forEach(doubles -> region.spacialStats.accept(doubles));
        region.pixels.addAll(pixels);
      }
      region.original_regions.add(row);
      return region;
    }).collect(Collectors.toList());
    RefUtil.freeRef(pixelFunction);
    RefUtil.freeRef(coordFunction);
    this.regions.addAll(collect);
    assert this.regions.stream().flatMap(x -> x.connections.values().stream()).allMatch(connection -> {
      final Connection reciprical = connection.reciprical();
      if (null == reciprical)
        return false;
      return Math.abs(connection.value - reciprical.value) < 1e-4;
    });
  }

  @Nonnull
  public int[] getPixelMap() {
    final int[] ints = new int[pixels];
    regions.forEach(i -> {
      final int id = i.minId();
      i.pixels.forEach(p -> ints[p] = id);
    });
    return ints;
  }

  @Nonnull
  public int[] getProjection() {
    final int[] ints = new int[regionCount];
    regions.forEach(i -> {
      final int id = i.minId();
      i.original_regions.forEach(p -> ints[p] = id);
    });
    return ints;
  }

  @Nonnull
  public RegionTree getTree() {
    return new RegionTree(regions.stream().map(x -> x.tree).toArray(i -> new RegionTree[i]));
  }

  @Nonnull
  public static RegionAssembler wrap(@Nonnull SparseMatrixFloat graph, @Nonnull int[] pixelMap, @Nonnull Function extractor,
                                     @Nonnull final Tensor content, @Nonnull @RefAware final RasterTopology topology, @Nonnull final Map assignments) {
    return new RegionAssembler(
        graph,
        pixelMap,
        assignments,
        RefUtil.wrapInterface(p -> content.getPixel(topology.getCoordsFromIndex(p)), content, RefUtil.addRef(topology)),
        RefUtil.wrapInterface(p -> Arrays.stream(topology.getCoordsFromIndex(p)).mapToDouble(x -> x).toArray(), topology),
        connection -> connection.to != connection.from && extractor.apply(connection) < Double.POSITIVE_INFINITY
    ) {
      @Override
      public int compare(Connection o1, Connection o2) {
        return Comparator.comparing(extractor).compare(o1, o2);
      }
    };
  }

  public @Nonnull
  static RegionAssembler volumeEntropy(@Nonnull SparseMatrixFloat graph, @Nonnull int[] pixelMap, @Nonnull Tensor content,
                                       @Nonnull RasterTopology topology) {
    return wrap(graph, pixelMap, new Function() {
      @Override
      public Double apply(@Nullable Connection entry) {
        if (null == entry)
          return Double.POSITIVE_INFINITY;
        if (0 == entry.to.getConnectionWeight())
          return -Double.POSITIVE_INFINITY;
        if (0 == entry.from.getConnectionWeight())
          return -Double.POSITIVE_INFINITY;
        final double entropy = reduce(-log(entry.value / entry.to.getConnectionWeight()),
            -log(entry.value / entry.from.getConnectionWeight()));
        final double smallness = reduce(entry.to.pixels.size(), entry.from.pixels.size());
        return smallness * entropy;
      }

      double reduce(double a, double b) {
        return Math.min(a, b);
      }
    }, content, topology, new HashMap());
  }

  public @Nonnull
  static RegionAssembler simpleEntropy(@Nonnull SparseMatrixFloat graph, @Nonnull int[] pixelMap, @Nonnull Tensor content,
                                       @Nonnull RasterTopology topology) {
    return wrap(graph, pixelMap, new Function() {
      @Override
      public Double apply(@Nullable Connection entry) {
        if (null == entry)
          return Double.POSITIVE_INFINITY;
        if (0 == entry.to.getConnectionWeight())
          return -Double.POSITIVE_INFINITY;
        if (0 == entry.from.getConnectionWeight())
          return -Double.POSITIVE_INFINITY;
        return reduce(-log(entry.value / entry.to.getConnectionWeight()),
            -log(entry.value / entry.from.getConnectionWeight()));
      }

      double reduce(double a, double b) {
        return Math.min(a, b);
      }
    }, content, topology, new HashMap());
  }

  public @Nonnull
  static RegionAssembler epidemic(@Nonnull SparseMatrixFloat graph, @Nonnull int[] pixelMap, @Nonnull Tensor content,
                                  @Nonnull RasterTopology topology, @Nonnull Map assignments) {
    return wrap(graph, pixelMap, new Function() {
      @Override
      public Double apply(@Nullable Connection entry) {
        if (null == entry)
          return Double.POSITIVE_INFINITY;
        if (0 == entry.to.getConnectionWeight())
          return -Double.POSITIVE_INFINITY;
        if (0 == entry.from.getConnectionWeight())
          return -Double.POSITIVE_INFINITY;

        final int[] to_marks = entry.to.marks.stream().mapToInt(x -> x).limit(1).toArray();
        final int[] from_marks = entry.from.marks.stream().mapToInt(x -> x).limit(1).toArray();
        if (to_marks.length == 0 && from_marks.length == 0)
          return Double.POSITIVE_INFINITY;
        if (to_marks.length > 0 && from_marks.length > 0) {
          if (to_marks[0] == from_marks[0])
            return -Double.POSITIVE_INFINITY;
          else
            return Double.POSITIVE_INFINITY;
        }

        final double entropy = reduce(-log(entry.value / entry.to.getConnectionWeight()),
            -log(entry.value / entry.from.getConnectionWeight()));
        final double smallness = reduce(entry.to.pixels.size(), entry.from.pixels.size());
        return smallness * entropy;
      }

      double reduce(double a, double b) {
        return Math.min(a, b);
      }
    }, content, topology, assignments);
  }

  public @Nonnull
  static RegionAssembler volume5D(@Nonnull SparseMatrixFloat graph, @Nonnull int[] pixelMap, @Nonnull Tensor content,
                                  @Nonnull RasterTopology topology) {
    final double minVol = 5e-1;
    final double color_coeff = 1e3;
    return wrap(graph, pixelMap, new Function() {
      @Override
      public Double apply(@Nullable Connection entry) {
        if (null == entry)
          return Double.POSITIVE_INFINITY;
        final int to_pixels = entry.to.pixels.size();
        final int from_pixels = entry.from.pixels.size();
        final double log_vol_to = logVol(entry.to.colorStats, entry.to.spacialStats);
        final double log_vol_from = logVol(entry.from.colorStats, entry.from.spacialStats);
        final double log_vol_union = logVol(union(entry.to.colorStats, entry.from.colorStats),
            union(entry.to.spacialStats, entry.from.spacialStats));
        final double sourceVolume = log_vol_to * to_pixels + log_vol_from * from_pixels;
        final double resultVolume = log_vol_union * (to_pixels + from_pixels);
        final double split_entropy = to_pixels * log((double) to_pixels / (to_pixels + from_pixels))
            + from_pixels * log((double) from_pixels / (to_pixels + from_pixels));
        return (resultVolume - sourceVolume) / split_entropy;
      }

      public double logVol(@Nonnull DoubleVectorStatistics colorStats, @Nonnull DoubleVectorStatistics spacialStats) {
        return color_coeff * log(volumeStdDev(minVol, colorStats)) + log(volumeExtrema(minVol, spacialStats));
      }

      public double volumeExtrema(double minVol, @Nonnull DoubleVectorStatistics union) {
        return Arrays.stream(union.firstOrder).mapToDouble(statistics -> statistics.getMax() - statistics.getMin())
            .map(x -> Math.abs(x) < minVol ? minVol : x).reduce((a, b) -> a * b).orElse(0);
      }

      public double volumeStdDev(double minVol, @Nonnull DoubleVectorStatistics stats) {
        final DoubleSummaryStatistics[] firstOrder = stats.firstOrder;
        final DoubleSummaryStatistics[] secondOrder = stats.secondOrder;
        return IntStream.range(0, firstOrder.length)
            .mapToDouble(
                i -> Math.pow(Math.abs(secondOrder[i].getAverage() - Math.pow(firstOrder[i].getAverage(), 2)), 0.5))
            .map(x -> Math.abs(x) < minVol ? minVol : x).reduce((a, b) -> a * b).orElse(0);
      }

      @Nonnull
      private DoubleVectorStatistics union(@Nonnull DoubleVectorStatistics a, @Nonnull DoubleVectorStatistics b) {
        final DoubleVectorStatistics statistics = new DoubleVectorStatistics(a.firstOrder.length);
        statistics.combine(a);
        statistics.combine(b);
        return statistics;
      }
    }, content, topology, new HashMap());
  }

  @Nonnull
  public static int[] reduce(@Nonnull SparseMatrixFloat graph, int targetCount, @Nonnull final int[] sizes, @Nonnull Tensor content,
                             @Nonnull RasterTopology topology) {
    return wrap(graph, sizes, (Connection entry) -> {
      return null == entry ? Double.POSITIVE_INFINITY : entry.value;
    }, content, topology, new HashMap()).reduceTo(targetCount).getProjection();
  }

  private static void assertEmpty(@Nonnull List collect) {
    if (!collect.isEmpty()) {
      throw new IllegalArgumentException("Items: " + collect.size());
    }
  }

  @Nonnull
  public RegionAssembler reduceTo(int count) {
    while (regions.parallelStream().map(region1 -> region1.connections_stream())
        .filter(stream -> stream.filter(connectionFilter).count() > 0).limit(count + 1).count() > count) {
      final int limit = Math.max(1, (regions.size() - count) / 100);
      List first = regions.parallelStream().flatMap(region -> region.connections_stream_parallel())
          .filter(connectionFilter).sorted(this).limit(limit).collect(Collectors.toList());
      if (first.isEmpty()) {
        System.out.println("No connections left");
        break;
      } else {
        final HashSet touched = new HashSet<>();
        first.stream().filter(connection -> touched.add(connection.to) && touched.add(connection.from))
            .forEach(connection1 -> connection1.join());
      }
    }
    return this;
  }

  public static class RegionTree {
    public final int[] regions;
    @Nonnull
    public final RegionTree[] children;

    public RegionTree(int... regions) {
      this.regions = regions;
      children = new RegionTree[]{};
    }

    public RegionTree(@Nonnull RegionTree... children) {
      this.regions = Arrays.stream(children).flatMapToInt(x -> Arrays.stream(x.regions)).toArray();
      this.children = children;
    }
  }

  public class Connection {
    public final Region from;
    public final Region to;
    public final float value;

    public Connection(Region from, Region to, float value) {
      this.from = from;
      this.to = to;
      this.value = value;
      assert value > 0.0;
    }

    public Connection reciprical() {
      return to.connections.get(from);
    }

    protected void join() {
      final Region toRemove;
      final Region consolidated;
      final int to_id = this.to.minId();
      final int from_id = this.from.minId();
      if (to_id < from_id) {
        toRemove = this.from;
        consolidated = this.to;
      } else {
        toRemove = this.to;
        consolidated = this.from;
      }
      if (!regions.remove(toRemove)) {
        System.out.println("Remove dead connection to Region " + toRemove.minId());
        if (!this.from.connections_remove(this)) {
          throw new IllegalStateException();
        } else {
          return;
        }
      } else {
        //com.simiacryptus.ref.wrappers.System.out.println("Remove Region " + minId(toRemove));
      }
      consolidated.union(toRemove);
      toRemove.clear();
    }
  }

  public class Region {
    public final HashMap connections = new HashMap<>();
    public final HashSet pixels = new HashSet<>();
    public final HashSet marks = new HashSet<>();
    public final HashSet original_regions = new HashSet<>();
    public final DoubleVectorStatistics colorStats = new DoubleVectorStatistics(3);
    public final DoubleVectorStatistics spacialStats = new DoubleVectorStatistics(2);
    @Nullable
    public RegionTree tree;
    private double connectionWeight = 0;

    public Region(int id, @Nonnull Connection... connections) {
      this.original_regions.add(id);
      tree = new RegionTree(id);
      Arrays.stream(connections).forEach(connection -> connections_add(connection));
    }

    public double getConnectionWeight() {
      return connectionWeight;
    }

    public Stream connections_stream() {
      return connections.values().stream();
    }

    public Stream connections_stream_parallel() {
      return connections.values().parallelStream();
    }

    public void connections_clear() {
      connectionWeight = 0;
      connections.clear();
    }

    public boolean connections_add(@Nonnull Connection connection) {
      assert connection.from == this;
      final boolean add = null == connections.put(connection.to, connection);
      if (add)
        connectionWeight += connection.value;
      return add;
    }

    public boolean connections_remove(@Nonnull Connection connection) {
      final boolean remove = null != connections.remove(connection.to);
      if (remove)
        connectionWeight -= connection.value;
      return remove;
    }

    public boolean connections_addAll(@Nonnull Collection connections) {
      return connections.stream().filter(x -> !this.connections_add(x)).allMatch(connection -> connections_add(connection));
    }

    public int minId() {
      return original_regions.stream().mapToInt(x -> x).min().getAsInt();
    }

    private void union(@Nonnull Region other) {
      final List newConnections = Stream.concat(other.connections_stream(), connections_stream())
          .filter(k -> k.to != other && k.to != this)
          .collect(Collectors.groupingBy(k -> k.to, Collectors.reducing((a, b) -> {
            assert a.from == this || a.from == other;
            assert b.from == this || b.from == other;
            assert a.to == b.to;
            return new Connection(a.from, a.to, a.value + b.value);
          }))).values().stream().map(optional -> RefUtil.get(optional)).map(c -> new Connection(this, c.to, c.value))
          .collect(Collectors.toList());

      this.colorStats.combine(other.colorStats);
      this.spacialStats.combine(other.spacialStats);
      this.pixels.addAll(other.pixels);
      this.original_regions.addAll(other.original_regions);
      this.marks.addAll(other.marks);
      this.tree = new RegionTree(this.tree, other.tree);
      connections_clear();
      connections_addAll(newConnections);

      newConnections.stream().allMatch(v -> {
        final Region thirdNode = v.to;
        final List connectionsToRemove = thirdNode.connections_stream()
            .filter(x -> x.to == other || x.to == this).collect(Collectors.toList());
        //assert Math.abs(connectionsToRemove.stream().mapToDouble(x -> x.value).sum() - v.value) < 1e-3;
        assertEmpty(
            connectionsToRemove.stream().filter(o -> !thirdNode.connections_remove(o)).collect(Collectors.toList()));
        thirdNode.connections_add(new Connection(thirdNode, v.from, v.value));
        return true;
      });
    }

    private void clear() {
      this.pixels.clear();
      this.marks.clear();
      this.original_regions.clear();
      this.tree = null;
      connections_clear();
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy