
com.simiacryptus.mindseye.test.integration.AutoencodingProblem Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-test Show documentation
Testing Tools for Neural Network Components
The newest version!
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.test.integration;
import com.simiacryptus.lang.UncheckedSupplier;
import com.simiacryptus.mindseye.eval.ArrayTrainable;
import com.simiacryptus.mindseye.eval.SampledArrayTrainable;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Result;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.TensorList;
import com.simiacryptus.mindseye.layers.StochasticComponent;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.mindseye.opt.Step;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.ValidatingTrainer;
import com.simiacryptus.mindseye.test.GraphVizNetworkInspector;
import com.simiacryptus.mindseye.test.StepRecord;
import com.simiacryptus.mindseye.test.TestUtil;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.notebook.TableOutput;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.util.Util;
import com.simiacryptus.util.test.LabeledObject;
import guru.nidi.graphviz.engine.Format;
import guru.nidi.graphviz.engine.Graphviz;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
/**
 * Integration-test harness for an autoencoding problem: an encoder (image -> feature vector)
 * and a decoder (feature vector -> image) are composed, trained against a reconstruction loss
 * with dropout noise injected between them, and the results (network diagrams, convergence
 * plots, saved model JSON, echoed validation examples, and per-feature unit-vector renderings)
 * are written to a {@link NotebookOutput} report.
 *
 * NOTE(review): generic type parameters appear to have been stripped from this listing
 * (e.g. the raw {@code List} history field and the {@code Function} cast in run()) --
 * likely an extraction artifact; confirm against the original source.
 */
@SuppressWarnings("FieldCanBeLocal")
public abstract class AutoencodingProblem implements Problem {

  // Shared counter so each saved model file ("encoder_model<N>.json" / "decoder_model<N>.json")
  // gets a unique name within a single JVM run.
  private static int modelNo = 0;

  // Batch size passed to both the sampled (training) and full-array (validation) trainables.
  private final int batchSize = 10000;
  // Supplies the training and validation image streams.
  private final ImageProblemData data;
  // Dropout rate handed to the subclass-provided dropout(...) factory in run().
  private final double dropout;
  // Width of the latent encoding (number of features produced by the encoder).
  private final int features;
  // Builds the encoder (image -> feature vector) network.
  private final FwdNetworkFactory fwdFactory;
  // Training step history consumed by TestUtil.getMonitor/plot/plotTime; raw List in this
  // listing -- presumably holds StepRecord entries, TODO confirm.
  @Nonnull
  private final List history = new ArrayList<>();
  // Strategy that configures and returns the ValidatingTrainer.
  private final OptimizationStrategy optimizer;
  // Builds the decoder (feature vector -> image) network.
  private final RevNetworkFactory revFactory;
  // Wall-clock training budget in minutes; see setTimeoutMinutes.
  private int timeoutMinutes = 1;

  /**
   * Instantiates a new Autoencoding problem.
   *
   * @param fwdFactory the factory for the encoder (image -> feature vector) network
   * @param optimizer  the optimization strategy used to build the trainer
   * @param revFactory the factory for the decoder (feature vector -> image) network
   * @param data       the source of training/validation images
   * @param features   the size of the latent feature vector
   * @param dropout    the dropout rate applied between encoder and decoder during training
   */
  public AutoencodingProblem(final FwdNetworkFactory fwdFactory, final OptimizationStrategy optimizer,
      final RevNetworkFactory revFactory, final ImageProblemData data, final int features, final double dropout) {
    this.fwdFactory = fwdFactory;
    this.optimizer = optimizer;
    this.revFactory = revFactory;
    this.data = data;
    this.features = features;
    this.dropout = dropout;
  }

  /**
   * Returns the mutable list of recorded training steps (populated while run() executes).
   *
   * @return the training history
   */
  @Nonnull
  @Override
  public List getHistory() {
    return history;
  }

  /**
   * Gets the training wall-clock budget.
   *
   * @return the timeout in minutes
   */
  public int getTimeoutMinutes() {
    return timeoutMinutes;
  }

  /**
   * Sets the training wall-clock budget.
   *
   * @param timeoutMinutes the timeout in minutes
   * @return this problem instance (for chaining)
   */
  @Nonnull
  public AutoencodingProblem setTimeoutMinutes(final int timeoutMinutes) {
    this.timeoutMinutes = timeoutMinutes;
    return this;
  }

  /**
   * Loads the full training set as a batched tensor array; each row is a single-element
   * {@code Tensor[]} holding one training image (labels are discarded -- autoencoding is
   * unsupervised).
   *
   * @return the training data, one {@code Tensor[]} row per example
   * @throws RuntimeException wrapping any IOException from the underlying data source
   */
  @Nonnull
  public Tensor[][] getTrainingData() {
    try {
      return data.trainingData().map(labeledObject -> {
        // Keep only the image tensor; release the labeled wrapper.
        Tensor[] tensors = {labeledObject.data};
        labeledObject.freeRef();
        return tensors;
      }).toArray(Tensor[][]::new);
    } catch (@Nonnull final IOException e) {
      throw Util.throwException(e);
    }
  }

  /**
   * Parses the integer embedded in a label by stripping all non-digit characters,
   * e.g. "[5]" -> 5. Note: a label containing no digits yields an empty string and a
   * NumberFormatException from Integer.parseInt.
   *
   * @param label the label text
   * @return the numeric value of the digits contained in the label
   */
  public int parse(@Nonnull final String label) {
    return Integer.parseInt(label.replaceAll("[^\\d]", ""));
  }

  /**
   * Runs the full autoencoding experiment into the given notebook: builds the echo
   * (encoder+decoder) and supervised (encoder+dropout+decoder+loss) networks, renders
   * their diagrams, trains under the configured optimizer and minute budget, plots
   * convergence, saves encoder/decoder JSON, then reports echoed validation examples
   * and per-feature unit-vector decodings.
   *
   * NOTE(review): the addRef()/freeRef() calls implement the library's reference
   * counting; their pairing below is order-sensitive and left byte-identical.
   *
   * @param log the notebook to write the report into
   * @return this problem instance (for chaining)
   */
  @Nonnull
  @Override
  public AutoencodingProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final DAGNetwork fwdNetwork = fwdFactory.imageToVector(log, features);
    @Nonnull final DAGNetwork revNetwork = revFactory.vectorToImage(log, features);
    // Echo network: encoder followed by decoder, used later to re-encode validation images.
    @Nonnull final PipelineNetwork echoNetwork = new PipelineNetwork(1);
    RefUtil.freeRef(echoNetwork.add(fwdNetwork.addRef()));
    RefUtil.freeRef(echoNetwork.add(revNetwork.addRef()));
    // Supervised network: encoder -> dropout noise -> decoder -> loss(head, original input).
    @Nonnull final PipelineNetwork supervisedNetwork = new PipelineNetwork(1);
    RefUtil.freeRef(supervisedNetwork.add(fwdNetwork.addRef()));
    @Nonnull final StochasticComponent dropoutNoiseLayer = dropout(dropout);
    RefUtil.freeRef(supervisedNetwork.add(dropoutNoiseLayer));
    RefUtil.freeRef(supervisedNetwork.add(revNetwork.addRef()));
    // Reconstruction loss compares the network head against input 0 (the original image).
    RefUtil.freeRef(supervisedNetwork.add(lossLayer(), supervisedNetwork.getHead(), supervisedNetwork.getInput(0)));
    log.h3("Network Diagrams");
    // Render encoder, decoder, and supervised network structure as 600x400 PNG diagrams.
    log.eval(RefUtil.wrapInterface((UncheckedSupplier) () -> {
      return Graphviz.fromGraph(GraphVizNetworkInspector.toGraphviz(fwdNetwork.addRef())).height(400)
          .width(600).render(Format.PNG).toImage();
    }, fwdNetwork.addRef()));
    log.eval(RefUtil.wrapInterface((UncheckedSupplier) () -> {
      return Graphviz.fromGraph(GraphVizNetworkInspector.toGraphviz(revNetwork.addRef())).height(400)
          .width(600).render(Format.PNG).toImage();
    }, revNetwork.addRef()));
    log.eval(RefUtil.wrapInterface((UncheckedSupplier) () -> {
      return Graphviz.fromGraph(GraphVizNetworkInspector.toGraphviz(supervisedNetwork.addRef()))
          .height(400).width(600).render(Format.PNG).toImage();
    }, supervisedNetwork.addRef()));
    // Monitor that forwards log lines and step-complete events into the shared history list.
    @Nonnull final TrainingMonitor monitor = new TrainingMonitor() {
      @Nonnull
      final TrainingMonitor inner = TestUtil.getMonitor(history);

      @Override
      public void log(final String msg) {
        inner.log(msg);
      }

      @Override
      public void onStepComplete(@Nullable final Step currentPoint) {
        // Hand an addRef'd copy downstream, then release our own reference.
        inner.onStepComplete(currentPoint == null ? null : currentPoint.addRef());
        if (null != currentPoint)
          currentPoint.freeRef();
      }
    };
    final Tensor[][] trainingData = getTrainingData();
    //MonitoredObject monitoringRoot = new MonitoredObject();
    //TestUtil.addMonitoring(supervisedNetwork, monitoringRoot);
    log.h3("Training");
    TestUtil.instrumentPerformance(supervisedNetwork.addRef());
    // Train on a half-size random sample per epoch; validate against the full array.
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log,
        new SampledArrayTrainable(RefUtil.addRef(trainingData),
            supervisedNetwork.addRef(), trainingData.length / 2, batchSize),
        new ArrayTrainable(RefUtil.addRef(trainingData), supervisedNetwork.addRef(),
            batchSize),
        monitor);
    RefUtil.freeRef(trainingData);
    log.run(RefUtil.wrapInterface(() -> {
      // Bound training by wall-clock minutes and a hard iteration cap.
      trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES);
      ValidatingTrainer temp_21_0003 = trainer.addRef();
      temp_21_0003.setMaxIterations(10000);
      ValidatingTrainer temp_21_0004 = temp_21_0003.addRef();
      temp_21_0004.run();
      temp_21_0004.freeRef();
      temp_21_0003.freeRef();
    }, trainer));
    // Plot objective vs. iteration and vs. wall-clock time, if any steps were recorded.
    if (!history.isEmpty()) {
      log.eval(() -> {
        return TestUtil.plot(history);
      });
      log.eval(() -> {
        return TestUtil.plotTime(history);
      });
    }
    TestUtil.extractPerformance(log, supervisedNetwork);
    {
      // Save the encoder; modelNo keeps filenames unique across problems in one run.
      @Nonnull final String modelName = "encoder_model" + AutoencodingProblem.modelNo++ + ".json";
      log.p("Saved model as " + log.file(fwdNetwork.getJson().toString(), modelName, modelName));
    }
    fwdNetwork.freeRef();
    // Save the decoder under its own sequence number.
    @Nonnull final String modelName = "decoder_model" + AutoencodingProblem.modelNo++ + ".json";
    log.p("Saved model as " + log.file(revNetwork.getJson().toString(), modelName, modelName));
    // log.h3("Metrics");
    // log.run(() -> {
    // return TestUtil.toFormattedJson(monitoringRoot.getMetrics());
    // });
    log.h3("Validation");
    log.p("Here are some re-encoded examples:");
    // Echo up to 10 validation images through encoder+decoder and tabulate
    // originals next to their reconstructions.
    log.eval(RefUtil.wrapInterface((UncheckedSupplier) () -> {
      @Nonnull final TableOutput table = new TableOutput();
      // NOTE(review): the Function cast below has visibly lost its generic parameters
      // in this listing ("Function super LabeledObject, ...") -- extraction artifact;
      // confirm against the original source.
      data.validationData().map(RefUtil.wrapInterface(
          (Function super LabeledObject, ? extends LinkedHashMap>) labeledObject -> {
            Result temp_21_0006 = echoNetwork.eval(labeledObject.data.addRef());
            assert temp_21_0006 != null;
            TensorList data = temp_21_0006.getData();
            Tensor tensor = data.get(0);
            LinkedHashMap row = toRow(log, labeledObject, tensor.getData());
            tensor.freeRef();
            data.freeRef();
            temp_21_0006.freeRef();
            return row;
          }, echoNetwork.addRef())).filter(Objects::nonNull).limit(10)
          .forEach(table::putRow);
      return table;
    }, echoNetwork));
    log.p("Some rendered unit vectors:");
    // Decode each one-hot latent vector to visualize what each individual feature encodes.
    for (int featureNumber = 0; featureNumber < features; featureNumber++) {
      Tensor temp_21_0001 = new Tensor(features);
      temp_21_0001.set(featureNumber, 1);
      @Nonnull final Tensor input = temp_21_0001.addRef();
      temp_21_0001.freeRef();
      Result temp_21_0007 = revNetwork.eval(input);
      assert temp_21_0007 != null;
      TensorList temp_21_0008 = temp_21_0007.getData();
      @Nullable final Tensor tensor = temp_21_0008.get(0);
      temp_21_0008.freeRef();
      temp_21_0007.freeRef();
      log.out(log.png(tensor.toImage(), ""));
      tensor.freeRef();
    }
    revNetwork.freeRef();
    return this;
  }

  /**
   * Builds one report-table row holding the original image and its reconstruction.
   * Releases the labeled object when done.
   *
   * @param log              the notebook used to embed the PNG images
   * @param labeledObject    the original labeled image (released by this method)
   * @param predictionSignal the decoded pixel data, reshaped to the original image dimensions
   * @return a row mapping "Image" and "Echo" to embedded PNGs
   */
  @Nonnull
  public LinkedHashMap toRow(@Nonnull final NotebookOutput log,
      @Nonnull final LabeledObject labeledObject, final double[] predictionSignal) {
    @Nonnull final LinkedHashMap row = new LinkedHashMap<>();
    row.put("Image", log.png(labeledObject.data.toImage(), labeledObject.label));
    Tensor temp_21_0002 = new Tensor(predictionSignal, labeledObject.data.getDimensions());
    row.put("Echo", log.png(temp_21_0002.toImage(), labeledObject.label));
    labeledObject.freeRef();
    temp_21_0002.freeRef();
    return row;
  }

  /**
   * Supplies the reconstruction loss layer; it is added to (and consumed by) the
   * supervised network in run(), comparing network output to the original input.
   *
   * @return the loss layer
   */
  @Nonnull
  protected abstract Layer lossLayer();

  /**
   * Supplies the dropout/noise component inserted between encoder and decoder
   * during training; it is added to (and consumed by) the supervised network in run().
   *
   * @param dropout the dropout rate configured for this problem
   * @return the stochastic component
   */
  @Nonnull
  protected abstract StochasticComponent dropout(double dropout);
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy