All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.test.unit.BatchDerivativeTester Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.test.unit;

import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.layers.PlaceholderLayer;
import com.simiacryptus.mindseye.test.SimpleEval;
import com.simiacryptus.mindseye.test.ToleranceStatistics;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.util.data.ScalarStatistics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * The type Derivative tester.
 */
public class BatchDerivativeTester extends ComponentTestBase {
  /**
   * The Logger.
   */
  static final Logger log = LoggerFactory.getLogger(BatchDerivativeTester.class);

  /**
   * The Probe size.
   */
  public final double probeSize;
  private final int batches;
  private final double tolerance;
  private boolean testFeedback = true;
  private boolean testLearning = true;
  private boolean verbose = true;
  private boolean verify = true;

  /**
   * Instantiates a new Derivative tester.
   *
   * @param tolerance the tolerance
   * @param probeSize the probe size
   * @param batches   the batches
   */
  public BatchDerivativeTester(final double tolerance, final double probeSize, final int batches) {
    this.tolerance = tolerance;
    this.probeSize = probeSize;
    this.batches = batches;
  }

  @Nonnull
  private Tensor getFeedbackGradient(@Nonnull final Layer component, final int inputIndex, @Nonnull final Tensor outputPrototype, final Tensor... inputPrototype) {
    final Tensor inputTensor = inputPrototype[inputIndex];
    final int inputDims = inputTensor.length();
    @Nonnull final Tensor result = new Tensor(inputDims, outputPrototype.length());
    for (int j = 0; j < outputPrototype.length(); j++) {
      final int j_ = j;
      @Nonnull final PlaceholderLayer inputKey = new PlaceholderLayer(new Tensor());
      @Nonnull final Result copyInput = new Result(TensorArray.create(inputPrototype), (@Nonnull final DeltaSet buffer, @Nonnull final TensorList data) -> {
        @Nonnull final Tensor gradientBuffer = new Tensor(inputDims, outputPrototype.length());
        if (!Arrays.equals(inputTensor.getDimensions(), data.get(inputIndex).getDimensions())) {
          throw new AssertionError();
        }
        for (int i = 0; i < inputDims; i++) {
          gradientBuffer.set(new int[]{i, j_}, data.get(inputIndex).getData()[i]);
        }
        buffer.get(inputKey.getId(), new double[gradientBuffer.length()]).addInPlace(gradientBuffer.getData());
      }) {

        @Override
        public boolean isAlive() {
          return true;
        }

      };
      @Nullable final Result eval = component.eval(copyInput);
      @Nonnull final DeltaSet xxx = new DeltaSet();
      @Nonnull TensorArray tensorArray = TensorArray.wrap(eval.getData().stream().map(x -> {
        @Nonnull Tensor set = x.set(j_, 1);
        x.freeRef();
        return set;
      }).toArray(i -> new Tensor[i]));
      eval.accumulate(xxx, tensorArray);
      final Delta inputDelta = xxx.getMap().get(inputKey);
      if (null != inputDelta) {
        result.addInPlace(new Tensor(inputDelta.getDelta(), result.getDimensions()));
      }
    }
    return result;
  }

  @Nonnull
  private Tensor getLearningGradient(@Nonnull final Layer component, final int layerNum, @Nonnull final Tensor outputPrototype, final Tensor... inputPrototype) {
    component.setFrozen(false);
    final double[] stateArray = component.state().get(layerNum);
    final int stateLen = stateArray.length;
    @Nonnull final Tensor gradient = new Tensor(stateLen, outputPrototype.length());
    for (int j = 0; j < outputPrototype.length(); j++) {
      final int j_ = j;
      @Nonnull final DeltaSet buffer = new DeltaSet();
      @Nonnull final Tensor data = new Tensor(outputPrototype.getDimensions()).set((k) -> k == j_ ? 1 : 0);
      @Nullable final Result eval = component.eval(ConstantResult.singleResultArray(new Tensor[][]{inputPrototype}));
      eval.getData().get(0);
      @Nonnull TensorArray tensorArray = TensorArray.wrap(data);
      eval.accumulate(buffer, tensorArray);
      final DoubleBuffer deltaFlushBuffer = buffer.getMap().values().stream().filter(x -> x.target == stateArray).findFirst().orElse(null);
      if (null != deltaFlushBuffer) {
        for (int i = 0; i < stateLen; i++) {
          gradient.set(new int[]{i, j_}, deltaFlushBuffer.getDelta()[i]);
        }
      }
    }
    return gradient;
  }

  /**
   * Is apply feedback boolean.
   *
   * @return the boolean
   */
  public boolean isTestFeedback() {
    return testFeedback;
  }

  /**
   * Sets apply feedback.
   *
   * @param testFeedback the apply feedback
   * @return the apply feedback
   */
  @Nonnull
  public BatchDerivativeTester setTestFeedback(final boolean testFeedback) {
    this.testFeedback = testFeedback;
    return this;
  }

  /**
   * Is apply learning boolean.
   *
   * @return the boolean
   */
  public boolean isTestLearning() {
    return testLearning;
  }

  /**
   * Sets apply learning.
   *
   * @param testLearning the apply learning
   * @return the apply learning
   */
  @Nonnull
  public BatchDerivativeTester setTestLearning(final boolean testLearning) {
    this.testLearning = testLearning;
    return this;
  }

  /**
   * Is verbose boolean.
   *
   * @return the boolean
   */
  public boolean isVerbose() {
    return verbose;
  }

  /**
   * Sets verbose.
   *
   * @param verbose the verbose
   * @return the verbose
   */
  @Nonnull
  public BatchDerivativeTester setVerbose(final boolean verbose) {
    this.verbose = verbose;
    return this;
  }

  /**
   * Is verify boolean.
   *
   * @return the boolean
   */
  public boolean isVerify() {
    return verify;
  }

  /**
   * Sets verify.
   *
   * @param verify the verify
   * @return the verify
   */
  @Nonnull
  public BatchDerivativeTester setVerify(final boolean verify) {
    this.verify = verify;
    return this;
  }

  @Nonnull
  private Tensor measureFeedbackGradient(@Nonnull final Layer component, final int inputIndex, @Nonnull final Tensor outputPrototype, @Nonnull final Tensor... inputPrototype) {
    @Nonnull final Tensor measuredGradient = new Tensor(inputPrototype[inputIndex].length(), outputPrototype.length());
    @Nullable final Tensor baseOutput = component.eval(ConstantResult.singleResultArray(new Tensor[][]{inputPrototype})).getData().get(0);
    outputPrototype.set(baseOutput);
    for (int i = 0; i < inputPrototype[inputIndex].length(); i++) {
      @Nonnull final Tensor inputProbe = inputPrototype[inputIndex].copy();
      inputProbe.add(i, probeSize * 1);
      @Nonnull final Tensor[] copyInput = Arrays.copyOf(inputPrototype, inputPrototype.length);
      copyInput[inputIndex] = inputProbe;
      @Nullable final Tensor evalProbe = component.eval(ConstantResult.singleResultArray(new Tensor[][]{copyInput})).getData().get(0);
      @Nonnull final Tensor delta = evalProbe.minus(baseOutput).scaleInPlace(1. / probeSize);
      for (int j = 0; j < delta.length(); j++) {
        measuredGradient.set(new int[]{i, j}, delta.getData()[j]);
      }
    }
    return measuredGradient;
  }

  @Nonnull
  private Tensor measureLearningGradient(@Nonnull final Layer component, final int layerNum, @Nonnull final Tensor outputPrototype, final Tensor... inputPrototype) {
    final int stateLen = component.state().get(layerNum).length;
    @Nonnull final Tensor gradient = new Tensor(stateLen, outputPrototype.length());

    @Nullable final Tensor baseOutput = component.eval(ConstantResult.singleResultArray(new Tensor[][]{inputPrototype})).getData().get(0);

    for (int i = 0; i < stateLen; i++) {
      @Nonnull final Layer copy = component.copy();
      copy.state().get(layerNum)[i] += probeSize;

      @Nullable final Tensor evalProbe = copy.eval(ConstantResult.singleResultArray(new Tensor[][]{inputPrototype})).getData().get(0);

      @Nonnull final Tensor delta = evalProbe.minus(baseOutput).scaleInPlace(1. / probeSize);
      for (int j = 0; j < delta.length(); j++) {
        gradient.set(new int[]{i, j}, delta.getData()[j]);
      }
    }
    return gradient;
  }

  /**
   * Test learning tolerance statistics.
   *
   * @param component  the component
   * @param IOPair     the io pair
   * @param statistics the statistics
   * @return the tolerance statistics
   */
  public ToleranceStatistics testLearning(@Nonnull Layer component, @Nonnull IOPair IOPair, ToleranceStatistics statistics) {
    final ToleranceStatistics prev = statistics;
    statistics = IntStream.range(0, component.state().size()).mapToObj(i -> {
      @Nullable final Tensor measuredGradient = !verify ? null : measureLearningGradient(component, i, IOPair.getOutputPrototype(), IOPair.getInputPrototype());
      @Nonnull final Tensor implementedGradient = getLearningGradient(component, i, IOPair.getOutputPrototype(), IOPair.getInputPrototype());
      try {
        final ToleranceStatistics result = IntStream.range(0, null == measuredGradient ? 0 : measuredGradient.length()).mapToObj(i1 -> {
          return new ToleranceStatistics().accumulate(measuredGradient.getData()[i1], implementedGradient.getData()[i1]);
        }).reduce((a, b) -> a.combine(b)).orElse(new ToleranceStatistics());
        if (!(result.absoluteTol.getMax() < tolerance)) {
          throw new AssertionError(result.toString());
        } else {
          //log.info(String.format("Component: %s", component));
          if (verbose) {

            log.info(String.format("Learning Gradient for weight setByCoord %s", i));
            log.info(String.format("Weights: %s", new Tensor(component.state().get(i)).prettyPrint()));
            log.info(String.format("Implemented Gradient: %s", implementedGradient.prettyPrint()));
            log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
            if (null != measuredGradient) {
              log.info(String.format("Measured Gradient: %s", measuredGradient.prettyPrint()));
              log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
              log.info(String.format("Gradient Error: %s", measuredGradient.minus(implementedGradient).prettyPrint()));
              log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(measuredGradient.minus(implementedGradient).getData())));
            }
          }
          return result;
        }
      } catch (@Nonnull final Throwable e) {
        //log.info(String.format("Component: %s", component));
        log.info(String.format("Learning Gradient for weight setByCoord %s", i));
        log.info(String.format("Implemented Gradient: %s", implementedGradient.prettyPrint()));
        log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
        if (null != measuredGradient) {
          log.info(String.format("Measured Gradient: %s", measuredGradient.prettyPrint()));
          log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
          log.info(String.format("Gradient Error: %s", measuredGradient.minus(implementedGradient).prettyPrint()));
          log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(measuredGradient.minus(implementedGradient).getData())));
        }
        throw e;
      }

    }).reduce((a, b) -> a.combine(b)).map(x -> x.combine(prev)).orElseGet(() -> prev);
    return statistics;
  }

  /**
   * Test feedback tolerance statistics.
   *
   * @param component  the component
   * @param IOPair     the io pair
   * @param statistics the statistics
   * @return the tolerance statistics
   */
  public ToleranceStatistics testFeedback(@Nonnull Layer component, @Nonnull IOPair IOPair, ToleranceStatistics statistics) {
    statistics = statistics.combine(IntStream.range(0, IOPair.getInputPrototype().length).mapToObj(i -> {
      @Nullable final Tensor measuredGradient = !verify ? null : measureFeedbackGradient(component, i, IOPair.getOutputPrototype(), IOPair.getInputPrototype());
      @Nonnull final Tensor implementedGradient = getFeedbackGradient(component, i, IOPair.getOutputPrototype(), IOPair.getInputPrototype());
      try {
        final ToleranceStatistics result = IntStream.range(0, null == measuredGradient ? 0 : measuredGradient.length()).mapToObj(i1 -> {
          return new ToleranceStatistics().accumulate(measuredGradient.getData()[i1], implementedGradient.getData()[i1]);
        }).reduce((a, b) -> a.combine(b)).orElse(new ToleranceStatistics());

        if (!(result.absoluteTol.getMax() < tolerance)) throw new AssertionError(result.toString());
        //log.info(String.format("Component: %s", component));
        if (verbose) {
          log.info(String.format("Feedback for input %s", i));
          log.info(String.format("Inputs Values: %s", IOPair.getInputPrototype()[i].prettyPrint()));
          log.info(String.format("Value Statistics: %s", new ScalarStatistics().add(IOPair.getInputPrototype()[i].getData())));
          log.info(String.format("Implemented Feedback: %s", implementedGradient.prettyPrint()));
          log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
          if (null != measuredGradient) {
            log.info(String.format("Measured Feedback: %s", measuredGradient.prettyPrint()));
            log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
            log.info(String.format("Feedback Error: %s", measuredGradient.minus(implementedGradient).prettyPrint()));
            log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(measuredGradient.minus(implementedGradient).getData())));
          }
        }
        return result;
      } catch (@Nonnull final Throwable e) {
        //log.info(String.format("Component: %s", component));
        log.info(String.format("Feedback for input %s", i));
        log.info(String.format("Inputs Values: %s", IOPair.getInputPrototype()[i].prettyPrint()));
        log.info(String.format("Value Statistics: %s", new ScalarStatistics().add(IOPair.getInputPrototype()[i].getData())));
        log.info(String.format("Implemented Feedback: %s", implementedGradient.prettyPrint()));
        log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
        if (null != measuredGradient) {
          log.info(String.format("Measured: %s", measuredGradient.prettyPrint()));
          log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
          log.info(String.format("Feedback Error: %s", measuredGradient.minus(implementedGradient).prettyPrint()));
          log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(measuredGradient.minus(implementedGradient).getData())));
        }
        throw e;
      }
    }).reduce((a, b) -> a.combine(b)).get());
    return statistics;
  }

  /**
   * Test tolerance statistics.
   *
   * @param log
   * @param component      the component
   * @param inputPrototype the input prototype
   * @return the tolerance statistics
   */
  @Override
  public ToleranceStatistics test(@Nonnull final NotebookOutput log, @Nonnull final Layer component, @Nonnull final Tensor... inputPrototype) {
    log.h1("Differential Validation");
    @Nonnull IOPair ioPair = new IOPair(component, inputPrototype[0]).invoke();

    if (verbose) {
      log.run(() -> {
        BatchDerivativeTester.log.info(String.format("Inputs: %s", Arrays.stream(inputPrototype).map(t -> t.prettyPrint()).reduce((a, b) -> a + ",\n" + b).get()));
        BatchDerivativeTester.log.info(String.format("Inputs Statistics: %s", Arrays.stream(inputPrototype).map(x -> new ScalarStatistics().add(x.getData()).toString()).reduce((a, b) -> a + ",\n" + b).get()));
        BatchDerivativeTester.log.info(String.format("Output: %s", ioPair.getOutputPrototype().prettyPrint()));
        BatchDerivativeTester.log.info(String.format("Outputs Statistics: %s", new ScalarStatistics().add(ioPair.getOutputPrototype().getData())));
      });
    }

    ToleranceStatistics _statistics = new ToleranceStatistics();

    if (isTestFeedback()) {
      log.h2("Feedback Validation");
      log.p("We validate the agreement between the implemented derivative _of the inputs_ apply finite difference estimations:");
      ToleranceStatistics statistics = _statistics;
      _statistics = log.eval(() -> {
        return testFeedback(component, ioPair, statistics);
      });
    }
    if (isTestLearning()) {
      log.h2("Learning Validation");
      log.p("We validate the agreement between the implemented derivative _of the internal weights_ apply finite difference estimations:");
      ToleranceStatistics statistics = _statistics;
      _statistics = log.eval(() -> {
        return testLearning(component, ioPair, statistics);
      });
    }

    log.h2("Total Accuracy");
    log.p("The overall agreement accuracy between the implemented derivative and the finite difference estimations:");
    ToleranceStatistics statistics = _statistics;
    log.run(() -> {
      //log.info(String.format("Component: %s\nInputs: %s\noutput=%s", component, Arrays.toStream(inputPrototype), outputPrototype));
      BatchDerivativeTester.log.info(String.format("Finite-Difference Derivative Accuracy:"));
      BatchDerivativeTester.log.info(String.format("absoluteTol: %s", statistics.absoluteTol));
      BatchDerivativeTester.log.info(String.format("relativeTol: %s", statistics.relativeTol));
    });

    log.h2("Frozen and Alive Status");
    log.run(() -> {
      testFrozen(component, ioPair.getInputPrototype());
      testUnFrozen(component, ioPair.getInputPrototype());
    });

    return _statistics;
  }

  /**
   * Test frozen.
   *
   * @param component      the component
   * @param inputPrototype the input prototype
   */
  public void testFrozen(@Nonnull final Layer component, @Nonnull final Tensor[] inputPrototype) {
    @Nonnull final AtomicBoolean reachedInputFeedback = new AtomicBoolean(false);
    @Nonnull final Layer frozen = component.copy().freeze();
    @Nullable final Result eval = frozen.eval(new Result(TensorArray.create(inputPrototype), (@Nonnull final DeltaSet buffer, @Nonnull final TensorList data) -> {
      reachedInputFeedback.set(true);
    }) {

      @Override
      public boolean isAlive() {
        return true;
      }


    });
    @Nonnull final DeltaSet buffer = new DeltaSet();
    TensorList tensorList = eval.getData().copy();
    eval.accumulate(buffer, tensorList);
    final List> deltas = component.state().stream().map(doubles -> {
      return buffer.stream().filter(x -> x.target == doubles).findFirst().orElse(null);
    }).filter(x -> x != null).collect(Collectors.toList());
    if (!deltas.isEmpty() && !component.state().isEmpty()) {
      throw new AssertionError("Frozen component listed in evalInputDelta. Deltas: " + deltas);
    }
    final int inElements = Arrays.stream(inputPrototype).mapToInt(x -> x.length()).sum();
    if (!reachedInputFeedback.get() && 0 < inElements) {
      throw new RuntimeException("Frozen component did not pass input backwards");
    }
  }

  /**
   * Test un frozen.
   *
   * @param component      the component
   * @param inputPrototype the input prototype
   */
  public void testUnFrozen(@Nonnull final Layer component, final Tensor[] inputPrototype) {
    @Nonnull final AtomicBoolean reachedInputFeedback = new AtomicBoolean(false);
    @Nonnull final Layer frozen = component.copy().setFrozen(false);
    @Nullable final Result eval = frozen.eval(new Result(TensorArray.create(inputPrototype), (@Nonnull final DeltaSet buffer, @Nonnull final TensorList data) -> {
      reachedInputFeedback.set(true);
    }) {

      @Override
      public boolean isAlive() {
        return true;
      }

    });
    @Nonnull final DeltaSet buffer = new DeltaSet();
    TensorList data = eval.getData();
    eval.accumulate(buffer, data);
    @Nullable final List stateList = frozen.state();
    final List> deltas = stateList.stream().map(doubles -> {
      return buffer.stream().filter(x -> x.target == doubles).findFirst().orElse(null);
    }).filter(x -> x != null).collect(Collectors.toList());
    if (deltas.isEmpty() && !stateList.isEmpty()) {
      throw new AssertionError("Nonfrozen component not listed in evalInputDelta. Deltas: " + deltas);
    }
    if (!reachedInputFeedback.get()) {
      throw new RuntimeException("Nonfrozen component did not pass input backwards");
    }
  }

  @Nonnull
  @Override
  public String toString() {
    return "BatchDerivativeTester{" +
        "probeSize=" + probeSize +
        ", batches=" + batches +
        ", tolerance=" + tolerance +
        ", testFeedback=" + testFeedback +
        ", testLearning=" + testLearning +
        ", verbose=" + verbose +
        ", verify=" + verify +
        '}';
  }

  private class IOPair {
    private final Layer component;
    private final Tensor tensor;
    private Tensor[] inputPrototype;
    private Tensor outputPrototype;

    /**
     * Instantiates a new Io pair.
     *
     * @param component the component
     * @param tensor    the tensor
     */
    public IOPair(Layer component, Tensor tensor) {
      this.component = component;
      this.tensor = tensor;
    }

    /**
     * Get input prototype tensor [ ].
     *
     * @return the tensor [ ]
     */
    public Tensor[] getInputPrototype() {
      return inputPrototype;
    }

    /**
     * Gets output prototype.
     *
     * @return the output prototype
     */
    public Tensor getOutputPrototype() {
      return outputPrototype;
    }

    /**
     * Invoke io pair.
     *
     * @return the io pair
     */
    @Nonnull
    public IOPair invoke() {
      inputPrototype = IntStream.range(0, batches).mapToObj(i -> tensor.copy()).toArray(j -> new Tensor[j]);
      outputPrototype = SimpleEval.run(component, inputPrototype[0]).getOutputAndFree();
      return this;
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy