All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.test.unit.StandardLayerTests Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.test.unit;

import com.google.gson.JsonObject;
import com.simiacryptus.devutil.Javadoc;
import com.simiacryptus.lang.ref.LifecycleException;
import com.simiacryptus.lang.ref.ReferenceCounting;
import com.simiacryptus.lang.ref.ReferenceCountingBase;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.layers.Explodable;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.test.NotebookReportBase;
import com.simiacryptus.mindseye.test.TestUtil;
import com.simiacryptus.mindseye.test.ToleranceStatistics;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.notebook.TableOutput;
import com.simiacryptus.util.IOUtil;
import com.simiacryptus.util.test.SysOutInterceptor;
import guru.nidi.graphviz.engine.Format;
import guru.nidi.graphviz.engine.Graphviz;
import guru.nidi.graphviz.model.Graph;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.File;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

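/**
 * Base class for the standard unit-test suite applied to a {@link Layer} implementation:
 * serialization, derivative, batching, performance, reference-IO, equivalency and training tests,
 * reported through a notebook log.
 */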
public abstract class StandardLayerTests extends NotebookReportBase {
  /**
   * Fixed seed used by {@link #getRandom()} and logged by the constructor.
   */
  public static final long seed = 51389; //System.nanoTime();
  // Javadoc summaries, loaded once when the class is initialized and looked up in run()
  // by the canonical name of the target class.
  // NOTE(review): the type arguments of this declaration look truncated ("HashMap>"),
  // presumably lost during extraction - restore them from the original source.
  private static final HashMap> javadocs = loadJavadoc();

  static {
    // Initializes the shared SysOutInterceptor when this class is loaded.
    SysOutInterceptor.INSTANCE.init();
  }

  // Backing generator for random(); obtained from getRandom() while the instance is being
  // constructed, i.e. before any subclass constructor body has run.
  private final Random random = getRandom();
  /**
   * Batch size handed to the batching tester and batch count handed to the performance tester.
   */
  protected int testingBatchSize = 5;
  /**
   * When false, {@link #getBatchingTester()} returns null and the batching test is skipped.
   */
  protected boolean validateBatchExecution = true;
  /**
   * When false, {@link #getDerivativeTester()} returns null; the flag is also passed to the
   * batching tester.
   */
  protected boolean validateDifferentials = true;
  /**
   * When false, {@link #getTrainingTester()} returns null and the training test is skipped.
   */
  protected boolean testTraining = true;
  /**
   * When false, {@link #getEquivalencyTester()} returns null. Note that run() sets this to
   * false before the large sub-tests of a network.
   */
  protected boolean testEquivalency = true;
  /**
   * Tolerance passed to the derivative tester; set to 1e-3 by the constructor.
   */
  protected double tolerance;

  /**
   * Creates the test harness: selects the default tolerance of {@code 1e-3} and logs the
   * fixed {@link #seed}.
   */
  public StandardLayerTests() {
    this.tolerance = 1e-3;
    logger.info("Seed: " + seed);
  }

  /**
   * Loads the Javadoc model summary and, as a side effect, writes a key-sorted copy of it to
   * {@code ./javadoc.json}. This is best effort: if loading or writing throws anything, the
   * problem is logged as a warning and an empty map is returned instead.
   *
   * NOTE(review): the generic type arguments in this method look truncated ("HashMap>"),
   * presumably lost during extraction - restore them from the original source.
   *
   * @return the loaded Javadoc data, or an empty map if loading or writing failed
   */
  @Nonnull
  private static HashMap> loadJavadoc() {
    try {
      HashMap> javadocData = Javadoc.loadModelSummary();
      // TreeMap copy so the written JSON is ordered by key.
      IOUtil.writeJson(new TreeMap<>(javadocData), new File("./javadoc.json"));
      return javadocData;
    } catch (Throwable e) {
      // Deliberately broad: missing Javadoc data must not prevent the tests from running.
      logger.warn("Error loading javadocs", e);
      return new HashMap<>();
    }
  }

  /**
   * Builds the tester that checks batched execution of the layer.
   *
   * @return a batching tester configured with {@link #testingBatchSize}, or null when
   * {@link #validateBatchExecution} is disabled
   */
  @Nullable
  public ComponentTest getBatchingTester() {
    if (validateBatchExecution) {
      final BatchingTester batchingTester = new BatchingTester(1e-2, validateDifferentials) {
        @Override
        public double getRandom() {
          return random();
        }
      };
      return batchingTester.setBatchSize(testingBatchSize);
    }
    return null;
  }

  /**
   * Gets big tests.
   *
   * @return the big tests
   */
  @Nonnull
  public List> getBigTests() {
    return Arrays.asList(
        getPerformanceTester(),
        getBatchingTester(),
        getReferenceIOTester(),
        getEquivalencyTester()
    );
  }

  /**
   * Gets big tests.
   *
   * @return the big tests
   */
  @Nonnull
  public List> getFinalTests() {
    return Arrays.asList(
        getTrainingTester()
    );
  }

  /**
   * Builds the tester that validates the layer's derivatives.
   *
   * @return a {@code SingleDerivativeTester} using {@link #tolerance} and a second argument of
   * {@code 1e-4}, or null when {@link #validateDifferentials} is disabled
   */
  @Nullable
  public ComponentTest getDerivativeTester() {
    return validateDifferentials ? new SingleDerivativeTester(tolerance, 1e-4) : null;
  }

  /**
   * Builds the tester that compares the layer against its reference implementation.
   *
   * @return an {@code EquivalencyTester}, or null when {@link #testEquivalency} is disabled or
   * {@link #getReferenceLayer()} yields no reference layer
   */
  @Nullable
  public ComponentTest getEquivalencyTester() {
    if (testEquivalency) {
      @Nullable final Layer reference = getReferenceLayer();
      if (reference != null) {
        @Nonnull final EquivalencyTester tester = new EquivalencyTester(1e-2, reference);
        // Our own reference is released once the tester has been constructed.
        reference.freeRef();
        return tester;
      }
    }
    return null;
  }

  /**
   * Supplies the input dimensions for the small-scale tests, one dimension array per input
   * tensor. Also the default source for {@link #getLargeDims(Random)}.
   *
   * @param random the random source an implementation may use to pick dimensions
   * @return the dimensions of each input tensor
   */
  public abstract int[][] getSmallDims(Random random);

  /**
   * Builds the tester that exercises serialization of the layer.
   *
   * @return a new {@code SerializationTest}
   */
  @Nullable
  protected ComponentTest getJsonTester() {
    final SerializationTest serializationTest = new SerializationTest();
    return serializationTest;
  }

  /**
   * Creates a new instance of the layer under test. Callers in this class release the
   * returned layer with {@code freeRef()} when they are done with it.
   *
   * @param inputSize the input dimensions, one array per input tensor, as produced by
   *                  {@link #getSmallDims(Random)} or {@link #getLargeDims(Random)}
   * @param random    the random source an implementation may use to initialize the layer
   * @return the layer to test
   */
  public abstract Layer getLayer(int[][] inputSize, Random random);

  /**
   * Gets little tests.
   *
   * @return the little tests
   */
  @Nonnull
  public List> getLittleTests() {
    return Arrays.asList(
        getJsonTester(),
        getDerivativeTester()
    );
  }

  /**
   * Supplies the input dimensions for the large-scale (performance) tests. By default these
   * are the small dimensions; subclasses override this to test bigger inputs.
   * <p>
   * The supplied random source is forwarded so that repeated calls with identically seeded
   * generators yield the same dimensions. Callers in this class rely on that when they build
   * a layer and its input tensors from two separate calls with the same seed.
   *
   * @param random the random source used to pick dimensions; a fresh unseeded generator is
   *               substituted when null
   * @return the dimensions of each input tensor
   */
  public int[][] getLargeDims(Random random) {
    return getSmallDims(null != random ? random : new Random());
  }

  /**
   * Supplies the reference input/output data handed to the reference-IO tester. The default
   * is an empty map; subclasses may override it to provide entries.
   *
   * NOTE(review): the raw {@code HashMap} return type has presumably lost its type arguments
   * during extraction - confirm the key/value types against the original source.
   *
   * @return a new, empty map
   */
  protected HashMap getReferenceIO() {
    return new HashMap<>();
  }

  /**
   * Builds the tester that measures the layer's performance.
   *
   * @return a {@code PerformanceTester} whose batch count is {@link #testingBatchSize}
   */
  @Nullable
  public ComponentTest getPerformanceTester() {
    final PerformanceTester performanceTester = new PerformanceTester();
    return performanceTester.setBatches(this.testingBatchSize);
  }

  /**
   * Builds the tester that checks the layer against the data from {@link #getReferenceIO()}.
   *
   * @return a new {@code ReferenceIO} tester
   */
  @Nullable
  protected ComponentTest getReferenceIOTester() {
    final ReferenceIO referenceIoTester = new ReferenceIO(getReferenceIO());
    return referenceIoTester;
  }

  /**
   * Creates a layer from small dimensions (using unseeded random sources) and converts it to
   * its reference implementation.
   *
   * @return the converted layer, or null when {@link #convertToReferenceLayer(Layer)}
   * converted nothing
   */
  @Nullable
  public Layer getReferenceLayer() {
    final int[][] dims = getSmallDims(new Random());
    final Layer candidate = getLayer(dims, new Random());
    return convertToReferenceLayer(candidate);
  }

  /**
   * Determines the runtime class of the layer under test by creating a throw-away instance
   * from small dimensions and releasing it again.
   *
   * @return the class of the layer produced by {@link #getLayer(int[][], Random)}
   */
  public Class getTestClass() {
    final Layer probe = getLayer(getSmallDims(new Random()), new Random());
    try {
      return probe.getClass();
    } finally {
      probe.freeRef();
    }
  }

  /**
   * Converts the given layer to its reference form via {@code cvt}, counting how many layers
   * were actually converted.
   *
   * @param layer the layer to convert; its reference is passed on to the conversion
   * @return the converted layer, or null (after releasing any intermediate result) when no
   * layer was converted
   */
  protected final Layer convertToReferenceLayer(Layer layer) {
    final AtomicInteger conversions = new AtomicInteger(0);
    final Layer converted = cvt(layer, conversions);
    if (conversions.get() != 0) {
      return converted;
    }
    if (converted != null) {
      converted.freeRef();
    }
    return null;
  }

  /**
   * Recursively replaces layers of the tested class with their reference implementation.
   * <ul>
   * <li>For a {@link DAGNetwork}, every node's layer is converted in place and the network
   * itself is returned.</li>
   * <li>For a layer assignable to {@link #getTestClass()}, the layer is released and replaced
   * by {@code layer.as(referenceClass)}; {@code counter} is incremented. If no reference class
   * is configured the layer is released and null is returned.</li>
   * <li>Any other layer is returned unchanged.</li>
   * </ul>
   *
   * @param layer   the layer to convert; this method takes over the caller's reference
   * @param counter incremented once for every layer that was converted
   * @return the converted (or unchanged) layer, or null when no reference class is available
   */
  private final Layer cvt(Layer layer, AtomicInteger counter) {
    if (layer instanceof DAGNetwork) {
      ((DAGNetwork) layer).visitNodes(node -> {
        Layer converted = cvt(node.getLayer().addRef(), counter);
        // A null result means "no reference class": the extra reference taken above has
        // already been released by the recursive call, so the node simply keeps its layer.
        // Previously null was handed to setLayer and then dereferenced.
        if (null != converted) {
          node.setLayer(converted);
          converted.freeRef();
        }
      });
      return layer;
    } else if (getTestClass().isAssignableFrom(layer.getClass())) {
      @Nullable Class referenceLayerClass = getReferenceLayerClass();
      if (null == referenceLayerClass) {
        layer.freeRef();
        return null;
      } else {
        @Nonnull Layer cast = layer.as(referenceLayerClass);
        layer.freeRef();
        counter.incrementAndGet();
        return cast;
      }
    } else {
      return layer;
    }
  }

  /**
   * Names the reference implementation that layers of the tested class are converted to for
   * the equivalency test. The default is null, in which case nothing is converted and
   * {@link #getReferenceLayer()} yields no reference layer; subclasses override this to
   * enable the comparison.
   *
   * @return the reference layer class, or null if there is none
   */
  @Nullable
  public Class getReferenceLayerClass() {
    return null;
  }

  /**
   * Builds the tester that trains the layer, using this test's {@link #lossLayer()} as the
   * loss.
   *
   * @return a {@code TrainingTester}, or null when {@link #isTestTraining()} is false
   */
  @Nullable
  public ComponentTest getTrainingTester() {
    if (!isTestTraining()) {
      return null;
    }
    return new TrainingTester() {
      @Override
      protected Layer lossLayer() {
        return StandardLayerTests.this.lossLayer();
      }
    };
  }

  /**
   * Supplies the loss layer that the training tester built by {@link #getTrainingTester()}
   * delegates to.
   *
   * @return the loss layer
   */
  protected abstract Layer lossLayer();

  /**
   * Draws the next test value from this instance's own random source.
   *
   * @return the value produced by {@link #random(Random)}
   */
  public double random() {
    final Random source = this.random;
    return random(source);
  }

  /**
   * Draws a test value: a uniform sample shifted to be centred on zero, scaled by 1000,
   * rounded to the nearest integer and divided by 250. The result is therefore a multiple of
   * 0.004 in the range [-2, 2].
   *
   * @param random the random source to draw from
   * @return the quantized random value
   */
  public double random(@Nonnull Random random) {
    final double centered = random.nextDouble() - 0.5;
    final long quantized = Math.round(1000.0 * centered);
    return quantized / 250.0;
  }

  /**
   * Creates one tensor per entry of {@code inputDims}, each filled with values from
   * {@link #random()}.
   *
   * @param inputDims the dimensions of each tensor to create
   * @return the new tensors, in the same order as {@code inputDims}
   */
  public Tensor[] randomize(@Nonnull final int[][] inputDims) {
    return Arrays.stream(inputDims)
        .map(dims -> new Tensor(dims).set(() -> random()))
        .toArray(Tensor[]::new);
  }

  /**
   * Runs the complete test suite for the layer under test and writes the report to
   * {@code log}.
   * <p>
   * Steps: print any Javadoc recorded for the target class; build a small and a large layer
   * from one freshly drawn seed; render a network diagram where possible; run
   * {@link #standardTests}; and, only if those reported failures, repeat the little/big tests
   * on every inner-layer invocation of a {@link DAGNetwork}. Collected failures are then
   * raised through {@link #throwException} inside {@code log.run}. Afterwards the tests from
   * {@link #getFinalTests()} are executed and the result table is written.
   *
   * @param log the notebook the report is written to
   */
  public void run(@Nonnull final NotebookOutput log) {
    logJavadoc(log);

    // Drawn per run; note this is NOT the static seed constant logged by the constructor.
    final long runSeed = (long) (Math.random() * Long.MAX_VALUE);
    int[][] smallDims = getSmallDims(new Random(runSeed));
    final Layer smallLayer = getLayer(smallDims, new Random(runSeed));
    int[][] largeDims = getLargeDims(new Random(runSeed));
    final Layer largeLayer = getLayer(largeDims, new Random(runSeed));

    log.h1("Test Modules");
    TableOutput results = new TableOutput();
    try {
      logNetworkDiagram(log, smallLayer);
      @Nonnull ArrayList<TestError> exceptions = standardTests(log, runSeed, results);
      if (!exceptions.isEmpty()) {
        // Something failed: drill down into the individual layers of a network to help
        // locate the failing component.
        if (smallLayer instanceof DAGNetwork) {
          for (@Nonnull Invocation invocation : getInvocations(smallLayer, smallDims)) {
            log.h1("Small SubTests: " + invocation.getLayer().getClass().getSimpleName());
            log.p(Arrays.deepToString(invocation.getDims()));
            tests(log, getLittleTests(), invocation, exceptions, results);
            invocation.freeRef();
          }
        }
        if (largeLayer instanceof DAGNetwork) {
          // Disables the equivalency tester for the sub-tests; the flag stays false afterwards.
          testEquivalency = false;
          for (@Nonnull Invocation invocation : getInvocations(largeLayer, largeDims)) {
            log.h1("Large SubTests: " + invocation.getLayer().getClass().getSimpleName());
            log.p(Arrays.deepToString(invocation.getDims()));
            tests(log, getBigTests(), invocation, exceptions, results);
            invocation.freeRef();
          }
        }
      }
      log.run(() -> {
        throwException(exceptions);
      });
    } finally {
      smallLayer.freeRef();
      largeLayer.freeRef();
    }
    getFinalTests().stream().filter(x -> null != x)
        .forEach(test -> runFinalTest(log, test, largeDims, runSeed, results));
    log.h1("Test Matrix");
    log.out(results.toMarkdownTable());

  }

  /**
   * Prints the Javadoc recorded for the target class: the class-level entry first, then one
   * paragraph per remaining entry. The shared static map is only read, never modified, so
   * repeated runs for the same class print the same text.
   */
  private void logJavadoc(@Nonnull final NotebookOutput log) {
    final Map<?, ?> javadoc = javadocs.get(getTargetClass().getCanonicalName());
    if (null == javadoc) return;
    log.p("Class Javadoc: " + javadoc.get(":class"));
    for (final Map.Entry<?, ?> entry : javadoc.entrySet()) {
      if (":class".equals(String.valueOf(entry.getKey()))) continue;
      log.p(String.format("Field __%s__: %s", entry.getKey(), entry.getValue()));
    }
  }

  /**
   * Renders a diagram of the layer if it is a network, or of its exploded form if it is
   * {@link Explodable}. Rendering problems are logged and otherwise ignored.
   */
  private void logNetworkDiagram(@Nonnull final NotebookOutput log, final Layer smallLayer) {
    if (smallLayer instanceof DAGNetwork) {
      try {
        log.h1("Network Diagram");
        log.p("This is a network apply the following layout:");
        log.eval(() -> {
          return Graphviz.fromGraph((Graph) TestUtil.toGraph((DAGNetwork) smallLayer))
              .height(400).width(600).render(Format.PNG).toImage();
        });
      } catch (Throwable e) {
        logger.info("Error plotting graph", e);
      }
    } else if (smallLayer instanceof Explodable) {
      try {
        // NOTE(review): the exploded layer is never released here - confirm whether
        // explode() hands ownership of a reference to the caller.
        Layer explode = ((Explodable) smallLayer).explode();
        if (explode instanceof DAGNetwork) {
          log.h1("Exploded Network Diagram");
          log.p("This is a network apply the following layout:");
          @Nonnull DAGNetwork network = (DAGNetwork) explode;
          log.eval(() -> {
            @Nonnull Graphviz graphviz = Graphviz.fromGraph((Graph) TestUtil.toGraph(network)).height(400).width(600);
            @Nonnull File file = new File(log.getResourceDir(), log.getName() + "_network.svg");
            graphviz.render(Format.SVG_STANDALONE).toFile(file);
            log.link(file, "Saved to File");
            return graphviz.render(Format.SVG).toString();
          });
        }
      } catch (Throwable e) {
        logger.info("Error plotting graph", e);
      }
    }
  }

  /**
   * Runs one final test against a copy of a freshly built large layer, records the outcome
   * as a row of {@code results} and releases everything it created, including the test.
   * Any failure is rethrown wrapped in a {@link RuntimeException}.
   */
  private void runFinalTest(@Nonnull final NotebookOutput log, @Nonnull final ComponentTest<?> test,
                            final int[][] largeDims, final long runSeed, @Nonnull final TableOutput results) {
    final Layer perfLayer = getLayer(largeDims, new Random(runSeed));
    perfLayer.assertAlive();
    @Nonnull final Layer copy = perfLayer.copy();
    final Tensor[] randomize = randomize(largeDims);
    // Kept as a raw map, as in the rest of this file: the row type expected by
    // TableOutput.putRow is not visible here.
    HashMap testResultProps = new HashMap<>();
    try {
      Class<?> testClass = test.getClass();
      String name = testClass.getCanonicalName();
      // Anonymous testers have no canonical name; fall back to the binary class name.
      if (null == name) name = testClass.getName();
      testResultProps.put("class", name);
      Object result = log.subreport(name, sublog -> test.test(sublog, copy, randomize));
      testResultProps.put("details", null == result ? null : result.toString());
      testResultProps.put("result", "OK");
    } catch (Throwable e) {
      testResultProps.put("result", e.toString());
      throw new RuntimeException(e);
    } finally {
      results.putRow(testResultProps);
      test.freeRef();
      for (@Nonnull Tensor tensor : randomize) {
        tensor.freeRef();
      }
      perfLayer.freeRef();
      copy.freeRef();
    }
  }

  /**
   * Discovers which inner layers of a network are invoked, and with which input dimensions.
   * <p>
   * A copy of {@code smallLayer} has every node's layer replaced by a recording wrapper; the
   * copy is then evaluated once on new tensors of the given dimensions. Each wrapper
   * delegates to the real layer and records an {@link Invocation}. Invocations compare equal
   * when layer class and input dimensions match, so each distinct combination is reported
   * once.
   *
   * @param smallLayer the network to inspect; it is cast to {@link DAGNetwork}
   * @param smallDims  the dimensions of the input tensors used for the probe evaluation
   * @return the distinct invocations; each one holds a layer reference that the caller
   * releases by calling {@code freeRef()} on the invocation
   */
  @Nonnull
  public Collection<Invocation> getInvocations(@Nonnull Layer smallLayer, @Nonnull int[][] smallDims) {
    @Nonnull DAGNetwork smallCopy = (DAGNetwork) smallLayer.copy();
    @Nonnull HashSet<Invocation> invocations = new HashSet<>();
    smallCopy.visitNodes(node -> {
      @Nullable Layer inner = node.getLayer();
      // The wrapper keeps its own reference to the wrapped layer until it is freed.
      inner.addRef();
      @Nullable Layer wrapper = new LayerBase() {
        @Nullable
        @Override
        public Result eval(@Nonnull Result... array) {
          @Nullable Result result = inner.eval(array);
          Invocation invocation = new Invocation(inner,
              Arrays.stream(array).map(x -> x.getData().getDimensions()).toArray(i -> new int[i][]));
          // A duplicate is rejected by the set; release it so the layer reference taken by
          // its constructor is not leaked.
          if (!invocations.add(invocation)) invocation.freeRef();
          return result;
        }

        // Parameter and return types of the two overrides below are kept exactly as they
        // appear in this file.
        @Override
        public JsonObject getJson(Map resources, DataSerializer dataSerializer) {
          return inner.getJson(resources, dataSerializer).getAsJsonObject();
        }

        @Nullable
        @Override
        public List state() {
          return inner.state();
        }

        @Override
        protected void _free() {
          inner.freeRef();
        }
      };
      node.setLayer(wrapper);
      wrapper.freeRef();
    });
    Tensor[] input = Arrays.stream(smallDims).map(i -> new Tensor(i)).toArray(i -> new Tensor[i]);
    try {
      Result eval = smallCopy.eval(input);
      // Release the result's data before the result itself, so the result is never touched
      // after its own reference has been given up.
      eval.getData().freeRef();
      eval.freeRef();
      return invocations;
    } finally {
      Arrays.stream(input).forEach(ReferenceCounting::freeRef);
      smallCopy.freeRef();
    }
  }

  /**
   * Logs every collected test failure and, if there is at least one, throws a
   * {@link RuntimeException} whose cause is the first failure. The remaining failures are
   * attached to it as suppressed exceptions so they are not lost. Does nothing when the list
   * is empty.
   * <p>
   * Before throwing, a garbage collection is requested with reference-counting logging
   * suppressed; the suppression flag is always reset.
   *
   * @param exceptions the failures collected while running the tests
   * @throws RuntimeException if {@code exceptions} is not empty
   */
  public void throwException(@Nonnull ArrayList<TestError> exceptions) {
    for (@Nonnull TestError exception : exceptions) {
      logger.info(String.format("LayerBase: %s", exception.layer));
      logger.info("Error", exception);
    }
    if (exceptions.isEmpty()) return;
    try {
      ReferenceCountingBase.supressLog = true;
      System.gc();
      final RuntimeException failure = new RuntimeException(exceptions.get(0));
      for (final Throwable other : exceptions.subList(1, exceptions.size())) {
        failure.addSuppressed(other);
      }
      throw failure;
    } finally {
      ReferenceCountingBase.supressLog = false;
    }
  }

  /**
   * Runs the little tests against a small layer and the big tests against a large layer,
   * both built from the given seed, and collects the failures instead of throwing them.
   *
   * @param log     the notebook the report is written to
   * @param seed    the seed used to build the dimensions and the layers
   * @param results the table that receives one row per executed test
   * @return the failures recorded by the tests; empty when everything passed
   */
  @Nonnull
  public ArrayList<TestError> standardTests(@Nonnull NotebookOutput log, long seed, TableOutput results) {
    log.p(String.format("Using Seed %d", seed));
    @Nonnull ArrayList<TestError> exceptions = new ArrayList<>();
    final Layer layer = getLayer(getSmallDims(new Random(seed)), new Random(seed));
    try {
      // Created inside the try so the layer is released even if this construction fails.
      final Invocation invocation = new Invocation(layer, getSmallDims(new Random(seed)));
      try {
        tests(log, getLittleTests(), invocation, exceptions, results);
      } finally {
        invocation.freeRef();
      }
    } finally {
      layer.freeRef();
    }
    final Layer perfLayer = getLayer(getLargeDims(new Random(seed)), new Random(seed));
    try {
      bigTests(log, seed, perfLayer, exceptions, results);
    } finally {
      perfLayer.freeRef();
    }
    return exceptions;
  }

  /**
   * Runs every non-null test from {@link #getBigTests()} against its own copy of
   * {@code perfLayer}, using random inputs with the large dimensions for {@code seed}.
   * <p>
   * Each test adds one row to {@code results}. A {@link LifecycleException}, or an exception
   * whose class is named {@code CudaError}, aborts the run by propagating to the caller; any
   * other failure is recorded in {@code exceptions} and the next test is run.
   *
   * @param log        the notebook the report is written to
   * @param seed       the seed used to derive the input dimensions
   * @param perfLayer  the layer to test; it is copied for every test and not released here
   * @param exceptions receives a {@code TestError} for every recorded failure
   * @param results    the table that receives one row per executed test
   */
  public void bigTests(NotebookOutput log, long seed, @Nonnull Layer perfLayer, @Nonnull ArrayList<TestError> exceptions, TableOutput results) {
    getBigTests().stream().filter(x -> null != x).forEach(test -> {
      @Nonnull Layer layer = perfLayer.copy();
      try {
        Tensor[] input = randomize(getLargeDims(new Random(seed)));
        // Kept as a raw map, as in the rest of this file: the row type expected by
        // TableOutput.putRow is not visible here.
        LinkedHashMap testResultProps = new LinkedHashMap<>();
        try {
          String testclass = test.getClass().getCanonicalName();
          // Anonymous testers have no canonical name.
          if (null == testclass || testclass.isEmpty()) testclass = test.toString();
          testResultProps.put("class", testclass);
          Object result = log.subreport(testclass, sublog -> test.test(sublog, layer, input));
          testResultProps.put("details", null == result ? null : result.toString());
          testResultProps.put("result", "OK");
        } catch (Throwable e) {
          testResultProps.put("result", e.toString());
          // Rethrow unchecked failures unwrapped. Wrapping everything in RuntimeException
          // hid the real type from the handlers below, so a LifecycleException or CudaError
          // raised by a test was never recognised.
          if (e instanceof RuntimeException) throw (RuntimeException) e;
          if (e instanceof Error) throw (Error) e;
          throw new RuntimeException(e);
        } finally {
          results.putRow(testResultProps);
          for (@Nonnull Tensor t : input) {
            t.freeRef();
          }
        }
      } catch (LifecycleException e) {
        throw e;
      } catch (Throwable e) {
        if (e.getClass().getSimpleName().equals("CudaError")) throw e;
        exceptions.add(new TestError(e, test, layer));
      } finally {
        layer.freeRef();
        test.freeRef();
        System.gc();
      }
    });
  }

  private void tests(final NotebookOutput log, final List> tests, @Nonnull final Invocation invocation, @Nonnull final ArrayList exceptions, TableOutput results) {
    tests.stream().filter(x -> null != x).forEach((ComponentTest test) -> {
      @Nonnull Layer layer = invocation.getLayer().copy();
      //layer.addRef();
      Tensor[] inputs = randomize(invocation.getDims());
      LinkedHashMap testResultProps = new LinkedHashMap<>();
      try {
        String testname = test.getClass().getCanonicalName();
        testResultProps.put("class", testname);
        Object result = log.subreport(testname, sublog -> test.test(sublog, layer, inputs));
        testResultProps.put("details", null == result ? null : result.toString());
        testResultProps.put("result", "OK");
      } catch (LifecycleException e) {
        throw e;
      } catch (Throwable e) {
        testResultProps.put("result", e.toString());
        exceptions.add(new TestError(e, test, layer));
      } finally {
        results.putRow(testResultProps);
        for (@Nonnull Tensor tensor : inputs) tensor.freeRef();
        layer.freeRef();
        test.freeRef();
        System.gc();
      }
    });
  }

  /**
   * Reports the class this test targets: the runtime class of a throw-away layer built from
   * small dimensions. If building or releasing that layer fails, the problem is logged and
   * the test class itself is reported instead.
   *
   * @return the layer class, or this test's own class on failure
   */
  @Override
  protected Class getTargetClass() {
    try {
      final Layer probe = getLayer(getSmallDims(new Random()), new Random());
      try {
        return probe.getClass();
      } finally {
        probe.freeRef();
      }
    } catch (Throwable e) {
      logger.warn("ERROR", e);
      return getClass();
    }
  }

  /**
   * Identifies the kind of report this test produces.
   *
   * @return always {@code ReportType.Components}
   */
  @Nonnull
  @Override
  public ReportType getReportType() {
    return ReportType.Components;
  }

  /**
   * Tells whether the training test is enabled; consulted by {@link #getTrainingTester()}.
   *
   * @return the current value of the {@code testTraining} flag
   */
  public boolean isTestTraining() {
    return testTraining;
  }

  /**
   * Enables or disables the training test.
   *
   * @param testTraining true to run the training test, false to skip it
   * @return this instance, so calls can be chained
   */
  @Nonnull
  public StandardLayerTests setTestTraining(boolean testTraining) {
    this.testTraining = testTraining;
    return this;
  }

  /**
   * Creates a new random generator seeded with the constant {@link #seed}, so every call
   * returns a generator producing the same sequence.
   * <p>
   * Note that the initializer of the private {@code random} field calls this method while
   * the instance is being constructed.
   *
   * @return a freshly seeded random generator
   */
  @Nonnull
  public Random getRandom() {
    return new Random(seed);
  }

  /**
   * A layer together with the input dimensions it was invoked with. The invocation holds its
   * own reference to the layer, released when the invocation is freed.
   * <p>
   * Two invocations are equal when their layers have the same class and their dimensions are
   * deeply equal; the layer instance itself is not compared.
   */
  private static class Invocation extends ReferenceCountingBase {
    private final Layer layer;
    private final int[][] smallDims;

    /**
     * Records the layer and dimensions and takes a reference on the layer.
     *
     * @param layer     the invoked layer; must not be null
     * @param smallDims the dimensions of each input
     */
    private Invocation(Layer layer, int[][] smallDims) {
      this.layer = layer;
      this.smallDims = smallDims;
      this.layer.addRef();
    }

    @Override
    protected void _free() {
      // Give back the reference taken in the constructor.
      this.layer.freeRef();
      super._free();
    }

    /**
     * Gets the invoked layer.
     *
     * @return the layer, without an additional reference being taken
     */
    public Layer getLayer() {
      return layer;
    }

    /**
     * Gets the input dimensions of this invocation.
     *
     * @return the dimensions of each input
     */
    public int[][] getDims() {
      return smallDims;
    }

    /**
     * Returns the layer's class, or null if there is no layer; used so that equals and
     * hashCode treat a missing layer identically on both sides of a comparison.
     */
    private Class<?> layerClass() {
      return null == layer ? null : layer.getClass();
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) return true;
      if (!(o instanceof Invocation)) return false;

      @Nonnull Invocation that = (Invocation) o;

      return Objects.equals(layerClass(), that.layerClass())
          && Arrays.deepEquals(smallDims, that.smallDims);
    }

    @Override
    public int hashCode() {
      // Same formula as before: 31 * hash(layer class, 0 if absent) + deep hash of the dims.
      return 31 * Objects.hashCode(layerClass()) + Arrays.deepHashCode(smallDims);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy