
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.test.unit;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.layers.PlaceholderLayer;
import com.simiacryptus.mindseye.test.SimpleEval;
import com.simiacryptus.mindseye.test.ToleranceStatistics;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.IntFunction;
/**
* Validates a layer's implemented derivatives against finite-difference estimates, replicating
* the given input prototype across a configurable number of batches. Both the feedback gradient
* (with respect to the inputs) and the learning gradient (with respect to the layer's weights)
* can be checked.
*/
public class BatchDerivativeTester extends ComponentTestBase<ToleranceStatistics> {
/**
* The Log.
*/
static final Logger log = LoggerFactory.getLogger(BatchDerivativeTester.class);
/**
* The Probe size.
*/
public final double probeSize;
private final int batches;
private final double tolerance;
private boolean testFeedback = true;
private boolean testLearning = true;
private boolean verbose = true;
private boolean verify = true;
/**
* Instantiates a new Batch derivative tester.
*
* @param tolerance the tolerance
* @param probeSize the probe size
* @param batches the batches
*/
public BatchDerivativeTester(final double tolerance, final double probeSize, final int batches) {
this.tolerance = tolerance;
this.probeSize = probeSize;
this.batches = batches;
}
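// Illustrative usage sketch (not part of the original source). The notebook output, layer, and
// input dimensions below are hypothetical placeholders:
//
//   BatchDerivativeTester tester = new BatchDerivativeTester(1e-3, 1e-4, 4);
//   tester.setVerbose(false);
//   ToleranceStatistics stats = tester.test(notebookOutput, someLayer, new Tensor(3, 3, 1));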
/**
* Is test feedback boolean.
*
* @return the boolean
*/
public boolean isTestFeedback() {
return testFeedback;
}
/**
* Sets test feedback.
*
* @param testFeedback the test feedback
*/
public void setTestFeedback(boolean testFeedback) {
this.testFeedback = testFeedback;
}
/**
* Is test learning boolean.
*
* @return the boolean
*/
public boolean isTestLearning() {
return testLearning;
}
/**
* Sets test learning.
*
* @param testLearning the test learning
*/
public void setTestLearning(boolean testLearning) {
this.testLearning = testLearning;
}
/**
* Is verbose boolean.
*
* @return the boolean
*/
public boolean isVerbose() {
return verbose;
}
/**
* Sets verbose.
*
* @param verbose the verbose
*/
public void setVerbose(boolean verbose) {
this.verbose = verbose;
}
/**
* Is verify boolean.
*
* @return the boolean
*/
public boolean isVerify() {
return verify;
}
/**
* Sets verify.
*
* @param verify the verify
*/
public void setVerify(boolean verify) {
this.verify = verify;
}
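// testLearning compares, for each weight array in component.state(), the implemented gradient
// from getLearningGradient against the finite-difference estimate from measureLearningGradient,
// and fails with an AssertionError when the maximum absolute error exceeds the tolerance.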
/**
* Test learning tolerance statistics.
*
* @param component the component
* @param IOPair the io pair
* @param statistics the statistics
* @return the tolerance statistics
*/
public ToleranceStatistics testLearning(@Nonnull Layer component, @Nonnull IOPair IOPair,
ToleranceStatistics statistics) {
final ToleranceStatistics prev = statistics;
RefList<double[]> temp_02_0021 = component.state();
assert temp_02_0021 != null;
statistics = RefIntStream.range(0, temp_02_0021.size())
.mapToObj(RefUtil.wrapInterface((IntFunction<ToleranceStatistics>) i -> {
Tensor temp_02_0022 = measureLearningGradient(component.addRef(), i,
IOPair.getOutputPrototype(), IOPair.getInputPrototype());
@Nullable final Tensor measuredGradient = !verify ? null : temp_02_0022.addRef();
temp_02_0022.freeRef();
@Nonnull final Tensor implementedGradient = getLearningGradient(component.addRef(), i,
IOPair.getOutputPrototype(), IOPair.getInputPrototype());
try {
final ToleranceStatistics result = RefIntStream
.range(0, null == measuredGradient ? 0 : measuredGradient.length())
.mapToObj(RefUtil.wrapInterface((IntFunction<ToleranceStatistics>) i1 -> {
assert measuredGradient != null;
return new ToleranceStatistics().accumulate(measuredGradient.get(i1),
implementedGradient.get(i1));
}, measuredGradient == null ? null : measuredGradient.addRef(),
implementedGradient.addRef()))
.reduce(ToleranceStatistics::combine).orElse(new ToleranceStatistics());
if (!(result.absoluteTol.getMax() < tolerance)) {
if (null != measuredGradient)
measuredGradient.freeRef();
implementedGradient.freeRef();
throw new AssertionError(result.toString());
} else {
//log.info(String.format("Component: %s", component));
if (verbose) {
log.info(RefString.format("Learning Gradient for weight setByCoord %s", i));
RefList<double[]> temp_02_0023 = component.state();
assert temp_02_0023 != null;
Tensor temp_02_0018 = new Tensor(temp_02_0023.get(i));
temp_02_0023.freeRef();
log.info(RefString.format("Weights: %s", temp_02_0018.prettyPrint()));
temp_02_0018.freeRef();
log.info(RefString.format("Implemented Gradient: %s", implementedGradient.prettyPrint()));
log.info(RefString.format("Implemented Statistics: %s",
implementedGradient.getScalarStatistics()));
if (null != measuredGradient) {
log.info(RefString.format("Measured Gradient: %s", measuredGradient.prettyPrint()));
log.info(RefString.format("Measured Statistics: %s",
measuredGradient.getScalarStatistics()));
Tensor temp_02_0024 = measuredGradient
.minus(implementedGradient.addRef());
log.info(RefString.format("Gradient Error: %s", temp_02_0024.prettyPrint()));
temp_02_0024.freeRef();
Tensor temp_02_0025 = measuredGradient
.minus(implementedGradient.addRef());
log.info(
RefString.format("Error Statistics: %s", temp_02_0025.getScalarStatistics()));
temp_02_0025.freeRef();
}
}
if (null != measuredGradient)
measuredGradient.freeRef();
implementedGradient.freeRef();
return result;
}
} catch (@Nonnull final Throwable e) {
//log.info(String.format("Component: %s", component));
log.info(RefString.format("Learning Gradient for weight setByCoord %s", i));
log.info(RefString.format("Implemented Gradient: %s", implementedGradient.prettyPrint()));
log.info(RefString.format("Implemented Statistics: %s",
implementedGradient.getScalarStatistics()));
if (null != measuredGradient) {
log.info(RefString.format("Measured Gradient: %s", measuredGradient.prettyPrint()));
log.info(
RefString.format("Measured Statistics: %s", measuredGradient.getScalarStatistics()));
Tensor temp_02_0026 = measuredGradient
.minus(implementedGradient.addRef());
log.info(RefString.format("Gradient Error: %s", temp_02_0026.prettyPrint()));
temp_02_0026.freeRef();
Tensor temp_02_0027 = measuredGradient
.minus(implementedGradient.addRef());
log.info(RefString.format("Error Statistics: %s", temp_02_0027.getScalarStatistics()));
temp_02_0027.freeRef();
}
throw e;
}
}, component, IOPair)).reduce(ToleranceStatistics::combine)
.map(x -> x.combine(prev)).orElse(prev);
temp_02_0021.freeRef();
return statistics;
}
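// testFeedback performs the analogous check for each input: the implemented feedback gradient
// from getFeedbackGradient is compared against the finite-difference estimate from
// measureFeedbackGradient, again failing when the absolute error exceeds the tolerance.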
/**
* Test feedback tolerance statistics.
*
* @param component the component
* @param IOPair the io pair
* @param statistics the statistics
* @return the tolerance statistics
*/
public ToleranceStatistics testFeedback(@Nonnull Layer component, @Nonnull IOPair IOPair,
ToleranceStatistics statistics) {
Tensor[] inputPrototype = IOPair.getInputPrototype();
assert inputPrototype != null;
statistics = statistics.combine(RefUtil.get(RefIntStream.range(0, inputPrototype.length)
.mapToObj(RefUtil.wrapInterface((IntFunction<ToleranceStatistics>) i -> {
Tensor temp_02_0029 = measureFeedbackGradient(component.addRef(), i,
IOPair.getOutputPrototype(), RefUtil.addRef(inputPrototype));
@Nullable final Tensor measuredGradient = !verify ? null : temp_02_0029.addRef();
temp_02_0029.freeRef();
@Nonnull final Tensor implementedGradient = getFeedbackGradient(component.addRef(), i,
IOPair.getOutputPrototype(), RefUtil.addRef(inputPrototype));
try {
final ToleranceStatistics result = RefIntStream
.range(0, null == measuredGradient ? 0 : measuredGradient.length())
.mapToObj(RefUtil.wrapInterface((IntFunction<ToleranceStatistics>) i1 -> {
assert measuredGradient != null;
return new ToleranceStatistics().accumulate(measuredGradient.get(i1),
implementedGradient.get(i1));
}, implementedGradient.addRef(),
measuredGradient == null ? null : measuredGradient.addRef()))
.reduce(ToleranceStatistics::combine).orElse(new ToleranceStatistics());
if (!(result.absoluteTol.getMax() < tolerance)) {
if (null != measuredGradient)
measuredGradient.freeRef();
implementedGradient.freeRef();
throw new AssertionError(result.toString());
}
//log.info(String.format("Component: %s", component));
if (verbose) {
log.info(RefString.format("Feedback for input %s", i));
log.info(RefString.format("Inputs Values: %s", inputPrototype[i].prettyPrint()));
log.info(RefString.format("Value Statistics: %s",
inputPrototype[i].getScalarStatistics()));
log.info(RefString.format("Implemented Feedback: %s", implementedGradient.prettyPrint()));
log.info(RefString.format("Implemented Statistics: %s",
implementedGradient.getScalarStatistics()));
if (null != measuredGradient) {
log.info(RefString.format("Measured Feedback: %s", measuredGradient.prettyPrint()));
log.info(RefString.format("Measured Statistics: %s",
measuredGradient.getScalarStatistics()));
Tensor temp_02_0030 = measuredGradient
.minus(implementedGradient.addRef());
log.info(RefString.format("Feedback Error: %s", temp_02_0030.prettyPrint()));
temp_02_0030.freeRef();
Tensor temp_02_0031 = measuredGradient
.minus(implementedGradient.addRef());
log.info(RefString.format("Error Statistics: %s", temp_02_0031.getScalarStatistics()));
temp_02_0031.freeRef();
}
}
if (null != measuredGradient)
measuredGradient.freeRef();
implementedGradient.freeRef();
return result;
} catch (@Nonnull final Throwable e) {
//log.info(String.format("Component: %s", component));
log.info(RefString.format("Feedback for input %s", i));
log.info(RefString.format("Inputs Values: %s", inputPrototype[i].prettyPrint()));
log.info(RefString.format("Value Statistics: %s",
inputPrototype[i].getScalarStatistics()));
log.info(RefString.format("Implemented Feedback: %s", implementedGradient.prettyPrint()));
log.info(RefString.format("Implemented Statistics: %s",
implementedGradient.getScalarStatistics()));
if (null != measuredGradient) {
log.info(RefString.format("Measured: %s", measuredGradient.prettyPrint()));
log.info(
RefString.format("Measured Statistics: %s", measuredGradient.getScalarStatistics()));
Tensor temp_02_0032 = measuredGradient
.minus(implementedGradient.addRef());
log.info(RefString.format("Feedback Error: %s", temp_02_0032.prettyPrint()));
temp_02_0032.freeRef();
Tensor temp_02_0033 = measuredGradient
.minus(implementedGradient.addRef());
log.info(RefString.format("Error Statistics: %s", temp_02_0033.getScalarStatistics()));
temp_02_0033.freeRef();
}
throw e;
}
}, IOPair, component, inputPrototype)).reduce(ToleranceStatistics::combine)));
return statistics;
}
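// test() drives the full validation sequence: it builds an IOPair that replicates the first
// input prototype across `batches` copies, optionally logs the input/output prototypes, runs
// feedback and learning validation, reports the combined tolerance statistics, and finally
// checks frozen/unfrozen backpropagation behavior.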
@Override
public ToleranceStatistics test(@Nonnull final NotebookOutput log, @Nonnull final Layer component,
@Nonnull final Tensor... inputPrototype) {
log.h1("Differential Validation");
BatchDerivativeTester.IOPair temp_02_0019 = new IOPair(component.addRef(),
inputPrototype[0].addRef(), BatchDerivativeTester.this.addRef());
temp_02_0019.invoke();
@Nonnull
IOPair ioPair = temp_02_0019.addRef();
temp_02_0019.freeRef();
if (verbose) {
log.run(RefUtil.wrapInterface(() -> {
BatchDerivativeTester.log
.info(RefString.format("Inputs: %s", RefUtil.get(RefArrays.stream(RefUtil.addRef(inputPrototype)).map(t -> {
String temp_02_0007 = t.prettyPrint();
t.freeRef();
return temp_02_0007;
}).reduce((a, b) -> a + ",\n" + b))));
BatchDerivativeTester.log
.info(RefString.format("Inputs Statistics: %s", RefUtil.get(RefArrays.stream(RefUtil.addRef(inputPrototype)).map(x -> {
String temp_02_0008 = x.getScalarStatistics().toString();
x.freeRef();
return temp_02_0008;
}).reduce((a, b) -> a + ",\n" + b))));
Tensor temp_02_0034 = ioPair.getOutputPrototype();
assert temp_02_0034 != null;
BatchDerivativeTester.log.info(RefString.format("Output: %s", temp_02_0034.prettyPrint()));
temp_02_0034.freeRef();
Tensor temp_02_0035 = ioPair.getOutputPrototype();
BatchDerivativeTester.log
.info(RefString.format("Outputs Statistics: %s", temp_02_0035.getScalarStatistics()));
temp_02_0035.freeRef();
}, RefUtil.addRef(inputPrototype), ioPair.addRef()));
}
RefUtil.freeRef(inputPrototype);
ToleranceStatistics _statistics = new ToleranceStatistics();
if (isTestFeedback()) {
log.h2("Feedback Validation");
log.p(
"We validate the agreement between the implemented derivative _of the inputs_ with finite difference estimations:");
ToleranceStatistics statistics = _statistics;
_statistics = log.eval(RefUtil.wrapInterface(() -> {
return testFeedback(component.addRef(), ioPair.addRef(),
statistics);
}, component.addRef(), ioPair.addRef()));
}
if (isTestLearning()) {
log.h2("Learning Validation");
log.p(
"We validate the agreement between the implemented derivative _of the internal weights_ with finite difference estimations:");
ToleranceStatistics statistics = _statistics;
_statistics = log.eval(RefUtil.wrapInterface(() -> {
return testLearning(component.addRef(), ioPair.addRef(),
statistics);
}, component.addRef(), ioPair.addRef()));
}
log.h2("Total Accuracy");
log.p("The overall agreement accuracy between the implemented derivative and the finite difference estimations:");
ToleranceStatistics statistics = _statistics;
log.run(() -> {
//log.info(String.format("Component: %s\nInputs: %s\noutput=%s", component, Arrays.toStream(inputPrototype), outputPrototype));
BatchDerivativeTester.log.info(RefString.format("Finite-Difference Derivative Accuracy:"));
BatchDerivativeTester.log.info(RefString.format("absoluteTol: %s", statistics.absoluteTol));
BatchDerivativeTester.log.info(RefString.format("relativeTol: %s", statistics.relativeTol));
});
log.h2("Frozen and Alive Status");
log.run(RefUtil.wrapInterface(() -> {
testFrozen(component.addRef(), ioPair.getInputPrototype());
testUnFrozen(component.addRef(), ioPair.getInputPrototype());
}, component, ioPair));
return _statistics;
}
/**
* Test frozen.
*
* @param component the component
* @param inputPrototype the input prototype
*/
public void testFrozen(@Nonnull final Layer component, @Nonnull final Tensor[] inputPrototype) {
@Nonnull final AtomicBoolean reachedInputFeedback = new AtomicBoolean(false);
Layer temp_02_0036 = component.copy();
temp_02_0036.freeze();
@Nonnull final Layer frozen = temp_02_0036.addRef();
temp_02_0036.freeRef();
Result.Accumulator accumulator = new Result.Accumulator() {
@Override
public void accept(@Nonnull DeltaSet<UUID> buffer, @Nonnull TensorList data) {
buffer.freeRef();
data.freeRef();
reachedInputFeedback.set(true);
}
@Override
public void _free() {
super._free();
}
};
@Nullable final Result eval = frozen
.eval(new Result(new TensorArray(RefUtil.addRef(inputPrototype)), accumulator, true));
frozen.freeRef();
@Nonnull final DeltaSet<UUID> buffer = new DeltaSet<>();
assert eval != null;
TensorList temp_02_0037 = eval.getData();
TensorList tensorList = temp_02_0037.copy();
temp_02_0037.freeRef();
eval.accumulate(buffer.addRef(), tensorList == null ? null : tensorList.addRef());
if (null != tensorList)
tensorList.freeRef();
eval.freeRef();
RefList<double[]> temp_02_0038 = component.state();
assert temp_02_0038 != null;
final RefList<Delta<UUID>> deltas = temp_02_0038.stream()
.map(RefUtil.wrapInterface((Function<? super double[], ? extends Delta<UUID>>) doubles -> {
Optional<Delta<UUID>> temp_02_0040 = buffer.stream().filter(x -> {
boolean temp_02_0009 = x.target == doubles;
x.freeRef();
return temp_02_0009;
}).findFirst();
Delta<UUID> temp_02_0039 = temp_02_0040.orElse(null);
RefUtil.freeRef(temp_02_0040);
return temp_02_0039;
}, buffer)).filter(x -> {
boolean temp_02_0010 = x != null;
if (null != x)
x.freeRef();
return temp_02_0010;
}).collect(RefCollectors.toList());
temp_02_0038.freeRef();
RefList temp_02_0041 = component.state();
assert temp_02_0041 != null;
try {
if (!deltas.isEmpty() && !temp_02_0041.isEmpty()) {
throw new AssertionError("Frozen component listed in evalInputDelta. Deltas: " + deltas);
}
final int inElements = RefArrays.stream(RefUtil.addRef(inputPrototype)).mapToInt(x -> {
int temp_02_0012 = x.length();
x.freeRef();
return temp_02_0012;
}).sum();
if (!reachedInputFeedback.get() && 0 < inElements) {
throw new RuntimeException("Frozen component did not pass input backwards");
}
} finally {
deltas.freeRef();
component.freeRef();
RefUtil.freeRef(inputPrototype);
temp_02_0041.freeRef();
}
}
/**
* Test un frozen.
*
* @param component the component
* @param inputPrototype the input prototype
*/
public void testUnFrozen(@Nonnull final Layer component, @Nullable final Tensor[] inputPrototype) {
@Nonnull final AtomicBoolean reachedInputFeedback = new AtomicBoolean(false);
Layer temp_02_0042 = component.copy();
temp_02_0042.setFrozen(false);
@Nonnull final Layer frozen = temp_02_0042.addRef();
temp_02_0042.freeRef();
component.freeRef();
Result.Accumulator accumulator = new Result.Accumulator() {
@Override
public void accept(@Nonnull DeltaSet<UUID> buffer, @Nonnull TensorList data) {
buffer.freeRef();
data.freeRef();
reachedInputFeedback.set(true);
}
@Override
public void _free() {
super._free();
}
};
@Nullable final Result eval = frozen.eval(new Result(new TensorArray(RefUtil.addRef(inputPrototype)), accumulator, true));
if (null != inputPrototype)
RefUtil.freeRef(inputPrototype);
@Nonnull final DeltaSet<UUID> buffer = new DeltaSet<>();
assert eval != null;
TensorList data = eval.getData();
eval.accumulate(buffer.addRef(), data.addRef());
data.freeRef();
eval.freeRef();
@Nullable final RefList<double[]> stateList = frozen.state();
frozen.freeRef();
assert stateList != null;
final RefList<Delta<UUID>> deltas = stateList.stream()
.map(RefUtil.wrapInterface((Function<? super double[], ? extends Delta<UUID>>) doubles -> {
Optional<Delta<UUID>> temp_02_0044 = buffer.stream().filter(x -> {
boolean temp_02_0013 = x.target == doubles;
x.freeRef();
return temp_02_0013;
}).findFirst();
Delta<UUID> temp_02_0043 = temp_02_0044.orElse(null);
RefUtil.freeRef(temp_02_0044);
return temp_02_0043;
}, buffer)).filter(x -> {
boolean temp_02_0014 = x != null;
if (null != x)
x.freeRef();
return temp_02_0014;
}).collect(RefCollectors.toList());
if (deltas.isEmpty() && !stateList.isEmpty()) {
stateList.freeRef();
AssertionError temp_02_0015 = new AssertionError(
"Nonfrozen component not listed in evalInputDelta. Deltas: " + deltas);
deltas.freeRef();
throw temp_02_0015;
}
deltas.freeRef();
stateList.freeRef();
if (!reachedInputFeedback.get()) {
throw new RuntimeException("Nonfrozen component did not pass input backwards");
}
}
@Nonnull
@Override
public String toString() {
return "BatchDerivativeTester{" + "probeSize=" + probeSize + ", batches=" + batches + ", tolerance=" + tolerance
+ ", testFeedback=" + testFeedback + ", testLearning=" + testLearning + ", verbose=" + verbose + ", verify="
+ verify + '}';
}
public @SuppressWarnings("unused")
void _free() {
super._free();
}
@Nonnull
public @Override
@SuppressWarnings("unused")
BatchDerivativeTester addRef() {
return (BatchDerivativeTester) super.addRef();
}
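// getFeedbackGradient assembles the implemented input Jacobian column by column: for each output
// element j, a one-hot gradient is backpropagated through the layer and the delta recorded under
// a PlaceholderLayer key yields column j of the (inputDims x outputLength) result tensor.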
@Nonnull
private Tensor getFeedbackGradient(@Nonnull final Layer component, final int inputIndex,
@Nonnull final Tensor outputPrototype, @Nullable final Tensor... inputPrototype) {
assert inputPrototype != null;
final Tensor inputTensor = inputPrototype[inputIndex].addRef();
final int inputDims = inputTensor.length();
@Nonnull final Tensor result = new Tensor(inputDims, outputPrototype.length());
for (int j = 0; j < outputPrototype.length(); j++) {
final int j_ = j;
@Nonnull final PlaceholderLayer<Tensor> inputKey = new PlaceholderLayer<>(new Tensor());
Result.Accumulator accumulator = new Result.Accumulator() {
{
inputTensor.addRef();
outputPrototype.addRef();
inputKey.addRef();
}
@Override
public void accept(@Nonnull DeltaSet<UUID> buffer, @Nonnull TensorList data) {
@Nonnull final Tensor gradientBuffer = new Tensor(inputDims, outputPrototype.length());
Tensor tensor = data.get(inputIndex);
if (!RefArrays.equals(inputTensor.getDimensions(), tensor.getDimensions())) {
buffer.freeRef();
data.freeRef();
tensor.freeRef();
gradientBuffer.freeRef();
throw new AssertionError();
}
for (int i = 0; i < inputDims; i++) {
gradientBuffer.set(new int[]{i, j_}, tensor.get(i));
}
tensor.freeRef();
Delta<UUID> delta = buffer.get(inputKey.getId(), new double[gradientBuffer.length()]);
assert delta != null;
delta.addInPlace(gradientBuffer);
delta.freeRef();
buffer.freeRef();
data.freeRef();
}
@Override
public void _free() {
inputTensor.freeRef();
outputPrototype.freeRef();
inputKey.freeRef();
super._free();
}
};
@Nonnull final Result copyInput = new Result(new TensorArray(RefUtil.addRef(inputPrototype)), accumulator, true);
@Nullable final Result eval = component.eval(copyInput);
@Nonnull final DeltaSet<UUID> xxx = new DeltaSet<>();
assert eval != null;
TensorList temp_02_0045 = eval.getData();
@Nonnull
TensorArray tensorArray = new TensorArray(temp_02_0045.stream().map(x -> {
x.set(j_, 1);
Tensor temp_02_0016 = x.addRef();
x.freeRef();
return temp_02_0016;
}).toArray(Tensor[]::new));
temp_02_0045.freeRef();
eval.accumulate(xxx.addRef(), tensorArray);
eval.freeRef();
final Delta<UUID> inputDelta = xxx.get(inputKey.getId());
xxx.freeRef();
if (null != inputDelta) {
result.addInPlace(new Tensor(inputDelta.getDelta(), result.getDimensions()));
}
if (null != inputKey)
inputKey.freeRef();
if (null != inputDelta)
inputDelta.freeRef();
}
RefUtil.freeRef(inputPrototype);
outputPrototype.freeRef();
component.freeRef();
inputTensor.freeRef();
return result;
}
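// getLearningGradient does the same for the selected weight array (layerNum): each one-hot output
// gradient is accumulated into a DeltaSet, and the delta recorded against that weight buffer
// supplies one column of the (stateLen x outputLength) gradient tensor.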
@Nonnull
private Tensor getLearningGradient(@Nonnull final Layer component, final int layerNum,
@Nonnull final Tensor outputPrototype, @Nullable final Tensor... inputPrototype) {
component.setFrozen(false);
RefList<double[]> temp_02_0047 = component.state();
assert temp_02_0047 != null;
final double[] stateArray = temp_02_0047.get(layerNum);
temp_02_0047.freeRef();
final int stateLen = stateArray.length;
@Nonnull final Tensor gradient = new Tensor(stateLen, outputPrototype.length());
for (int j = 0; j < outputPrototype.length(); j++) {
final int j_ = j;
@Nonnull final DeltaSet<UUID> buffer = new DeltaSet<>();
Tensor temp_02_0020 = new Tensor(outputPrototype.getDimensions());
temp_02_0020.set(k -> k == j_ ? 1 : 0);
@Nonnull final Tensor data = temp_02_0020.addRef();
temp_02_0020.freeRef();
@Nullable final Result eval = component
.eval(ConstantResult.singleResultArray(new Tensor[][]{RefUtil.addRef(inputPrototype)}));
assert eval != null;
TensorList temp_02_0048 = eval.getData();
RefUtil.freeRef(temp_02_0048.get(0));
temp_02_0048.freeRef();
@Nonnull
TensorArray tensorArray = new TensorArray(data);
eval.accumulate(buffer.addRef(), tensorArray);
eval.freeRef();
RefMap<UUID, Delta<UUID>> temp_02_0049 = buffer.getMap();
RefCollection<Delta<UUID>> temp_02_0050 = temp_02_0049.values();
Optional<Delta<UUID>> temp_02_0051 = temp_02_0050.stream().filter(x -> {
boolean temp_02_0017 = x.target == stateArray;
x.freeRef();
return temp_02_0017;
}).findFirst();
final DoubleBuffer<UUID> deltaFlushBuffer = temp_02_0051.orElse(null);
RefUtil.freeRef(temp_02_0051);
temp_02_0050.freeRef();
temp_02_0049.freeRef();
buffer.freeRef();
if (null != deltaFlushBuffer) {
for (int i = 0; i < stateLen; i++) {
gradient.set(new int[]{i, j_}, deltaFlushBuffer.getDelta()[i]);
}
}
if (null != deltaFlushBuffer)
deltaFlushBuffer.freeRef();
}
if (null != inputPrototype)
RefUtil.freeRef(inputPrototype);
outputPrototype.freeRef();
component.freeRef();
return gradient;
}
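// measureFeedbackGradient estimates the input Jacobian numerically with a one-sided finite
// difference: each input element i is perturbed by probeSize, and row i of the measured gradient
// is filled with (f(x + probeSize * e_i) - f(x)) / probeSize.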
@Nonnull
private Tensor measureFeedbackGradient(@Nonnull final Layer component, final int inputIndex,
@Nonnull final Tensor outputPrototype, @Nonnull final Tensor... inputPrototype) {
@Nonnull final Tensor measuredGradient = new Tensor(inputPrototype[inputIndex].length(), outputPrototype.length());
Result temp_02_0052 = component
.eval(ConstantResult.singleResultArray(new Tensor[][]{RefUtil.addRef(inputPrototype)}));
assert temp_02_0052 != null;
TensorList temp_02_0053 = temp_02_0052.getData();
@Nullable final Tensor baseOutput = temp_02_0053.get(0);
temp_02_0053.freeRef();
temp_02_0052.freeRef();
outputPrototype.set(baseOutput.addRef());
outputPrototype.freeRef();
for (int i = 0; i < inputPrototype[inputIndex].length(); i++) {
@Nonnull final Tensor inputProbe = inputPrototype[inputIndex].copy();
inputProbe.add(i, probeSize * 1);
@Nonnull final Tensor[] copyInput = RefArrays.copyOf(RefUtil.addRef(inputPrototype), inputPrototype.length);
RefUtil.set(copyInput, inputIndex, inputProbe);
Result temp_02_0054 = component
.eval(ConstantResult.singleResultArray(new Tensor[][]{copyInput}));
assert temp_02_0054 != null;
TensorList temp_02_0055 = temp_02_0054.getData();
@Nullable final Tensor evalProbe = temp_02_0055.get(0);
temp_02_0055.freeRef();
temp_02_0054.freeRef();
Tensor temp_02_0056 = evalProbe.minus(baseOutput.addRef());
temp_02_0056.scaleInPlace(1. / probeSize);
@Nonnull final Tensor delta = temp_02_0056.addRef();
temp_02_0056.freeRef();
evalProbe.freeRef();
for (int j = 0; j < delta.length(); j++) {
measuredGradient.set(new int[]{i, j}, delta.get(j));
}
delta.freeRef();
}
RefUtil.freeRef(inputPrototype);
component.freeRef();
baseOutput.freeRef();
return measuredGradient;
}
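// measureLearningGradient applies the same one-sided finite difference to the weights: a copy of
// the layer has element i of state(layerNum) nudged by probeSize, and the resulting change in the
// output, divided by probeSize, fills row i of the measured gradient.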
@Nonnull
private Tensor measureLearningGradient(@Nonnull final Layer component, final int layerNum,
@Nonnull final Tensor outputPrototype, @Nullable final Tensor... inputPrototype) {
RefList<double[]> temp_02_0057 = component.state();
assert temp_02_0057 != null;
double[] doubles = temp_02_0057.get(layerNum);
final int stateLen = doubles.length;
temp_02_0057.freeRef();
@Nonnull final Tensor gradient = new Tensor(stateLen, outputPrototype.length());
outputPrototype.freeRef();
Result temp_02_0058 = component
.eval(ConstantResult.singleResultArray(new Tensor[][]{RefUtil.addRef(inputPrototype)}));
assert temp_02_0058 != null;
TensorList temp_02_0059 = temp_02_0058.getData();
@Nullable final Tensor baseOutput = temp_02_0059.get(0);
temp_02_0059.freeRef();
temp_02_0058.freeRef();
for (int i = 0; i < stateLen; i++) {
@Nonnull final Layer copy = component.copy();
RefList temp_02_0060 = copy.state();
assert temp_02_0060 != null;
double[] doubles1 = temp_02_0060.get(layerNum);
doubles1[i] += probeSize;
temp_02_0060.freeRef();
Result temp_02_0061 = copy
.eval(ConstantResult.singleResultArray(new Tensor[][]{RefUtil.addRef(inputPrototype)}));
assert temp_02_0061 != null;
TensorList temp_02_0062 = temp_02_0061.getData();
@Nullable final Tensor evalProbe = temp_02_0062.get(0);
temp_02_0062.freeRef();
temp_02_0061.freeRef();
copy.freeRef();
Tensor delta = evalProbe.minus(baseOutput.addRef());
delta.scaleInPlace(1. / probeSize);
evalProbe.freeRef();
for (int j = 0; j < delta.length(); j++) {
gradient.set(new int[]{i, j}, delta.get(j));
}
delta.freeRef();
}
if (null != inputPrototype)
RefUtil.freeRef(inputPrototype);
component.freeRef();
baseOutput.freeRef();
return gradient;
}
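// IOPair holds the replicated input prototypes and the output prototype produced by evaluating
// the component once via SimpleEval; it supplies both to the feedback and learning tests above.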
private static class IOPair extends ReferenceCountingBase {
@Nullable
private final Layer component;
@Nonnull
private final Tensor tensor;
@Nullable
private final BatchDerivativeTester parent;
@Nullable
private Tensor[] inputPrototype;
@Nullable
private Tensor outputPrototype;
/**
* Instantiates a new Io pair.
*
* @param component the component
* @param tensor the tensor
* @param parent the parent
*/
public IOPair(@Nullable Layer component, @Nullable Tensor tensor, @Nullable BatchDerivativeTester parent) {
this.component = component;
this.tensor = tensor;
this.parent = parent;
}
/**
* Get input prototype tensor [ ].
*
* @return the tensor [ ]
*/
@Nullable
public Tensor[] getInputPrototype() {
return RefUtil.addRef(inputPrototype);
}
/**
* Gets output prototype.
*
* @return the output prototype
*/
@Nullable
public Tensor getOutputPrototype() {
return outputPrototype == null ? null : outputPrototype.addRef();
}
/**
* Invoke.
*/
public void invoke() {
assert parent != null;
if (null != inputPrototype)
RefUtil.freeRef(inputPrototype);
inputPrototype = RefIntStream.range(0, parent.batches)
.mapToObj(i -> tensor.copy())
.toArray(Tensor[]::new);
SimpleEval simpleEval = SimpleEval.run(component == null ? null : component.addRef(),
inputPrototype[0].addRef());
if (null != outputPrototype)
outputPrototype.freeRef();
outputPrototype = simpleEval.getOutput();
simpleEval.freeRef();
}
public @SuppressWarnings("unused")
void _free() {
super._free();
if (null != outputPrototype)
outputPrototype.freeRef();
outputPrototype = null;
if (null != inputPrototype)
RefUtil.freeRef(inputPrototype);
inputPrototype = null;
if (null != parent)
parent.freeRef();
tensor.freeRef();
if (null != component)
component.freeRef();
}
@Nonnull
public @Override
@SuppressWarnings("unused")
IOPair addRef() {
return (IOPair) super.addRef();
}
}
}