com.simiacryptus.mindseye.applications.ImageSegmenter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-art Show documentation
Show all versions of mindseye-art Show documentation
Visual Neural Network Applications
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.applications;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.lang.cudnn.Precision;
import com.simiacryptus.mindseye.layers.cudnn.ImgBandSelectLayer;
import com.simiacryptus.mindseye.layers.cudnn.ImgConcatLayer;
import com.simiacryptus.mindseye.layers.java.SumReducerLayer;
import com.simiacryptus.mindseye.models.CVPipe;
import com.simiacryptus.mindseye.models.CVPipe_VGG19;
import com.simiacryptus.mindseye.models.LayerEnum;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.mindseye.test.TestUtil;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.notebook.NullNotebookOutput;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import java.awt.image.BufferedImage;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
* The type Object location.
*
* @param the type parameter
* @param the type parameter
*/
public abstract class ImageSegmenter, U extends CVPipe> extends PixelClusterer {
private static final Logger logger = LoggerFactory.getLogger(ImageSegmenter.class);
/**
* Instantiates a new Image segmenter.
*
* @param clusters the clusters
* @param orientation the orientation
* @param globalDistributionEmphasis the global distribution emphasis
* @param selectionEntropyAdj the selection entropy adj
* @param maxIterations the max iterations
* @param timeoutMinutes the timeout minutes
* @param seedPcaPower the seed pca power
* @param seedMagnitude the seed magnitude
*/
public ImageSegmenter(
final int clusters,
final int orientation,
final double globalDistributionEmphasis,
final double selectionEntropyAdj,
final int maxIterations,
final int timeoutMinutes,
final double seedPcaPower,
final double seedMagnitude
) {
super(
clusters,
orientation,
globalDistributionEmphasis,
selectionEntropyAdj,
maxIterations,
timeoutMinutes,
seedPcaPower,
seedMagnitude,
false,
true,
0.0,
1.0,
new double[]{1e-1, 1e-3}
);
}
/**
* Instantiates a new Image segmenter.
*
* @param clusters the clusters
*/
public ImageSegmenter(final int clusters) {
super(clusters);
}
/**
* Quick masks list.
*
* @param img the img
* @return the list
*/
public static List quickMasks(final Tensor img) {
return quickMasks(img, 3);
}
/**
* Quick masks list.
*
* @param img the img
* @param clusters
* @return the list
*/
public static List quickMasks(final Tensor img, final int clusters) {
return quickMasks(img, clusters, clusters, clusters);
}
public static List quickMasks(final Tensor img, final int masks, final int colorClusters, final int textureClusters) {
return quickMasks(new NullNotebookOutput(), img, masks, colorClusters, textureClusters);
}
public static List quickMasks(
@Nonnull final NotebookOutput log,
final Tensor img,
final int masks,
final int colorClusters,
final int textureClusters
) {
return quickmasks(log, img, masks, colorClusters, textureClusters);
}
public static List quickmasks(
@Nonnull final NotebookOutput log,
final Tensor img,
final int masks,
final int colorClusters,
final int textureClusters
) {
if (1 >= masks) return Arrays.asList(img.sumChannels().map(x -> 1.0));
return quickmasks(
log,
img,
9,
masks,
colorClusters,
textureClusters,
CVPipe_VGG19.Layer.Layer_0,
CVPipe_VGG19.Layer.Layer_1a,
CVPipe_VGG19.Layer.Layer_1e
);
}
public static List quickmasks(
@Nonnull final NotebookOutput log,
final Tensor img,
final int blur,
final int masks,
final int colorClusters,
final int textureClusters,
final CVPipe_VGG19.Layer... layers
) {
ImageSegmenter segmenter = new VGG19(masks) {
@Override
public Layer modelingNetwork(final CVPipe_VGG19.Layer layer, final Tensor metrics) {
if (layer == CVPipe_VGG19.Layer.Layer_0) {
return modelingNetwork(getGlobalBias(), getGlobalGain(), metrics, true, isRescale(), colorClusters, getSeedMagnitude(), 0);
} else {
return modelingNetwork(
getGlobalBias(),
getGlobalGain(),
metrics,
isRecenter(),
isRescale(),
textureClusters,
getSeedMagnitude(),
getSeedPcaPower()
);
}
}
};
List featureMasks = segmenter.featureClusters(log, img, layers);
List blur1 = PCAObjectLocation.blur(featureMasks, blur);
List spatialClusters = segmenter.spatialClusters(log, img, blur1);
blur1.forEach(ReferenceCountingBase::freeRef);
return spatialClusters;
}
/**
* Alpha png mask buffered png.
*
* @param log the log
* @param img the img
* @param mask the mask
* @return the buffered png
*/
public static BufferedImage alphaImageMask(@Nonnull final NotebookOutput log, final Tensor img, Tensor mask) {
return log.eval(() -> {
return img.mapCoords(c -> img.get(c) * mask.get(
c.getCoords()[0],
c.getCoords()[1],
Math.min(c.getCoords()[2], mask.getDimensions()[2])
)).toImage();
});
}
/**
* Alpha png mask buffered png.
*
* @param img the img
* @param mask the mask
* @return the buffered png
*/
public static BufferedImage alphaImageMask(final Tensor img, Tensor mask) {
Tensor tensor = img.mapCoords(c -> img.get(c) * mask.get(
c.getCoords()[0],
c.getCoords()[1],
Math.min(c.getCoords()[2], mask.getDimensions()[2] - 1)
));
BufferedImage image = tensor.toImage();
tensor.freeRef();
return image;
}
/**
* Display png mask.
*
* @param log the log
* @param img the img
* @param mask the mask
*/
public static void displayImageMask(@Nonnull final NotebookOutput log, final Tensor img, Tensor mask) {
Tensor scale = mask.scale(255.0);
Tensor alphaMask = mask.normalizeDistribution().scaleInPlace(255.0);
log.p(log.png(img.toRgbImageAlphaMask(0, 1, 2, scale), "") +
log.png(img.toRgbImageAlphaMask(0, 1, 2, alphaMask), ""));
alphaMask.freeRef();
scale.freeRef();
}
/**
* Feature clusters list.
*
* @param log the log
* @param img the img
* @param layers the layers
* @return the list
*/
public List featureClusters(@Nonnull final NotebookOutput log, final Tensor img, final T... layers) {
if (1 >= getClusters()) return Arrays.asList(img.map(x -> 1.0));
return Arrays.stream(getLayerTypes()).filter(x -> Arrays.asList(layers).contains(x)).flatMap(layer -> {
log.h2(layer.name());
Map prototypes = getInstance().getPrototypes();
Layer network = prototypes.get(layer);
assert null != network : prototypes.toString();
ArtistryUtil.setPrecision((DAGNetwork) network, Precision.Float);
network.setFrozen(true);
Result imageFeatures = network.evalAndFree(new MutableResult(img));
Tensor featureImage = imageFeatures.getData().get(0);
log.p("Feature Image Dimension: " + Arrays.toString(featureImage.getDimensions()));
Layer analyze1 = analyze(layer, log, featureImage);
featureImage.freeRef();
List layerMasks = IntStream.range(0, getClusters()).mapToObj(i -> {
try {
PipelineNetwork net = PipelineNetwork.wrap(
1,
analyze1.copy().freeze(),
new ImgBandSelectLayer(i, i + 1),
new SumReducerLayer()
);
ArtistryUtil.setPrecision(net, Precision.Float);
double[] singleDelta;
try {
Result eval = net.eval(imageFeatures);
try {
singleDelta = eval.getSingleDelta();
} finally {
eval.getData().freeRef();
eval.freeRef();
}
} finally {
net.freeRef();
}
Tensor maskData = new Tensor(singleDelta, img.getDimensions()).mapAndFree(v -> Math.abs(v));
Tensor sumChannels = maskData.sumChannels();
double rms = sumChannels.rms();
displayImageMask(log, img, sumChannels.scaleInPlace(1.0 / rms));
sumChannels.freeRef();
return maskData;
} catch (Throwable e) {
logger.warn("Error", e);
return null;
}
}).filter(x -> x != null).collect(Collectors.toList());
imageFeatures.freeRef();
imageFeatures.getData().freeRef();
analyze1.freeRef();
log.p(TestUtil.animatedGif(log, layerMasks.stream().map(selectedBand -> {
Tensor mask = selectedBand.rescaleRms(1.0);
BufferedImage image = alphaImageMask(img, mask);
mask.freeRef();
return image;
}).toArray(i -> new BufferedImage[i])));
return layerMasks.stream();
}).collect(Collectors.toList());
}
/**
* Spatial clusters list.
*
* @param log the log
* @param img the img
* @param featureMasks the feature masks
* @return the list
*/
public List spatialClusters(@Nonnull final NotebookOutput log, final Tensor img, final List featureMasks) {
List tensors = featureMasks.stream().map(Tensor::sumChannels).collect(Collectors.toList());
Tensor concat = ImgConcatLayer.eval(tensors);
tensors.forEach(ReferenceCountingBase::freeRef);
PipelineNetwork analyze = analyze(null, log, concat);
ArtistryUtil.setPrecision(analyze, Precision.Float);
Tensor reclustered = analyze.eval(concat).getDataAndFree().getAndFree(0);
analyze.freeRef();
concat.freeRef();
List tensorList = IntStream.range(
0,
reclustered.getDimensions()[2]
).mapToObj(i -> reclustered.selectBand(i)).collect(Collectors.toList());
reclustered.freeRef();
log.p(TestUtil.animatedGif(log, tensorList.stream().map(selectedBand -> alphaImageMask(img, selectedBand)).toArray(i -> new BufferedImage[i])));
for (Tensor selectBand : tensorList) {
displayImageMask(log, img, selectBand);
}
return tensorList;
}
/**
* Gets instance.
*
* @return the instance
*/
public abstract U getInstance();
/**
* Get key types t [ ].
*
* @return the t [ ]
*/
@Nonnull
public abstract T[] getLayerTypes();
/**
* The type Vgg 19.
*/
public static class VGG19 extends ImageSegmenter {
/**
* Instantiates a new Vgg 19.
*
* @param clusters the clusters
*/
public VGG19(final int clusters) {
super(clusters);
}
/**
* Instantiates a new Vgg 19.
*
* @param clusters the clusters
* @param orientation the orientation
* @param globalDistributionEmphasis the global distribution emphasis
* @param selectionEntropyAdj the selection entropy adj
* @param maxIterations the max iterations
* @param timeoutMinutes the timeout minutes
* @param seedPcaPower the seed pca power
* @param seedMagnitude the seed magnitude
*/
public VGG19(
final int clusters,
final int orientation,
final double globalDistributionEmphasis,
final double selectionEntropyAdj,
final int maxIterations,
final int timeoutMinutes,
final double seedPcaPower,
final double seedMagnitude
) {
super(clusters, orientation, globalDistributionEmphasis, selectionEntropyAdj, maxIterations, timeoutMinutes, seedPcaPower, seedMagnitude);
}
@Override
public CVPipe_VGG19 getInstance() {
return CVPipe_VGG19.INSTANCE;
}
@Override
@Nonnull
public CVPipe_VGG19.Layer[] getLayerTypes() {
return CVPipe_VGG19.Layer.values();
}
}
}