
com.simiacryptus.mindseye.test.integration.EncodingProblem Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-test Show documentation
Show all versions of mindseye-test Show documentation
Testing Tools for Neural Network Components
The newest version!
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.test.integration;
import com.simiacryptus.lang.UncheckedSupplier;
import com.simiacryptus.mindseye.eval.ArrayTrainable;
import com.simiacryptus.mindseye.eval.SampledArrayTrainable;
import com.simiacryptus.mindseye.lang.Result;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.TensorList;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.PipelineNetwork;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.ValidatingTrainer;
import com.simiacryptus.mindseye.test.GraphVizNetworkInspector;
import com.simiacryptus.mindseye.test.StepRecord;
import com.simiacryptus.mindseye.test.TestUtil;
import com.simiacryptus.mindseye.util.ImageUtil;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.notebook.TableOutput;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefDoubleStream;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.util.Util;
import com.simiacryptus.util.data.ScalarStatistics;
import guru.nidi.graphviz.engine.Format;
import guru.nidi.graphviz.engine.Graphviz;
import org.jetbrains.annotations.NotNull;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
/**
* The type Encoding problem.
*/
public abstract class EncodingProblem implements Problem {
private static int modelNo = 0;
private final ImageProblemData data;
private final List history = new ArrayList<>();
private final OptimizationStrategy optimizer;
private final RevNetworkFactory revFactory;
private int batchSize = 10000;
private int features;
private int timeoutMinutes = 1;
private int trainingSize = 15000;
/**
* Instantiates a new Encoding problem.
*
* @param revFactory the rev factory
* @param optimizer the optimizer
* @param data the data
* @param features the features
*/
public EncodingProblem(final RevNetworkFactory revFactory, final OptimizationStrategy optimizer,
final ImageProblemData data, final int features) {
this.revFactory = revFactory;
this.optimizer = optimizer;
this.data = data;
this.features = features;
}
/**
* Gets batch size.
*
* @return the batch size
*/
public int getBatchSize() {
return batchSize;
}
/**
* Sets batch size.
*
* @param batchSize the batch size
* @return the batch size
*/
@Nonnull
public EncodingProblem setBatchSize(final int batchSize) {
this.batchSize = batchSize;
return this;
}
/**
* Gets features.
*
* @return the features
*/
public int getFeatures() {
return features;
}
/**
* Sets features.
*
* @param features the features
* @return the features
*/
@Nonnull
public EncodingProblem setFeatures(final int features) {
this.features = features;
return this;
}
@Nonnull
@Override
public List getHistory() {
return history;
}
/**
* Get tensors tensor [ ] [ ].
*
* @return the tensor [ ] [ ]
*/
@NotNull
public Tensor[][] getTensors() {
try {
return data.trainingData().map(labeledObject -> {
Tensor tensor = new Tensor(features);
tensor.set(this::random);
Tensor data = labeledObject.data;
labeledObject.freeRef();
return new Tensor[]{tensor, data};
}).toArray(Tensor[][]::new);
} catch (@Nonnull final IOException e) {
throw Util.throwException(e);
}
}
/**
* Gets timeout minutes.
*
* @return the timeout minutes
*/
public int getTimeoutMinutes() {
return timeoutMinutes;
}
/**
* Sets timeout minutes.
*
* @param timeoutMinutes the timeout minutes
* @return the timeout minutes
*/
@Nonnull
public EncodingProblem setTimeoutMinutes(final int timeoutMinutes) {
this.timeoutMinutes = timeoutMinutes;
return this;
}
/**
* Gets training size.
*
* @return the training size
*/
public int getTrainingSize() {
return trainingSize;
}
/**
* Sets training size.
*
* @param trainingSize the training size
* @return the training size
*/
@Nonnull
public EncodingProblem setTrainingSize(final int trainingSize) {
this.trainingSize = trainingSize;
return this;
}
/**
* Random double.
*
* @return the double
*/
public double random() {
return 0.1 * (Math.random() - 0.5);
}
@Nonnull
@Override
public EncodingProblem run(@Nonnull final NotebookOutput log) {
@Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
Tensor[][] trainingData = getTensors();
@Nonnull final DAGNetwork imageNetwork = revFactory.vectorToImage(log, features);
log.h3("Network Diagram");
log.eval(RefUtil.wrapInterface((UncheckedSupplier) () -> {
return Graphviz.fromGraph(GraphVizNetworkInspector.toGraphviz(imageNetwork.addRef()))
.height(400).width(600).render(Format.PNG).toImage();
}, imageNetwork.addRef()));
final DAGNetwork trainingNetwork = trainingNetwork(imageNetwork.addRef());
log.h3("Training");
log.p("We start by training apply a very small population to improve initial convergence performance:");
TestUtil.instrumentPerformance(trainingNetwork.addRef());
@Nonnull final Tensor[][] primingData = RefArrays.copyOfRange(RefUtil.addRef(trainingData), 0, 1000);
SampledArrayTrainable temp_20_0004 = new SampledArrayTrainable(RefUtil.addRef(primingData),
trainingNetwork.addRef(), trainingSize, batchSize);
temp_20_0004.setMinSamples(trainingSize);
temp_20_0004.setMask(true, false);
@Nonnull final ValidatingTrainer preTrainer = optimizer.train(log, temp_20_0004,
new ArrayTrainable(RefUtil.addRef(primingData), trainingNetwork.addRef(),
batchSize),
monitor);
RefUtil.freeRef(primingData);
log.run(RefUtil.wrapInterface(() -> {
preTrainer.setTimeout(timeoutMinutes / 2, TimeUnit.MINUTES);
ValidatingTrainer temp_20_0008 = preTrainer.addRef();
temp_20_0008.setMaxIterations(batchSize);
ValidatingTrainer temp_20_0009 = temp_20_0008.addRef();
temp_20_0009.run();
temp_20_0009.freeRef();
temp_20_0008.freeRef();
}, preTrainer));
TestUtil.extractPerformance(log, trainingNetwork.addRef());
log.p("Then our main training phase:");
TestUtil.instrumentPerformance(trainingNetwork.addRef());
SampledArrayTrainable temp_20_0005 = new SampledArrayTrainable(RefUtil.addRef(trainingData),
trainingNetwork.addRef(), trainingSize, batchSize);
temp_20_0005.setMinSamples(trainingSize);
temp_20_0005.setMask(true, false);
@Nonnull final ValidatingTrainer mainTrainer = optimizer.train(log, temp_20_0005,
new ArrayTrainable(RefUtil.addRef(trainingData), trainingNetwork.addRef(),
batchSize),
monitor);
log.run(RefUtil.wrapInterface(() -> {
mainTrainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES);
ValidatingTrainer temp_20_0011 = mainTrainer.addRef();
temp_20_0011.setMaxIterations(batchSize);
ValidatingTrainer temp_20_0012 = temp_20_0011.addRef();
temp_20_0012.run();
temp_20_0012.freeRef();
temp_20_0011.freeRef();
}, mainTrainer));
TestUtil.extractPerformance(log, trainingNetwork.addRef());
if (!history.isEmpty()) {
log.eval(() -> {
return TestUtil.plot(history);
});
log.eval(() -> {
return TestUtil.plotTime(history);
});
}
try {
@Nonnull
String filename = log.getFileName() + EncodingProblem.modelNo++ + "_plot.png";
ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", log.file(filename));
log.addMetadata("result_plot", filename);
} catch (IOException e) {
throw Util.throwException(e);
}
//log.file()
@Nonnull final String modelName = "encoding_model_" + EncodingProblem.modelNo++ + ".json";
log.addMetadata("result_model", modelName);
log.p("Saved model as " + log.file(trainingNetwork.getJson().toString(), modelName, modelName));
log.h3("Results");
@Nonnull final PipelineNetwork testNetwork = new PipelineNetwork(2);
RefUtil.freeRef(testNetwork.add(imageNetwork.addRef(), testNetwork.getInput(0)));
log.eval(RefUtil.wrapInterface((UncheckedSupplier) () -> {
@Nonnull final TableOutput table = new TableOutput();
RefArrays.stream(RefUtil.addRef(trainingData)).map(RefUtil
.wrapInterface((Function super Tensor[], ? extends LinkedHashMap>) tensorArray -> {
Result temp_20_0013 = testNetwork.eval(RefUtil.addRef(tensorArray));
assert temp_20_0013 != null;
TensorList temp_20_0014 = temp_20_0013.getData();
@Nullable final Tensor predictionSignal = temp_20_0014.get(0);
temp_20_0014.freeRef();
temp_20_0013.freeRef();
@Nonnull final LinkedHashMap row = new LinkedHashMap<>();
row.put("Source", log.png(tensorArray[1].toImage(), ""));
RefUtil.freeRef(tensorArray);
row.put("Echo", log.png(predictionSignal.toImage(), ""));
predictionSignal.freeRef();
return row;
}, testNetwork.addRef())).filter(Objects::nonNull).limit(10)
.forEach(table::putRow);
return table;
}, testNetwork, RefUtil.addRef(trainingData)));
log.p("Learned Model Statistics:");
RefUtil.freeRef(log.eval(RefUtil.wrapInterface((UncheckedSupplier
© 2015 - 2025 Weber Informatics LLC | Privacy Policy