All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.applications.ImageSegmenter Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.applications;

import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.lang.cudnn.Precision;
import com.simiacryptus.mindseye.layers.cudnn.ImgBandSelectLayer;
import com.simiacryptus.mindseye.layers.cudnn.ImgConcatLayer;
import com.simiacryptus.mindseye.layers.java.SumReducerLayer;
import com.simiacryptus.mindseye.models.CVPipe;
import com.simiacryptus.mindseye.models.CVPipe_VGG19;
import com.simiacryptus.mindseye.models.LayerEnum;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.mindseye.test.TestUtil;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.notebook.NullNotebookOutput;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import java.awt.image.BufferedImage;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

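/**
 * Segments an image into masks: pixels are first clustered on the output of selected
 * vision-pipeline layers (see featureClusters), and the resulting masks are then
 * clustered again as a whole (see spatialClusters).
 *
 * @param <T> the type used to identify pipeline layers
 * @param <U> the vision pipeline type that supplies the layer prototypes
 */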
public abstract class ImageSegmenter, U extends CVPipe> extends PixelClusterer {

  // Class-wide SLF4J logger; featureClusters uses it to report a failure while building a single cluster mask.
  private static final Logger logger = LoggerFactory.getLogger(ImageSegmenter.class);

  /**
   * Creates a segmenter with explicit clustering settings.
   * <p>
   * The eight arguments are passed through unchanged to the superclass constructor,
   * followed by five fixed values: {@code false}, {@code true}, {@code 0.0}, {@code 1.0}
   * and {@code new double[]{1e-1, 1e-3}}. The meaning of those fixed values is defined by
   * the superclass constructor.
   *
   * @param clusters                   the number of clusters
   * @param orientation                the orientation setting, forwarded to the superclass
   * @param globalDistributionEmphasis the global distribution emphasis, forwarded to the superclass
   * @param selectionEntropyAdj        the selection entropy adjustment, forwarded to the superclass
   * @param maxIterations              the iteration limit, forwarded to the superclass
   * @param timeoutMinutes             the timeout in minutes, forwarded to the superclass
   * @param seedPcaPower               the seed PCA power, forwarded to the superclass
   * @param seedMagnitude              the seed magnitude, forwarded to the superclass
   */
  public ImageSegmenter(
      final int clusters, final int orientation,
      final double globalDistributionEmphasis, final double selectionEntropyAdj,
      final int maxIterations, final int timeoutMinutes,
      final double seedPcaPower, final double seedMagnitude
  ) {
    super(
        // Caller-supplied settings, in declaration order.
        clusters, orientation, globalDistributionEmphasis, selectionEntropyAdj,
        maxIterations, timeoutMinutes, seedPcaPower, seedMagnitude,
        // Settings this class always fixes.
        false, true, 0.0, 1.0, new double[]{1e-1, 1e-3}
    );
  }

  /**
   * Creates a segmenter for the given number of clusters; every other setting is left to
   * the single-argument superclass constructor.
   *
   * @param clusters the number of clusters
   */
  public ImageSegmenter(final int clusters) {
    super(clusters);
  }

  /**
   * Computes masks for an image using a cluster count of 3.
   *
   * @param img the image to segment
   * @return the masks produced by {@code quickMasks(img, 3)}
   */
  public static List quickMasks(final Tensor img) {
    final int defaultClusters = 3;
    return quickMasks(img, defaultClusters);
  }

  /**
   * Computes masks for an image, using the same count for the number of masks, the
   * color clusters and the texture clusters.
   *
   * @param img      the image to segment
   * @param clusters the count used for masks, color clusters and texture clusters alike
   * @return the masks produced by the four-argument overload
   */
  public static List quickMasks(final Tensor img, final int clusters) {
    final int masks = clusters;
    final int colorClusters = clusters;
    final int textureClusters = clusters;
    return quickMasks(img, masks, colorClusters, textureClusters);
  }

  /**
   * Computes masks for an image without producing notebook output: delegates to the
   * logging overload with a fresh {@link NullNotebookOutput}.
   *
   * @param img             the image to segment
   * @param masks           the number of masks requested
   * @param colorClusters   the cluster count used for the color layer
   * @param textureClusters the cluster count used for the remaining layers
   * @return the masks produced by the logging overload
   */
  public static List quickMasks(final Tensor img, final int masks, final int colorClusters, final int textureClusters) {
    final NotebookOutput silentLog = new NullNotebookOutput();
    return quickMasks(silentLog, img, masks, colorClusters, textureClusters);
  }

  /**
   * Camel-case alias for {@code quickmasks(log, img, masks, colorClusters, textureClusters)}.
   *
   * @param log             the notebook that receives the output
   * @param img             the image to segment
   * @param masks           the number of masks requested
   * @param colorClusters   the cluster count used for the color layer
   * @param textureClusters the cluster count used for the remaining layers
   * @return whatever the lower-case variant returns
   */
  public static List quickMasks(
      @Nonnull final NotebookOutput log,
      final Tensor img,
      final int masks,
      final int colorClusters,
      final int textureClusters
  ) {
    final List result = quickmasks(log, img, masks, colorClusters, textureClusters);
    return result;
  }

  /**
   * Computes masks for an image using a blur setting of 9 and three VGG19 layers
   * ({@code Layer_0}, {@code Layer_1a}, {@code Layer_1e}).
   * <p>
   * When {@code masks} is 1 or less, no clustering is performed: the result is a single
   * tensor shaped like {@code img.sumChannels()} with every element set to 1.0.
   *
   * @param log             the notebook that receives the output
   * @param img             the image to segment
   * @param masks           the number of masks requested
   * @param colorClusters   the cluster count used for {@code Layer_0}
   * @param textureClusters the cluster count used for the other layers
   * @return the list of masks
   */
  public static List quickmasks(
      @Nonnull final NotebookOutput log,
      final Tensor img,
      final int masks,
      final int colorClusters,
      final int textureClusters
  ) {
    if (1 >= masks) {
      // mapAndFree releases the intermediate channel-sum tensor, which plain map() left unreleased.
      return Arrays.asList(img.sumChannels().mapAndFree(x -> 1.0));
    }
    return quickmasks(
        log,
        img,
        9,
        masks,
        colorClusters,
        textureClusters,
        CVPipe_VGG19.Layer.Layer_0,
        CVPipe_VGG19.Layer.Layer_1a,
        CVPipe_VGG19.Layer.Layer_1e
    );
  }

  /**
   * Computes masks for an image in three steps: per-layer feature clustering, blurring
   * of the resulting masks, and a final clustering over the blurred masks.
   * <p>
   * The segmenter is an anonymous {@link VGG19} whose {@code modelingNetwork} uses
   * {@code colorClusters} for {@code Layer_0} (with the recenter flag forced to
   * {@code true} and a seed PCA power of 0) and {@code textureClusters} with the
   * segmenter's own settings for every other layer.
   *
   * @param log             the notebook that receives the output
   * @param img             the image to segment
   * @param blur            the blur setting handed to {@code PCAObjectLocation.blur}
   * @param masks           the cluster count of the segmenter
   * @param colorClusters   the cluster count used for {@code Layer_0}
   * @param textureClusters the cluster count used for the other layers
   * @param layers          the layers whose features are clustered
   * @return the masks returned by {@code spatialClusters}
   */
  public static List quickmasks(
      @Nonnull final NotebookOutput log,
      final Tensor img,
      final int blur,
      final int masks,
      final int colorClusters,
      final int textureClusters,
      final CVPipe_VGG19.Layer... layers
  ) {
    final ImageSegmenter vggSegmenter = new VGG19(masks) {
      @Override
      public Layer modelingNetwork(final CVPipe_VGG19.Layer layer, final Tensor metrics) {
        if (layer != CVPipe_VGG19.Layer.Layer_0) {
          // Every layer other than Layer_0: segmenter defaults with the texture cluster count.
          return modelingNetwork(
              getGlobalBias(), getGlobalGain(), metrics,
              isRecenter(), isRescale(),
              textureClusters, getSeedMagnitude(), getSeedPcaPower()
          );
        }
        // Layer_0: recenter forced on, PCA seed power 0, color cluster count.
        return modelingNetwork(getGlobalBias(), getGlobalGain(), metrics, true, isRescale(), colorClusters, getSeedMagnitude(), 0);
      }
    };
    // NOTE(review): the tensors in perLayerMasks are never released here - confirm whether
    // PCAObjectLocation.blur takes ownership of them.
    final List perLayerMasks = vggSegmenter.featureClusters(log, img, layers);
    final List blurredMasks = PCAObjectLocation.blur(perLayerMasks, blur);
    final List result = vggSegmenter.spatialClusters(log, img, blurredMasks);
    blurredMasks.forEach(ReferenceCountingBase::freeRef);
    return result;
  }

  /**
   * Multiplies an image by a mask, element by element, and renders the product as an
   * image, evaluating the whole computation through {@code log.eval}.
   * <p>
   * For every coordinate of {@code img}, the mask is read at the same first two
   * coordinates. The third coordinate is clamped to the last valid index of the mask's
   * third dimension, so a mask with fewer bands than the image is reused for the
   * remaining bands.
   *
   * @param log  the notebook used to evaluate and record the computation
   * @param img  the image tensor
   * @param mask the mask tensor
   * @return the rendered product of image and mask
   */
  public static BufferedImage alphaImageMask(@Nonnull final NotebookOutput log, final Tensor img, Tensor mask) {
    return log.eval(() -> {
      Tensor tensor = img.mapCoords(c -> img.get(c) * mask.get(
          c.getCoords()[0],
          c.getCoords()[1],
          // Clamp to the last valid band index (size - 1); clamping to the size itself read out of range.
          Math.min(c.getCoords()[2], mask.getDimensions()[2] - 1)
      ));
      try {
        return tensor.toImage();
      } finally {
        // Release the intermediate product, as the overload without a log does.
        tensor.freeRef();
      }
    });
  }

  /**
   * Multiplies an image by a mask, element by element, and renders the product as an image.
   * <p>
   * For every coordinate of {@code img}, the mask is read at the same first two
   * coordinates, with the third coordinate clamped to the last index of the mask's
   * third dimension.
   *
   * @param img  the image tensor
   * @param mask the mask tensor
   * @return the rendered product of image and mask
   */
  public static BufferedImage alphaImageMask(final Tensor img, Tensor mask) {
    final Tensor product = img.mapCoords(c -> {
      final int[] position = c.getCoords();
      final int lastMaskBand = mask.getDimensions()[2] - 1;
      final int maskBand = Math.min(position[2], lastMaskBand);
      return img.get(c) * mask.get(position[0], position[1], maskBand);
    });
    final BufferedImage rendered = product.toImage();
    // The product tensor is only needed for rendering.
    product.freeRef();
    return rendered;
  }

  /**
   * Writes two PNG renderings of an image to the notebook as one paragraph. Both are
   * produced by {@code img.toRgbImageAlphaMask(0, 1, 2, ...)}: the first with the mask
   * scaled by 255, the second with {@code mask.normalizeDistribution()} scaled by 255.
   *
   * @param log  the notebook that receives the images
   * @param img  the image tensor
   * @param mask the mask tensor; it is not released by this method
   */
  public static void displayImageMask(@Nonnull final NotebookOutput log, final Tensor img, Tensor mask) {
    final Tensor scaledMask = mask.scale(255.0);
    final Tensor normalizedMask = mask.normalizeDistribution().scaleInPlace(255.0);
    // Render in the same order as the paragraph lists them.
    final String scaledPng = log.png(img.toRgbImageAlphaMask(0, 1, 2, scaledMask), "");
    final String normalizedPng = log.png(img.toRgbImageAlphaMask(0, 1, 2, normalizedMask), "");
    log.p(scaledPng + normalizedPng);
    normalizedMask.freeRef();
    scaledMask.freeRef();
  }

  /**
   * Builds one mask per cluster for each requested layer.
   * <p>
   * When {@code getClusters()} is 1 or less, the result is a single tensor shaped like
   * {@code img} with every element set to 1.0. Otherwise, for each entry of
   * {@link #getLayerTypes()} that is also present in {@code layers}:
   * <ol>
   *   <li>the layer's prototype network from {@code getInstance().getPrototypes()} is set
   *       to float precision, frozen, and evaluated on {@code img};</li>
   *   <li>{@code analyze} is called on the first tensor of the result to obtain a
   *       clustering layer;</li>
   *   <li>for each cluster index, a network made of a frozen copy of that layer, a band
   *       selector for the index and a sum reducer is evaluated on the image features; the
   *       absolute values of the delta read back from it, shaped like {@code img}, form
   *       the mask;</li>
   *   <li>each mask is displayed, and an animated GIF of all masks for the layer is
   *       written to the notebook.</li>
   * </ol>
   * A cluster whose mask computation throws is logged and left out of the result.
   *
   * @param log    the notebook that receives headings, images and diagnostics
   * @param img    the image tensor
   * @param layers the layers to process; other entries of {@code getLayerTypes()} are skipped
   * @return the masks of all processed layers, in {@code getLayerTypes()} order
   */
  public List featureClusters(@Nonnull final NotebookOutput log, final Tensor img, final T... layers) {
    if (1 >= getClusters()) return Arrays.asList(img.map(x -> 1.0));
    // Wrap the varargs array once instead of once per candidate layer.
    final List<T> requestedLayers = Arrays.asList(layers);
    return Arrays.stream(getLayerTypes()).filter(requestedLayers::contains).flatMap(layer -> {
      log.h2(layer.name());
      Map prototypes = getInstance().getPrototypes();
      Layer network = prototypes.get(layer);
      assert null != network : prototypes.toString();
      ArtistryUtil.setPrecision((DAGNetwork) network, Precision.Float);
      network.setFrozen(true);
      Result imageFeatures = network.evalAndFree(new MutableResult(img));
      Tensor featureImage = imageFeatures.getData().get(0);
      log.p("Feature Image Dimension: " + Arrays.toString(featureImage.getDimensions()));
      Layer analyze1 = analyze(layer, log, featureImage);
      featureImage.freeRef();
      List layerMasks = IntStream.range(0, getClusters()).mapToObj(i -> {
        try {
          // Select cluster band i and reduce it to a scalar.
          PipelineNetwork net = PipelineNetwork.wrap(
              1,
              analyze1.copy().freeze(),
              new ImgBandSelectLayer(i, i + 1),
              new SumReducerLayer()
          );
          ArtistryUtil.setPrecision(net, Precision.Float);
          double[] singleDelta;
          try {
            Result eval = net.eval(imageFeatures);
            try {
              singleDelta = eval.getSingleDelta();
            } finally {
              eval.getData().freeRef();
              eval.freeRef();
            }
          } finally {
            net.freeRef();
          }
          Tensor maskData = new Tensor(singleDelta, img.getDimensions()).mapAndFree(v -> Math.abs(v));
          Tensor sumChannels = maskData.sumChannels();
          double rms = sumChannels.rms();
          displayImageMask(log, img, sumChannels.scaleInPlace(1.0 / rms));
          sumChannels.freeRef();
          return maskData;
        } catch (Throwable e) {
          // Best effort: a failing cluster is reported and skipped instead of aborting the layer.
          logger.warn("Error", e);
          return null;
        }
      }).filter(x -> x != null).collect(Collectors.toList());
      // Release the data before its owner, matching the order used for eval above;
      // the previous order touched imageFeatures after it had been released.
      imageFeatures.getData().freeRef();
      imageFeatures.freeRef();
      analyze1.freeRef();
      log.p(TestUtil.animatedGif(log, layerMasks.stream().map(selectedBand -> {
        Tensor mask = selectedBand.rescaleRms(1.0);
        BufferedImage image = alphaImageMask(img, mask);
        mask.freeRef();
        return image;
      }).toArray(i -> new BufferedImage[i])));
      return layerMasks.stream();
    }).collect(Collectors.toList());
  }

  /**
   * Clusters a set of feature masks into a final set of masks.
   * <p>
   * Each feature mask is reduced with {@code sumChannels}, the reduced masks are
   * concatenated with {@code ImgConcatLayer.eval}, and the layer returned by
   * {@code analyze(null, log, concat)} is evaluated on the concatenation at float
   * precision. The first tensor of that result is split into its bands; the bands are
   * written to the notebook as an animated GIF, displayed individually, and returned.
   *
   * @param log          the notebook that receives the images
   * @param img          the image the masks are rendered against
   * @param featureMasks the masks to cluster; they are not released by this method
   * @return one tensor per band of the clustered result
   */
  public List spatialClusters(@Nonnull final NotebookOutput log, final Tensor img, final List featureMasks) {
    final List channelSums = featureMasks.stream().map(Tensor::sumChannels).collect(Collectors.toList());
    final Tensor stacked = ImgConcatLayer.eval(channelSums);
    // The per-mask sums are no longer needed once they are concatenated.
    channelSums.forEach(ReferenceCountingBase::freeRef);

    final PipelineNetwork clusterNetwork = analyze(null, log, stacked);
    ArtistryUtil.setPrecision(clusterNetwork, Precision.Float);
    final Tensor clustered = clusterNetwork.eval(stacked).getDataAndFree().getAndFree(0);
    clusterNetwork.freeRef();
    stacked.freeRef();

    final int bandCount = clustered.getDimensions()[2];
    final List bands = IntStream.range(0, bandCount)
        .mapToObj(clustered::selectBand)
        .collect(Collectors.toList());
    clustered.freeRef();

    final BufferedImage[] frames = bands.stream()
        .map(band -> alphaImageMask(img, band))
        .toArray(BufferedImage[]::new);
    log.p(TestUtil.animatedGif(log, frames));
    for (Tensor band : bands) {
      displayImageMask(log, img, band);
    }
    return bands;
  }

  /**
   * Supplies the vision pipeline used by this segmenter. featureClusters reads the
   * per-layer prototype networks from it via {@code getPrototypes()}.
   *
   * @return the pipeline instance
   */
  public abstract U getInstance();

  /**
   * Lists the layer identifiers this segmenter can process. featureClusters walks this
   * array in order and processes only the entries that were also requested by its caller.
   *
   * @return the array of layer identifiers
   */
  @Nonnull
  public abstract T[] getLayerTypes();

  /**
   * Segmenter bound to the VGG19 pipeline: the pipeline is {@code CVPipe_VGG19.INSTANCE}
   * and the available layers are all values of {@code CVPipe_VGG19.Layer}.
   */
  public static class VGG19 extends ImageSegmenter {

    /**
     * Creates a VGG19 segmenter for the given number of clusters, leaving every other
     * setting to the single-argument {@link ImageSegmenter} constructor.
     *
     * @param clusters the number of clusters
     */
    public VGG19(final int clusters) {
      super(clusters);
    }

    /**
     * Creates a VGG19 segmenter with explicit settings; all arguments are passed through
     * unchanged to the matching {@link ImageSegmenter} constructor.
     *
     * @param clusters                   the number of clusters
     * @param orientation                the orientation setting
     * @param globalDistributionEmphasis the global distribution emphasis
     * @param selectionEntropyAdj        the selection entropy adjustment
     * @param maxIterations              the iteration limit
     * @param timeoutMinutes             the timeout in minutes
     * @param seedPcaPower               the seed PCA power
     * @param seedMagnitude              the seed magnitude
     */
    public VGG19(
        final int clusters, final int orientation,
        final double globalDistributionEmphasis, final double selectionEntropyAdj,
        final int maxIterations, final int timeoutMinutes,
        final double seedPcaPower, final double seedMagnitude
    ) {
      super(
          clusters,
          orientation,
          globalDistributionEmphasis,
          selectionEntropyAdj,
          maxIterations,
          timeoutMinutes,
          seedPcaPower,
          seedMagnitude
      );
    }

    /**
     * @return the shared {@code CVPipe_VGG19.INSTANCE}
     */
    @Override
    public CVPipe_VGG19 getInstance() {
      return CVPipe_VGG19.INSTANCE;
    }

    /**
     * @return every constant of {@code CVPipe_VGG19.Layer}, as returned by {@code values()}
     */
    @Nonnull
    @Override
    public CVPipe_VGG19.Layer[] getLayerTypes() {
      return CVPipe_VGG19.Layer.values();
    }

  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy