com.simiacryptus.mindseye.test.TestUtil
Testing Tools for Neural Network Components
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.test;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.layers.LoggingWrapperLayer;
import com.simiacryptus.mindseye.layers.MonitoringWrapperLayer;
import com.simiacryptus.mindseye.layers.StochasticComponent;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.DAGNode;
import com.simiacryptus.mindseye.opt.Step;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.notebook.FileHTTPD;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.util.JsonUtil;
import com.simiacryptus.util.MonitoredObject;
import com.simiacryptus.util.data.DoubleStatistics;
import com.simiacryptus.util.data.PercentileStatistics;
import com.simiacryptus.util.data.ScalarStatistics;
import com.simiacryptus.util.io.GifSequenceWriter;
import guru.nidi.graphviz.attribute.Label;
import guru.nidi.graphviz.attribute.RankDir;
import guru.nidi.graphviz.engine.Format;
import guru.nidi.graphviz.engine.Graphviz;
import guru.nidi.graphviz.model.*;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.plot.PlotCanvas;
import smile.plot.ScatterPlot;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.imageio.ImageIO;
import javax.swing.*;
import javax.swing.filechooser.FileFilter;
import java.awt.*;
import java.awt.event.*;
import java.awt.image.BufferedImage;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.lang.ref.WeakReference;
import java.net.URI;
import java.util.*;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.*;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.Stream;
/**
* Utility methods for instrumenting, profiling, and visualizing neural network components under test.
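*
* <p>A typical profiling flow, sketched under the assumption that {@code network} is an existing
* {@link DAGNetwork}, {@code log} is an open {@link NotebookOutput}, and training is driven elsewhere:</p>
* <pre>{@code
* List<StepRecord> history = new ArrayList<>();
* TrainingMonitor monitor = TestUtil.getMonitor(history);
* TestUtil.instrumentPerformance(network);
* // ... run training with the monitor attached to the optimizer ...
* TestUtil.extractPerformance(log, network); // logs per-layer timings, then strips the wrappers
* TestUtil.printHistory(log, history);       // renders the convergence plot into the notebook
* }</pre>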
*/
public class TestUtil {
/**
* The constant S3_ROOT.
*/
public static final URI S3_ROOT = URI.create("https://s3-us-west-2.amazonaws.com/simiacryptus/");
private static final Logger logger = LoggerFactory.getLogger(TestUtil.class);
/**
* The constant scheduledThreadPool.
*/
public static ScheduledExecutorService scheduledThreadPool = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder().setDaemon(true).build());
private static int gifNumber = 0;
/**
* Add logging.
*
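* <p>A minimal usage sketch, assuming {@code network} is an existing {@link DAGNetwork}:</p>
* <pre>{@code
* TestUtil.addLogging(network);
* // ... evaluate the network; each wrapped layer logs its inputs and outputs ...
* TestUtil.removeLogging(network);
* }</pre>
*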
* @param network the network
*/
public static void addLogging(@Nonnull final DAGNetwork network) {
network.visitNodes(node -> {
if (!(node.getLayer() instanceof LoggingWrapperLayer)) {
node.setLayer(new LoggingWrapperLayer(node.getLayer()));
}
});
}
/**
* Add monitoring.
*
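* <p>A minimal usage sketch; the no-argument {@link MonitoredObject} constructor is an assumption here:</p>
* <pre>{@code
* MonitoredObject monitoringRoot = new MonitoredObject();
* TestUtil.addMonitoring(network, monitoringRoot);
* // ... train or evaluate, then inspect the metrics gathered under monitoringRoot ...
* TestUtil.removeMonitoring(network);
* }</pre>
*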
* @param network the network
* @param monitoringRoot the monitoring root
*/
public static void addMonitoring(@Nonnull final DAGNetwork network, @Nonnull final MonitoredObject monitoringRoot) {
network.visitNodes(node -> {
if (!(node.getLayer() instanceof MonitoringWrapperLayer)) {
node.setLayer(new MonitoringWrapperLayer(node.getLayer()).addTo(monitoringRoot));
}
});
}
/**
* Plots the convergence histories of several problem runs against iteration count on a shared canvas.
*
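* <p>A minimal usage sketch, assuming {@code runA} and {@code runB} are previously collected
* {@link ProblemRun}s and {@code log} is a {@link NotebookOutput}:</p>
* <pre>{@code
* PlotCanvas canvas = TestUtil.compare("Optimizer Comparison", runA, runB);
* if (canvas != null) log.eval(() -> canvas); // null means no finite data points were found
* }</pre>
*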
* @param title the title
* @param trials the trials
* @return the plot canvas, or {@code null} if no finite data is available
*/
public static PlotCanvas compare(final String title, @Nonnull final ProblemRun... trials) {
try {
final DoubleSummaryStatistics xStatistics = Arrays.stream(trials)
.flatMapToDouble(x -> x.history.stream().mapToDouble(step -> step.iteration))
.filter(Double::isFinite)
.summaryStatistics();
final DoubleSummaryStatistics yStatistics = Arrays.stream(trials)
.flatMapToDouble(x -> x.history.stream().filter(y -> y.fitness > 0).mapToDouble(step -> Math.log10(step.fitness)))
.filter(Double::isFinite)
.summaryStatistics();
if (xStatistics.getCount() == 0) {
logger.info("No Data");
return null;
}
@Nonnull final double[] lowerBound = {xStatistics.getCount() == 0 ? 0 : xStatistics.getMin(), yStatistics.getCount() < 2 ? 0 : yStatistics.getMin()};
@Nonnull final double[] upperBound = {xStatistics.getCount() == 0 ? 1 : xStatistics.getMax(), yStatistics.getCount() < 2 ? 1 : yStatistics.getMax()};
@Nonnull final PlotCanvas canvas = new PlotCanvas(lowerBound, upperBound);
canvas.setTitle(title);
canvas.setAxisLabels("Iteration", "log10(Fitness)");
canvas.setSize(600, 400);
final List<ProblemRun> filtered = Arrays.stream(trials).filter(x -> !x.history.isEmpty()).collect(Collectors.toList());
if (filtered.isEmpty()) {
logger.info("No Data");
return null;
}
DoubleSummaryStatistics valueStatistics = filtered.stream().flatMap(x -> x.history.stream()).mapToDouble(x -> x.fitness).filter(x -> x > 0).summaryStatistics();
logger.info(String.format("Plotting range=%s, %s; valueStats=%s", Arrays.toString(lowerBound), Arrays.toString(upperBound), valueStatistics));
for (@Nonnull final ProblemRun trial : filtered) {
final double[][] pts = trial.history.stream().map(step -> new double[]{
step.iteration, Math.log10(Math.max(step.fitness, valueStatistics.getMin()))})
.filter(x -> Arrays.stream(x).allMatch(Double::isFinite))
.toArray(i -> new double[i][]);
if (pts.length > 1) {
logger.info(String.format("Plotting %s points for %s", pts.length, trial.name));
canvas.add(trial.plot(pts));
} else {
logger.info(String.format("Only %s points for %s", pts.length, trial.name));
}
}
return canvas;
} catch (@Nonnull final Exception e) {
e.printStackTrace(System.out);
return null;
}
}
/**
* Plots the convergence histories of several problem runs against elapsed wall-clock time on a shared canvas.
*
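* <p>Usage mirrors {@link #compare}, but the x-axis is elapsed seconds rather than iteration count:</p>
* <pre>{@code
* PlotCanvas canvas = TestUtil.compareTime("Optimizer Comparison (wall clock)", runA, runB);
* }</pre>
*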
* @param title the title
* @param trials the trials
* @return the plot canvas
*/
public static PlotCanvas compareTime(final String title, @Nonnull final ProblemRun... trials) {
try {
final DoubleSummaryStatistics[] xStatistics = Arrays.stream(trials)
.map(x -> x.history.stream().mapToDouble(step -> step.epochTime)
.filter(Double::isFinite)
.summaryStatistics()).toArray(i -> new DoubleSummaryStatistics[i]);
final double totalTime = Arrays.stream(xStatistics).mapToDouble(x -> x.getMax() - x.getMin()).max().getAsDouble();
final DoubleSummaryStatistics yStatistics = Arrays.stream(trials)
.flatMapToDouble(x -> x.history.stream().filter(y -> y.fitness > 0).mapToDouble(step -> Math.log10(step.fitness)))
.filter(Double::isFinite)
.summaryStatistics();
if (yStatistics.getCount() == 0) {
logger.info("No Data");
return null;
}
@Nonnull final double[] lowerBound = {0, yStatistics.getCount() == 0 ? 0 : yStatistics.getMin()};
@Nonnull final double[] upperBound = {totalTime / 1000.0, yStatistics.getCount() == 1 ? 0 : yStatistics.getMax()};
@Nonnull final PlotCanvas canvas = new PlotCanvas(lowerBound, upperBound);
canvas.setTitle(title);
canvas.setAxisLabels("Time", "log10(Fitness)");
canvas.setSize(600, 400);
final List<ProblemRun> filtered = Arrays.stream(trials).filter(x -> !x.history.isEmpty()).collect(Collectors.toList());
if (filtered.isEmpty()) {
logger.info("No Data");
return null;
}
DoubleSummaryStatistics valueStatistics = filtered.stream().flatMap(x -> x.history.stream()).mapToDouble(x -> x.fitness).filter(x -> x > 0).summaryStatistics();
logger.info(String.format("Plotting range=%s, %s; valueStats=%s", Arrays.toString(lowerBound), Arrays.toString(upperBound), valueStatistics));
for (int t = 0; t < filtered.size(); t++) {
final ProblemRun trial = filtered.get(t);
final DoubleSummaryStatistics trialStats = xStatistics[t];
final double[][] pts = trial.history.stream().map(step -> {
return new double[]{(step.epochTime - trialStats.getMin()) / 1000.0, Math.log10(Math.max(step.fitness, valueStatistics.getMin()))};
}).filter(x -> Arrays.stream(x).allMatch(Double::isFinite))
.toArray(i -> new double[i][]);
if (pts.length > 1) {
logger.info(String.format("Plotting %s points for %s", pts.length, trial.name));
canvas.add(trial.plot(pts));
} else {
logger.info(String.format("Only %s points for %s", pts.length, trial.name));
}
}
return canvas;
} catch (@Nonnull final Exception e) {
e.printStackTrace(System.out);
return null;
}
}
/**
* Logs per-layer forward and backward performance metrics, then removes the monitoring wrappers.
*
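* <p>A minimal sketch of the instrument/report cycle, assuming {@code network} is already built:</p>
* <pre>{@code
* TestUtil.instrumentPerformance(network);
* // ... run the workload being profiled ...
* TestUtil.extractPerformance(log, network); // prints per-layer timings and removes the wrappers
* }</pre>
*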
* @param log the logger
* @param network the network
*/
public static void extractPerformance(@Nonnull final NotebookOutput log, @Nonnull final DAGNetwork network) {
log.p("Per-key Performance Metrics:");
log.run(() -> {
@Nonnull final Map<CharSequence, MonitoringWrapperLayer> metrics = new HashMap<>();
network.visitNodes(node -> {
if (node.getLayer() instanceof MonitoringWrapperLayer) {
@Nullable final MonitoringWrapperLayer layer = node.getLayer();
Layer inner = layer.getInner();
String str = inner.toString();
str += " class=" + inner.getClass().getName();
// if (inner instanceof MultiPrecision<?>) {
// str += "; precision=" + ((MultiPrecision<?>) inner).getPrecision().name();
// }
metrics.put(str, layer);
}
});
TestUtil.logger.info("Performance: \n\t" + metrics.entrySet().stream().sorted(Comparator.comparing(x -> -x.getValue().getForwardPerformance().getMean())).map(e -> {
@Nonnull final PercentileStatistics performanceF = e.getValue().getForwardPerformance();
@Nonnull final PercentileStatistics performanceB = e.getValue().getBackwardPerformance();
return String.format("%.6fs +- %.6fs (%d) <- %s", performanceF.getMean(), performanceF.getStdDev(), performanceF.getCount(), e.getKey()) +
(performanceB.getCount() == 0 ? "" : String.format("%n\tBack: %.6fs +- %.6fs (%s)", performanceB.getMean(), performanceB.getStdDev(), performanceB.getCount()));
}).reduce((a, b) -> a + "\n\t" + b).get());
});
removeInstrumentation(network);
}
/**
* Remove instrumentation.
*
* @param network the network
*/
public static void removeInstrumentation(@Nonnull final DAGNetwork network) {
network.visitNodes(node -> {
if (node.getLayer() instanceof MonitoringWrapperLayer) {
Layer layer = node.getLayer().getInner();
layer.addRef();
node.setLayer(layer);
layer.freeRef();
}
});
}
/**
* Samples per-layer forward and backward performance metrics into a map keyed by layer description.
*
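* <p>A minimal sketch: sample the metrics while the network is still instrumented:</p>
* <pre>{@code
* Map<CharSequence, Object> metrics = TestUtil.samplePerformance(network);
* metrics.forEach((layer, timings) -> System.out.println(layer + " => " + timings));
* }</pre>
*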
* @param network the network
* @return the metrics map
*/
public static Map<CharSequence, Object> samplePerformance(@Nonnull final DAGNetwork network) {
@Nonnull final Map<CharSequence, Object> metrics = new HashMap<>();
network.visitLayers(layer -> {
if (layer instanceof MonitoringWrapperLayer) {
MonitoringWrapperLayer monitoringWrapperLayer = (MonitoringWrapperLayer) layer;
Layer inner = monitoringWrapperLayer.getInner();
String str = inner.toString();
str += " class=" + inner.getClass().getName();
HashMap<CharSequence, Object> row = new HashMap<>();
row.put("fwd", monitoringWrapperLayer.getForwardPerformance().getMetrics());
row.put("rev", monitoringWrapperLayer.getBackwardPerformance().getMetrics());
metrics.put(str, row);
}
});
return metrics;
}
/**
* Builds a training monitor that appends a {@link StepRecord} to the given history for each completed optimization step.
*
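* <p>A minimal sketch: collect step records during training for later plotting:</p>
* <pre>{@code
* List<StepRecord> history = new ArrayList<>();
* TrainingMonitor monitor = TestUtil.getMonitor(history);
* // hand the monitor to the optimizer driving training; each completed step appends a StepRecord
* }</pre>
*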
* @param history the history
* @return the monitor
*/
public static TrainingMonitor getMonitor(@Nonnull final List<StepRecord> history) {
return getMonitor(history, null);
}
/**
* Builds a training monitor for the given network that appends a {@link StepRecord} to the given history for each completed optimization step.
*
* @param history the history
* @param network the network
* @return the monitor
*/
public static TrainingMonitor getMonitor(@Nonnull final List<StepRecord> history, final Layer network) {
return new TrainingMonitor() {
@Override
public void clear() {
super.clear();
}
@Override
public void log(final String msg) {
logger.info(msg);
super.log(msg);
}
@Override
public void onStepComplete(@Nonnull final Step currentPoint) {
history.add(new StepRecord(currentPoint.point.getMean(), currentPoint.time, currentPoint.iteration));
super.onStepComplete(currentPoint);
}
};
}
/**
* Add performance wrappers.
*
* @param network the network
*/
public static void instrumentPerformance(@Nonnull final DAGNetwork network) {
network.visitNodes(node -> {
Layer layer = node.getLayer();
if (layer instanceof MonitoringWrapperLayer) {
((MonitoringWrapperLayer) layer).shouldRecordSignalMetrics(false);
} else {
@Nonnull MonitoringWrapperLayer monitoringWrapperLayer = new MonitoringWrapperLayer(layer).shouldRecordSignalMetrics(false);
node.setLayer(monitoringWrapperLayer);
monitoringWrapperLayer.freeRef();
}
});
}
/**
* Plots the training history against iteration count, using a log10 fitness axis when all fitness values are positive.
*
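* <p>A minimal sketch, assuming {@code history} was filled by a monitor from {@link #getMonitor(List)}
* and {@code log} is a {@link NotebookOutput}:</p>
* <pre>{@code
* JPanel panel = TestUtil.plot(history);
* if (panel != null) log.eval(() -> panel); // null when there is nothing finite to plot
* }</pre>
*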
* @param history the history
* @return the plot panel, or {@code null} if there is no finite data to plot
*/
public static JPanel plot(@Nonnull final List<StepRecord> history) {
try {
final DoubleSummaryStatistics valueStats = history.stream().mapToDouble(x -> x.fitness).summaryStatistics();
double min = valueStats.getMin();
if (0 < min) {
double[][] data = history.stream().map(step -> new double[]{
step.iteration, Math.log10(Math.max(min, step.fitness))})
.filter(x -> Arrays.stream(x).allMatch(Double::isFinite))
.toArray(i -> new double[i][]);
if (Arrays.stream(data).mapToInt(x -> x.length).sum() == 0) return null;
@Nonnull final PlotCanvas plot = ScatterPlot.plot(data);
plot.setTitle("Convergence Plot");
plot.setAxisLabels("Iteration", "log10(Fitness)");
plot.setSize(600, 400);
return plot;
} else {
double[][] data = history.stream().map(step -> new double[]{
step.iteration, step.fitness})
.filter(x -> Arrays.stream(x).allMatch(Double::isFinite))
.toArray(i -> new double[i][]);
if (Arrays.stream(data).mapToInt(x -> x.length).sum() == 0) return null;
@Nonnull final PlotCanvas plot = ScatterPlot.plot(data);
plot.setTitle("Convergence Plot");
plot.setAxisLabels("Iteration", "Fitness");
plot.setSize(600, 400);
return plot;
}
} catch (@Nonnull final Exception e) {
logger.warn("Error plotting", e);
return null;
}
}
/**
* Plots the training history against elapsed wall-clock time on a log10 fitness axis.
*
* @param history the history
* @return the plot canvas
*/
public static PlotCanvas plotTime(@Nonnull final List<StepRecord> history) {
try {
final LongSummaryStatistics timeStats = history.stream().mapToLong(x -> x.epochTime).summaryStatistics();
final DoubleSummaryStatistics valueStats = history.stream().mapToDouble(x -> x.fitness).filter(x -> x > 0).summaryStatistics();
@Nonnull final PlotCanvas plot = ScatterPlot.plot(history.stream().map(step -> new double[]{
(step.epochTime - timeStats.getMin()) / 1000.0, Math.log10(Math.max(valueStats.getMin(), step.fitness))})
.filter(x -> Arrays.stream(x).allMatch(Double::isFinite))
.toArray(i -> new double[i][]));
plot.setTitle("Convergence Plot");
plot.setAxisLabels("Time", "log10(Fitness)");
plot.setSize(600, 400);
return plot;
} catch (@Nonnull final Exception e) {
e.printStackTrace(System.out);
return null;
}
}
/**
* Print data statistics.
*
* @param log the logger
* @param data the data
*/
public static void printDataStatistics(@Nonnull final NotebookOutput log, @Nonnull final Tensor[][] data) {
for (int col = 1; col < data[0].length; col++) {
final int c = col;
log.out("Learned Representation Statistics for Column " + col + " (all bands)");
log.eval(() -> {
@Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
Arrays.stream(data)
.flatMapToDouble(row -> Arrays.stream(row[c].getData()))
.forEach(v -> scalarStatistics.add(v));
return scalarStatistics.getMetrics();
});
final int _col = col;
log.out("Learned Representation Statistics for Column " + col + " (by band)");
log.eval(() -> {
@Nonnull final int[] dimensions = data[0][_col].getDimensions();
return IntStream.range(0, dimensions[2]).mapToObj(x -> x).flatMap(b -> {
return Arrays.stream(data).map(r -> r[_col]).map(tensor -> {
@Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
scalarStatistics.add(new Tensor(dimensions[0], dimensions[1]).setByCoord(coord -> tensor.get(coord.getCoords()[0], coord.getCoords()[1], b)).getData());
return scalarStatistics;
});
}).map(x -> x.getMetrics().toString()).reduce((a, b) -> a + "\n" + b).get();
});
}
}
/**
* Print history.
*
* @param log the logger
* @param history the history
*/
public static void printHistory(@Nonnull final NotebookOutput log, @Nonnull final List<StepRecord> history) {
if (!history.isEmpty()) {
log.out("Convergence Plot: ");
log.eval(() -> {
final DoubleSummaryStatistics valueStats = history.stream().mapToDouble(x -> x.fitness).filter(x -> x > 0).summaryStatistics();
@Nonnull final PlotCanvas plot = ScatterPlot.plot(history.stream().map(step ->
new double[]{step.iteration, Math.log10(Math.max(valueStats.getMin(), step.fitness))})
.toArray(i -> new double[i][]));
plot.setTitle("Convergence Plot");
plot.setAxisLabels("Iteration", "log10(Fitness)");
plot.setSize(600, 400);
return plot;
});
}
}
/**
* Remove logging.
*
* @param network the network
*/
public static void removeLogging(@Nonnull final DAGNetwork network) {
network.visitNodes(node -> {
if (node.getLayer() instanceof LoggingWrapperLayer) {
node.setLayer(((LoggingWrapperLayer) node.getLayer()).getInner());
}
});
}
/**
* Remove monitoring.
*
* @param network the network
*/
public static void removeMonitoring(@Nonnull final DAGNetwork network) {
network.visitNodes(node -> {
if (node.getLayer() instanceof MonitoringWrapperLayer) {
node.setLayer(((MonitoringWrapperLayer) node.getLayer()).getInner());
}
});
}
/**
* Renders the tensor's bands as inline PNG images and returns the concatenated notebook markup.
*
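* <p>A minimal sketch, assuming {@code tensor} is a three-dimensional image tensor:</p>
* <pre>{@code
* log.p(TestUtil.render(log, tensor, false).toString());
* }</pre>
*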
* @param log the logger
* @param tensor the tensor
* @param normalize the normalize
* @return the string
*/
public static CharSequence render(@Nonnull final NotebookOutput log, @Nonnull final Tensor tensor, final boolean normalize) {
return TestUtil.renderToImages(tensor, normalize).map(image -> {
return log.png(image, "");
}).reduce((a, b) -> a + b).get();
}
/**
* Renders the tensor to a stream of images, optionally normalizing each band using its per-band statistics.
*
* @param tensor the tensor
* @param normalize the normalize
* @return the stream
*/
public static Stream<BufferedImage> renderToImages(@Nonnull final Tensor tensor, final boolean normalize) {
final DoubleStatistics[] statistics = IntStream.range(0, tensor.getDimensions()[2]).mapToObj(band -> {
return new DoubleStatistics().accept(tensor.coordStream(false)
.filter(x -> x.getCoords()[2] == band)
.mapToDouble(c -> tensor.get(c)).toArray());
}).toArray(i -> new DoubleStatistics[i]);
@Nonnull final BiFunction<Double, DoubleStatistics, Double> transform = (value, stats) -> {
final double width = Math.sqrt(2) * stats.getStandardDeviation();
final double centered = value - stats.getAverage();
final double distance = Math.abs(value - stats.getAverage());
final double positiveMax = stats.getMax() - stats.getAverage();
final double negativeMax = stats.getAverage() - stats.getMin();
final double unitValue;
if (centered < 0) { // value is below the band mean
if (distance > width) {
unitValue = 0.25 - 0.25 * ((distance - width) / (negativeMax - width));
} else {
unitValue = 0.5 - 0.25 * (distance / width);
}
} else {
if (distance > width) {
unitValue = 0.75 + 0.25 * ((distance - width) / (positiveMax - width));
} else {
unitValue = 0.5 + 0.25 * (distance / width);
}
}
return 0xFF * unitValue;
};
tensor.coordStream(true).collect(Collectors.groupingBy(x -> x.getCoords()[2], Collectors.toList()));
@Nullable final Tensor normal = tensor.mapCoords((c) -> transform.apply(tensor.get(c), statistics[c.getCoords()[2]]))
.map(v -> Math.min(0xFF, Math.max(0, v)));
return (normalize ? normal : tensor).toImages().stream();
}
/**
* Resizes an image to {@code size} pixels in both width and height (aspect ratio is not preserved).
*
* @param source the source
* @param size the size
* @return the resized image
*/
@Nonnull
public static BufferedImage resize(@Nonnull final BufferedImage source, final int size) {
return resize(source, size, false);
}
/**
* Resizes an image so that its width equals {@code size}, scaling incrementally for quality; the height keeps the original aspect ratio only when {@code preserveAspect} is true.
*
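* <p>A minimal sketch; {@code example.png} is a hypothetical file and IOException handling is omitted:</p>
* <pre>{@code
* BufferedImage source = ImageIO.read(new File("example.png"));
* BufferedImage scaled = TestUtil.resize(source, 256, true); // width 256, aspect ratio preserved
* ImageIO.write(scaled, "png", new File("example_256.png"));
* }</pre>
*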
* @param source the source
* @param size the size
* @param preserveAspect the preserve aspect
* @return the resized image
*/
@Nonnull
public static BufferedImage resize(@Nonnull final BufferedImage source, final int size, boolean preserveAspect) {
if (size <= 0) return source;
double zoom = (double) size / source.getWidth();
int steps = (int) Math.ceil(Math.abs(Math.log(zoom)) / Math.log(1.5));
BufferedImage img = source;
for (int i = 1; i <= steps; i++) {
double pos = ((double) i / steps);
double z = Math.pow(zoom, pos);
img = resize(img,
(int) (source.getWidth() * z),
(int) ((preserveAspect ? source.getHeight() : source.getWidth()) * z)
);
}
return img;
}
/**
* Resizes an image so that its total pixel count is approximately {@code size}, preserving aspect ratio.
*
* @param source the source
* @param size the size
* @return the resized image
*/
public static BufferedImage resizePx(@Nonnull final BufferedImage source, final long size) {
if (size < 0) return source;
double scale = Math.sqrt(size / ((double) source.getHeight() * source.getWidth()));
int width = (int) (scale * source.getWidth());
int height = (int) (scale * source.getHeight());
return resize(source, width, height);
}
/**
* Resizes an image to exactly the given width and height.
*
* @param source the source
* @param width the width
* @param height the height
* @return the resized image
*/
@Nonnull
public static BufferedImage resize(BufferedImage source, int width, int height) {
@Nonnull final BufferedImage image = new BufferedImage(width, height, source.getType());
@Nonnull final Graphics2D graphics = (Graphics2D) image.getGraphics();
// NOTE: the original listing is truncated at this point; the remainder is a plausible
// reconstruction that applies quality rendering hints and draws the scaled image.
HashMap<Object, Object> hints = new HashMap<>();
hints.put(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BICUBIC);
hints.put(RenderingHints.KEY_RENDERING, RenderingHints.VALUE_RENDER_QUALITY);
hints.put(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
graphics.setRenderingHints(hints);
graphics.drawImage(source, 0, 0, width, height, null);
return image;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy