All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.test.MermaidGrapher Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2020 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.test;

import com.google.gson.GsonBuilder;
import com.google.gson.JsonElement;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.SerialPrecision;
import com.simiacryptus.mindseye.layers.WrapperLayer;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.DAGNode;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.ref.lang.RefAware;
import com.simiacryptus.ref.lang.RefIgnore;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.*;
import org.apache.commons.io.FileUtils;
import org.jetbrains.annotations.NotNull;

import javax.annotation.Nonnull;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class MermaidGrapher {
  private final NotebookOutput log;
  private final boolean writeZip;

  public MermaidGrapher(NotebookOutput log, boolean writeZip) {
    this.log = log;
    this.writeZip = writeZip;
  }

  public static void writeZip(NotebookOutput log, Layer layer) {
    File file = new File(log.getResourceDir(), log.getFileName() + "_" + layer.getId() + ".zip");
    layer.writeZip(file, SerialPrecision.Double);
    layer.freeRef();
    log.p(log.link(file, RefString.format("Layer Zip (%.3fMiB bytes)", file.length() * 1.0 / 0x100000)));
  }

  public static void writeJson(NotebookOutput log, Layer layer) {
    File file = new File(log.getResourceDir(), log.getFileName() + "_" + layer.getId() + ".json");
    try {
      FileUtils.writeStringToFile(file, compactJson(layer), "UTF-8");
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
    log.p(log.link(file, RefString.format("Layer Json (%.3fMiB bytes)", file.length() * 1.0 / 0x100000)));
  }

  @NotNull
  public static String getLabel(@Nonnull @RefIgnore DAGNetwork network, DAGNode node) {
    Layer layer = node.getLayer();
    UUID id = node.getId();
    node.freeRef();
    if (null != layer) {
      String name = layer.getName();
      layer.freeRef();
      assert name != null;
      if (name.endsWith("Layer")) {
        return name.substring(0, name.length() - 5);
      } else {
        return name;
      }
    } else {
      assert network != null;
      return "Input " + network.inputHandles.indexOf(id);
    }
  }

  public static Map> connections(RefList nodes) {
    Map> collect = Arrays.stream(nodes.stream().flatMap(to -> {
      UUID toId = to.getId();
      @Nonnull DAGNode[] inputs = to.getInputs();
      to.freeRef();
      return RefArrays.stream(inputs).map(from -> new UUID[]{getUuid(from), toId});
    }).toArray(UUID[][]::new)).collect(Collectors.groupingBy(x -> x[0], Collectors.mapping(x -> x[1], Collectors.toList())));
    nodes.freeRef();
    return collect;
  }

  public static UUID getUuid(DAGNode node) {
    UUID id = node.getId();
    node.freeRef();
    return id;
  }

  private static String compactJson(Layer layer) {
    if (layer instanceof WrapperLayer) {
      Layer copy = layer.copy();
      layer.freeRef();
      layer = copy;
      ((WrapperLayer) layer).setInner(null);
    }
    JsonElement compactJson = layer.getJson(new HashMap<>(), SerialPrecision.Double);
    layer.freeRef();
    return new GsonBuilder().setPrettyPrinting().create().toJson(compactJson);
  }

  @NotNull
  private static String getLabel(@Nonnull @RefIgnore Layer layer) {
    String name = layer.getName();
    assert name != null;
    if (name.endsWith("Layer")) {
      return name.substring(0, name.length() - 5);
    } else {
      return name;
    }
  }

  private static String idString(@RefIgnore Layer l) {
    return l.getId().toString();
  }

  @NotNull
  public String toString(UUID toId) {
    return toId.toString().replaceAll("-", "");
  }

  public void mermaid(DAGNetwork network) {
    log.subreport("Network Diagrams for " + network.getName(), (NotebookOutput sublog) -> {
      List logged = new ArrayList<>();
      RefHashMap found = new RefHashMap<>();
      UUID networkId = network.getId();
      RefUtil.freeRef(found.put(networkId, network.addRef()));
      network.visitNodes(dagNode -> {
        Layer layer = dagNode.getLayer();
        dagNode.freeRef();
        if (null != layer) RefUtil.freeRef(found.put(layer.getId(), layer));
        else layer.freeRef();
      });
      RefHashSet values = found.values();
      found.freeRef();
      values.stream()
          .collect(Collectors.groupingBy(Layer::getClass, Collectors.toList()))
          .entrySet().stream().sorted(Comparator.comparingDouble(x -> {
        Class key = x.getKey();
        RefUtil.freeRef(x);
        if (network.getClass().equals(key)) return -1.0;
        if (DAGNetwork.class.isAssignableFrom(key)) return 0.0;
        if (WrapperLayer.class.isAssignableFrom(key)) return 0.5;
        return 1.0;
      })).forEach(e -> {
        String type = e.getKey().getSimpleName();
        sublog.h1(type);
        out(sublog, logged, networkId, e.getValue());
        RefUtil.freeRef(e);
      });
      values.freeRef();
      return null;
    });
    network.freeRef();
  }

  private void out(NotebookOutput sublog, List logged, UUID networkId, List value) {
    value.stream()
        .filter((@RefIgnore Layer x) -> x.getId().equals(networkId))
        .sorted(Comparator.comparing(Layer::getName).thenComparing(MermaidGrapher::idString))
        .forEach(layer -> out(sublog, logged, layer, writeZip));
    value.stream()
        .filter((@RefIgnore Layer x) -> !x.getId().equals(networkId))
        .sorted(Comparator.comparing(Layer::getName).thenComparing(MermaidGrapher::idString))
        .forEach(layer -> out(sublog, logged, layer, writeZip));
    RefUtil.freeRef(value);
  }

  protected String describeNode(MermaidNode mermaidNode) {
    String name = nodeHtml((MermaidNode) mermaidNode.addRef());
    name = "\"" + name.replaceAll("\\\"", "\\\\\\\"") + "\"";
    return nodeID(mermaidNode) + "(" + name + ")";
  }

  @NotNull
  protected String nodeID(MermaidNode mermaidNode) {
    String str = toString(mermaidNode.id);
    mermaidNode.freeRef();
    return str;
  }

  protected String nodeHtml(MermaidNode mermaidNode) {
    String name = mermaidNode.name;
    if (null != mermaidNode.layer) name += " 
" + mermaidNode.layer.getId(); if (mermaidNode.layer != null) { name = "" + name + ""; } mermaidNode.freeRef(); return name; } private void out(NotebookOutput log, List logged, Layer layer, boolean writeZip) { UUID id = layer.getId(); String label = getLabel(layer); log.out(String.format("", toString(id))); log.h2(label); log.p("ID: " + id.toString()); log.p("Class: " + layer.getClass().getName()); if (writeZip) writeZip(log, layer.addRef()); add(log, layer); logged.add(id); } private void add(NotebookOutput log, Layer layer) { if (layer instanceof DAGNetwork) { writeJson(log, layer.addRef()); String src = toMermaid((DAGNetwork) layer); if (Arrays.stream(src.split("\n")).map(String::trim).filter(x -> !x.isEmpty()).count() > 1) { log.mermaid(src); } else { log.p("Trivial Graph"); } } else if (layer instanceof WrapperLayer) { log.p("Wrapper Type: " + layer.getClass().getSimpleName()); //writeJson(log, layer); Layer inner = ((WrapperLayer) layer).getInner(); log.out("\n\n```json\n" + compactJson(layer) + "\n```\n\n"); add(log, inner); } else { writeJson(log, layer); //log.out("\n\n```json\n" + compactJson(layer) + "\n```\n\n"); } } @NotNull private String toMermaid(@Nonnull DAGNetwork network) { RefFunction fn = RefUtil.wrapInterface(node -> getLabel(network, node), network.addRef()); return toMermaid(network, fn); } @Nonnull private String toMermaid(@Nonnull final DAGNetwork network, @RefAware @Nonnull RefFunction fn) { final RefList nodes = network.getNodes(); network.freeRef(); final RefMap graphNodes = nodes.stream().collect(RefCollectors.toMap( node -> getUuid(node), node -> new MermaidNode(fn.apply(node), node.getId(), node.getLayer()) )); RefUtil.freeRef(fn); final Map> connections = connections(nodes.addRef()); try { assert graphNodes != null; return "graph TB\n\t" + nodes.stream().flatMap(from -> { UUID fromId = from.getId(); List toList = connections.computeIfAbsent(fromId, a1 -> Arrays.asList()); try { if (toList != null) { MermaidNode fromNode = graphNodes.get(fromId); String describe = fromNode.describe(); fromNode.freeRef(); return toList.stream().map(toId -> { MermaidNode toNode = graphNodes.get(toId); String str = String.format("%s-->%s\n", describe, toNode.describe()); toNode.freeRef(); return str; }); } else { return Stream.empty(); } } finally { from.freeRef(); } }).map(String::trim).filter(x -> !x.isEmpty()) .reduce((a, b) -> a + "\n\t" + b).orElse(""); } finally { nodes.freeRef(); graphNodes.freeRef(); } } protected final class MermaidNode extends ReferenceCountingBase { public final String name; public final UUID id; public final AtomicBoolean described = new AtomicBoolean(false); public final Layer layer; protected MermaidNode(String name, UUID id, Layer layer) { this.name = name; this.id = id; this.layer = layer; } @NotNull public String describe() { return this.described.getAndSet(true) ? nodeID((MermaidNode) this.addRef()) : describeNode((MermaidNode) this.addRef()); } @Override protected void _free() { if (null != layer) layer.freeRef(); super._free(); } } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy