All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.test.TestUtil Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.test;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.layers.LoggingWrapperLayer;
import com.simiacryptus.mindseye.layers.MonitoringWrapperLayer;
import com.simiacryptus.mindseye.layers.StochasticComponent;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.DAGNode;
import com.simiacryptus.mindseye.opt.Step;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.notebook.FileHTTPD;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.util.JsonUtil;
import com.simiacryptus.util.MonitoredObject;
import com.simiacryptus.util.data.DoubleStatistics;
import com.simiacryptus.util.data.PercentileStatistics;
import com.simiacryptus.util.data.ScalarStatistics;
import com.simiacryptus.util.io.GifSequenceWriter;
import guru.nidi.graphviz.attribute.Label;
import guru.nidi.graphviz.attribute.RankDir;
import guru.nidi.graphviz.engine.Format;
import guru.nidi.graphviz.engine.Graphviz;
import guru.nidi.graphviz.model.*;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.plot.PlotCanvas;
import smile.plot.ScatterPlot;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.imageio.ImageIO;
import javax.swing.*;
import javax.swing.filechooser.FileFilter;
import java.awt.*;
import java.awt.event.*;
import java.awt.image.BufferedImage;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.lang.ref.WeakReference;
import java.net.URI;
import java.util.*;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.*;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.Stream;

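/**
 * Static test utilities: network instrumentation (logging and monitoring wrappers), performance
 * reporting, convergence plots, tensor/image rendering and resizing, and network graph export.
 */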
public class TestUtil {
  /**
   * Base URI of the public "simiacryptus" S3 bucket (us-west-2 endpoint).
   * NOTE(review): not referenced in the visible part of this file; presumably used by callers to
   * build links to published artifacts — confirm.
   */
  public static final URI S3_ROOT = URI.create("https://s3-us-west-2.amazonaws.com/simiacryptus/");
  // Class logger used by the plotting, monitoring and performance-reporting helpers.
  private static final Logger logger = LoggerFactory.getLogger(TestUtil.class);
  /**
   * Shared single-thread scheduler whose thread is a daemon, so it does not keep the JVM alive.
   * Used by {@code monitorImage} to periodically refresh the displayed image.
   */
  public static ScheduledExecutorService scheduledThreadPool = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder().setDaemon(true).build());
  // NOTE(review): not referenced in the visible part of this file; presumably a counter for naming GIF output — confirm.
  private static int gifNumber = 0;

  /**
   * Wraps every node's layer in a {@link LoggingWrapperLayer}, leaving nodes that are already
   * wrapped untouched (so calling this twice does not double-wrap).
   *
   * @param network the network whose nodes are instrumented in place
   */
  public static void addLogging(@Nonnull final DAGNetwork network) {
    network.visitNodes(node -> {
      if (node.getLayer() instanceof LoggingWrapperLayer) {
        return;
      }
      node.setLayer(new LoggingWrapperLayer(node.getLayer()));
    });
  }

  /**
   * Wraps every node's layer in a {@link MonitoringWrapperLayer} registered under
   * {@code monitoringRoot}; nodes that are already monitored are left as-is.
   *
   * @param network        the network whose nodes are instrumented in place
   * @param monitoringRoot the monitoring registry each new wrapper is added to
   */
  public static void addMonitoring(@Nonnull final DAGNetwork network, @Nonnull final MonitoredObject monitoringRoot) {
    network.visitNodes(node -> {
      if (node.getLayer() instanceof MonitoringWrapperLayer) {
        return;
      }
      final MonitoringWrapperLayer wrapper = new MonitoringWrapperLayer(node.getLayer()).addTo(monitoringRoot);
      node.setLayer(wrapper);
    });
  }

  /**
   * Plots log10(fitness) against iteration number for several training runs on one canvas.
   *
   * <p>Fitness values {@code <= 0} are clamped up to the smallest positive fitness seen across all
   * runs so they stay plottable on the log scale. Runs with fewer than two plottable points are
   * skipped (and logged). With fewer than two positive fitness values the y range falls back to
   * [0, 1].
   *
   * @param title  the plot title
   * @param trials the runs to compare
   * @return the plot canvas, or {@code null} if there is no data or plotting fails
   */
  public static PlotCanvas compare(final String title, @Nonnull final ProblemRun... trials) {
    try {
      final DoubleSummaryStatistics xStatistics = Arrays.stream(trials)
          .flatMapToDouble(x -> x.history.stream().mapToDouble(step -> step.iteration))
          .filter(Double::isFinite)
          .summaryStatistics();
      final DoubleSummaryStatistics yStatistics = Arrays.stream(trials)
          .flatMapToDouble(x -> x.history.stream().filter(y -> y.fitness > 0).mapToDouble(step -> Math.log10(step.fitness)))
          .filter(Double::isFinite)
          .summaryStatistics();
      if (xStatistics.getCount() == 0) {
        logger.info("No Data");
        return null;
      }
      // xStatistics is non-empty past the guard above, so its min/max are always real values.
      @Nonnull final double[] lowerBound = {xStatistics.getMin(), yStatistics.getCount() < 2 ? 0 : yStatistics.getMin()};
      @Nonnull final double[] upperBound = {xStatistics.getMax(), yStatistics.getCount() < 2 ? 1 : yStatistics.getMax()};
      @Nonnull final PlotCanvas canvas = new PlotCanvas(lowerBound, upperBound);
      canvas.setTitle(title);
      canvas.setAxisLabels("Iteration", "log10(Fitness)");
      canvas.setSize(600, 400);
      final List<ProblemRun> filtered = Arrays.stream(trials).filter(x -> !x.history.isEmpty()).collect(Collectors.toList());
      if (filtered.isEmpty()) {
        logger.info("No Data");
        return null;
      }
      // Smallest positive fitness over all runs; used as the clamp floor for the log scale.
      final DoubleSummaryStatistics valueStatistics = filtered.stream().flatMap(x -> x.history.stream()).mapToDouble(x -> x.fitness).filter(x -> x > 0).summaryStatistics();
      logger.info(String.format("Plotting range=%s, %s; valueStats=%s", Arrays.toString(lowerBound), Arrays.toString(upperBound), valueStatistics));
      for (@Nonnull final ProblemRun trial : filtered) {
        final double[][] pts = trial.history.stream().map(step -> new double[]{
            step.iteration, Math.log10(Math.max(step.fitness, valueStatistics.getMin()))})
            .filter(x -> Arrays.stream(x).allMatch(Double::isFinite))
            .toArray(i -> new double[i][]);
        if (pts.length > 1) {
          logger.info(String.format("Plotting %s points for %s", pts.length, trial.name));
          canvas.add(trial.plot(pts));
        } else {
          logger.info(String.format("Only %s points for %s", pts.length, trial.name));
        }
      }
      return canvas;
    } catch (@Nonnull final Exception e) {
      // Report through the class logger (as plot() does) instead of dumping to System.out.
      logger.warn("Error plotting", e);
      return null;
    }
  }

  /**
   * Plots log10(fitness) against elapsed wall-clock seconds for several training runs on one canvas.
   * Each run's time axis starts at its own earliest recorded step, so runs are compared by duration
   * rather than by absolute start time.
   *
   * <p>Fitness values {@code <= 0} are clamped up to the smallest positive fitness seen across all
   * runs. Runs with an empty history, or with fewer than two plottable points, are skipped. With
   * fewer than two positive fitness values the y range falls back to [0, 1], as in
   * {@code compare}.
   *
   * @param title  the plot title
   * @param trials the runs to compare
   * @return the plot canvas, or {@code null} if there is no data or plotting fails
   */
  public static PlotCanvas compareTime(final String title, @Nonnull final ProblemRun... trials) {
    try {
      // All per-run statistics are derived from this single list so that index t always refers to
      // the same run (previously the time offsets were indexed over *all* trials while the plot loop
      // iterated only the non-empty ones, misaligning offsets whenever a run had no history).
      final List<ProblemRun> filtered = Arrays.stream(trials).filter(x -> !x.history.isEmpty()).collect(Collectors.toList());
      final DoubleSummaryStatistics yStatistics = filtered.stream()
          .flatMapToDouble(x -> x.history.stream().filter(y -> y.fitness > 0).mapToDouble(step -> Math.log10(step.fitness)))
          .filter(Double::isFinite)
          .summaryStatistics();
      if (filtered.isEmpty() || yStatistics.getCount() == 0) {
        logger.info("No Data");
        return null;
      }
      final DoubleSummaryStatistics[] xStatistics = filtered.stream()
          .map(x -> x.history.stream().mapToDouble(step -> step.epochTime)
              .filter(Double::isFinite)
              .summaryStatistics())
          .toArray(DoubleSummaryStatistics[]::new);
      // Longest run duration (ms); every run here is non-empty, so each span is finite.
      final double totalTime = Arrays.stream(xStatistics).mapToDouble(x -> x.getMax() - x.getMin()).max().orElse(0);
      @Nonnull final double[] lowerBound = {0, yStatistics.getCount() < 2 ? 0 : yStatistics.getMin()};
      @Nonnull final double[] upperBound = {totalTime / 1000.0, yStatistics.getCount() < 2 ? 1 : yStatistics.getMax()};
      @Nonnull final PlotCanvas canvas = new PlotCanvas(lowerBound, upperBound);
      canvas.setTitle(title);
      canvas.setAxisLabels("Time", "log10(Fitness)");
      canvas.setSize(600, 400);
      final DoubleSummaryStatistics valueStatistics = filtered.stream().flatMap(x -> x.history.stream()).mapToDouble(x -> x.fitness).filter(x -> x > 0).summaryStatistics();
      logger.info(String.format("Plotting range=%s, %s; valueStats=%s", Arrays.toString(lowerBound), Arrays.toString(upperBound), valueStatistics));
      for (int t = 0; t < filtered.size(); t++) {
        final ProblemRun trial = filtered.get(t);
        final double startTime = xStatistics[t].getMin();
        final double[][] pts = trial.history.stream()
            .map(step -> new double[]{(step.epochTime - startTime) / 1000.0, Math.log10(Math.max(step.fitness, valueStatistics.getMin()))})
            .filter(x -> Arrays.stream(x).allMatch(Double::isFinite))
            .toArray(i -> new double[i][]);
        if (pts.length > 1) {
          logger.info(String.format("Plotting %s points for %s", pts.length, trial.name));
          canvas.add(trial.plot(pts));
        } else {
          logger.info(String.format("Only %s points for %s", pts.length, trial.name));
        }
      }
      return canvas;
    } catch (@Nonnull final Exception e) {
      logger.warn("Error plotting", e);
      return null;
    }
  }

  /**
   * Logs the per-layer forward/backward timings collected by {@link MonitoringWrapperLayer}s, then
   * strips those wrappers from the network via {@code removeInstrumentation}.
   *
   * <p>Layers are listed in descending order of mean forward time. Each layer is keyed by its
   * {@code toString()} plus its class name. If no layer is monitored, an empty report is logged.
   *
   * @param log     the notebook; the report is produced inside {@code log.run}
   * @param network the instrumented network; its monitoring wrappers are removed afterwards
   */
  public static void extractPerformance(@Nonnull final NotebookOutput log, @Nonnull final DAGNetwork network) {
    log.p("Per-key Performance Metrics:");
    log.run(() -> {
      @Nonnull final Map<CharSequence, MonitoringWrapperLayer> metrics = new HashMap<>();
      network.visitNodes(node -> {
        if (node.getLayer() instanceof MonitoringWrapperLayer) {
          @Nullable final MonitoringWrapperLayer layer = node.getLayer();
          Layer inner = layer.getInner();
          String str = inner.toString();
          str += " class=" + inner.getClass().getName();
          metrics.put(str, layer);
        }
      });
      // joining() yields an empty report instead of reduce(...).get() throwing when nothing is monitored.
      final String report = metrics.entrySet().stream()
          .sorted(Comparator.comparing(x -> -x.getValue().getForwardPerformance().getMean()))
          .map(e -> {
            @Nonnull final PercentileStatistics performanceF = e.getValue().getForwardPerformance();
            @Nonnull final PercentileStatistics performanceB = e.getValue().getBackwardPerformance();
            return String.format("%.6fs +- %.6fs (%d) <- %s", performanceF.getMean(), performanceF.getStdDev(), performanceF.getCount(), e.getKey()) +
                (performanceB.getCount() == 0 ? "" : String.format("%n\tBack: %.6fs +- %.6fs (%s)", performanceB.getMean(), performanceB.getStdDev(), performanceB.getCount()));
          })
          .collect(Collectors.joining("\n\t"));
      TestUtil.logger.info("Performance: \n\t" + report);
    });
    removeInstrumentation(network);
  }

  /**
   * Replaces every {@link MonitoringWrapperLayer} in the network with the layer it wraps.
   *
   * @param network the network to un-instrument in place
   */
  public static void removeInstrumentation(@Nonnull final DAGNetwork network) {
    network.visitNodes(node -> {
      if (!(node.getLayer() instanceof MonitoringWrapperLayer)) {
        return;
      }
      final Layer unwrapped = ((MonitoringWrapperLayer) node.getLayer()).getInner();
      // Extra reference held across setLayer and released afterwards; presumably so the inner
      // layer survives the wrapper being replaced — confirm against the ref-counting contract.
      unwrapped.addRef();
      node.setLayer(unwrapped);
      unwrapped.freeRef();
    });
  }

  /**
   * Snapshots the timing metrics of every monitored layer without removing the wrappers.
   *
   * @param network the instrumented network
   * @return a map keyed by "{inner.toString()} class={inner class name}" whose values are maps with
   *     "fwd" and "rev" entries holding the forward and backward performance metrics
   */
  public static Map samplePerformance(@Nonnull final DAGNetwork network) {
    @Nonnull final Map metrics = new HashMap<>();
    network.visitLayers(layer -> {
      if (!(layer instanceof MonitoringWrapperLayer)) {
        return;
      }
      final MonitoringWrapperLayer wrapper = (MonitoringWrapperLayer) layer;
      final Layer inner = wrapper.getInner();
      final String key = inner.toString() + " class=" + inner.getClass().getName();
      final HashMap timings = new HashMap<>();
      timings.put("fwd", wrapper.getForwardPerformance().getMetrics());
      timings.put("rev", wrapper.getBackwardPerformance().getMetrics());
      metrics.put(key, timings);
    });
    return metrics;
  }

  /**
   * Creates a training monitor that records every completed step into {@code history}; shorthand
   * for the two-argument overload with no network.
   *
   * @param history the list that receives one step record per completed step
   * @return the monitor
   */
  public static TrainingMonitor getMonitor(@Nonnull final List history) {
    final Layer noNetwork = null;
    return getMonitor(history, noNetwork);
  }

  /**
   * Creates a training monitor that appends a {@code StepRecord} (mean fitness, time, iteration)
   * to {@code history} for every completed step and echoes log messages to this class's logger.
   * Both overridden callbacks still delegate to the base {@link TrainingMonitor}.
   *
   * @param history the list that receives one record per completed step
   * @param network not used by the returned monitor (NOTE(review): kept for API compatibility?)
   * @return the monitor
   */
  public static TrainingMonitor getMonitor(@Nonnull final List history, final Layer network) {
    return new TrainingMonitor() {
      @Override
      public void onStepComplete(@Nonnull final Step currentPoint) {
        final StepRecord record = new StepRecord(currentPoint.point.getMean(), currentPoint.time, currentPoint.iteration);
        history.add(record);
        super.onStepComplete(currentPoint);
      }

      @Override
      public void log(final String msg) {
        logger.info(msg);
        super.log(msg);
      }
    };
  }

  /**
   * Wraps every node's layer in a {@link MonitoringWrapperLayer} with signal-metric recording
   * disabled, so only timing is collected. Already-wrapped layers just get signal metrics turned off.
   *
   * @param network the network to instrument in place
   */
  public static void instrumentPerformance(@Nonnull final DAGNetwork network) {
    network.visitNodes(node -> {
      final Layer current = node.getLayer();
      if (current instanceof MonitoringWrapperLayer) {
        ((MonitoringWrapperLayer) current).shouldRecordSignalMetrics(false);
        return;
      }
      @Nonnull final MonitoringWrapperLayer wrapper = new MonitoringWrapperLayer(current).shouldRecordSignalMetrics(false);
      node.setLayer(wrapper);
      // The node now holds the wrapper; release this method's reference to it.
      wrapper.freeRef();
    });
  }

  /**
   * Builds a 600x400 "Convergence Plot" scatter of fitness against iteration.
   *
   * <p>When every recorded fitness is strictly positive the y axis is log10(fitness); otherwise the
   * raw fitness is plotted. Points with a non-finite coordinate are dropped.
   *
   * @param history the recorded training steps
   * @return the plot, or {@code null} if no finite points remain or plotting fails
   */
  public static JPanel plot(@Nonnull final List history) {
    try {
      final DoubleSummaryStatistics fitnessStats = history.stream().mapToDouble(x -> x.fitness).summaryStatistics();
      final double floor = fitnessStats.getMin();
      final boolean logScale = 0 < floor;
      final double[][] points = history.stream()
          .map(step -> logScale
              ? new double[]{step.iteration, Math.log10(Math.max(floor, step.fitness))}
              : new double[]{step.iteration, step.fitness})
          .filter(pt -> Arrays.stream(pt).allMatch(Double::isFinite))
          .toArray(n -> new double[n][]);
      if (points.length == 0) {
        return null;
      }
      @Nonnull final PlotCanvas canvas = ScatterPlot.plot(points);
      canvas.setTitle("Convergence Plot");
      canvas.setAxisLabels("Iteration", logScale ? "log10(Fitness)" : "Fitness");
      canvas.setSize(600, 400);
      return canvas;
    } catch (@Nonnull final Exception e) {
      logger.warn("Error plotting", e);
      return null;
    }
  }

  /**
   * Builds a 600x400 "Convergence Plot" scatter of log10(fitness) against seconds elapsed since the
   * earliest recorded step.
   *
   * <p>Fitness values {@code <= 0} are clamped up to the smallest positive fitness; points with a
   * non-finite coordinate are dropped.
   *
   * @param history the recorded training steps
   * @return the plot, or {@code null} if no finite points remain or plotting fails
   */
  public static PlotCanvas plotTime(@Nonnull final List history) {
    try {
      final LongSummaryStatistics timeStats = history.stream().mapToLong(x -> x.epochTime).summaryStatistics();
      final DoubleSummaryStatistics valueStats = history.stream().mapToDouble(x -> x.fitness).filter(x -> x > 0).summaryStatistics();
      final double[][] data = history.stream().map(step -> new double[]{
          (step.epochTime - timeStats.getMin()) / 1000.0, Math.log10(Math.max(valueStats.getMin(), step.fitness))})
          .filter(x -> Arrays.stream(x).allMatch(Double::isFinite))
          .toArray(i -> new double[i][]);
      // Same convention as plot(): nothing plottable yields null instead of a failing ScatterPlot call.
      if (data.length == 0) {
        return null;
      }
      @Nonnull final PlotCanvas plot = ScatterPlot.plot(data);
      plot.setTitle("Convergence Plot");
      plot.setAxisLabels("Time", "log10(Fitness)");
      plot.setSize(600, 400);
      return plot;
    } catch (@Nonnull final Exception e) {
      // Report through the class logger (as plot() does) instead of dumping to System.out.
      logger.warn("Error plotting", e);
      return null;
    }
  }

  /**
   * Writes summary statistics for every column of {@code data} except column 0 to the notebook.
   * For each column it emits (1) one metrics block over all values of that column across all rows,
   * and (2) one metrics line per (band, row) pair, computed over that row's 2-D slice for the band.
   *
   * <p>NOTE(review): assumes every tensor is rank 3 with the band as the third coordinate, and that
   * column 0 (skipped) holds inputs rather than learned representations — confirm.
   *
   * @param log  the notebook to write to
   * @param data rows of tensors, indexed {@code data[row][column]}; an empty array writes nothing
   */
  public static void printDataStatistics(@Nonnull final NotebookOutput log, @Nonnull final Tensor[][] data) {
    // Guard: data[0] below would throw on an empty dataset.
    if (data.length == 0) {
      return;
    }
    for (int col = 1; col < data[0].length; col++) {
      final int c = col;
      log.out("Learned Representation Statistics for Column " + col + " (all bands)");
      log.eval(() -> {
        @Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
        Arrays.stream(data)
            .flatMapToDouble(row -> Arrays.stream(row[c].getData()))
            .forEach(v -> scalarStatistics.add(v));
        return scalarStatistics.getMetrics();
      });
      log.out("Learned Representation Statistics for Column " + col + " (by band)");
      log.eval(() -> {
        @Nonnull final int[] dimensions = data[0][c].getDimensions();
        return IntStream.range(0, dimensions[2]).mapToObj(x -> x).flatMap(b -> {
          return Arrays.stream(data).map(r -> r[c]).map(tensor -> {
            @Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
            scalarStatistics.add(new Tensor(dimensions[0], dimensions[1]).setByCoord(coord -> tensor.get(coord.getCoords()[0], coord.getCoords()[1], b)).getData());
            return scalarStatistics;
          });
          // joining() returns "" instead of reduce(...).get() throwing when there are no bands.
        }).map(x -> x.getMetrics().toString()).collect(Collectors.joining("\n"));
      });
    }
  }

  /**
   * Writes a 600x400 "Convergence Plot" of log10(fitness) against iteration to the notebook.
   * Fitness values {@code <= 0} are clamped up to the smallest positive fitness, and points with a
   * non-finite coordinate are dropped (as in {@code plot}/{@code plotTime}). Nothing is written for
   * an empty history.
   *
   * @param log     the notebook to write to
   * @param history the recorded training steps
   */
  public static void printHistory(@Nonnull final NotebookOutput log, @Nonnull final List history) {
    if (history.isEmpty()) {
      return;
    }
    log.out("Convergence Plot: ");
    log.eval(() -> {
      final DoubleSummaryStatistics valueStats = history.stream().mapToDouble(x -> x.fitness).filter(x -> x > 0).summaryStatistics();
      @Nonnull final PlotCanvas plot = ScatterPlot.plot(history.stream().map(step ->
          new double[]{step.iteration, Math.log10(Math.max(valueStats.getMin(), step.fitness))})
          // Without any positive fitness the clamp floor is +Infinity; drop such points.
          .filter(x -> Arrays.stream(x).allMatch(Double::isFinite))
          .toArray(i -> new double[i][]));
      plot.setTitle("Convergence Plot");
      plot.setAxisLabels("Iteration", "log10(Fitness)");
      plot.setSize(600, 400);
      return plot;
    });
  }

  /**
   * Replaces every {@link LoggingWrapperLayer} in the network with the layer it wraps.
   *
   * @param network the network to un-wrap in place
   */
  public static void removeLogging(@Nonnull final DAGNetwork network) {
    network.visitNodes(node -> {
      if (!(node.getLayer() instanceof LoggingWrapperLayer)) {
        return;
      }
      final LoggingWrapperLayer wrapper = (LoggingWrapperLayer) node.getLayer();
      node.setLayer(wrapper.getInner());
    });
  }

  /**
   * Replaces every {@link MonitoringWrapperLayer} in the network with the layer it wraps
   * (without the extra reference handling done by {@code removeInstrumentation}).
   *
   * @param network the network to un-wrap in place
   */
  public static void removeMonitoring(@Nonnull final DAGNetwork network) {
    network.visitNodes(node -> {
      if (!(node.getLayer() instanceof MonitoringWrapperLayer)) {
        return;
      }
      final MonitoringWrapperLayer wrapper = (MonitoringWrapperLayer) node.getLayer();
      node.setLayer(wrapper.getInner());
    });
  }

  /**
   * Writes each image produced by {@code renderToImages(tensor, normalize)} to the notebook as a PNG
   * and returns the concatenated notebook markup.
   *
   * @param log       the notebook to write the images to
   * @param tensor    the tensor to render
   * @param normalize whether to contrast-normalize each band before rendering
   * @return the concatenation of the {@code log.png} results; empty if there are no images
   */
  public static CharSequence render(@Nonnull final NotebookOutput log, @Nonnull final Tensor tensor, final boolean normalize) {
    // joining() returns "" instead of reduce(...).get() throwing when the tensor yields no images.
    return TestUtil.renderToImages(tensor, normalize)
        .map(image -> log.png(image, ""))
        .collect(Collectors.joining());
  }

  /**
   * Converts a tensor into images via {@code Tensor.toImages()}, optionally contrast-normalizing
   * each band first.
   *
   * <p>Normalization maps every value through a piecewise-linear curve built from its band's
   * statistics. Values within {@code sqrt(2) * stddev} of the band mean map to
   * [0.25, 0.75] * 255, with the mean at 127.5. Values beyond that range are stretched linearly out
   * to 0 (band minimum) or 255 (band maximum). Results are clamped to [0, 255], and a constant band
   * renders as mid-gray.
   *
   * <p>NOTE(review): assumes a rank-3 tensor whose third coordinate is the band/channel — confirm.
   *
   * @param tensor    the tensor to render
   * @param normalize whether to apply the per-band normalization described above
   * @return a stream over the rendered images
   */
  public static Stream renderToImages(@Nonnull final Tensor tensor, final boolean normalize) {
    if (!normalize) {
      // The band statistics are only needed for normalization; skip that work entirely.
      return tensor.toImages().stream();
    }
    final DoubleStatistics[] statistics = IntStream.range(0, tensor.getDimensions()[2]).mapToObj(band -> {
      return new DoubleStatistics().accept(tensor.coordStream(false)
          .filter(x -> x.getCoords()[2] == band)
          .mapToDouble(c -> tensor.get(c)).toArray());
    }).toArray(i -> new DoubleStatistics[i]);
    @Nonnull final BiFunction<Double, DoubleStatistics, Double> transform = (value, stats) -> {
      final double width = Math.sqrt(2) * stats.getStandardDeviation();
      final double centered = value - stats.getAverage();
      final double distance = Math.abs(centered);
      // A constant band has zero spread: map it to mid-gray instead of producing NaN from 0/0.
      if (!(width > 0)) {
        return 0xFF * 0.5;
      }
      final double positiveMax = stats.getMax() - stats.getAverage();
      final double negativeMax = stats.getAverage() - stats.getMin();
      final double unitValue;
      // Below-mean values go to the dark half, at-or-above-mean values to the bright half.
      // (The previous test `value < centered` reduced to `mean < 0`, ignoring the value itself.)
      if (centered < 0) {
        if (distance > width) {
          unitValue = 0.25 - 0.25 * ((distance - width) / (negativeMax - width));
        } else {
          unitValue = 0.5 - 0.25 * (distance / width);
        }
      } else {
        if (distance > width) {
          unitValue = 0.75 + 0.25 * ((distance - width) / (positiveMax - width));
        } else {
          unitValue = 0.5 + 0.25 * (distance / width);
        }
      }
      return 0xFF * unitValue;
    };
    @Nullable final Tensor normal = tensor.mapCoords((c) -> transform.apply(tensor.get(c), statistics[c.getCoords()[2]]))
        .map(v -> Math.min(0xFF, Math.max(0, v)));
    return normal.toImages().stream();
  }

  /**
   * Resizes {@code source} to a width of {@code size} pixels without preserving the aspect ratio,
   * i.e. the result is (approximately) {@code size x size}.
   *
   * @param source the image to resize
   * @param size   the target width and height in pixels; values {@code <= 0} return {@code source}
   * @return the resized image
   */
  @Nonnull
  public static BufferedImage resize(@Nonnull final BufferedImage source, final int size) {
    final boolean preserveAspect = false;
    return resize(source, size, preserveAspect);
  }

  /**
   * Resizes {@code source} to a width of {@code size} pixels in a series of geometric steps, each
   * scaling by at most a factor of 1.5.
   *
   * <p>If {@code preserveAspect} is true the height is scaled by the same factor; otherwise the
   * result is square ({@code size x size}). Dimensions are rounded to the nearest pixel and never
   * drop below 1, so the final step lands exactly on {@code size}.
   *
   * @param source         the image to resize
   * @param size           the target width in pixels; values {@code <= 0} return {@code source}
   * @param preserveAspect whether to keep the source aspect ratio
   * @return the resized image ({@code source} itself when no resampling is needed)
   */
  @Nonnull
  public static BufferedImage resize(@Nonnull final BufferedImage source, final int size, boolean preserveAspect) {
    if (size <= 0) return source;
    final double zoom = (double) size / source.getWidth();
    // Number of <= 1.5x steps needed to reach the target zoom (presumably for resampling quality).
    int steps = (int) Math.ceil(Math.abs(Math.log(zoom)) / Math.log(1.5));
    // Width already matches but a square result was requested: one height-only pass is still needed.
    if (steps == 0 && !preserveAspect && source.getHeight() != source.getWidth()) steps = 1;
    BufferedImage img = source;
    for (int i = 1; i <= steps; i++) {
      // Math.pow(zoom, 1.0) == zoom exactly, so the last step targets the requested size.
      final double z = Math.pow(zoom, (double) i / steps);
      // Round rather than truncate: e.g. 49 * (1.0 / 49) == 0.9999999999999999 would truncate to 0,
      // which new BufferedImage(...) rejects.
      final int width = Math.max(1, (int) Math.round(source.getWidth() * z));
      final int height = Math.max(1, (int) Math.round((preserveAspect ? source.getHeight() : source.getWidth()) * z));
      img = resize(img, width, height);
    }
    return img;
  }

  /**
   * Resizes {@code source} so that its total pixel count is approximately {@code size}, keeping the
   * aspect ratio.
   *
   * @param source the image to resize
   * @param size   the target pixel count (width * height); values {@code <= 0} return {@code source}
   * @return the resized image; each dimension is at least 1 pixel
   */
  public static BufferedImage resizePx(@Nonnull final BufferedImage source, final long size) {
    // A zero budget used to yield a 0x0 target, which new BufferedImage(...) rejects.
    if (size <= 0) return source;
    final double scale = Math.sqrt(size / ((double) source.getHeight() * source.getWidth()));
    // Clamp so very small budgets or extreme aspect ratios never produce a zero dimension.
    final int width = Math.max(1, (int) (scale * source.getWidth()));
    final int height = Math.max(1, (int) (scale * source.getHeight()));
    return resize(source, width, height);
  }

  /**
   * Draws {@code source} scaled to exactly {@code width x height} into a new image, using bicubic
   * interpolation and quality rendering hints.
   *
   * <p>The result has the same image type as the source, except for {@code TYPE_CUSTOM} sources
   * (which {@code new BufferedImage(w, h, type)} does not accept); those produce a
   * {@code TYPE_INT_ARGB} image.
   *
   * @param source the image to scale
   * @param width  the target width in pixels (must be positive)
   * @param height the target height in pixels (must be positive)
   * @return the newly allocated, scaled image
   */
  @Nonnull
  public static BufferedImage resize(BufferedImage source, int width, int height) {
    final int type = source.getType() == BufferedImage.TYPE_CUSTOM ? BufferedImage.TYPE_INT_ARGB : source.getType();
    @Nonnull final BufferedImage image = new BufferedImage(width, height, type);
    final Graphics2D graphics = image.createGraphics();
    try {
      final Map<RenderingHints.Key, Object> hints = new HashMap<>();
      hints.put(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BICUBIC);
      hints.put(RenderingHints.KEY_ALPHA_INTERPOLATION, RenderingHints.VALUE_ALPHA_INTERPOLATION_QUALITY);
      hints.put(RenderingHints.KEY_COLOR_RENDERING, RenderingHints.VALUE_COLOR_RENDER_QUALITY);
      hints.put(RenderingHints.KEY_RENDERING, RenderingHints.VALUE_RENDER_QUALITY);
      graphics.setRenderingHints(hints);
      graphics.drawImage(source, 0, 0, width, height, null);
    } finally {
      // Release the graphics context's system resources; it was previously never disposed.
      graphics.dispose();
    }
    return image;
  }

  /**
   * Serializes {@code metrics} to JSON using {@code JsonUtil.getMapper()}.
   *
   * @param metrics the object to serialize
   * @return the JSON text
   * @throws RuntimeException wrapping any {@link IOException} raised during serialization
   */
  public static CharSequence toFormattedJson(final Object metrics) {
    try {
      @Nonnull final ByteArrayOutputStream out = new ByteArrayOutputStream();
      JsonUtil.getMapper().writeValue(out, metrics);
      // Decode explicitly instead of with the platform default charset. Assumes the mapper is a
      // Jackson ObjectMapper, which writes UTF-8 to byte streams — confirm.
      return out.toString("UTF-8");
    } catch (@Nonnull final IOException e1) {
      throw new RuntimeException(e1);
    }
  }

  /**
   * Builds a Graphviz model of the network, labelling each node with {@code getName(node)}.
   *
   * @param network the network to convert
   * @return the graph model, typed as {@code Object}
   */
  public static Object toGraph(@Nonnull final DAGNetwork network) {
    return toGraph(network, node -> getName(node));
  }

  /**
   * Renders the network's DAG into the notebook, first as a PNG and then as a standalone SVG, both
   * at 600x400 and captioned "Configuration Graph".
   *
   * <p>Nodes with a layer are labelled with the layer name minus any trailing "Layer" suffix. Nodes
   * without a layer are labelled "Input N", where N is the node id's index in the network's input
   * handles.
   *
   * @param log     the notebook to write the rendered graph to
   * @param network the network to draw
   */
  public static void graph(@Nonnull final NotebookOutput log, @Nonnull final DAGNetwork network) {
    final Graph model = (Graph) toGraph(network, node -> {
      final Layer layer = node.getLayer();
      if (null == layer) {
        return "Input " + node.getNetwork().inputHandles.indexOf(node.getId());
      }
      final String layerName = layer.getName();
      final String suffix = "Layer";
      return layerName.endsWith(suffix) ? layerName.substring(0, layerName.length() - suffix.length()) : layerName;
    });
    final Graphviz graphviz = Graphviz.fromGraph(model);
    final String caption = "Configuration Graph";
    log.out("\n" + log.png(graphviz.height(400).width(600).render(Format.PNG).toImage(), caption) + "\n");
    log.out("\n" + log.svg(graphviz.height(400).width(600).render(Format.SVG_STANDALONE).toString(), caption) + "\n");
  }

  /**
   * Builds a directed Graphviz model of the network, laid out top-to-bottom. There is one graph node
   * per entry of {@code network.getNodes()}, labelled with {@code fn(node)} as an HTML label, and an
   * edge from each node to every node that lists it among its inputs.
   *
   * <p>NOTE(review): an input id that does not belong to {@code getNodes()} gets no graph node, and
   * its outgoing edges are dropped — confirm whether input nodes are always included.
   *
   * @param network the network to convert
   * @param fn      produces the label text for a node
   * @return the graph model built by {@code Factory.graph()...directed()}, typed as {@code Object}
   */
  public static Object toGraph(@Nonnull final DAGNetwork network, Function fn) {
    final List nodes = network.getNodes();
    // Node id -> mutable Graphviz node carrying the label produced by fn.
    final Map graphNodes = nodes.stream().collect(Collectors.toMap(node -> node.getId(), node -> {
      String name = fn.apply(node);
      return Factory.mutNode(Label.html(name + ""));
    }));
    // One {inputId, consumerId} pair per input edge of every node.
    final Stream stream = nodes.stream().flatMap(to -> {
      return Arrays.stream(to.getInputs()).map(from -> {
        return new UUID[]{from.getId(), to.getId()};
      });
    });
    // Input id -> ids of the nodes that consume it.
    final Map> idMap = stream.collect(Collectors.groupingBy(x -> x[0],
        Collectors.mapping(x -> x[1], Collectors.toList())));
    // Note the naming below: `to` is the *source* of each edge and `from` is a consumer id, so each
    // node gets a link to every node that consumes it.
    nodes.forEach(to -> {
      graphNodes.get(to.getId()).addLink(
          idMap.getOrDefault(to.getId(), Arrays.asList()).stream().map(from -> {
            return Link.to(graphNodes.get(from));
          }).toArray(i -> new LinkTarget[i]));
    });
    final LinkSource[] nodeArray = graphNodes.values().stream().map(x -> (LinkSource) x).toArray(i -> new LinkSource[i]);
    return Factory.graph().with(nodeArray).graphAttr().with(RankDir.TOP_TO_BOTTOM).directed();
  }

  /**
   * Default node label: the layer's simple class name and layer id on two lines, or the node id
   * when the node has no layer.
   *
   * @param node the DAG node to label
   * @return the label text
   */
  @NotNull
  public static String getName(DAGNode node) {
    @Nullable final Layer layer = node.getLayer();
    if (null == layer) {
      return node.getId().toString();
    }
    return layer.getClass().getSimpleName() + "\n" + layer.getId();
  }

  /**
   * Deterministically permutes the elements of an int stream.
   *
   * <p>With primes A and B and modulus m = A*B - 1 we have A*B ≡ 1 (mod m). Multiplying by A^2 is
   * therefore a bijection on [0, m) whose inverse is multiplying by B^2. Elements are mapped
   * forward, sorted, and mapped back, so the output holds the same values ordered by their scrambled
   * keys. Inputs must lie in [0, m); this is checked only when assertions are enabled.
   *
   * @param stream the values to shuffle
   * @return the same values in a pseudo-random but reproducible order
   */
  public static IntStream shuffle(@Nonnull IntStream stream) {
    // Primes from http://primes.utm.edu/lists/small/10000.txt
    final long primeA = 41387;
    final long primeB = 9967;
    final long modulus = primeA * primeB - 1;
    @Nonnull final IntToLongFunction scramble = value -> (value * primeA * primeA) % modulus;
    @Nonnull final LongToIntFunction unscramble = key -> (int) ((key * primeB * primeB) % modulus);
    @Nonnull final IntUnaryOperator checkRange = value -> {
      assert value < modulus;
      assert value >= 0;
      return value;
    };
    return stream.map(checkRange).mapToLong(scramble).sorted().mapToInt(unscramble);
  }

  /**
   * Combines suppliers into one that returns the first non-null value, trying them in order.
   * Suppliers after the first non-null result are not invoked.
   *
   * @param <T>       the supplied type
   * @param suppliers the candidates, in priority order
   * @return a supplier yielding the first non-null value, or {@code null} if every candidate returns null
   */
  @SafeVarargs
  public static <T> Supplier<T> orElse(@Nonnull Supplier<T>... suppliers) {
    return () -> Arrays.stream(suppliers)
        .map(Supplier::get)
        .filter(Objects::nonNull)
        .findFirst()
        .orElse(null);
  }

  /**
   * Builds a compact one-line stack trace of the current thread, formatted as
   * "[File.java:123, Other.java:45, ...]". The first three frames are skipped, at most 27 frames are
   * listed, and ", ..." is appended when the full trace has more than 30 frames.
   *
   * @return the bracketed, comma-separated frame list; "[]" when no frames remain after skipping
   */
  public static CharSequence miniStackTrace() {
    final int max = 30;
    final StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
    final List<String> frames = Arrays.stream(stackTrace).skip(3).limit(max - 3)
        .map(x -> x.isNativeMethod() ? "(Native Method)" :
            (x.getFileName() != null && x.getLineNumber() >= 0 ?
                x.getFileName() + ":" + x.getLineNumber() :
                (x.getFileName() != null ? x.getFileName() : "(Unknown Source)")))
        .collect(Collectors.toList());
    // String.join handles an empty list: getStackTrace() may return few or no frames (shallow
    // callers or VMs that omit frames), where reduce(...).get() used to throw.
    return "[" + String.join(", ", frames) + (stackTrace.length > max ? ", ..." : "") + "]";
  }

  /**
   * Shows {@code input} in a monitor window refreshed every 30 seconds; shorthand for the
   * four-argument overload with the default period.
   *
   * @param input       the tensor to display
   * @param exitOnClose passed through to the four-argument overload
   * @param normalize   whether to normalize bands before display
   */
  public static void monitorImage(final Tensor input, final boolean exitOnClose, final boolean normalize) {
    final int defaultPeriodSeconds = 30;
    monitorImage(input, exitOnClose, defaultPeriodSeconds, normalize);
  }

  /**
   * Monitor ui.
   *
   * @param input       the input
   * @param exitOnClose the exit on close
   * @param period      the period
   * @param normalize   the normalize
   */
  public static void monitorImage(final Tensor input, final boolean exitOnClose, final int period, final boolean normalize) {
    if (GraphicsEnvironment.isHeadless() || !Desktop.isDesktopSupported() || !Desktop.getDesktop().isSupported(Desktop.Action.BROWSE))
      return;
    JLabel label = new JLabel(new ImageIcon(input.toImage()));
    final AtomicReference dialog = new AtomicReference();
    WeakReference labelWeakReference = new WeakReference<>(label);
    ScheduledFuture updater = scheduledThreadPool.scheduleAtFixedRate(() -> {
      try {
        JLabel jLabel = labelWeakReference.get();
        if (null != jLabel && !input.isFinalized()) {
          BufferedImage image = (normalize ? normalizeBands(input) : input).toImage();
          int width = jLabel.getWidth();
          if (width > 0) TestUtil.resize(image, width, jLabel.getHeight());
          jLabel.setIcon(new ImageIcon(image));
          return;
        }
      } catch (Throwable e) {
        logger.warn("Error updating png", e);
      }
      JDialog jDialog = dialog.get();
      jDialog.setVisible(false);
      jDialog.dispose();
    }, 0, period, TimeUnit.SECONDS);
    new Thread(() -> {
      Window window = JOptionPane.getRootFrame();
      String title = "Image: " + Arrays.toString(input.getDimensions());
      if (window instanceof Frame) {
        dialog.set(new JDialog((Frame) window, title, true));
      } else {
        dialog.set(new JDialog((Dialog) window, title, true));
      }
      dialog.get().setResizable(false);
      dialog.get().setSize(input.getDimensions()[0], input.getDimensions()[1]);
      JMenuBar menu = new JMenuBar();
      JMenu fileMenu = new JMenu("File");
      JMenuItem saveAction = new JMenuItem("Save");
      fileMenu.add(saveAction);
      saveAction.addActionListener(new ActionListener() {
        @Override
        public void actionPerformed(final ActionEvent e) {
          JFileChooser fileChooser = new JFileChooser();
          fileChooser.setAcceptAllFileFilterUsed(false);
          fileChooser.addChoosableFileFilter(new FileFilter() {
            @Override
            public boolean accept(final File f) {
              return f.getName().toUpperCase().endsWith(".PNG");
            }

            @Override
            public String getDescription() {
              return "*.png";
            }
          });

          int result = fileChooser.showSaveDialog(dialog.get());
          if (JFileChooser.APPROVE_OPTION == result) {
            try {
              File selectedFile = fileChooser.getSelectedFile();
              if (!selectedFile.getName().toUpperCase().endsWith(".PNG"))
                selectedFile = new File(selectedFile.getParent(), selectedFile.getName() + ".png");
              BufferedImage image = (normalize ? normalizeBands(input) : input).toImage();
              if (!ImageIO.write(image, "PNG", selectedFile)) throw new IllegalArgumentException();
            } catch (IOException e1) {
              throw new RuntimeException(e1);
            }
          }
        }
      });
      menu.add(fileMenu);
      dialog.get().setJMenuBar(menu);

      Container contentPane = dialog.get().getContentPane();
      contentPane.setLayout(new BorderLayout());
      contentPane.add(label, BorderLayout.CENTER);
      //contentPane.add(dialog, BorderLayout.CENTER);
      if (JDialog.isDefaultLookAndFeelDecorated()) {
        boolean supportsWindowDecorations = UIManager.getLookAndFeel().getSupportsWindowDecorations();
        if (supportsWindowDecorations) {
          dialog.get().setUndecorated(true);
          SwingUtilities.getRootPane(dialog.get()).setWindowDecorationStyle(JRootPane.PLAIN_DIALOG);
        }
      }
      dialog.get().pack();
      dialog.get().setLocationRelativeTo(null);
      dialog.get().addComponentListener(new ComponentAdapter() {
        @Override
        public void componentResized(final ComponentEvent e) {
          //dialog.pack();
          super.componentResized(e);
          BufferedImage image = input.toImage();
          int width = e.getComponent().getWidth();
          if (width > 0) TestUtil.resize(image, width, e.getComponent().getHeight());
          label.setIcon(new ImageIcon(image));
          dialog.get().pack();
        }
      });
      dialog.get().addWindowListener(new WindowAdapter() {
        private boolean gotFocus = false;

        public void windowClosed(WindowEvent e) {
          dialog.get().getContentPane().removeAll();
          updater.cancel(false);
          if (exitOnClose) {
            logger.warn("Exiting test", new RuntimeException("Stack Trace"));
            System.exit(0);
          }
        }

        public void windowGainedFocus(WindowEvent we) {
          // Once window gets focus, set initial focus
          if (!gotFocus) {
            gotFocus = true;
          }
        }

      });
      dialog.get().setVisible(true);
      dialog.get().dispose();
    }).start();
  }

  /**
   * Rescales every band of the image independently into the range [0, 255].
   *
   * <p>Convenience overload that delegates to {@code normalizeBands(image, 255)}.
   *
   * @param image the tensor whose bands are rescaled
   * @return the rescaled tensor
   */
  public static Tensor normalizeBands(final Tensor image) {
    final int fullScale = 255;
    return normalizeBands(image, fullScale);
  }

  /**
   * Linearly rescales each band of the image independently into the range [0, {@code max}].
   *
   * <p>The band is the coordinate at index 2, so this assumes a rank-3 tensor whose last
   * dimension is the band/channel axis (TODO confirm for all callers). For each band the minimum
   * maps to 0 and the maximum maps to {@code max}. A band whose values are all equal has no range
   * to stretch; it is mapped to 0 instead of producing NaN.
   *
   * @param image the tensor whose bands are rescaled
   * @param max   the value the largest element of each band is mapped to
   * @return the rescaled tensor produced by {@code image.mapCoords}
   */
  public static Tensor normalizeBands(final Tensor image, final int max) {
    final int bands = image.getDimensions()[2];
    final DoubleStatistics[] statistics = IntStream.range(0, bands)
        .mapToObj(i -> new DoubleStatistics())
        .toArray(DoubleStatistics[]::new);
    // First pass: gather min/max per band.
    image.coordStream(false).forEach(c -> statistics[c.getCoords()[2]].accept(image.get(c)));
    // Second pass: map every element into [0, max] using its band's statistics.
    return image.mapCoords(c -> {
      final DoubleStatistics statistic = statistics[c.getCoords()[2]];
      final double min = statistic.getMin();
      final double range = statistic.getMax() - min;
      // Without this guard a constant band computes 0/0 and every element becomes NaN.
      if (range == 0) return 0.0;
      return max * (image.get(c) - min) / range;
    });
  }

  /**
   * Writes the frames as an animated GIF whose full loop lasts 15 seconds.
   *
   * <p>Convenience overload that delegates with a loop time of 15000 ms.
   *
   * @param log    the notebook output whose resource directory receives the GIF
   * @param images the animation frames, in order
   * @return the markup returned by the three-argument overload
   */
  public static CharSequence animatedGif(@Nonnull final NotebookOutput log, @Nonnull final BufferedImage... images) {
    final int defaultLoopTimeMs = 15000;
    return animatedGif(log, defaultLoopTimeMs, images);
  }

  /**
   * Writes the frames to a new GIF file in the notebook's resource directory and returns an HTML
   * {@code <img>} tag that references it.
   *
   * @param log        the notebook output whose resource directory receives the GIF file
   * @param loopTimeMs the duration of one full animation loop, in milliseconds; it is divided
   *                   (integer division) evenly across the frames
   * @param images     the animation frames, in order; must not be empty
   * @return an HTML image tag pointing at the written GIF
   * @throws IllegalArgumentException if {@code images} is empty
   * @throws RuntimeException         wrapping any {@link IOException} raised while writing the file
   */
  public static CharSequence animatedGif(@Nonnull final NotebookOutput log, final int loopTimeMs, @Nonnull final BufferedImage... images) {
    // Reject empty input explicitly; otherwise the per-frame delay below fails with a bare
    // ArithmeticException (integer division by zero).
    if (images.length == 0) {
      throw new IllegalArgumentException("animatedGif requires at least one image");
    }
    try {
      @Nonnull String filename = gifNumber++ + ".gif";
      @Nonnull File file = new File(log.getResourceDir(), filename);
      GifSequenceWriter.write(file, loopTimeMs / images.length, true, images);
      // NOTE(review): the src path assumes getResourceDir() is the notebook-relative "etc" directory — confirm.
      return String.format("<img src=\"etc/%s\" />", filename);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Collects the frames and appends them to the notebook as an animated GIF, preceded by the line
   * "Animated Sequence:".
   *
   * @param log         the notebook output to write to
   * @param imageStream the animation frames, in order; this call consumes the stream
   */
  public static void writeGif(@Nonnull final NotebookOutput log, final Stream<BufferedImage> imageStream) {
    // The stream must be typed: on a raw Stream, toArray(IntFunction) erases to Object[].
    final BufferedImage[] frames = imageStream.toArray(BufferedImage[]::new);
    log.p("Animated Sequence:");
    log.p(animatedGif(log, frames));
  }

  /**
   * Build map map.
   *
   * @param        the type parameter
   * @param        the type parameter
   * @param configure the configure
   * @return the map
   */
  @Nonnull
  public static  Map buildMap(Consumer> configure) {
    Map map = new HashMap<>();
    configure.accept(map);
    return map;
  }

  /**
   * Returns a supplier of geometric sequences running from {@code start} to {@code end}.
   *
   * <p>Each call to {@link Supplier#get()} yields a fresh stream of {@code steps} values
   * {@code start * r^i} for {@code i = 0 .. steps-1}, where {@code r = (end/start)^(1/(steps-1))}.
   * The last value is therefore {@code end}. When {@code steps == 1} the stream contains only
   * {@code start}; when {@code steps == 0} it is empty. {@code start} and {@code end} should be
   * non-zero and have the same sign, otherwise the ratio's fractional powers are NaN or infinite.
   *
   * @param start the first value
   * @param end   the last value
   * @param steps the number of values to produce
   * @return a supplier that produces a new stream on each call
   * @throws IllegalArgumentException if {@code steps} is negative
   */
  @Nonnull
  public static Supplier<DoubleStream> geometricStream(final double start, final double end, final int steps) {
    if (steps < 0) {
      throw new IllegalArgumentException("steps must be non-negative: " + steps);
    }
    final double ratio = end / start;
    // Number of gaps between values; clamped so steps == 1 yields exponent 0 (value = start) instead of 0/0.
    final int gaps = Math.max(1, steps - 1);
    // Compute each value from its index so rounding error does not accumulate across repeated multiplication.
    return () -> IntStream.range(0, steps).mapToDouble(i -> start * Math.pow(ratio, (double) i / gaps));
  }

  /**
   * Returns a supplier of evenly spaced (arithmetic) sequences running from {@code start} to {@code end}.
   *
   * <p>Each call to {@link Supplier#get()} yields a fresh stream of {@code steps} values
   * {@code start + i * d} for {@code i = 0 .. steps-1}, where {@code d = (end - start) / (steps - 1)}.
   * The last value is therefore {@code end}. When {@code steps == 1} the stream contains only
   * {@code start}; when {@code steps == 0} it is empty.
   *
   * @param start the first value
   * @param end   the last value
   * @param steps the number of values to produce
   * @return a supplier that produces a new stream on each call
   * @throws IllegalArgumentException if {@code steps} is negative
   */
  @Nonnull
  public static Supplier<DoubleStream> arithmeticStream(final double start, final double end, final int steps) {
    if (steps < 0) {
      throw new IllegalArgumentException("steps must be non-negative: " + steps);
    }
    // The spacing is the total span divided by the number of gaps (steps - 1), so the last value
    // lands on end. The gap count is clamped to 1 so steps == 1 does not divide by zero.
    final double step = (end - start) / Math.max(1, steps - 1);
    // Compute each value from its index so rounding error does not accumulate across repeated addition.
    return () -> IntStream.range(0, steps).mapToDouble(i -> start + i * step);
  }

  /**
   * Returns a supplier that streams the given fixed values.
   *
   * <p>Each call to {@link Supplier#get()} yields a fresh stream over {@code values}. The array is
   * not copied, so later writes to it are visible in streams obtained afterwards.
   *
   * @param values the values to stream, in order
   * @return a supplier that produces a new stream on each call
   */
  @Nonnull
  public static Supplier<DoubleStream> constantStream(final double... values) {
    return () -> Arrays.stream(values);
  }

  /**
   * Returns a randomly shuffled copy of the given list; the input list is not modified.
   *
   * @param <T>  the element type
   * @param list the list to copy
   * @return a new, mutable list holding the same elements in random order
   */
  public static <T> List<T> shuffle(final List<T> list) {
    final List<T> copy = new ArrayList<>(list);
    Collections.shuffle(copy);
    return copy;
  }

  /**
   * Registers diagnostic HTTP endpoints on the given server.
   *
   * <p>This currently registers only {@code threads.json}. It serves the JSON-serialized result of
   * {@link #getStackInfo()} with content type {@code text/json}. The method does nothing when
   * {@code httpd} is null.
   *
   * @param httpd the server to register handlers on; may be null
   */
  public static void addGlobalHandlers(final FileHTTPD httpd) {
    if (httpd == null) {
      return;
    }
    // A "gpu.json" endpoint (CUDA execution statistics) was once registered here; it is currently disabled.
    httpd.addGET("threads.json", "text/json", stream -> {
      try {
        JsonUtil.getMapper().writer().writeValue(stream, getStackInfo());
        stream.close();
      } catch (IOException ex) {
        throw new RuntimeException(ex);
      }
    });
  }

  /**
   * Gets stack info.
   *
   * @return the stack info
   */
  public static Map> getStackInfo() {
    return Thread.getAllStackTraces().entrySet().stream().collect(Collectors.toMap(entry -> {
      Thread key = entry.getKey();
      return String.format("%s@%d", key.getName(), key.getId());
    }, entry -> {
      return Arrays.stream(entry.getValue()).map(StackTraceElement::toString).collect(Collectors.toList());
    }));
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy