All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.test.data.MNIST Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.test.data;

import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.test.TestUtil;
import com.simiacryptus.util.Util;
import com.simiacryptus.util.io.BinaryChunkIterator;
import com.simiacryptus.util.io.DataLoader;
import com.simiacryptus.util.test.LabeledObject;
import org.apache.commons.io.IOUtils;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.*;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.*;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import java.util.zip.GZIPInputStream;

/**
 * References: [LeCun et al., 1998a] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based learning applied to
 * document recognition." Proceedings of the IEEE, 86(11):2278-2324, November 1998. See Also:
 * http://yann.lecun.com/exdb/mnist/
 */
public class MNIST {

  /**
   * The constant training.
   */
  public static final DataLoader> training = new DataLoader>() {
    @Override
    protected void read(@Nonnull final List> queue) {
      try {
        final Stream imgStream = MNIST.binaryStream("train-images-idx3-ubyte.gz", 16, 28 * 28).map(b -> {
          return MNIST.fillImage(b, new Tensor(28, 28, 1));
        });
        @Nonnull final Stream labelStream = MNIST.binaryStream("train-labels-idx1-ubyte.gz", 8, 1);

        @Nonnull final Stream> merged = MNIST.toStream(new Iterator>() {
          @Nonnull
          Iterator imgItr = imgStream.iterator();
          @Nonnull
          Iterator labelItr = labelStream.iterator();

          @Override
          public boolean hasNext() {
            return imgItr.hasNext() && labelItr.hasNext();
          }

          @Nonnull
          @Override
          public LabeledObject next() {
            return new LabeledObject<>(imgItr.next(), Arrays.toString(labelItr.next()));
          }
        }, 100);
        merged.forEach(x -> queue.add(x));
      } catch (@Nonnull final IOException e) {
        throw new RuntimeException(e);
      }
    }
  };
  /**
   * The constant validation.
   */
  public static final DataLoader> validation = new DataLoader>() {
    @Override
    protected void read(@Nonnull final List> queue) {
      try {
        final Stream imgStream = MNIST.binaryStream("t10k-images-idx3-ubyte.gz", 16, 28 * 28).map(b -> {
          return MNIST.fillImage(b, new Tensor(28, 28, 1));
        });
        @Nonnull final Stream labelStream = MNIST.binaryStream("t10k-labels-idx1-ubyte.gz", 8, 1);

        @Nonnull final Stream> merged = MNIST.toStream(new Iterator>() {
          @Nonnull
          Iterator imgItr = imgStream.iterator();
          @Nonnull
          Iterator labelItr = labelStream.iterator();

          @Override
          public boolean hasNext() {
            return imgItr.hasNext() && labelItr.hasNext();
          }

          @Nonnull
          @Override
          public LabeledObject next() {
            return new LabeledObject<>(imgItr.next(), Arrays.toString(labelItr.next()));
          }
        }, 100);
        merged.forEach(x -> queue.add(x));
      } catch (@Nonnull final IOException e) {
        throw new RuntimeException(e);
      }
    }
  };

  private static Stream binaryStream(@Nonnull final String name, final int skip, final int recordSize) throws IOException {
    @Nullable InputStream stream = null;
    try {
      stream = Util.cacheStream(TestUtil.S3_ROOT.resolve(name));
    } catch (@Nonnull NoSuchAlgorithmException | KeyManagementException e) {
      throw new RuntimeException(e);
    }
    final byte[] fileData = IOUtils.toByteArray(new BufferedInputStream(new GZIPInputStream(new BufferedInputStream(stream))));
    @Nonnull final DataInputStream in = new DataInputStream(new ByteArrayInputStream(fileData));
    in.skip(skip);
    return MNIST.toIterator(new BinaryChunkIterator(in, recordSize));
  }

  @Nonnull
  private static Tensor fillImage(final byte[] b, @Nonnull final Tensor tensor) {
    for (int x = 0; x < 28; x++) {
      for (int y = 0; y < 28; y++) {
        tensor.set(new int[]{x, y}, b[x + y * 28] & 0xFF);
      }
    }
    return tensor;
  }

  private static  Stream toIterator(@Nonnull final Iterator iterator) {
    return StreamSupport.stream(Spliterators.spliterator(iterator, 1, Spliterator.ORDERED), false);
  }

  private static  Stream toStream(@Nonnull final Iterator iterator, final int size) {
    return MNIST.toStream(iterator, size, false);
  }

  private static  Stream toStream(@Nonnull final Iterator iterator, final int size, final boolean parallel) {
    return StreamSupport.stream(Spliterators.spliterator(iterator, size, Spliterator.ORDERED), parallel);
  }

  /**
   * Training data stream stream.
   *
   * @return the stream
   */
  public static Stream> trainingDataStream() {
    return MNIST.training.stream();
  }

  /**
   * Validation data stream stream.
   *
   * @return the stream
   */
  public static Stream> validationDataStream() {
    return MNIST.validation.stream();
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy