All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*-
 *  * Copyright 2016 Skymind, Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 */

package org.deeplearning4j.datasets.fetchers;

import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.MultiImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;

/**
 * Fetcher for Tiny ImageNet, a subset of the ImageNet database. Tiny ImageNet is the default course
 * challenge for CS231n at Stanford University.
 *
 * Tiny ImageNet has 200 classes, each consisting of 500 training images.
 *
 * See: http://cs231n.stanford.edu/ and
 * https://tiny-imagenet.herokuapp.com/
 *
 * @author Justin Long (crockpotveggies)
 */
public class TinyImageNetFetcher extends CacheableExtractableDataSetFetcher {

    // NOTE(review): not read or written anywhere in this class; kept to leave the class layout untouched.
    private File fileDir;
    private static final Logger log = LoggerFactory.getLogger(TinyImageNetFetcher.class);

    /** Width, in pixels, passed to the {@link ImageRecordReader} created by this fetcher. */
    public static int INPUT_WIDTH = 64;
    /** Height, in pixels, passed to the {@link ImageRecordReader} created by this fetcher. */
    public static int INPUT_HEIGHT = 64;
    /** Number of channels passed to the {@link ImageRecordReader} created by this fetcher. */
    public static int INPUT_CHANNELS = 3;
    /** Number of classes in the dataset. */
    public static int NUM_LABELS = 200;
    /** Total number of training examples (500 per label). */
    public static int NUM_EXAMPLES = NUM_LABELS * 500;

    /**
     * @param set requested dataset split; ignored, the same archive is returned for every split
     * @return URL of the Tiny ImageNet archive
     */
    @Override
    public String remoteDataUrl(DataSetType set) {
        return "http://blob.deeplearning4j.org/datasets/tinyimagenet_200_dl4j.v1.zip";
    }

    /**
     * @return name used for the local cache of this dataset
     */
    @Override
    public String localCacheName() {
        return "TINYIMAGENET_200";
    }

    /**
     * @param set requested dataset split; ignored, the same value is returned for every split
     * @return expected checksum of the downloaded archive
     */
    @Override
    public long expectedChecksum(DataSetType set) {
        return 33822361L;
    }

    /**
     * Builds an initialized record reader over the images of the requested split, downloading and
     * extracting the dataset first when the local cache is missing or empty.
     *
     * @param rngSeed        seed for the {@link Random} that drives the path filter and the file split
     * @param imgDim         optional target dimensions; when it holds at least two values,
     *                       {@code imgDim[0]} and {@code imgDim[1]} are passed, in that order, to a
     *                       {@link ResizeImageTransform} that is applied before {@code imageTransform}.
     *                       A {@code null} array or one with fewer than two values means "no resize".
     * @param set            split to read; {@code TEST} reads {@code /test/}, anything else except
     *                       {@code VALIDATION} reads {@code /train/}
     * @param imageTransform optional additional transform, may be {@code null}
     * @return a record reader initialized on the selected files
     * @throws IllegalArgumentException if {@code set} is {@code VALIDATION}
     * @throws RuntimeException         if the download fails or the reader cannot be initialized
     */
    @Override
    public RecordReader getRecordReader(long rngSeed, int[] imgDim, DataSetType set, ImageTransform imageTransform) {
        // An existing but empty cache directory is treated as "not downloaded yet".
        if (LOCAL_CACHE.exists()) {
            // listFiles() returns null when the path is not a directory or an I/O error occurs.
            File[] cachedFiles = LOCAL_CACHE.listFiles();
            if (cachedFiles != null && cachedFiles.length < 1) LOCAL_CACHE.delete();
        }

        try {
            if (!LOCAL_CACHE.exists()) downloadAndExtract();
        } catch (Exception e) {
            throw new RuntimeException("Could not download TinyImageNet", e);
        }

        Random rng = new Random(rngSeed);
        File datasetPath;
        switch (set) {
            case TEST:
                datasetPath = new File(LOCAL_CACHE, "/test/");
                break;
            case VALIDATION:
                throw new IllegalArgumentException("You will need to manually iterate the /validation/images/ directory, TinyImageNet does not provide labels");
            case TRAIN:
            default:
                datasetPath = new File(LOCAL_CACHE, "/train/");
        }

        // set up file paths
        RandomPathFilter pathFilter = new RandomPathFilter(rng, BaseImageLoader.ALLOWED_FORMATS);
        FileSplit filesInDir = new FileSplit(datasetPath, BaseImageLoader.ALLOWED_FORMATS, rng);
        InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, 1);

        // add transforms: the resize (if requested) runs first, then the caller-supplied transform
        List<ImageTransform> transforms = new LinkedList<>();
        if (imgDim != null && imgDim.length >= 2) {
            transforms.add(new ResizeImageTransform(imgDim[0], imgDim[1]));
        }
        if (imageTransform != null) transforms.add(imageTransform);

        ImageRecordReader rr = new ImageRecordReader(TinyImageNetFetcher.INPUT_HEIGHT, TinyImageNetFetcher.INPUT_WIDTH,
                    TinyImageNetFetcher.INPUT_CHANNELS, new ParentPathLabelGenerator(),
                    new MultiImageTransform(transforms.toArray(new ImageTransform[transforms.size()])));

        try {
            rr.initialize(filesInDirSplit[0]);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        return rr;
    }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy