com.simiacryptus.mindseye.art.ops.MomentMatcher Maven / Gradle / Ivy
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.art.ops;
import com.simiacryptus.mindseye.art.ArtSettings;
import com.simiacryptus.mindseye.art.TiledTrainable;
import com.simiacryptus.mindseye.art.VisualModifier;
import com.simiacryptus.mindseye.art.VisualModifierParameters;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Result;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.TensorList;
import com.simiacryptus.mindseye.lang.cudnn.MultiPrecision;
import com.simiacryptus.mindseye.lang.cudnn.Precision;
import com.simiacryptus.mindseye.layers.cudnn.*;
import com.simiacryptus.mindseye.layers.java.BoundedActivationLayer;
import com.simiacryptus.mindseye.layers.java.LinearActivationLayer;
import com.simiacryptus.mindseye.layers.java.NthPowerActivationLayer;
import com.simiacryptus.mindseye.network.DAGNode;
import com.simiacryptus.mindseye.network.InnerNode;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefStream;
import com.simiacryptus.ref.wrappers.RefString;
import com.simiacryptus.util.Util;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.UUID;
/**
 * {@link VisualModifier} that matches low-order activation statistics ("moments") of a network against
 * those measured on a set of style images.
 *
 * <p>{@link #build(VisualModifierParameters)} extends a copy of the supplied network with three loss terms,
 * each built by {@link #lossSq(Precision, Tensor)} and summed:
 * <ul>
 *   <li>the per-band average, weighted by {@link #getPosCoeff()};</li>
 *   <li>the per-band RMS about that average, weighted by {@link #getScaleCoeff()};</li>
 *   <li>the Gramian of the re-centered, RMS-normalized activations, weighted by {@link #getCovCoeff()}.</li>
 * </ul>
 * Target values are computed from the style images tile by tile (see {@link #getTileSize()}).
 *
 * <p>The helpers follow the library's reference-counting convention: reference-counted arguments are
 * consumed (freed) by the callee, so callers pass {@code addRef()} when they keep using a value.
 */
public class MomentMatcher implements VisualModifier {
  private static final Logger log = LoggerFactory.getLogger(MomentMatcher.class);
  /** Tile padding, in pixels, passed to {@link TiledTrainable#selectors} when evaluating style images. */
  private static final int PADDING = 8;
  private Precision precision = Precision.Float;
  private int tileSize = ArtSettings.INSTANCE().defaultTileSize;
  private double posCoeff = 1.0;
  private double scaleCoeff = 1.0;
  private double covCoeff = 1.0;

  /** @return weight applied to the Gramian (covariance) loss term */
  public double getCovCoeff() {
    return covCoeff;
  }

  /** @param covCoeff weight applied to the Gramian (covariance) loss term */
  public void setCovCoeff(double covCoeff) {
    this.covCoeff = covCoeff;
  }

  /** @return weight applied to the per-band average loss term */
  public double getPosCoeff() {
    return posCoeff;
  }

  /** @param posCoeff weight applied to the per-band average loss term */
  public void setPosCoeff(double posCoeff) {
    this.posCoeff = posCoeff;
  }

  /** @return precision applied to the cuDNN layers this modifier creates */
  public Precision getPrecision() {
    return precision;
  }

  /** @param precision precision applied to the cuDNN layers this modifier creates */
  public void setPrecision(Precision precision) {
    this.precision = precision;
  }

  /** @return weight applied to the per-band RMS loss term */
  public double getScaleCoeff() {
    return scaleCoeff;
  }

  /** @param scaleCoeff weight applied to the per-band RMS loss term */
  public void setScaleCoeff(double scaleCoeff) {
    this.scaleCoeff = scaleCoeff;
  }

  /** @return tile size used when evaluating style images */
  public int getTileSize() {
    return tileSize;
  }

  /**
   * @param tileSize tile size used when evaluating style images
   * @return this instance, for chaining
   */
  @Nonnull
  public MomentMatcher setTileSize(int tileSize) {
    this.tileSize = tileSize;
    return this;
  }

  /**
   * Builds a loss layer computing {@code mean(((x - target) / rms)^2)}, where {@code rms} is the RMS of
   * {@code target}, or 1 when that RMS is zero.
   *
   * @param precision precision applied to the cuDNN layers of the loss
   * @param target    target value; consumed by this call
   * @return a named pipeline producing the normalized mean squared error
   */
  @Nonnull
  public static Layer lossSq(Precision precision, @Nonnull Tensor target) {
    final double rms = target.rms();
    // Normalizing by the target's RMS makes the loss independent of the target's magnitude; a zero-RMS
    // target is left unnormalized to avoid dividing by zero.
    final double scale = 0 == rms ? 1 : Math.pow(rms, -1);
    final Tensor bias = target.scale(0 == rms ? 1 : -Math.pow(rms, -1));
    target.freeRef();
    LinearActivationLayer scaleLayer = new LinearActivationLayer();
    scaleLayer.setScale(scale);
    ImgBandBiasLayer biasLayer = new ImgBandBiasLayer(bias);
    biasLayer.setPrecision(precision);
    SquareActivationLayer squareLayer = new SquareActivationLayer();
    squareLayer.setPrecision(precision);
    AvgReducerLayer avgLayer = new AvgReducerLayer();
    avgLayer.setPrecision(precision);
    Layer layer = PipelineNetwork.build(1, scaleLayer, biasLayer, squareLayer, avgLayer);
    layer.setName(RefString.format("RMS[x-C] / %.0E", 0 == rms ? 1 : rms));
    return layer;
  }

  /**
   * Sums a stream of tensors element-wise.
   *
   * @param tensorStream tensors to add together; consumed by this call
   * @return the element-wise sum, or {@code null} if the stream is empty
   */
  @Nullable
  public static Tensor sum(@Nonnull RefStream<Tensor> tensorStream) {
    return RefUtil.orElse(tensorStream.reduce((a, b) -> Tensor.add(a, b)), null);
  }

  /**
   * Derives a deterministic id for a layer of type {@code layerClass} appended after the network's current
   * head. The id is a name-based UUID of the head layer's id plus the class name.
   *
   * @param network    network whose head is inspected; consumed by this call
   * @param layerClass class of the layer about to be appended
   * @return the derived UUID, or a random UUID when the head node has no layer
   */
  @Nonnull
  public static UUID getAppendUUID(@Nonnull PipelineNetwork network, @Nonnull Class<?> layerClass) {
    DAGNode head = network.getHead();
    Layer layer = head.getLayer();
    network.freeRef();
    head.freeRef();
    if (null == layer) {
      return UUID.randomUUID();
    }
    final String seed = layer.getId().toString() + layerClass.getName();
    layer.freeRef();
    // Fixed charset so the same seed yields the same UUID on every platform.
    return UUID.nameUUIDFromBytes(seed.getBytes(StandardCharsets.UTF_8));
  }

  /**
   * Evaluates {@code in} through a copy of {@code network} in which every {@link ImgBandBiasLayer} has its
   * weights set to zero.
   *
   * @param network   network to copy; consumed by this call
   * @param in        input tensor; consumed by this call; may be {@code null}
   * @param precision precision applied to the copied network
   * @return the first output tensor, or {@code null} when {@code in} is {@code null}
   */
  @Nullable
  public static Tensor transform(PipelineNetwork network, @Nullable Tensor in, Precision precision) {
    if (null == in) {
      network.freeRef();
      return null;
    }
    assert in.assertAlive();
    PipelineNetwork copyPipeline = network.copyPipeline();
    network.freeRef();
    assert copyPipeline != null;
    MultiPrecision.setPrecision(copyPipeline.addRef(), precision);
    // Zero all band biases in the copy. NOTE(review): presumably so that bias terms alone cannot mark
    // masked-out pixels as active; confirm intent.
    copyPipeline.visitLayers(layer -> {
      if (layer instanceof ImgBandBiasLayer) {
        ((ImgBandBiasLayer) layer).setWeights(i -> 0);
      }
      layer.freeRef();
    });
    Result result = copyPipeline.eval(in);
    TensorList data = result.getData();
    final Tensor tensor = data.get(0);
    result.freeRef();
    data.freeRef();
    copyPipeline.freeRef();
    return tensor;
  }

  /**
   * Returns a copy of {@code network} whose output is multiplied by the constant {@code finalMask}.
   *
   * @param network   network to copy; consumed by this call
   * @param finalMask constant multiplied into the head output; handed to the new constant node
   * @return the gated copy
   */
  @Nonnull
  public static PipelineNetwork gateNetwork(@Nonnull PipelineNetwork network, Tensor finalMask) {
    final PipelineNetwork copyPipeline = network.copyPipeline();
    network.freeRef();
    assert copyPipeline != null;
    final DAGNode head = copyPipeline.getHead();
    copyPipeline.add(new ProductLayer(), head, copyPipeline.constValue(finalMask)).freeRef();
    return copyPipeline;
  }

  /**
   * Binarizes a tensor per pixel. A pixel becomes all ones if any of its values is non-zero, and all
   * zeros otherwise.
   *
   * @param tensor tensor to binarize; consumed by this call; may be {@code null}
   * @return the binary mask, or {@code null} when {@code tensor} is {@code null}
   */
  @Nullable
  public static Tensor toMask(@Nullable Tensor tensor) {
    if (tensor == null) {
      return null;
    }
    Tensor mask = tensor.mapPixels(pixel -> {
      final int fill = Arrays.stream(pixel).anyMatch(x -> x != 0) ? 1 : 0;
      return Arrays.stream(pixel).map(x -> fill).toArray();
    });
    tensor.freeRef();
    return mask;
  }

  /**
   * Checks that {@code network} evaluates each image without error. Failures are rethrown through
   * {@link Util#throwException}.
   *
   * @param network network to evaluate; consumed by this call
   * @param images  at least one input image; consumed by this call
   * @return {@code true} when every evaluation succeeds
   * @throws IllegalArgumentException if no images are given
   */
  public static boolean test(@Nonnull PipelineNetwork network, @Nonnull Tensor... images) {
    if (images.length == 0) {
      network.freeRef();
      throw new IllegalArgumentException("images.length <= 0");
    }
    if (images.length > 1) {
      Boolean test = RefUtil.get(RefArrays.stream(images).map(x -> test(network.addRef(), x)).reduce((a, b) -> a && b));
      network.freeRef();
      return test;
    }
    try {
      network.eval(images[0].addRef()).freeRef();
      network.freeRef();
      RefUtil.freeRef(images);
      return true;
    } catch (Throwable e) {
      throw Util.throwException(e);
    }
  }

  /**
   * Computes a pixel-weighted power mean of the network output over tiles of the given images.
   *
   * <p>Each image is split into padded tiles of {@code tileSize}. The network is evaluated on each tile,
   * each output element is raised to {@code power}, and the result is weighted by the tile's pixel count.
   * All tiles are summed and the total is divided by {@code pixels}. Finally the {@code 1/power} root is
   * taken (root 1 when {@code power == 0}), and any non-finite element is replaced by 0.
   *
   * @param pixels   total pixel count used as the divisor
   * @param network  network to evaluate; consumed by this call
   * @param tileSize tile size used to split the images
   * @param power    exponent of the power mean
   * @param image    input images; consumed by this call
   * @return the per-element power mean
   * @throws IllegalArgumentException if no images are given or an image yields no tiles
   */
  @Nonnull
  protected static Tensor eval(int pixels, @Nonnull PipelineNetwork network, int tileSize, double power, @Nonnull Tensor[] image) {
    if (image.length <= 0) {
      network.freeRef();
      RefUtil.freeRef(image);
      throw new IllegalArgumentException("image.length <= 0");
    }
    final Tensor sum = sum(RefArrays.stream(image).flatMap(img -> {
      int[] imageDimensions = img.getDimensions();
      final Layer[] selectors = TiledTrainable.selectors(PADDING, imageDimensions[0], imageDimensions[1], tileSize,
          true);
      if (selectors.length <= 0) {
        img.freeRef();
        RefUtil.freeRef(selectors);
        throw new IllegalArgumentException("selectors.length <= 0");
      }
      // The wrapped interface holds 'img' and releases it once the tile stream is done with it.
      return RefArrays.stream(selectors).map(RefUtil.wrapInterface(selector -> {
        Result tileResult = selector.eval(img.addRef());
        selector.freeRef();
        TensorList tileData = tileResult.getData();
        tileResult.freeRef();
        Tensor tile = tileData.get(0);
        tileData.freeRef();
        int[] tileDimensions = tile.getDimensions();
        int tilePixels = tileDimensions[0] * tileDimensions[1];
        Result outResult = network.eval(tile);
        TensorList outData = outResult.getData();
        outResult.freeRef();
        Tensor output = outData.get(0);
        outData.freeRef();
        Tensor powered = output.map(x -> Math.pow(x, power));
        output.freeRef();
        powered.scaleInPlace(tilePixels);
        return powered;
      }, img));
    }));
    network.freeRef();
    assert sum != null;
    sum.scaleInPlace(1.0 / pixels);
    Tensor tensor = sum.map(x -> {
      double root = Math.pow(x, 0 == power ? 1 : 1.0 / power);
      return Double.isFinite(root) ? root : 0;
    });
    sum.freeRef();
    return tensor;
  }

  /**
   * Builds a copy of the configured network extended with the moment-matching loss.
   *
   * <p>The network output is first scaled by the reciprocal of the squared magnitude of the style's
   * band averages. If a mask is supplied, it is evaluated through the network with band biases zeroed
   * (see {@link #transform}) and then binarized. The binarized mask multiplies the network output before
   * the moment nodes are attached, and its mean value ({@code maskFactor}) rescales the moment statistics.
   *
   * @param visualModifierParameters network, style images and optional mask; consumed by this call
   * @return the frozen loss network
   */
  @Nonnull
  @Override
  public PipelineNetwork build(@Nonnull VisualModifierParameters visualModifierParameters) {
    PipelineNetwork network = visualModifierParameters.copyNetwork();
    MultiPrecision.setPrecision(network.addRef(), precision);
    assert network != null;
    Tensor evalRoot = avg(network.addRef(), getPixels(visualModifierParameters.getStyle()), visualModifierParameters.getStyle());
    assert evalRoot != null;
    double sumSq = evalRoot.sumSq();
    evalRoot.freeRef();
    log.info(RefString.format("Adjust for %s by %s: %s", network.getName(), this.getClass().getSimpleName(), sumSq));
    // Normalize by 1/sumSq. A degenerate magnitude (zero or non-finite) falls back to a neutral factor,
    // because dividing by it would give an Inf/NaN scale.
    final double factor = Double.isFinite(sumSq) && 0 < sumSq ? 1.0 / sumSq : 1.0;
    network.add(new ScaleLayer(factor)).freeRef();
    Tensor mask = visualModifierParameters.getMask();
    final double maskFactor;
    final Tensor boolMask;
    if (null != mask) {
      boolMask = toMask(transform(network.addRef(), mask.addRef(), getPrecision()));
      log.info("Mask: " + RefArrays.toString(boolMask.getDimensions()));
      // NOTE(review): an all-zero mask gives maskFactor == 0, which makes the 1/maskFactor scales infinite.
      maskFactor = boolMask.doubleStream().average().getAsDouble();
    } else {
      maskFactor = 1;
      boolMask = null;
    }
    PipelineNetwork maskedNetwork = network.copyPipeline();
    MultiPrecision.setPrecision(maskedNetwork.addRef(), getPrecision());
    assert maskedNetwork != null;
    assert mask == null || test(maskedNetwork.addRef(), mask.addRef());
    // Target moments are measured on the style images through the (unmasked) network; this consumes it.
    final MomentParams params = getMomentParams(network, maskFactor, visualModifierParameters.getStyle());
    assert mask == null || test(maskedNetwork.addRef(), mask.addRef());
    if (null != boolMask) {
      final DAGNode head = maskedNetwork.getHead();
      maskedNetwork.add(new ProductLayer(), head, maskedNetwork.constValue(boolMask)).freeRef();
    }
    assert mask == null || test(maskedNetwork.addRef(), mask.addRef());
    final MomentParams nodes = getMomentNodes(maskedNetwork.addRef(), maskFactor);
    assert mask == null || test(maskedNetwork.addRef(), mask.addRef());
    // Pair the masked network's moment nodes with the style's target values.
    final MomentParams momentParams = new MomentParams(nodes.avgNode.addRef(), params.avgValue.addRef(),
        nodes.rmsNode.addRef(), params.rmsValue.addRef(), nodes.covNode.addRef(), params.covValue.addRef(),
        this);
    params.freeRef();
    nodes.freeRef();
    momentParams.addLoss(maskedNetwork.addRef()).freeRef();
    momentParams.freeRef();
    assert mask == null || test(maskedNetwork.addRef(), mask.addRef());
    visualModifierParameters.freeRef();
    MultiPrecision.setPrecision(maskedNetwork.addRef(), getPrecision());
    maskedNetwork.freeze();
    RefUtil.freeRef(mask);
    return maskedNetwork;
  }

  /**
   * Counts the total number of pixels (width x height) across the given images.
   *
   * @param images images to measure; each is consumed by this call
   * @return the pixel total, never less than 1
   */
  public int getPixels(@Nonnull Tensor[] images) {
    return Math.max(1, RefArrays.stream(images).mapToInt(tensor -> {
      int[] dimensions = tensor.getDimensions();
      tensor.freeRef();
      return dimensions[0] * dimensions[1];
    }).sum());
  }

  /**
   * Appends the average, RMS and Gramian moment nodes to {@code network} and evaluates each one on the
   * style images to obtain its target value.
   *
   * <p>Statement order matters here. Each {@code eval} runs the network as extended so far (presumably
   * ending at the node just added), and the Gramian layer's id is derived from the current head through
   * {@link #getAppendUUID}.
   *
   * @param network    network to extend; consumed by this call
   * @param maskFactor mean mask coverage used to rescale the statistics (1 when unmasked)
   * @param images     style images; consumed by this call
   * @return the moment nodes together with their target values
   */
  @Nonnull
  public MomentMatcher.MomentParams getMomentParams(@Nonnull PipelineNetwork network, double maskFactor, @Nonnull Tensor... images) {
    int pixels = getPixels(RefUtil.addRef(images));
    DAGNode mainIn = network.getHead();
    // First moment: per-band average.
    BandAvgReducerLayer avgLayer = new BandAvgReducerLayer();
    avgLayer.setPrecision(getPrecision());
    InnerNode avgNode = network.add(avgLayer, mainIn.addRef());
    Tensor avgValue = eval(pixels, network.addRef(), getTileSize(), 1.0, RefUtil.addRef(images));
    // Re-center the input by biasing each band with -avg / maskFactor.
    ScaleLayer negAvgScale = new ScaleLayer(-1 / maskFactor);
    negAvgScale.setPrecision(getPrecision());
    ImgBandDynamicBiasLayer recenterLayer = new ImgBandDynamicBiasLayer();
    recenterLayer.setPrecision(getPrecision());
    InnerNode recentered = network.add(recenterLayer, mainIn,
        network.add(negAvgScale, avgNode.addRef()));
    // Second moment: sqrt(bandAvg(recentered^2) / maskFactor).
    NthPowerActivationLayer sqrtLayer = new NthPowerActivationLayer();
    sqrtLayer.setPower(0.5);
    SquareActivationLayer squareLayer = new SquareActivationLayer();
    squareLayer.setPrecision(getPrecision());
    BandAvgReducerLayer sqAvgLayer = new BandAvgReducerLayer();
    sqAvgLayer.setPrecision(getPrecision());
    ScaleLayer rmsMaskScale = new ScaleLayer(1 / maskFactor);
    rmsMaskScale.setPrecision(getPrecision());
    InnerNode rmsNode = network.add(sqrtLayer,
        network.add(rmsMaskScale,
            network.add(sqAvgLayer,
                network.add(squareLayer, recentered.addRef()))));
    Tensor rmsValue = eval(pixels, network.addRef(), getTileSize(), 2.0, RefUtil.addRef(images));
    // Multiply the re-centered input by 1/rms, clamped to [0, 1e4] so near-zero RMS bands cannot blow up.
    BoundedActivationLayer clampLayer = new BoundedActivationLayer();
    clampLayer.setMinValue(0.0);
    clampLayer.setMaxValue(1e4);
    NthPowerActivationLayer reciprocalLayer = new NthPowerActivationLayer();
    reciprocalLayer.setPower(-1);
    ProductLayer rescaleLayer = new ProductLayer();
    rescaleLayer.setPrecision(getPrecision());
    InnerNode rescaled = network.add(rescaleLayer, recentered,
        network.add(clampLayer,
            network.add(reciprocalLayer, rmsNode.addRef())));
    // Covariance: Gramian of the rescaled activations, scaled by 1/maskFactor. The Gramian id depends on
    // the current head, so this layer must be created only after 'rescaled' is added.
    GramianLayer gramianLayer = new GramianLayer(getAppendUUID(network.addRef(), GramianLayer.class));
    gramianLayer.setPrecision(getPrecision());
    ScaleLayer covMaskScale = new ScaleLayer(1 / maskFactor);
    covMaskScale.setPrecision(getPrecision());
    InnerNode covNode = network.add(covMaskScale, network.add(gramianLayer, rescaled));
    Tensor covValue = eval(pixels, network, getTileSize(), 1.0, images);
    return new MomentParams(avgNode, avgValue, rmsNode, rmsValue, covNode, covValue, this);
  }

  /**
   * Appends the same average, RMS and Gramian moment nodes as {@link #getMomentParams}, without
   * evaluating any target values.
   *
   * @param network    network to extend; consumed by this call
   * @param maskFactor mean mask coverage used to rescale the statistics (1 when unmasked)
   * @return the moment nodes; all target values are {@code null}
   */
  @Nonnull
  public MomentMatcher.MomentParams getMomentNodes(@Nonnull PipelineNetwork network, double maskFactor) {
    DAGNode mainIn = network.getHead();
    // First moment: per-band average.
    BandAvgReducerLayer avgLayer = new BandAvgReducerLayer();
    avgLayer.setPrecision(getPrecision());
    InnerNode avgNode = network.add(avgLayer, mainIn.addRef());
    // Re-center the input by biasing each band with -avg / maskFactor.
    ScaleLayer negAvgScale = new ScaleLayer(-1 / maskFactor);
    negAvgScale.setPrecision(getPrecision());
    ImgBandDynamicBiasLayer recenterLayer = new ImgBandDynamicBiasLayer();
    recenterLayer.setPrecision(getPrecision());
    InnerNode recentered = network.add(recenterLayer, mainIn,
        network.add(negAvgScale, avgNode.addRef()));
    // Second moment: sqrt(bandAvg(recentered^2) / maskFactor).
    NthPowerActivationLayer sqrtLayer = new NthPowerActivationLayer();
    sqrtLayer.setPower(0.5);
    SquareActivationLayer squareLayer = new SquareActivationLayer();
    squareLayer.setPrecision(getPrecision());
    BandAvgReducerLayer sqAvgLayer = new BandAvgReducerLayer();
    sqAvgLayer.setPrecision(getPrecision());
    ScaleLayer rmsMaskScale = new ScaleLayer(1 / maskFactor);
    rmsMaskScale.setPrecision(getPrecision());
    InnerNode rmsNode = network.add(sqrtLayer,
        network.add(rmsMaskScale,
            network.add(sqAvgLayer,
                network.add(squareLayer, recentered.addRef()))));
    // Multiply the re-centered input by 1/rms, clamped to [0, 1e4] so near-zero RMS bands cannot blow up.
    BoundedActivationLayer clampLayer = new BoundedActivationLayer();
    clampLayer.setMinValue(0.0);
    clampLayer.setMaxValue(1e4);
    NthPowerActivationLayer reciprocalLayer = new NthPowerActivationLayer();
    reciprocalLayer.setPower(-1);
    ProductLayer rescaleLayer = new ProductLayer();
    rescaleLayer.setPrecision(getPrecision());
    InnerNode rescaled = network.add(rescaleLayer, recentered,
        network.add(clampLayer,
            network.add(reciprocalLayer, rmsNode.addRef())));
    // Covariance: Gramian of the rescaled activations, scaled by 1/maskFactor. The Gramian id depends on
    // the current head, so this layer must be created only after 'rescaled' is added.
    GramianLayer gramianLayer = new GramianLayer(getAppendUUID(network.addRef(), GramianLayer.class));
    gramianLayer.setPrecision(getPrecision());
    ScaleLayer covMaskScale = new ScaleLayer(1 / maskFactor);
    covMaskScale.setPrecision(getPrecision());
    InnerNode covNode = network.add(covMaskScale, network.add(gramianLayer, rescaled));
    network.freeRef();
    return new MomentParams(avgNode, null, rmsNode, null, covNode, null, this);
  }

  /**
   * Computes the pixel-weighted per-band average of the network output over the given images.
   *
   * @param network network to evaluate; consumed by this call
   * @param pixels  total pixel count used as the divisor
   * @param image   input images; consumed by this call
   * @return the per-band average
   */
  @Nullable
  protected Tensor avg(@Nonnull PipelineNetwork network, int pixels, @Nonnull Tensor[] image) {
    BandAvgReducerLayer bandAvg = new BandAvgReducerLayer();
    // Use the configured precision, like every other cuDNN layer this class creates.
    bandAvg.setPrecision(getPrecision());
    PipelineNetwork avgNet = PipelineNetwork.build(1, network, bandAvg);
    return eval(pixels, avgNet, getTileSize(), 1.0, image);
  }

  /**
   * Holds the network nodes producing each moment, together with each moment's target value when known.
   * Owns one reference to every non-null node and tensor and releases them in {@link #_free()}.
   */
  private static class MomentParams extends ReferenceCountingBase {
    private final InnerNode avgNode;
    private final Tensor avgValue;
    private final InnerNode rmsNode;
    private final Tensor rmsValue;
    private final InnerNode covNode;
    private final Tensor covValue;
    /** Source of loss weights and precision; not reference-counted. */
    private final MomentMatcher parent;

    private MomentParams(InnerNode avgNode, Tensor avgValue, InnerNode rmsNode, Tensor rmsValue, InnerNode covNode,
                         Tensor covValue, MomentMatcher parent) {
      this.parent = parent;
      this.avgNode = avgNode;
      this.avgValue = avgValue;
      this.rmsNode = rmsNode;
      this.rmsValue = rmsValue;
      this.covNode = covNode;
      this.covValue = covValue;
    }

    /** Releases every held node and tensor. */
    public void _free() {
      if (null != avgNode)
        avgNode.freeRef();
      if (null != avgValue)
        avgValue.freeRef();
      if (null != rmsNode)
        rmsNode.freeRef();
      if (null != rmsValue)
        rmsValue.freeRef();
      if (null != covNode)
        covNode.freeRef();
      if (null != covValue)
        covValue.freeRef();
      super._free();
    }

    /**
     * Appends the weighted sum of the three moment losses to {@code network}. All three target values
     * must be non-null.
     *
     * @param network network to extend; consumed by this call
     * @return the node summing the weighted losses
     */
    @Nullable
    public InnerNode addLoss(@Nonnull PipelineNetwork network) {
      ScaleLayer covWeight = new ScaleLayer(parent.getCovCoeff());
      covWeight.setPrecision(parent.getPrecision());
      ScaleLayer rmsWeight = new ScaleLayer(parent.getScaleCoeff());
      rmsWeight.setPrecision(parent.getPrecision());
      ScaleLayer avgWeight = new ScaleLayer(parent.getPosCoeff());
      avgWeight.setPrecision(parent.getPrecision());
      SumInputsLayer sumLayer = new SumInputsLayer();
      sumLayer.setPrecision(parent.getPrecision());
      final InnerNode wrap = network.add(sumLayer,
          network.add(avgWeight,
              network.add(lossSq(parent.getPrecision(), avgValue.addRef()), avgNode.addRef())),
          network.add(rmsWeight,
              network.add(lossSq(parent.getPrecision(), rmsValue.addRef()), rmsNode.addRef())),
          network.add(covWeight,
              network.add(lossSq(parent.getPrecision(), covValue.addRef()), covNode.addRef())));
      network.freeRef();
      return wrap;
    }

    @Nonnull
    public @Override
    @SuppressWarnings("unused")
    MomentParams addRef() {
      return (MomentParams) super.addRef();
    }
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy