All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.test.unit.SingleDerivativeTester Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.test.unit;

import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.test.SimpleEval;
import com.simiacryptus.mindseye.test.ToleranceStatistics;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.ref.lang.RefIgnore;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.*;
import com.simiacryptus.util.data.ScalarStatistics;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * The type Single derivative tester.
 */
public class SingleDerivativeTester extends ComponentTestBase {
  private static final Logger log = LoggerFactory.getLogger(SingleDerivativeTester.class);

  /**
   * The Probe size.
   */
  public final double probeSize;
  private final double tolerance;
  private boolean testFeedback = true;
  private boolean testLearning = true;
  private boolean verbose = true;
  private boolean verify = true;

  /**
   * Creates a tester that compares implemented derivatives against
   * finite-difference estimates.
   *
   * @param tolerance maximum allowed absolute error before a check fails
   * @param probeSize step size used for the finite-difference probes
   */
  public SingleDerivativeTester(final double tolerance, final double probeSize) {
    this.probeSize = probeSize;
    this.tolerance = tolerance;
  }

  /**
   * Reports whether input-derivative (feedback) validation is enabled.
   *
   * @return true if feedback validation will run
   */
  public boolean isTestFeedback() {
    return this.testFeedback;
  }

  /**
   * Enables or disables input-derivative (feedback) validation.
   *
   * @param enabled true to run feedback validation
   */
  public void setTestFeedback(boolean enabled) {
    this.testFeedback = enabled;
  }

  /**
   * Reports whether weight-derivative (learning) validation is enabled.
   *
   * @return true if learning validation will run
   */
  public boolean isTestLearning() {
    return this.testLearning;
  }

  /**
   * Enables or disables weight-derivative (learning) validation.
   *
   * @param enabled true to run learning validation
   */
  public void setTestLearning(boolean enabled) {
    this.testLearning = enabled;
  }

  /**
   * Reports whether detailed per-gradient logging is enabled.
   *
   * @return true if verbose logging is on
   */
  public boolean isVerbose() {
    return this.verbose;
  }

  /**
   * Enables or disables detailed per-gradient logging.
   *
   * @param enabled true to log gradients, statistics, and errors
   */
  public void setVerbose(boolean enabled) {
    this.verbose = enabled;
  }

  /**
   * Reports whether verification is enabled; the gradient tests assert on
   * this flag before measuring.
   *
   * @return true if verification is on
   */
  public boolean isVerify() {
    return this.verify;
  }

  /**
   * Enables or disables verification.
   *
   * @param enabled true to verify measured against implemented gradients
   */
  public void setVerify(boolean enabled) {
    this.verify = enabled;
  }

  /**
   * Runs the full differential validation suite against {@code component}:
   * optional verbose input/output logging, feedback (input-derivative)
   * validation, learning (weight-derivative) validation, a summary accuracy
   * report, and frozen/unfrozen status checks.
   *
   * @param output         notebook used to render sections and captured logs
   * @param component      layer under test (one ref consumed)
   * @param inputPrototype example inputs (refs consumed)
   * @return combined tolerance statistics from all enabled checks
   */
  @Override
  public ToleranceStatistics test(@Nonnull final NotebookOutput output, @Nonnull final Layer component,
                                  @Nonnull final Tensor... inputPrototype) {
    output.h1("Differential Validation");
    // Evaluate once up front to obtain the reference output used by all checks.
    SimpleEval temp_00_0023 = SimpleEval.run(component.addRef(),
        RefUtil.addRef(inputPrototype));
    final Tensor outputPrototype = temp_00_0023.getOutput();
    temp_00_0023.freeRef();
    try {
      if (verbose) {
        output.run(RefUtil.wrapInterface(() -> {
              log.info(RefString.format("Inputs: %s", prettyPrint(inputPrototype)));
              log.info(RefString.format("Inputs Statistics: %s", printStats(inputPrototype)));
              log.info(RefString.format("Output: %s", outputPrototype.prettyPrint()));
              // NOTE(review): this assert runs after outputPrototype has already
              // been dereferenced above, so it cannot catch a null output here.
              assert outputPrototype != null;
              log.info(RefString.format("Outputs Statistics: %s", outputPrototype.getScalarStatistics()));
            },
            outputPrototype.addRef(),
            RefUtil.addRef(inputPrototype)));
      }
      // Statistics accumulate across the feedback and learning phases.
      ToleranceStatistics _statistics = new ToleranceStatistics();
      if (isTestFeedback()) {
        output.h2("Feedback Validation");
        output.p(
            "We validate the agreement between the implemented derivative _of the inputs_ apply finite difference estimations:");
        final ToleranceStatistics statistics = _statistics;
        _statistics = output.eval(RefUtil.wrapInterface(() -> {
              return testFeedback(
                  statistics,
                  component.addRef(),
                  RefUtil.addRef(inputPrototype),
                  outputPrototype.addRef());
            },
            outputPrototype.addRef(),
            RefUtil.addRef(inputPrototype),
            component.addRef()));
      }
      if (isTestLearning()) {
        output.h2("Learning Validation");
        output.p(
            "We validate the agreement between the implemented derivative _of the internal weights_ apply finite difference estimations:");
        final ToleranceStatistics statistics = _statistics;
        _statistics = output.eval(RefUtil.wrapInterface(() -> {
              return testLearning(
                  statistics,
                  component.addRef(),
                  RefUtil.addRef(inputPrototype),
                  outputPrototype.addRef());
            },
            outputPrototype.addRef(),
            RefUtil.addRef(inputPrototype),
            component.addRef()));
      }
      output.h2("Total Accuracy");
      output
          .p("The overall agreement accuracy between the implemented derivative and the finite difference estimations:");
      final ToleranceStatistics statistics = _statistics;
      output.run(() -> {
        //log.info(String.format("Component: %s\nInputs: %s\noutput=%s", component, Arrays.toStream(inputPrototype), outputPrototype));
        log.info(RefString.format("Finite-Difference Derivative Accuracy:"));
        log.info(RefString.format("absoluteTol: %s", statistics.absoluteTol));
        log.info(RefString.format("relativeTol: %s", statistics.relativeTol));
      });

      output.h2("Frozen and Alive Status");
      output.run(RefUtil.wrapInterface(() -> {
        testFrozen(component.addRef(), RefUtil.addRef(inputPrototype));
        testUnFrozen(component.addRef(), RefUtil.addRef(inputPrototype));
      }, RefUtil.addRef(inputPrototype), component.addRef()));
      return _statistics;
    } finally {
      // Release the refs consumed by this method regardless of outcome.
      outputPrototype.freeRef();
      component.freeRef();
      RefUtil.freeRef(inputPrototype);
    }
  }

  /**
   * Formats the per-tensor scalar statistics of {@code array} as a
   * ",\n"-separated string (empty string for an empty array).
   *
   * @param array the tensors to summarize (refs are not consumed)
   * @return one {@link ScalarStatistics} rendering per tensor, joined by ",\n"
   */
  @NotNull
  @RefIgnore
  public String printStats(@RefIgnore Tensor[] array) {
    // Collectors.joining builds the result in a single buffer, avoiding the
    // O(n^2) string concatenation of the previous reduce-based implementation.
    return Arrays.stream(array)
        .map(Tensor::getScalarStatistics)
        .map(ScalarStatistics::toString)
        .collect(Collectors.joining(",\n"));
  }

  /**
   * Pretty-prints each tensor in {@code array} and joins the renderings with
   * ",\n" (empty string for an empty array).
   *
   * @param array the tensors to render (refs are not consumed)
   * @return the joined pretty-printed tensors
   */
  @NotNull
  @RefIgnore
  public String prettyPrint(@RefIgnore Tensor[] array) {
    // Collectors.joining builds the result in a single buffer, avoiding the
    // O(n^2) string concatenation of the previous reduce-based implementation.
    return Arrays.stream(array)
        .map(Tensor::prettyPrint)
        .collect(Collectors.joining(",\n"));
  }

  /**
   * Validates the implemented derivative of the component's internal weights
   * against a finite-difference estimate, one state array at a time, and
   * throws {@link AssertionError} when the maximum absolute error reaches the
   * configured tolerance.
   *
   * @param prev            statistics to combine the new results into
   * @param component       the layer under test (ref consumed)
   * @param inputPrototype  example inputs (refs consumed)
   * @param outputPrototype reference output (ref consumed)
   * @return combined tolerance statistics over all state arrays
   */
  public ToleranceStatistics testLearning(@Nonnull ToleranceStatistics prev, @Nonnull Layer component,
                                          @Nullable Tensor[] inputPrototype, @Nonnull Tensor outputPrototype) {
    // One iteration per weight/state array exposed by the layer.
    RefList temp_00_0024 = component.state();
    assert temp_00_0024 != null;
    int size = temp_00_0024.size();
    temp_00_0024.freeRef();
    assert verify;
    return RefIntStream.range(0, size)
        .mapToObj(RefUtil.wrapInterface((IntFunction) i -> {
              // Finite-difference estimate vs. the backprop-derived gradient.
              Tensor measuredGradient = measureLearningGradient(component.addRef(), i,
                  outputPrototype.addRef(), RefUtil.addRef(inputPrototype));
              @Nonnull final Tensor implementedGradient = getLearningGradient(component.addRef(), i,
                  outputPrototype.addRef(), RefUtil.addRef(inputPrototype));
              @Nonnull
              Tensor difference = measuredGradient.minus(implementedGradient.addRef());
              try {
                // Accumulate element-wise error statistics over the gradient.
                final ToleranceStatistics result = RefIntStream
                    .range(0, measuredGradient.length())
                    .mapToObj(RefUtil.wrapInterface((IntFunction) gradientIndex -> {
                          return new ToleranceStatistics().accumulate(
                              measuredGradient.get(gradientIndex),
                              implementedGradient.get(gradientIndex));
                        },
                        implementedGradient.addRef(),
                        measuredGradient.addRef()))
                    .reduce(ToleranceStatistics::combine)
                    .orElse(new ToleranceStatistics());

                //log.info(String.format("Component: %s", component));
                // Written as a negated "<" so NaN errors also fail the check.
                if (!(result.absoluteTol.getMax() < tolerance)) {
                  throw new AssertionError(result.toString());
                }
                if (verbose) {
                  log.info(RefString.format("Learning Gradient for weight setByCoord %s", i));
                  RefList temp_00_0026 = component.state();
                  assert temp_00_0026 != null;
                  double[] doubles = temp_00_0026.get(i);
                  log.info(RefString.format("Weights: %s", Tensor.prettyPrint(doubles)));
                  temp_00_0026.freeRef();
                  log.info(RefString.format("Implemented Gradient: %s", implementedGradient.prettyPrint()));
                  log.info(RefString.format("Implemented Statistics: %s",
                      implementedGradient.getScalarStatistics()));
                  log.info(RefString.format("Measured Gradient: %s", measuredGradient.prettyPrint()));
                  log.info(RefString.format("Measured Statistics: %s",
                      measuredGradient.getScalarStatistics()));
                  log.info(RefString.format("Gradient Error: %s", difference.prettyPrint()));
                  log.info(RefString.format("Error Statistics: %s", difference.getScalarStatistics()));
                }
                return result;
              } catch (@Nonnull final Throwable e) {
                // On failure, dump the full diagnostic picture before rethrowing.
                //log.info(String.format("Component: %s", component));
                log.info(RefString.format("Learning Gradient for weight setByCoord %s", i));
                log.info(RefString.format("Implemented Gradient: %s", implementedGradient.prettyPrint()));
                log.info(RefString.format("Implemented Statistics: %s",
                    implementedGradient.getScalarStatistics()));
                log.info(RefString.format("Measured Gradient: %s", measuredGradient.prettyPrint()));
                log.info(
                    RefString.format("Measured Statistics: %s", measuredGradient.getScalarStatistics()));
                log.info(RefString.format("Gradient Error: %s", difference.prettyPrint()));
                log.info(RefString.format("Error Statistics: %s", difference.getScalarStatistics()));
                throw e;
              } finally {
                measuredGradient.freeRef();
                implementedGradient.freeRef();
                difference.freeRef();
              }
            },
            outputPrototype,
            component,
            inputPrototype)
        ).reduce(ToleranceStatistics::combine).map(x -> x.combine(prev)).orElse(prev);
  }

  /**
   * Validates the implemented derivative of each input against a
   * finite-difference estimate, and throws {@link AssertionError} when the
   * maximum absolute error reaches the configured tolerance. Positions where
   * the measured gradient is NaN are masked out of the implemented gradient
   * before comparison.
   *
   * @param statistics      statistics to combine the new results into
   * @param component       the layer under test (ref consumed)
   * @param inputPrototype  example inputs (refs consumed)
   * @param outputPrototype reference output (ref consumed)
   * @return combined tolerance statistics over all inputs
   */
  @Nonnull
  public ToleranceStatistics testFeedback(@Nonnull ToleranceStatistics statistics, @Nonnull Layer component,
                                          @Nonnull Tensor[] inputPrototype, @Nonnull Tensor outputPrototype) {
    Optional optional = RefIntStream.range(0, inputPrototype.length)
        .mapToObj(RefUtil.wrapInterface((IntFunction) i -> {
              assert verify;
              //@Nullable final Tensor measuredGradient = !verify ? null : temp_00_0027.addRef();
              final Tensor measuredGradient = measureFeedbackGradient(component.addRef(), i,
                  outputPrototype.addRef(), RefUtil.addRef(inputPrototype));
              @Nonnull final Tensor implementedGradient = getFeedbackGradient(component.addRef(), i,
                  outputPrototype.addRef(), RefUtil.addRef(inputPrototype));
              // Mask out coordinates the finite-difference probe could not
              // measure (NaN), so they do not count as disagreement.
              Tensor maskedGradient = implementedGradient.mapCoords(RefUtil.wrapInterface(
                  c -> {
                    return Double.isNaN(measuredGradient.get(c.getCoords())) ? Double.NaN : implementedGradient.get(c);
                  },
                  implementedGradient.addRef(),
                  measuredGradient.addRef()
              ));
              @Nonnull
              Tensor difference = measuredGradient.minus(maskedGradient.addRef());
              try {
                // Accumulate element-wise error statistics over the gradient.
                final ToleranceStatistics result = RefIntStream
                    .range(0, measuredGradient.length())
                    .mapToObj(RefUtil.wrapInterface((IntFunction) i1 -> {
                          return new ToleranceStatistics().accumulate(
                              measuredGradient.get(i1),
                              maskedGradient.get(i1));
                        },
                        maskedGradient.addRef(),
                        measuredGradient.addRef()
                    ))
                    .reduce(ToleranceStatistics::combine)
                    .orElse(new ToleranceStatistics());

                //log.info(String.format("Component: %s", component));
                // Written as a negated "<" so NaN errors also fail the check.
                if (!(result.absoluteTol.getMax() < tolerance)) {
                  throw new AssertionError(result.toString());
                }
                if (verbose) {
                  log.info(RefString.format("Feedback for input %s", i));
                  log.info(RefString.format("Inputs Values: %s", inputPrototype[i].prettyPrint()));
                  log.info(
                      RefString.format("Value Statistics: %s", inputPrototype[i].getScalarStatistics()));
                  log.info(RefString.format("Implemented Feedback: %s", implementedGradient.prettyPrint()));
                  log.info(RefString.format("Implemented Statistics: %s",
                      implementedGradient.getScalarStatistics()));
                  log.info(RefString.format("Measured Feedback: %s", measuredGradient.prettyPrint()));
                  log.info(RefString.format("Measured Statistics: %s",
                      measuredGradient.getScalarStatistics()));
                  log.info(RefString.format("Feedback Error: %s", difference.prettyPrint()));
                  log.info(RefString.format("Error Statistics: %s", difference.getScalarStatistics()));
                }
                return result;
              } catch (@Nonnull final Throwable e) {
                // On failure, dump the full diagnostic picture before rethrowing.
                //log.info(String.format("Component: %s", component));
                log.info(RefString.format("Feedback for input %s", i));
                log.info(RefString.format("Inputs Values: %s", inputPrototype[i].prettyPrint()));
                log.info(RefString.format("Value Statistics: %s", inputPrototype[i].getScalarStatistics()));
                if (!implementedGradient.isFreed()) {
                  log.info(RefString.format("Implemented Feedback: %s", implementedGradient.prettyPrint()));
                  log.info(RefString.format("Implemented Statistics: %s",
                      implementedGradient.getScalarStatistics()));
                }
                log.info(RefString.format("Measured: %s", measuredGradient.prettyPrint()));
                log.info(RefString.format(
                    "Measured Statistics: %s", measuredGradient.getScalarStatistics()));
                log.info(RefString.format("Feedback Error: %s", difference.prettyPrint()));
                log.info(RefString.format("Error Statistics: %s", difference.getScalarStatistics()));
                throw e;
              } finally {
                measuredGradient.freeRef();
                implementedGradient.freeRef();
                maskedGradient.freeRef();
                difference.freeRef();
              }
            },
            component,
            inputPrototype,
            outputPrototype
        )).reduce(ToleranceStatistics::combine);
    if (!optional.isPresent())
      return statistics;
    return statistics.combine(RefUtil.orElse(optional, null));
  }

  /**
   * Verifies that a frozen copy of the component produces no weight deltas
   * during backpropagation, while still passing gradients back to its inputs.
   * Throws {@link AssertionError} if any state delta is recorded, or
   * {@link RuntimeException} if input feedback never arrived despite the
   * inputs being non-empty.
   *
   * @param component      the layer under test (ref consumed)
   * @param inputPrototype example inputs (refs consumed)
   */
  public void testFrozen(@Nonnull final Layer component, @Nonnull Tensor[] inputPrototype) {
    // Total element count; zero inputs legitimately produce no feedback.
    final int inElements = RefArrays.stream(RefUtil.addRef(inputPrototype)).mapToInt(x -> {
      int temp_00_0005 = x.length();
      x.freeRef();
      return temp_00_0005;
    }).sum();
    // Work on copies so the caller's prototypes are not mutated.
    inputPrototype = RefArrays.stream(inputPrototype).map(tensor -> {
      Tensor temp_00_0006 = tensor.copy();
      tensor.freeRef();
      return temp_00_0006;
    }).toArray(Tensor[]::new);
    @Nonnull final AtomicBoolean reachedInputFeedback = new AtomicBoolean(false);
    RefList inputCopies = RefArrays.stream(inputPrototype)
        .map(TensorArray::new)
        .collect(RefCollectors.toList());
    // Each input records whether the backward pass reached it.
    Result[] input = inputCopies.stream().map(tensorArray -> {
      Result.Accumulator accumulator = new Result.Accumulator() {
        @Override
        public void accept(@Nonnull DeltaSet buffer, @Nonnull TensorList data) {
          reachedInputFeedback.set(true);
          buffer.freeRef();
          data.freeRef();
        }

        @Override
        public void _free() {
          super._free();
        }
      };
      return new Result(tensorArray, accumulator, true);
    }).toArray(Result[]::new);
    inputCopies.freeRef();
    Layer frozen = component.copy();
    frozen.freeze();
    @Nullable final Result eval = frozen.eval(input);
    frozen.freeRef();
    assert eval != null;
    @Nonnull final DeltaSet buffer = new DeltaSet();
    TensorList evalData = eval.getData();
    eval.accumulate(buffer.addRef(), evalData.copy());
    evalData.freeRef();
    eval.freeRef();
    // Collect any buffer entries whose target is one of the layer's state arrays.
    RefList temp_00_0029 = component.state();
    assert temp_00_0029 != null;
    final RefList> deltas = temp_00_0029.stream()
        .map(RefUtil.wrapInterface((Function>) doubles -> {
          Optional> temp_00_0031 = buffer.stream().filter(x -> {
            boolean temp_00_0009 = x.target == doubles;
            x.freeRef();
            return temp_00_0009;
          }).findFirst();
          Delta temp_00_0030 = temp_00_0031.orElse(null);
          RefUtil.freeRef(temp_00_0031);
          return temp_00_0030;
        }, buffer)).filter(x -> {
          boolean temp_00_0010 = x != null;
          if (null != x)
            x.freeRef();
          return temp_00_0010;
        }).collect(RefCollectors.toList());
    temp_00_0029.freeRef();
    RefList temp_00_0032 = component.state();
    assert temp_00_0032 != null;
    // A frozen layer with state must not have contributed any deltas.
    if (!deltas.isEmpty() && !temp_00_0032.isEmpty()) {
      temp_00_0032.freeRef();
      AssertionError temp_00_0011 = new AssertionError("Frozen component listed in evalInputDelta. Deltas: " + deltas);
      deltas.freeRef();
      component.freeRef();
      throw temp_00_0011;
    }
    temp_00_0032.freeRef();
    component.freeRef();
    deltas.freeRef();
    // Freezing must not block gradient flow to the inputs.
    if (!reachedInputFeedback.get() && 0 < inElements) {
      throw new RuntimeException("Frozen component did not pass input backwards");
    }
  }

  /**
   * Verifies that an unfrozen copy of the component records weight deltas
   * during backpropagation and passes gradients back to its inputs. Throws
   * {@link AssertionError} if the layer has state but produced no deltas, or
   * {@link RuntimeException} if input feedback never arrived.
   *
   * @param component      the layer under test (ref consumed)
   * @param inputPrototype example inputs (refs consumed)
   */
  public void testUnFrozen(@Nonnull final Layer component, Tensor[] inputPrototype) {
    // Work on copies so the caller's prototypes are not mutated.
    inputPrototype = RefArrays.stream(inputPrototype).map(tensor -> {
      Tensor temp_00_0012 = tensor.copy();
      tensor.freeRef();
      return temp_00_0012;
    }).toArray(Tensor[]::new);
    @Nonnull final AtomicBoolean reachedInputFeedback = new AtomicBoolean(false);
    Layer frozen = component.copy();
    frozen.setFrozen(false);
    component.freeRef();
    RefList inputCopies = RefArrays.stream(RefUtil.addRef(inputPrototype)).map(TensorArray::new).collect(RefCollectors.toList());
    // Each input records whether the backward pass reached it.
    Result[] inputs = inputCopies.stream().map(tensor -> {
      Result.Accumulator accumulator = new Result.Accumulator() {
        @Override
        public void accept(@Nonnull DeltaSet buffer, @Nonnull TensorList data) {
          buffer.freeRef();
          data.freeRef();
          reachedInputFeedback.set(true);
        }

        @Override
        public void _free() {
          super._free();
        }
      };
      return new Result(tensor, accumulator, true);
    }).toArray(Result[]::new);
    inputCopies.freeRef();
    @Nullable final Result eval = frozen.eval(inputs);
    @Nonnull final DeltaSet buffer = new DeltaSet();
    assert eval != null;
    eval.accumulate(buffer.addRef(), eval.getData());
    eval.freeRef();
    @Nullable final RefList stateList = frozen.state();
    frozen.freeRef();
    assert stateList != null;
    // Collect any buffer entries whose target is one of the layer's state arrays.
    final RefList> deltas = stateList.stream()
        .map(RefUtil.wrapInterface((Function>) doubles -> {
          Optional> temp_00_0035 = buffer.stream().filter(x -> {
            boolean temp_00_0015 = x.target == doubles;
            x.freeRef();
            return temp_00_0015;
          }).findFirst();
          Delta temp_00_0034 = temp_00_0035.orElse(null);
          RefUtil.freeRef(temp_00_0035);
          return temp_00_0034;
        }, buffer)).filter(x -> {
          boolean temp_00_0016 = x != null;
          if (null != x) x.freeRef();
          return temp_00_0016;
        }).collect(RefCollectors.toList());
    try {
      // An unfrozen layer with state must have contributed deltas.
      if (deltas.isEmpty() && !stateList.isEmpty()) {
        throw new AssertionError(
            "Nonfrozen component not listed in evalInputDelta. Deltas: " + deltas);
      }
      if (!reachedInputFeedback.get() && inputPrototype.length != 0) {
        throw new RuntimeException("Nonfrozen component did not pass input backwards");
      }
    } finally {
      deltas.freeRef();
      stateList.freeRef();
      RefUtil.freeRef(inputPrototype);
    }
  }

  @Nonnull
  @Override
  public String toString() {
    // Render all configuration fields in the conventional Type{...} form.
    StringBuilder sb = new StringBuilder("SingleDerivativeTester{");
    sb.append("probeSize=").append(probeSize);
    sb.append(", tolerance=").append(tolerance);
    sb.append(", testFeedback=").append(testFeedback);
    sb.append(", testLearning=").append(testLearning);
    sb.append(", verbose=").append(verbose);
    sb.append(", verify=").append(verify);
    sb.append('}');
    return sb.toString();
  }

  /**
   * Releases this tester's resources; no state beyond the superclass to free.
   */
  public @SuppressWarnings("unused")
  void _free() {
    super._free();
  }

  /**
   * Adds a reference to this tester.
   *
   * @return this instance, with its reference count incremented
   */
  @Nonnull
  public @Override
  @SuppressWarnings("unused")
  SingleDerivativeTester addRef() {
    return (SingleDerivativeTester) super.addRef();
  }

  /**
   * Measures one column of the finite-difference feedback gradient: perturbs
   * element {@code probeIndex} of input {@code inputIndex} by {@code probeSize},
   * re-evaluates the layer, and writes (probed - base) / probeSize into row
   * {@code probeIndex} of {@code measuredGradient}.
   *
   * @param component        the component (ref consumed)
   * @param inputIndex       which input tensor to perturb
   * @param baseOutput       the unperturbed output (ref consumed)
   * @param inputPrototype   the input prototypes (refs consumed via copyInput)
   * @param measuredGradient gradient tensor to fill (ref consumed)
   * @param probeIndex       flat index of the element to perturb
   */
  protected void measureFeedback(@Nonnull Layer component, int inputIndex, @Nullable Tensor baseOutput,
                                 @Nonnull Tensor[] inputPrototype, @Nonnull Tensor measuredGradient, int probeIndex) {
    @Nonnull final Tensor inputProbe = inputPrototype[inputIndex].copy();
    // NOTE(review): "* 1" appears to be a vestigial scale factor.
    inputProbe.add(probeIndex, probeSize * 1);
    // Swap the perturbed copy in place of the original input.
    @Nonnull final Tensor[] copyInput = RefArrays.copyOf(inputPrototype, inputPrototype.length);
    RefUtil.set(copyInput, inputIndex, inputProbe);
    try {
      Result temp_00_0036 = component.eval(ConstantResult.batchResultArray(new Tensor[][]{copyInput}));
      assert temp_00_0036 != null;
      TensorList temp_00_0037 = temp_00_0036.getData();
      @Nullable final Tensor evalProbe = temp_00_0037.get(0);
      temp_00_0037.freeRef();
      temp_00_0036.freeRef();
      // Forward-difference quotient: (f(x + h) - f(x)) / h.
      Tensor delta = evalProbe.minus(baseOutput == null ? null : baseOutput.addRef());
      delta.scaleInPlace(1. / probeSize);
      evalProbe.freeRef();
      for (int j = 0; j < delta.length(); j++) {
        measuredGradient.set(new int[]{probeIndex, j}, delta.get(j));
      }
      delta.freeRef();
    } finally {
      // Release the refs consumed by this method; the caller retains its own.
      measuredGradient.freeRef();
      if (null != baseOutput)
        baseOutput.freeRef();
      component.freeRef();
    }
  }

  /**
   * Computes the implemented (backprop-derived) feedback Jacobian for one
   * input: for each output element, backpropagates a one-hot gradient and
   * records the deltas that arrive at input {@code inputIndex}, producing an
   * (inputLength x outputLength) tensor.
   *
   * @param component       the component (ref consumed)
   * @param inputIndex      which input's gradient to capture
   * @param outputPrototype reference output, used only for its dimensions (ref consumed)
   * @param inputPrototype  the input prototypes (refs consumed)
   * @return the implemented feedback gradient tensor
   */
  @Nonnull
  private Tensor getFeedbackGradient(@Nonnull final Layer component, final int inputIndex,
                                     @Nonnull final Tensor outputPrototype, @Nonnull final Tensor... inputPrototype) {
    final Tensor inputTensor = inputPrototype[inputIndex].addRef();
    final int inputLength = inputTensor.length();
    int[] inputDimensions = inputTensor.getDimensions();
    final int outputLength = outputPrototype.length();
    int[] outputDimensions = outputPrototype.getDimensions();
    outputPrototype.freeRef();
    @Nonnull final Tensor result = new Tensor(inputLength, outputLength);
    try {
      // One backward pass per output element, each with a distinct probe key.
      IntStream.range(0, outputLength).forEach(outputIndex -> {
        final UUID inputKeyId = UUID.randomUUID();
        // All inputs get a null accumulator except the one under measurement.
        final Result[] copyInput = RefArrays.stream(RefUtil.addRef(inputPrototype))
            .map(TensorArray::new)
            .map(data -> new Result(data, new NullAccumulator()))
            .toArray(Result[]::new);
        double[] target = new double[inputLength * outputLength];
        Result.Accumulator accumulator = new Result.Accumulator() {

          @Override
          public void accept(@Nonnull DeltaSet buffer, @Nonnull TensorList data) {
            try {
              // NOTE(review): these two assertions are duplicates of each other.
              if (1 != data.length()) throw new AssertionError();
              if (data.length() != 1) throw new AssertionError();
              if (!RefArrays.equals(inputDimensions, data.getDimensions())) throw new AssertionError();
              // Copy the arriving gradient into the outputIndex column.
              @Nonnull final Tensor gradientBuffer = new Tensor(inputLength, outputLength);
              IntStream.range(0, data.length()).forEach(batchIndex -> {
                IntStream.range(0, inputLength).forEach(inputIndex -> {
                  Tensor tensor = data.get(batchIndex);
                  double value = tensor.get(inputIndex);
                  tensor.freeRef();
                  gradientBuffer.set(new int[]{inputIndex, outputIndex}, value);
                });
              });
              Delta delta = buffer.get(inputKeyId, target);
              assert delta != null;
              delta.addInPlace(gradientBuffer);
              delta.freeRef();
            } finally {
              data.freeRef();
              buffer.freeRef();
            }
          }

          @Override
          public void _free() {
            super._free();
          }
        };
        RefUtil.set(copyInput, inputIndex, new Result(new TensorArray(inputTensor.addRef()), accumulator, true));
        @Nullable final Result eval = eval(component.addRef(), copyInput);
        assert eval != null;
        @Nonnull final DeltaSet deltaSet = new DeltaSet<>();
        // Backpropagate a one-hot gradient selecting this output element.
        eval.accumulate(deltaSet.addRef(), oneHotTensorArray(outputDimensions, outputIndex));
        eval.freeRef();
        Tensor tensor = getDelta(deltaSet, inputKeyId, result.getDimensions());
        if (null != tensor) result.addInPlace(tensor);
      });
    } finally {
      RefUtil.freeRef(inputPrototype);
      component.freeRef();
      inputTensor.freeRef();
    }
    return result;
  }

  /**
   * Extracts the delta recorded under {@code inputKeyId} as a tensor of the
   * given dimensions, or returns null when no delta was recorded. Consumes
   * the {@code deltaSet} ref.
   */
  @org.jetbrains.annotations.Nullable
  private Tensor getDelta(DeltaSet deltaSet, UUID inputKeyId, int[] dimensions) {
    final Delta recorded = deltaSet.get(inputKeyId);
    Tensor result = null;
    if (recorded != null) {
      result = new Tensor(recorded.getDelta(), dimensions);
      recorded.freeRef();
    }
    deltaSet.freeRef();
    return result;
  }

  /**
   * Builds a one-element TensorArray holding a tensor that is zero everywhere
   * except for a 1 at flat index {@code j}.
   */
  @NotNull
  private TensorArray oneHotTensorArray(int[] outputDimensions, int j) {
    Tensor oneHot = new Tensor(outputDimensions);
    oneHot.set(j, 1);
    return new TensorArray(oneHot);
  }

  /**
   * Evaluates the component on the given inputs, releasing the component ref
   * whether or not evaluation throws.
   */
  private Result eval(@Nonnull Layer component, Result[] copyInput) {
    final Result result;
    try {
      result = component.eval(copyInput);
    } finally {
      component.freeRef();
    }
    return result;
  }

  /**
   * Computes the implemented (backprop-derived) gradient of state array
   * {@code layerNum}: for each output element, backpropagates a one-hot
   * gradient and reads the delta buffer whose target is that state array,
   * producing a (stateLen x outputLength) tensor. Unfreezes the component as
   * a side effect.
   *
   * @param component       the component (ref consumed)
   * @param layerNum        index of the state array within component.state()
   * @param outputPrototype reference output, used for its size/dimensions (ref consumed)
   * @param inputPrototype  the input prototypes (refs consumed)
   * @return the implemented learning gradient tensor
   */
  @Nonnull
  private Tensor getLearningGradient(@Nonnull final Layer component, final int layerNum,
                                     @Nonnull final Tensor outputPrototype, @Nullable final Tensor... inputPrototype) {
    // Must be unfrozen for weight deltas to be recorded.
    component.setFrozen(false);
    RefList temp_00_0039 = component.state();
    assert temp_00_0039 != null;
    final double[] stateArray = temp_00_0039.get(layerNum);
    temp_00_0039.freeRef();
    final int stateLen = stateArray.length;
    @Nonnull final Tensor gradient = new Tensor(stateLen, outputPrototype.length());
    for (int j = 0; j < outputPrototype.length(); j++) {
      final int j_ = j;
      @Nonnull final DeltaSet buffer = new DeltaSet();
      Result[] array = ConstantResult.batchResultArray(new Tensor[][]{RefUtil.addRef(inputPrototype)});
      @Nullable final Result eval = component.eval(array);
      // One-hot gradient selecting output element j.
      Tensor temp_00_0022 = new Tensor(outputPrototype.getDimensions());
      temp_00_0022.set(k -> k == j_ ? 1 : 0);
      @Nonnull
      TensorArray tensorArray = new TensorArray(temp_00_0022.addRef());
      temp_00_0022.freeRef();
      assert eval != null;
      eval.accumulate(buffer.addRef(), tensorArray);
      RefUtil.freeRef(eval.getData());
      eval.freeRef();
      // Identify this state array's delta by reference identity of its target.
      RefMap> temp_00_0040 = buffer.getMap();
      RefCollection> temp_00_0041 = temp_00_0040.values();
      final DoubleBuffer deltaFlushBuffer = RefUtil.orElse(temp_00_0041.stream().filter(x -> {
        try {
          return x.target == stateArray;
        } finally {
          x.freeRef();
        }
      }).findFirst(), null);
      temp_00_0041.freeRef();
      temp_00_0040.freeRef();
      buffer.freeRef();
      if (null != deltaFlushBuffer) {
        for (int i = 0; i < stateLen; i++) {
          gradient.set(new int[]{i, j_}, deltaFlushBuffer.getDelta()[i]);
        }
      }
      if (null != deltaFlushBuffer)
        deltaFlushBuffer.freeRef();
    }
    if (null != inputPrototype)
      RefUtil.freeRef(inputPrototype);
    outputPrototype.freeRef();
    component.freeRef();
    return gradient;
  }

  /**
   * Measures the full finite-difference feedback gradient for input
   * {@code inputIndex} by probing every element of that input; also copies
   * the layer's base output into {@code outputPrototype} as a side effect.
   *
   * @param component       the component (ref consumed)
   * @param inputIndex      which input tensor to probe
   * @param outputPrototype tensor to receive the base output (ref consumed)
   * @param inputPrototype  the input prototypes (refs consumed)
   * @return an (inputLength x outputLength) measured gradient tensor
   */
  @Nonnull
  private Tensor measureFeedbackGradient(@Nonnull final Layer component, final int inputIndex,
                                         @Nonnull final Tensor outputPrototype, @Nonnull final Tensor... inputPrototype) {
    int length = inputPrototype[inputIndex].length();
    @Nonnull final Tensor measuredGradient = new Tensor(length, outputPrototype.length());
    // Establish the unperturbed baseline output.
    Result[] input0 = ConstantResult.batchResultArray(new Tensor[][]{RefUtil.addRef(inputPrototype)});
    Result temp_00_0043 = component.eval(input0);
    assert temp_00_0043 != null;
    TensorList temp_00_0044 = Result.getData(temp_00_0043);
    @Nullable final Tensor baseOutput = temp_00_0044.get(0);
    temp_00_0044.freeRef();
    // Side effect: outputPrototype is overwritten with the base output.
    outputPrototype.set(baseOutput.addRef());
    outputPrototype.freeRef();
    // Probe every input element in turn; each fills one gradient row.
    for (int probeIndex = 0; probeIndex < length; probeIndex++) {
      measureFeedback(component.addRef(), inputIndex,
          baseOutput.addRef(), RefUtil.addRef(inputPrototype),
          measuredGradient.addRef(), probeIndex);
    }
    RefUtil.freeRef(inputPrototype);
    component.freeRef();
    baseOutput.freeRef();
    return measuredGradient;
  }

  /**
   * Measures the finite-difference gradient of state array {@code layerNum}:
   * for each weight, evaluates a copy of the layer with that weight perturbed
   * by {@code probeSize} and records (probed - base) / probeSize, producing a
   * (stateLen x outputLength) tensor.
   *
   * @param component       the component (ref consumed)
   * @param layerNum        index of the state array within component.state()
   * @param outputPrototype reference output, used only for its size (ref consumed)
   * @param inputPrototype  the input prototypes (refs consumed)
   * @return the measured learning gradient tensor
   */
  @Nonnull
  private Tensor measureLearningGradient(@Nonnull final Layer component, final int layerNum,
                                         @Nonnull final Tensor outputPrototype, @Nullable final Tensor... inputPrototype) {
    RefList temp_00_0045 = component.state();
    assert temp_00_0045 != null;
    double[] doubles = temp_00_0045.get(layerNum);
    final int stateLen = doubles.length;
    temp_00_0045.freeRef();
    @Nonnull final Tensor gradient = new Tensor(stateLen, outputPrototype.length());

    outputPrototype.freeRef();
    Result[] input2 = ConstantResult.batchResultArray(new Tensor[][]{RefUtil.addRef(inputPrototype)});
    if (null != inputPrototype)
      RefUtil.freeRef(inputPrototype);
    // Establish the unperturbed baseline output.
    Result temp_00_0046 = component.eval(RefUtil.addRef(input2));
    assert temp_00_0046 != null;
    TensorList temp_00_0047 = temp_00_0046.getData();
    @Nullable final Tensor baseOutput = temp_00_0047.get(0);

    temp_00_0047.freeRef();
    temp_00_0046.freeRef();
    for (int i = 0; i < stateLen; i++) {
      // Perturb a single weight on a fresh copy so probes stay independent.
      @Nonnull final Layer copy = component.copy();
      RefList temp_00_0048 = copy.state();
      assert temp_00_0048 != null;
      double[] doubles1 = temp_00_0048.get(layerNum);
      doubles1[i] += probeSize;
      temp_00_0048.freeRef();
      Result temp_00_0049 = copy.eval(RefUtil.addRef(input2));
      assert temp_00_0049 != null;
      TensorList temp_00_0050 = temp_00_0049.getData();
      @Nullable final Tensor evalProbe = temp_00_0050.get(0);
      temp_00_0050.freeRef();
      temp_00_0049.freeRef();
      copy.freeRef();
      // Forward-difference quotient: (f(w + h) - f(w)) / h.
      Tensor delta = evalProbe.minus(baseOutput.addRef());
      delta.scaleInPlace(1. / probeSize);
      evalProbe.freeRef();
      for (int j = 0; j < delta.length(); j++) {
        gradient.set(new int[]{i, j}, delta.get(j));
      }
      delta.freeRef();
    }
    component.freeRef();
    baseOutput.freeRef();
    RefUtil.freeRef(input2);
    return gradient;
  }

  /**
   * Accumulator that discards gradient callbacks, only releasing the refs it
   * receives; used for inputs whose feedback is not under measurement.
   */
  private static class NullAccumulator extends Result.Accumulator {
    @Override
    public void accept(@Nonnull DeltaSet buffer, @Nonnull TensorList data) {
      // Intentionally ignore the gradient; just release both references.
      buffer.freeRef();
      data.freeRef();
    }

    @Override
    public void _free() {
      super._free();
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy