All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.datasets.fetchers.MnistDataFetcher Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.datasets.fetchers;

import lombok.SneakyThrows;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.datasets.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.common.resources.ResourceType;
import org.deeplearning4j.datasets.mnist.MnistManager;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.common.util.MathUtils;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import java.util.zip.Adler32;
import java.util.zip.Checksum;


public class MnistDataFetcher extends BaseDataFetcher {
    /** Number of examples in the MNIST training split. */
    public static final int NUM_EXAMPLES = 60000;
    /** Number of examples in the MNIST test split. */
    public static final int NUM_EXAMPLES_TEST = 10000;

    // Expected Adler-32 checksums of the unzipped data files (verified by validateFiles)
    protected static final long CHECKSUM_TRAIN_FEATURES = 2094436111L;
    protected static final long CHECKSUM_TRAIN_LABELS = 4008842612L;
    protected static final long CHECKSUM_TEST_FEATURES = 2165396896L;
    protected static final long CHECKSUM_TEST_LABELS = 2212998611L;

    // Same order as the {images, labels} file array built in the constructor
    protected static final long[] CHECKSUMS_TRAIN = new long[]{CHECKSUM_TRAIN_FEATURES, CHECKSUM_TRAIN_LABELS};
    protected static final long[] CHECKSUMS_TEST = new long[]{CHECKSUM_TEST_FEATURES, CHECKSUM_TEST_LABELS};

    protected boolean binarize = true;
    protected boolean train;
    /** Permutation of example indices (initially identity); shuffled by {@link #reset()} when shuffling is on. */
    protected int[] order;
    protected Random rng;
    protected boolean shuffle;
    protected boolean oneIndexed = false;
    protected boolean fOrder = false; //MNIST is C order, EMNIST is F order

    protected boolean firstShuffle = true;
    protected final int numExamples;
    protected String images, labels;
    //note: we default to zero here on purpose, otherwise when first initializes an error is thrown.
    private long lastCursor = 0;

    // Reusable per-batch pixel buffer; grown on demand by fetch(int) and copied into an INDArray each call
    private float[][] featureData = null;


    /**
     * Constructor telling whether to binarize the dataset or not.
     * Uses the training split, shuffling seeded from the current time, and all {@link #NUM_EXAMPLES} examples.
     *
     * @param binarize whether to binarize the dataset or not
     * @throws IOException if the data set cannot be downloaded or read
     */
    public MnistDataFetcher(boolean binarize) throws IOException {
        this(binarize, true, true, System.currentTimeMillis(), NUM_EXAMPLES);
    }

    /**
     * Creates a fetcher over the MNIST training or test split.
     * <p>
     * If any of the four MNIST files is missing, the data set is downloaded first. The selected split is then
     * checked against known Adler-32 checksums; on any failure the MNIST directory is deleted (best effort)
     * and the data set is downloaded and validated once more.
     *
     * @param binarize    if true, pixels are thresholded to 0/1; otherwise they are scaled to [0, 1]
     * @param train       true for the training split, false for the test split
     * @param shuffle     whether {@link #reset()} shuffles the example order
     * @param rngSeed     seed for the shuffling random number generator
     * @param numExamples number of examples in use; when smaller than the split size, later resets only
     *                    reshuffle the first {@code numExamples} entries of the order
     * @throws IOException if the data set cannot be downloaded or read
     */
    public MnistDataFetcher(boolean binarize, boolean train, boolean shuffle, long rngSeed, int numExamples) throws IOException {
        if (!mnistExists()) {
            new MnistFetcher().downloadAndUntar();
        }

        String MNIST_ROOT = DL4JResources.getDirectory(ResourceType.DATASET, "MNIST").getAbsolutePath();
        long[] checksums;
        if (train) {
            images = FilenameUtils.concat(MNIST_ROOT, MnistFetcher.TRAINING_FILES_FILENAME_UNZIPPED);
            labels = FilenameUtils.concat(MNIST_ROOT, MnistFetcher.TRAINING_FILE_LABELS_FILENAME_UNZIPPED);
            totalExamples = NUM_EXAMPLES;
            checksums = CHECKSUMS_TRAIN;
        } else {
            images = FilenameUtils.concat(MNIST_ROOT, MnistFetcher.TEST_FILES_FILENAME_UNZIPPED);
            labels = FilenameUtils.concat(MNIST_ROOT, MnistFetcher.TEST_FILE_LABELS_FILENAME_UNZIPPED);
            totalExamples = NUM_EXAMPLES_TEST;
            checksums = CHECKSUMS_TEST;
        }
        String[] files = new String[]{images, labels};

        try {
            MnistManager man = new MnistManager(images, labels, train);
            try {
                validateFiles(files, checksums);
            } finally {
                // Close before the recovery path below: open handles could prevent deleting the directory
                man.close();
            }
        } catch (Exception e) {
            try {
                try {
                    FileUtils.deleteDirectory(new File(MNIST_ROOT));
                } catch (Exception ignored) {
                    // Best effort: the re-download and re-validation below run regardless
                }
                new MnistFetcher().downloadAndUntar();
                MnistManager man = new MnistManager(images, labels, train);
                try {
                    lastCursor = man.getCurrent();
                    validateFiles(files, checksums);
                } finally {
                    man.close();
                }
            } catch (Exception retryFailure) {
                // Keep the original failure visible; precise rethrow means only IOException/unchecked escape here
                retryFailure.addSuppressed(e);
                throw retryFailure;
            }
        }

        MnistManager man = new MnistManager(images, labels, train);
        try {
            inputColumns = man.getImages().getEntryLength();
        } finally {
            man.close();
        }

        numOutcomes = 10;
        this.binarize = binarize;
        cursor = 0;
        this.train = train;
        this.shuffle = shuffle;

        order = new int[train ? NUM_EXAMPLES : NUM_EXAMPLES_TEST];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        rng = new Random(rngSeed);
        this.numExamples = numExamples;
        reset(); //Shuffle order
    }

    /**
     * Returns true only if all four unzipped MNIST files (train/test images and labels) exist on disk.
     */
    private boolean mnistExists() {
        String mnistRoot = DL4JResources.getDirectory(ResourceType.DATASET, "MNIST").getAbsolutePath();
        String[] required = {
                MnistFetcher.TRAINING_FILES_FILENAME_UNZIPPED,
                MnistFetcher.TRAINING_FILE_LABELS_FILENAME_UNZIPPED,
                MnistFetcher.TEST_FILES_FILENAME_UNZIPPED,
                MnistFetcher.TEST_FILE_LABELS_FILENAME_UNZIPPED};
        for (String name : required) {
            if (!new File(mnistRoot, name).exists()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Verifies that each file exists and has the expected Adler-32 checksum.
     *
     * @param files     paths of the files to check
     * @param checksums expected checksums, index-aligned with {@code files}
     * @throws IllegalStateException if a file is missing or its checksum does not match
     * @throws RuntimeException      wrapping the {@link IOException} if a checksum cannot be computed
     */
    private void validateFiles(String[] files, long[] checksums) {
        for (int i = 0; i < files.length; i++) {
            File f = new File(files[i]);
            boolean exists = f.exists();
            long checksum = -1;
            if (exists) {
                Checksum adler = new Adler32();
                try {
                    checksum = FileUtils.checksum(f, adler).getValue();
                } catch (IOException e) {
                    throw new RuntimeException("Failed to compute checksum for file: " + f, e);
                }
            }
            if (!exists || checksum != checksums[i]) {
                throw new IllegalStateException("Failed checksum: expected " + checksums[i] +
                        ", got " + checksum + " for file: " + f);
            }
        }
    }

    /**
     * Creates a binarizing fetcher over the shuffled training split.
     *
     * @throws IOException if the data set cannot be downloaded or read
     */
    public MnistDataFetcher() throws IOException {
        this(true);
    }

    /**
     * Loads the next batch of up to {@code numExamples} examples into {@code curr}, advancing the cursor.
     * <p>
     * Features are one row per example of 28*28 pixel values, either thresholded ({@code > 30} becomes 1)
     * when binarizing or divided by 255. Labels are one-hot rows over {@code numOutcomes} classes. The batch
     * is truncated if the data runs out part-way.
     *
     * @param numExamples maximum number of examples in the batch
     * @throws IllegalStateException if no examples remain
     */
    @SneakyThrows
    @Override
    public void fetch(int numExamples) {
        if (!hasMore()) {
            throw new IllegalStateException("Unable to get more; there are no more images");
        }

        MnistManager man = new MnistManager(images, labels, totalExamples);
        try {
            man.setCurrent((int) lastCursor);
            // Named labelMatrix so it does not shadow the 'labels' file-path field
            INDArray labelMatrix = Nd4j.zeros(DataType.FLOAT, numExamples, numOutcomes);

            // NOTE: the buffer width and the transposition below assume 28x28 images
            if (featureData == null || featureData.length < numExamples) {
                featureData = new float[numExamples][28 * 28];
            }

            int actualExamples = 0;
            byte[] working = null;
            for (int i = 0; i < numExamples; i++, cursor++) {
                if (!hasMore()) {
                    break;
                }

                man.setCurrent(cursor);
                lastCursor = cursor;
                byte[] img = man.readImageUnsafe(order[cursor]);

                if (fOrder) {
                    //EMNIST requires F order to C order: transpose the 28x28 image into a reused buffer
                    if (working == null) {
                        working = new byte[28 * 28];
                    }
                    for (int j = 0; j < 28 * 28; j++) {
                        working[j] = img[28 * (j % 28) + j / 28];
                    }
                    img = working;
                }

                int label = man.readLabel(order[cursor]);
                if (oneIndexed) {
                    //For some inexplicable reason, Emnist LETTERS set is indexed 1 to 26 (i.e., 1 to nClasses), while everything else
                    // is indexed (0 to nClasses-1) :/
                    label--;
                }

                labelMatrix.put(actualExamples, label, 1.0f);

                // Bytes are signed in Java; mask to recover the unsigned 0-255 pixel value
                for (int j = 0; j < img.length; j++) {
                    featureData[actualExamples][j] = ((int) img[j]) & 0xFF;
                }

                actualExamples++;
            }

            // The buffer may be larger than this batch (reused across calls), so copy only the filled rows
            INDArray features;
            if (featureData.length == actualExamples) {
                features = Nd4j.create(featureData);
            } else {
                features = Nd4j.create(Arrays.copyOfRange(featureData, 0, actualExamples));
            }

            if (actualExamples < numExamples) {
                labelMatrix = labelMatrix.get(NDArrayIndex.interval(0, actualExamples), NDArrayIndex.all());
            }

            if (binarize) {
                features = features.gt(30.0).castTo(DataType.FLOAT);
            } else {
                features.divi(255.0);
            }

            curr = new DataSet(features, labelMatrix);
        } finally {
            man.close();
        }
    }

    /**
     * Rewinds the cursor to the start, clears the current batch and, if shuffling is enabled, reshuffles
     * the example order. When only a subset of the split is in use, the whole order is shuffled on the
     * first call and only the first {@code numExamples} entries on subsequent calls.
     */
    @Override
    public void reset() {
        cursor = 0;
        curr = null;
        if (shuffle) {
            if ((train && numExamples < NUM_EXAMPLES) || (!train && numExamples < NUM_EXAMPLES_TEST)) {
                //Shuffle only first N elements
                if (firstShuffle) {
                    MathUtils.shuffleArray(order, rng);
                    firstShuffle = false;
                } else {
                    MathUtils.shuffleArraySubset(order, numExamples, rng);
                }
            } else {
                MathUtils.shuffleArray(order, rng);
            }
        }
    }

    /**
     * Returns the next batch as produced by the base fetcher.
     */
    @Override
    public DataSet next() {
        return super.next();
    }

    /**
     * No-op: every {@code MnistManager} opened by this class is closed within the method that opened it.
     */
    public void close() {
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy