All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.vision.ImageDataset Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2024 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */
package smile.vision;

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.ToIntFunction;
import javax.imageio.ImageIO;
import smile.deep.Dataset;
import smile.deep.SampleBatch;
import smile.deep.tensor.Tensor;
import smile.math.MathEx;
import smile.vision.transform.Transform;

/**
 * An image classification dataset backed by a directory tree.
 * <p>
 * The root directory should contain one subdirectory for each class in the
 * dataset. Each subdirectory is named after the corresponding class and
 * contains all the images for that class. Only regular files whose names end
 * with {@code .jpg}, {@code .jpeg} or {@code .png} (case-insensitive) are
 * loaded. Other files and nested directories are ignored. Ensure that each
 * image file is named uniquely.
 * <p>
 * Each call to {@link #iterator()} shuffles the samples and decodes the
 * mini-batches of one epoch on a dedicated daemon thread, so image loading
 * overlaps with the consumer's work.
 *
 * @author Haifeng Li
 */
public class ImageDataset implements Dataset {
    /** An image file and its class label (the name of its parent directory). */
    record ImageFile(File file, String label) { }

    /**
     * An element of the prefetch queue. A {@code null} batch marks the end of
     * an epoch, because {@link LinkedBlockingQueue} rejects null elements.
     */
    private record Prefetched(SampleBatch batch) { }

    /** The marker the loader thread enqueues after the last mini-batch of an epoch. */
    private static final Prefetched END_OF_EPOCH = new Prefetched(null);

    /** The maximum number of decoded mini-batches buffered ahead of the consumer. */
    private static final int PREFETCH_CAPACITY = 100;

    /** How often a waiting consumer checks that the loader thread is still alive. */
    private static final long LOADER_POLL_MILLIS = 100;

    private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ImageDataset.class);
    private final List<ImageFile> samples = new ArrayList<>();
    private final int batch;
    private final Transform transform;
    private final ToIntFunction<? super String> targetTransform;

    /**
     * Constructor.
     * @param batch the mini-batch size; must be positive.
     * @param root the root directory of image dataset.
     * @param transform the transformation from image to tensor.
     * @param targetTransform the transform from image label to class index.
     * @throws IllegalArgumentException if {@code batch} is not positive.
     * @throws IOException if the root directory doesn't exist, is not a directory,
     *                     cannot be listed, or doesn't have images.
     */
    public ImageDataset(int batch, String root, Transform transform, ToIntFunction<? super String> targetTransform) throws IOException {
        if (batch <= 0) {
            // A non-positive batch size would make the loader loop forever.
            throw new IllegalArgumentException("Invalid mini-batch size: " + batch);
        }
        this.batch = batch;
        this.transform = Objects.requireNonNull(transform, "transform");
        this.targetTransform = Objects.requireNonNull(targetTransform, "targetTransform");

        File dir = new File(root);
        if (!dir.exists()) {
            throw new IOException("Dataset root directory doesn't exist: " + root);
        }
        if (!dir.isDirectory()) {
            throw new IOException("Dataset root is not a directory: " + root);
        }

        for (File child : listFiles(dir)) {
            if (child.isDirectory()) {
                String label = child.getName();
                for (File image : listFiles(child)) {
                    if (image.isFile() && isSupportedImage(image.getName())) {
                        samples.add(new ImageFile(image, label));
                    }
                }
            }
        }

        if (samples.isEmpty()) {
            throw new IOException("No JPEG or PNG images found in " + root);
        }
    }

    /**
     * Lists the entries of a directory.
     * @param dir the directory.
     * @return the directory entries.
     * @throws IOException if the directory cannot be listed.
     */
    private static File[] listFiles(File dir) throws IOException {
        File[] files = dir.listFiles();
        if (files == null) {
            // File.listFiles() signals an I/O error (or a non-directory) with null.
            throw new IOException("Failed to list directory: " + dir);
        }
        return files;
    }

    /**
     * Returns true if the file name has a supported image extension.
     * @param fileName the file name.
     * @return true for .jpg, .jpeg and .png files, case-insensitive.
     */
    private static boolean isSupportedImage(String fileName) {
        // Locale.ROOT keeps the case mapping independent of the default locale.
        String name = fileName.toLowerCase(Locale.ROOT);
        return name.endsWith(".jpg") || name.endsWith(".jpeg") || name.endsWith(".png");
    }

    @Override
    public void close() {
        // We don't hold any (external) resources. Loader threads are daemon
        // threads that are created per iterator.
    }

    @Override
    public long size() {
        return samples.size();
    }

    /**
     * Returns an iterator over one epoch of shuffled mini-batches.
     * <p>
     * Batches are decoded on a daemon thread named
     * {@code ImageDatasetLoader}. A mini-batch that fails to load is logged
     * and skipped. {@link Iterator#hasNext()} may block until the next batch
     * has been decoded.
     *
     * @return the mini-batch iterator.
     */
    @Override
    public Iterator<SampleBatch> iterator() {
        final int size = samples.size();
        final int[] permutation = MathEx.permutate(size);
        final BlockingQueue<Prefetched> queue = new LinkedBlockingQueue<>(PREFETCH_CAPACITY);

        final Runnable worker = () -> {
            try {
                for (int i = 0; i < size; ) {
                    // Math.min(batch, size - i) keeps i + n <= size, so no int overflow.
                    int n = Math.min(batch, size - i);
                    int[] index = Arrays.copyOfRange(permutation, i, i + n);
                    i += n;

                    SampleBatch data;
                    try {
                        data = readImages(index);
                    } catch (IOException | RuntimeException ex) {
                        // Best effort: skip the unreadable mini-batch and keep going.
                        logger.error("Failed to load images", ex);
                        continue;
                    }
                    queue.put(new Prefetched(data));
                }
                queue.put(END_OF_EPOCH);
            } catch (InterruptedException ex) {
                Thread.currentThread().interrupt();
            }
        };

        final Thread loader = new Thread(worker, "ImageDatasetLoader");
        // An abandoned iterator leaves the loader blocked on a full queue;
        // as a daemon it won't keep the JVM alive.
        loader.setDaemon(true);
        loader.start();

        return new Iterator<>() {
            // The batch fetched by hasNext() but not yet returned by next().
            private SampleBatch pending;
            // True once the end of the epoch has been observed.
            private boolean exhausted;

            @Override
            public boolean hasNext() {
                if (pending == null && !exhausted) {
                    pending = fetch();
                    exhausted = pending == null;
                }
                return pending != null;
            }

            @Override
            public SampleBatch next() {
                if (!hasNext()) {
                    throw new NoSuchElementException("No more sample batches in this epoch");
                }
                SampleBatch result = pending;
                pending = null;
                return result;
            }

            /**
             * Blocks until the next mini-batch is available.
             * @return the next mini-batch, or null at the end of the epoch.
             */
            private SampleBatch fetch() {
                try {
                    while (true) {
                        Prefetched item = queue.poll(LOADER_POLL_MILLIS, TimeUnit.MILLISECONDS);
                        if (item == null && !loader.isAlive()) {
                            // The loader terminated without enqueuing END_OF_EPOCH (e.g. an Error).
                            // A thread's actions happen-before isAlive() returns false,
                            // so anything it enqueued is visible to this final poll.
                            item = queue.poll();
                            if (item == null) {
                                return null;
                            }
                        }
                        if (item != null) {
                            return item.batch();
                        }
                    }
                } catch (InterruptedException ex) {
                    Thread.currentThread().interrupt();
                    throw new IllegalStateException("Interrupted while waiting for the next sample batch", ex);
                }
            }
        };
    }

    /**
     * Reads a mini-batch of image samples.
     * @param index the sample index.
     * @return the sample batch.
     * @throws IOException if fail to read or decode the image.
     */
    private SampleBatch readImages(int[] index) throws IOException {
        int n = index.length;
        long[] target = new long[n];
        BufferedImage[] images = new BufferedImage[n];
        for (int i = 0; i < n; i++) {
            ImageFile sample = samples.get(index[i]);
            BufferedImage image = ImageIO.read(sample.file());
            if (image == null) {
                // ImageIO.read returns null, not an exception, when no reader can decode the file.
                throw new IOException("Unsupported or corrupt image file: " + sample.file());
            }
            images[i] = image;
            target[i] = targetTransform.applyAsInt(sample.label());
        }
        return new SampleBatch(transform.forward(images), Tensor.of(target, n));
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy