All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.datasets.fetchers.Cifar10Fetcher Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.datasets.fetchers;

import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.ImageTransform;
import org.deeplearning4j.common.resources.DL4JResources;
import org.nd4j.common.base.Preconditions;

import java.io.File;
import java.util.Random;

/**
 * Dataset fetcher for CIFAR-10: makes sure the image archive is present in the
 * local cache (downloading and extracting it when it is not) and exposes the
 * extracted train/test image directories through an {@link ImageRecordReader}.
 */
public class Cifar10Fetcher extends CacheableExtractableDataSetFetcher {
    public static final String LABELS_FILENAME = "labels.txt";
    public static final String LOCAL_CACHE_NAME = "cifar10";

    // NOTE(review): these are mutable public statics; they are deliberately left
    // non-final here so that existing code which assigns them keeps compiling.
    public static int INPUT_WIDTH = 32;
    public static int INPUT_HEIGHT = 32;
    public static int INPUT_CHANNELS = 3;
    public static int NUM_LABELS = 10;

    /**
     * Returns the URL of the CIFAR-10 archive, as resolved by {@link DL4JResources}.
     *
     * @param set the requested dataset split; not consulted, the same archive is used for every split
     * @return the URL string of {@code datasets/cifar10_dl4j.v1.zip}
     */
    @Override
    public String remoteDataUrl(DataSetType set) {
        return DL4JResources.getURLString("datasets/cifar10_dl4j.v1.zip");
    }

    /**
     * @return the local cache name for this dataset, {@link #LOCAL_CACHE_NAME}
     */
    @Override
    public String localCacheName() {
        return LOCAL_CACHE_NAME;
    }

    /**
     * Returns the checksum constant for the CIFAR-10 archive.
     * NOTE(review): presumably compared by the superclass against the downloaded
     * archive - confirm in CacheableExtractableDataSetFetcher.
     *
     * @param set the requested dataset split; not consulted, the value is the same for every split
     * @return the fixed checksum value
     */
    @Override
    public long expectedChecksum(DataSetType set) {
        return 292852033L;
    }

    /**
     * Builds an initialized record reader over the CIFAR-10 images of the requested split.
     * The archive is downloaded and extracted first when the local cache directory does not exist.
     *
     * @param rngSeed        seed for the {@link Random} that drives the file split and the path filter
     * @param imgDim         {@code null} to use {@link #INPUT_HEIGHT} x {@link #INPUT_WIDTH}, otherwise a
     *                       length-2 array ordered as {@code [height, width]}
     * @param set            the split to read; {@code TRAIN} and {@code TEST} map to the {@code train} and
     *                       {@code test} directories, any other value except {@code VALIDATION} falls back
     *                       to {@code train}
     * @param imageTransform transform handed to the {@link ImageRecordReader} as-is
     * @return a record reader initialized on the files of the selected directory, labelled by parent directory
     * @throws IllegalArgumentException if {@code set} is {@code VALIDATION}
     * @throws RuntimeException         if the download/extraction or the reader initialization fails
     */
    @Override
    public RecordReader getRecordReader(long rngSeed, int[] imgDim, DataSetType set, ImageTransform imageTransform) {
        Preconditions.checkState(imgDim == null || imgDim.length == 2, "Invalid image dimensions: must be null or length 2. Got: %s", imgDim);

        // Clear out an empty cache directory first so the existence check below triggers a download.
        // NOTE(review): assumes deleteIfEmpty removes the directory only when it has no content - confirm in the superclass.
        File localCache = getLocalCacheDir();
        deleteIfEmpty(localCache);

        try {
            if (!localCache.exists()) {
                downloadAndExtract();
            }
        } catch (Exception e) {
            throw new RuntimeException("Could not download CIFAR-10", e);
        }

        Random rng = new Random(rngSeed);
        File datasetPath;
        switch (set) {
            case TEST:
                datasetPath = new File(localCache, "/test/");
                break;
            case VALIDATION:
                throw new IllegalArgumentException("You will need to manually create and iterate a validation directory, CIFAR-10 does not provide labels");
            case TRAIN:
            default:
                // Any split other than TEST/VALIDATION is served from the training directory.
                datasetPath = new File(localCache, "/train/");
                break;
        }

        // Set up file paths: the split and the filter share one seeded Random, and the
        // single-weight sample yields exactly one split containing the filtered files.
        RandomPathFilter pathFilter = new RandomPathFilter(rng, BaseImageLoader.ALLOWED_FORMATS);
        FileSplit filesInDir = new FileSplit(datasetPath, BaseImageLoader.ALLOWED_FORMATS, rng);
        InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, 1);

        int h = (imgDim == null ? Cifar10Fetcher.INPUT_HEIGHT : imgDim[0]);
        int w = (imgDim == null ? Cifar10Fetcher.INPUT_WIDTH : imgDim[1]);
        ImageRecordReader rr = new ImageRecordReader(h, w, Cifar10Fetcher.INPUT_CHANNELS, new ParentPathLabelGenerator(), imageTransform);

        try {
            rr.initialize(filesInDirSplit[0]);
        } catch (Exception e) {
            // Keep the original exception type but say which directory failed.
            throw new RuntimeException("Could not initialize CIFAR-10 record reader for " + datasetPath, e);
        }

        return rr;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy