// Source listing: com.simiacryptus.mindseye.test.integration.ClassifyProblem
// Artifact: mindseye-test - Testing Tools for Neural Network Components
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.test.integration;
import com.google.common.collect.Lists;
import com.simiacryptus.mindseye.eval.ArrayTrainable;
import com.simiacryptus.mindseye.eval.SampledArrayTrainable;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.network.SimpleLossNetwork;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.ValidatingTrainer;
import com.simiacryptus.mindseye.test.StepRecord;
import com.simiacryptus.mindseye.test.TestUtil;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.notebook.TableOutput;
import com.simiacryptus.util.Util;
import com.simiacryptus.util.test.LabeledObject;
import guru.nidi.graphviz.engine.Format;
import guru.nidi.graphviz.engine.Graphviz;
import guru.nidi.graphviz.model.Graph;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
/**
 * Base class for classification test problems.
 *
 * <p>A run builds a network through the supplied {@link FwdNetworkFactory}, wraps it together
 * with the subclass-provided loss layer, trains it with the supplied
 * {@link OptimizationStrategy}, and writes the network diagram, training plots, the serialized
 * model and a validation summary to a {@link NotebookOutput}.
 */
public abstract class ClassifyProblem implements Problem {
  private static final Logger logger = LoggerFactory.getLogger(ClassifyProblem.class);
  // Shared by every instance. NOTE(review): run() advances it twice (once for the plot file name,
  // once for the model file name), so the two artifacts of one run carry different numbers.
  private static int modelNo = 0;
  private final int categories;
  private final ImageProblemData data;
  private final FwdNetworkFactory fwdFactory;
  // NOTE(review): the element type looks like it was lost from this declaration (StepRecord is
  // imported by the file but never referenced) - confirm and parameterize.
  private final List history = new ArrayList<>();
  private final OptimizationStrategy optimizer;
  // Distinct labels of the training and validation data in sorted order; a label's position in
  // this list is its category index (see parse).
  private final List<CharSequence> labels;
  private int batchSize = 10000;
  private int timeoutMinutes = 1;

  /**
   * Instantiates a new classification problem.
   *
   * <p>Both the training and the validation stream are read once here to collect the distinct,
   * sorted set of labels.
   *
   * @param fwdFactory the factory that builds the network to train
   * @param optimizer  the strategy that builds the trainer
   * @param data       the source of training and validation data
   * @param categories the number of output categories
   * @throws RuntimeException wrapping the {@link IOException} if the data cannot be read
   */
  public ClassifyProblem(final FwdNetworkFactory fwdFactory, final OptimizationStrategy optimizer, final ImageProblemData data, final int categories) {
    this.fwdFactory = fwdFactory;
    this.optimizer = optimizer;
    this.data = data;
    this.categories = categories;
    try {
      this.labels = Stream.concat(this.data.trainingData(), this.data.validationData())
          .map(x -> x.label)
          .distinct()
          .sorted()
          .collect(Collectors.toList());
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Returns the live history list that {@link #run(NotebookOutput)} hands to the training monitor.
   *
   * @return the history list (the internal instance, not a copy)
   */
  @Nonnull
  @Override
  public List getHistory() {
    return history;
  }

  /**
   * Gets the training timeout.
   *
   * @return the timeout in minutes
   */
  public int getTimeoutMinutes() {
    return timeoutMinutes;
  }

  /**
   * Sets the training timeout.
   *
   * @param timeoutMinutes the timeout in minutes
   * @return this problem, for call chaining
   */
  @Nonnull
  public ClassifyProblem setTimeoutMinutes(final int timeoutMinutes) {
    this.timeoutMinutes = timeoutMinutes;
    return this;
  }

  /**
   * Builds the supervised training set.
   *
   * <p>Each row is a two-element array: the example's data tensor followed by a
   * {@code new Tensor(categories)} whose entry at the label's category index is set to 1.
   *
   * @param log the notebook output (not used by this implementation)
   * @return one {@code {input, target}} pair per training example
   * @throws RuntimeException wrapping the {@link IOException} if the data cannot be read
   */
  public Tensor[][] getTrainingData(final NotebookOutput log) {
    try {
      return data.trainingData().map(labeledObject -> {
        @Nonnull final Tensor categoryTensor = new Tensor(categories);
        final int category = parse(labeledObject.label);
        categoryTensor.set(category, 1);
        return new Tensor[]{labeledObject.data, categoryTensor};
      }).toArray(i -> new Tensor[i][]);
    } catch (@Nonnull final IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Maps a label to its category index.
   *
   * @param label the label to look up
   * @return the position of the label in the sorted label list, or {@code -1} if the list does
   *     not contain it
   */
  public int parse(final CharSequence label) {
    return this.labels.indexOf(label);
  }

  /**
   * Evaluates the network on one example and ranks the categories by the resulting signal.
   *
   * @param network       the network to evaluate
   * @param labeledObject the example whose data is fed to the network
   * @return the category indices ordered from highest to lowest signal; element 0 is the
   *     predicted category
   */
  public int[] predict(@Nonnull final Layer network, @Nonnull final LabeledObject<Tensor> labeledObject) {
    final double[] predictionSignal = network.eval(labeledObject.data).getData().get(0).getData();
    return rankCategories(predictionSignal);
  }

  /**
   * Trains the network and records the whole process in the notebook: network diagram, training,
   * history plots, the saved model, validation accuracy and a table of mispredicted examples.
   *
   * @param log the notebook output to write to
   * @return this problem
   */
  @Nonnull
  @Override
  public ClassifyProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
    final Tensor[][] trainingData = getTrainingData(log);
    @Nonnull final DAGNetwork network = fwdFactory.imageToVector(log, categories);

    log.h3("Network Diagram");
    log.eval(() -> {
      return Graphviz.fromGraph((Graph) TestUtil.toGraph(network))
          .height(400).width(600).render(Format.PNG).toImage();
    });

    log.h3("Training");
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, lossLayer());
    TestUtil.instrumentPerformance(supervisedNetwork);
    // A fifth of the data, but never less than min(10, half of the data).
    int initialSampleSize = Math.max(trainingData.length / 5, Math.min(10, trainingData.length / 2));
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log,
        new SampledArrayTrainable(trainingData, supervisedNetwork, initialSampleSize, getBatchSize()),
        new ArrayTrainable(trainingData, supervisedNetwork, getBatchSize()), monitor);
    log.run(() -> {
      trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
    });
    if (!history.isEmpty()) {
      log.eval(() -> {
        return TestUtil.plot(history);
      });
      log.eval(() -> {
        return TestUtil.plotTime(history);
      });
    }

    // Persist the training plot as a resource file; a write failure is logged, not fatal.
    @Nonnull String training_name = log.getName() + "_" + ClassifyProblem.modelNo++ + "_plot.png";
    try {
      BufferedImage image = Util.toImage(TestUtil.plot(history));
      if (null != image) {
        ImageIO.write(image, "png", log.file(training_name));
      }
    } catch (IOException e) {
      logger.warn("Error writing result images", e);
    }
    log.appendFrontMatterProperty("result_plot", new File(log.getResourceDir(), training_name).toString(), ";");

    TestUtil.extractPerformance(log, supervisedNetwork);
    @Nonnull final String modelName = "classification_model_" + ClassifyProblem.modelNo++ + ".json";
    log.appendFrontMatterProperty("result_model", modelName, ";");
    log.p("Saved model as " + log.file(network.getJson().toString(), modelName, modelName));

    log.h3("Validation");
    log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
    log.eval(() -> {
      // An empty validation set has no average; report NaN instead of throwing.
      return data.validationData().mapToDouble(labeledObject ->
          predict(network, labeledObject)[0] == parse(labeledObject.label) ? 1 : 0)
          .average().orElse(Double.NaN) * 100;
    });

    log.p("Let's examine some incorrectly predicted results in more detail:");
    log.eval(() -> {
      try {
        @Nonnull final TableOutput table = new TableOutput();
        // Evaluate in batches of 100 and keep the first 10 mispredicted rows (toRow returns null
        // for correct predictions).
        Lists.partition(data.validationData().collect(Collectors.toList()), 100).stream().flatMap(batch -> {
          @Nonnull TensorList batchIn = TensorArray.create(batch.stream().map(x -> x.data).toArray(i -> new Tensor[i]));
          TensorList batchOut = network.eval(new ConstantResult(batchIn)).getData();
          return IntStream.range(0, batchOut.length())
              .mapToObj(i -> toRow(log, batch.get(i), batchOut.get(i).getData()));
        }).filter(x -> null != x).limit(10).forEach(table::putRow);
        return table;
      } catch (@Nonnull final IOException e) {
        throw new RuntimeException(e);
      }
    });
    return this;
  }

  /**
   * Supplies the loss layer that {@link #run(NotebookOutput)} combines with the network under
   * test.
   *
   * @return the loss layer
   */
  protected abstract Layer lossLayer();

  /**
   * Builds a report row for one validation example, but only if it was mispredicted.
   *
   * <p>NOTE(review): the return type carries no type arguments in this listing; the row built
   * here uses string keys ("Image", "Prediction") - confirm the intended parameterization.
   *
   * @param log              the notebook output used to embed the example image
   * @param labeledObject    the validation example
   * @param predictionSignal the network output for the example, indexed by category
   * @return a row holding the image and the top three predictions with their signal as a
   *     percentage, or {@code null} if the top prediction matches the actual category
   */
  @Nullable
  public LinkedHashMap toRow(@Nonnull final NotebookOutput log, @Nonnull final LabeledObject<Tensor> labeledObject, final double[] predictionSignal) {
    final int actualCategory = parse(labeledObject.label);
    final int[] predictionList = rankCategories(predictionSignal);
    if (predictionList[0] == actualCategory) {
      return null; // We will only examine mispredicted rows
    }
    @Nonnull final LinkedHashMap<CharSequence, Object> row = new LinkedHashMap<>();
    row.put("Image", log.png(labeledObject.data.toImage(), labeledObject.label));
    row.put("Prediction", Arrays.stream(predictionList).limit(3)
        .mapToObj(i -> String.format("%d (%.1f%%)", i, 100.0 * predictionSignal[i]))
        .collect(Collectors.joining(", ")));
    return row;
  }

  /**
   * Gets the batch size passed to the trainables.
   *
   * @return the batch size
   */
  public int getBatchSize() {
    return batchSize;
  }

  /**
   * Sets the batch size passed to the trainables.
   *
   * @param batchSize the batch size
   * @return this problem, for call chaining
   */
  @Nonnull
  public ClassifyProblem setBatchSize(int batchSize) {
    this.batchSize = batchSize;
    return this;
  }

  /**
   * Orders the category indices {@code 0..categories-1} by descending signal value.
   * The sort is stable, so categories with equal signal keep ascending index order.
   */
  private int[] rankCategories(final double[] predictionSignal) {
    return IntStream.range(0, categories).boxed()
        // Negate the key so that the natural ascending sort yields highest-signal-first.
        .sorted(Comparator.comparingDouble((Integer i) -> -predictionSignal[i]))
        .mapToInt(Integer::intValue)
        .toArray();
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy