
com.simiacryptus.mindseye.test.data.MNIST Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-test Show documentation
Show all versions of mindseye-test Show documentation
Testing Tools for Neural Network Components
The newest version!
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.test.data;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.test.TestUtil;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.*;
import com.simiacryptus.util.Util;
import com.simiacryptus.util.io.BinaryChunkIterator;
import com.simiacryptus.util.io.DataLoader;
import com.simiacryptus.util.test.LabeledObject;
import org.apache.commons.io.IOUtils;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.*;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.zip.GZIPInputStream;
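/**
 * Accessors for the MNIST image data set. Exposes the training set and the validation
 * ("t10k") set as {@link DataLoader}s of labeled 28x28x1 tensors, read from gzip-compressed
 * idx files that are resolved against {@code TestUtil.S3_ROOT}.
 */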
public class MNIST {
/**
* The constant training.
*/
public static final DataLoader> training = new DataLoader>() {
{
}
public @SuppressWarnings("unused")
void _free() {
super._free();
}
@Override
protected void read(@Nonnull final RefList> queue) {
try {
final RefStream imgStream = MNIST.binaryStream("train-images-idx3-ubyte.gz", 16, 28 * 28).map(b -> {
return MNIST.fillImage(b, new Tensor(28, 28, 1));
});
@Nonnull final RefStream labelStream = MNIST.binaryStream("train-labels-idx1-ubyte.gz", 8, 1);
@Nonnull final RefStream> merged = MNIST
.toStream(new LabeledObjectIterator(imgStream, labelStream), 100);
merged.forEach(RefUtil.wrapInterface((Consumer super LabeledObject>) queue::add,
queue.addRef()));
} catch (@Nonnull final IOException e) {
throw Util.throwException(e);
}
queue.freeRef();
}
};
/**
* The constant validation.
*/
public static final DataLoader> validation = new DataLoader>() {
{
}
public @SuppressWarnings("unused")
void _free() {
super._free();
}
@Override
protected void read(@Nonnull final RefList> queue) {
try {
final RefStream imgStream = MNIST.binaryStream("t10k-images-idx3-ubyte.gz", 16, 28 * 28).map(b -> {
return MNIST.fillImage(b, new Tensor(28, 28, 1));
});
@Nonnull final RefStream labelStream = MNIST.binaryStream("t10k-labels-idx1-ubyte.gz", 8, 1);
@Nonnull final RefStream> merged = MNIST
.toStream(new LabeledObjectIterator(imgStream, labelStream), 100);
merged.forEach(RefUtil.wrapInterface((Consumer super LabeledObject>) queue::add,
queue.addRef()));
} catch (@Nonnull final IOException e) {
throw Util.throwException(e);
}
queue.freeRef();
}
};
/**
* Training data stream ref stream.
*
* @return the ref stream
*/
@Nonnull
public static RefStream> trainingDataStream() {
return MNIST.training.stream();
}
/**
* Validation data stream ref stream.
*
* @return the ref stream
*/
@Nonnull
public static RefStream> validationDataStream() {
return MNIST.validation.stream();
}
/**
 * Loads a gzip-compressed resource fully into memory and exposes it as a stream of
 * fixed-size binary records.
 *
 * @param name       resource name, resolved against {@code TestUtil.S3_ROOT}
 * @param skip       number of leading (header) bytes to discard from the decompressed data
 * @param recordSize chunk size handed to {@link BinaryChunkIterator}
 * @return a non-parallel stream over the chunks produced by the {@link BinaryChunkIterator}
 * @throws IOException if reading or decompressing fails, or if the decompressed data holds
 *                     fewer than {@code skip} bytes
 */
private static RefStream<byte[]> binaryStream(@Nonnull final String name, final int skip, final int recordSize)
    throws IOException {
  final byte[] fileData;
  // try-with-resources closes the cached stream even when decompression fails.
  try (InputStream raw = Util.cacheStream(TestUtil.S3_ROOT.resolve(name));
       InputStream gzip = new GZIPInputStream(new BufferedInputStream(raw))) {
    fileData = IOUtils.toByteArray(gzip);
  } catch (@Nonnull NoSuchAlgorithmException | KeyManagementException e) {
    throw Util.throwException(e);
  }
  @Nonnull final DataInputStream in = new DataInputStream(new ByteArrayInputStream(fileData));
  // skipBytes reports how much was actually skipped; a short skip means the header is truncated.
  final int skipped = in.skipBytes(skip);
  if (skipped < skip) {
    throw new EOFException(
        "Resource " + name + " is truncated: expected " + skip + " header bytes, found " + skipped);
  }
  return MNIST.toIterator(new BinaryChunkIterator(in, recordSize));
}
/**
 * Copies a 28 * 28 byte record into the given tensor.
 * <p>
 * The byte at index {@code x + y * 28} is written to coordinate {@code (x, y)}; each byte is
 * masked with {@code 0xFF} so it is stored as its unsigned value (0..255).
 *
 * @param b      the raw pixel bytes; must hold at least 28 * 28 entries
 * @param tensor the tensor to fill
 * @return the same {@code tensor} instance that was passed in
 */
@Nonnull
private static Tensor fillImage(final byte[] b, @Nonnull final Tensor tensor) {
  final int side = 28;
  for (int col = 0; col < side; col++) {
    for (int row = 0; row < side; row++) {
      final int pixel = b[row * side + col] & 0xFF;
      tensor.set(new int[]{col, row}, pixel);
    }
  }
  return tensor;
}
/**
 * Wraps an iterator in a non-parallel, ordered stream.
 * <p>
 * Despite its name this method returns a stream, not an iterator. The value {@code 1} is passed
 * as the size argument of {@code RefSpliterators.spliterator}.
 *
 * @param iterator the iterator supplying the elements
 * @param <T>      the element type
 * @return a stream backed by {@code iterator}
 */
private static <T> RefStream<T> toIterator(@Nonnull final RefIteratorBase<T> iterator) {
  return RefStreamSupport
      .stream(RefSpliterators.spliterator(iterator, 1, Spliterator.ORDERED), false);
}
/**
 * Wraps an iterator in a non-parallel, ordered stream.
 *
 * @param iterator the iterator supplying the elements
 * @param size     value forwarded as the size argument of the spliterator
 * @param <T>      the element type
 * @return a stream backed by {@code iterator}
 */
private static <T> RefStream<T> toStream(@Nonnull final RefIteratorBase<T> iterator, final int size) {
  return MNIST.toStream(iterator, size, false);
}
/**
 * Wraps an iterator in an ordered stream.
 *
 * @param iterator the iterator supplying the elements
 * @param size     value passed as the size argument of {@code RefSpliterators.spliterator}
 *                 (presumably an element-count estimate, as in {@code java.util.Spliterators} -
 *                 to be confirmed against the ref-wrappers library)
 * @param parallel whether the resulting stream is requested as parallel
 * @param <T>      the element type
 * @return a stream backed by {@code iterator}
 */
private static <T> RefStream<T> toStream(@Nonnull final RefIteratorBase<T> iterator, final int size,
    final boolean parallel) {
  return RefStreamSupport
      .stream(RefSpliterators.spliterator(iterator, size, Spliterator.ORDERED), parallel);
}
private static class LabeledObjectIterator extends RefIteratorBase> {
@Nonnull
private final RefIterator imgItr;
@Nonnull
private final RefIterator labelItr;
/**
* Instantiates a new Labeled object iterator.
*
* @param imgStream the img stream
* @param labelStream the label stream
*/
public LabeledObjectIterator(@Nonnull RefStream imgStream, @Nonnull RefStream labelStream) {
imgItr = imgStream.iterator();
labelItr = labelStream.iterator();
}
@Override
public boolean hasNext() {
return imgItr.hasNext() && labelItr.hasNext();
}
@Nonnull
@Override
public LabeledObject next() {
return new LabeledObject<>(imgItr.next(), RefArrays.toString(labelItr.next()));
}
public @SuppressWarnings("unused")
void _free() {
super._free();
labelItr.freeRef();
imgItr.freeRef();
}
@Nonnull
public @Override
@SuppressWarnings("unused")
LabeledObjectIterator addRef() {
return (LabeledObjectIterator) super.addRef();
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy