All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.base.LFWLoader Maven / Gradle / Ivy

package org.deeplearning4j.base;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.URL;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.FeatureUtil;
import org.deeplearning4j.util.ArchiveUtils;
import org.deeplearning4j.util.ImageLoader;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Loads the LFW faces data set. The archive is downloaded from {@link #LFW_URL}
 * into {@code <user.home>/lfw} when that directory does not exist yet. Every
 * sub-directory of that directory is treated as one person (one label) and every
 * regular file inside it as one image of that person.
 * <p>
 * The size the images are scaled to can be customized through
 * {@link #LFWLoader(int, int)}. {@link #getIfNotExists()} must be called before
 * the methods that depend on {@link #getNumNames()} or
 * {@link #getNumPixelColumns()}.
 *
 * @author Adam Gibson
 */
public class LFWLoader {

    /** Name of the data set directory inside the base directory. */
    public final static String LFW = "lfw";
    /** Location the data set archive is downloaded from. */
    public final static String LFW_URL = "http://vis-www.cs.umass.edu/lfw/lfw.tgz";

    private static final Logger log = LoggerFactory.getLogger(LFWLoader.class);
    /** Upper bound on download/verify rounds so a broken mirror cannot loop forever. */
    private static final int MAX_DOWNLOAD_ATTEMPTS = 3;

    // Declaration order matters: each File below is derived from the previous one.
    private final File baseDir = new File(System.getProperty("user.home"));
    private final File lfwDir = new File(baseDir, LFW);
    private final File lfwTarFile = new File(lfwDir, "lfw.tgz");

    private final ImageLoader loader;
    /** Absolute paths of all image files, filled by {@link #getIfNotExists()}. */
    private final List<String> images = new ArrayList<>();
    /** Absolute paths of all person directories; the list index is the label. */
    private final List<String> outcomes = new ArrayList<>();
    private int numNames;
    private int numPixelColumns;

    /**
     * Creates a loader that scales images to 28x28.
     */
    public LFWLoader() {
        this(28, 28);
    }

    /**
     * Creates a loader that scales images to the given size.
     *
     * @param imageWidth  width passed to the underlying {@link ImageLoader}
     * @param imageHeight height passed to the underlying {@link ImageLoader}
     */
    public LFWLoader(int imageWidth, int imageHeight) {
        loader = new ImageLoader(imageWidth, imageHeight);
    }

    /**
     * Makes sure the data set is present locally and (re)builds the image and
     * label indexes. If the data directory is missing it is downloaded and
     * unpacked; if no image can be found afterwards the directory is deleted and
     * the download is retried, up to {@value #MAX_DOWNLOAD_ATTEMPTS} rounds.
     * Calling this method again rebuilds the indexes from scratch.
     *
     * @throws Exception if the download or extraction fails, if no image is
     *                   available after all attempts, or if the first image
     *                   cannot be read
     */
    public void getIfNotExists() throws Exception {
        File firstImage = null;
        for (int attempt = 1; firstImage == null && attempt <= MAX_DOWNLOAD_ATTEMPTS; attempt++) {
            if (!lfwDir.exists()) {
                download();
            }
            firstImage = findFirstImage();
            if (firstImage == null) {
                log.warn("Error opening first image; probably corrupt download...trying again (attempt {} of {})",
                        attempt, MAX_DOWNLOAD_ATTEMPTS);
                // Removing the directory is what triggers a fresh download next round.
                FileUtils.deleteDirectory(lfwDir);
            }
        }
        if (firstImage == null) {
            throw new IOException("No usable LFW images found under " + lfwDir
                    + " after " + MAX_DOWNLOAD_ATTEMPTS + " attempts");
        }

        //number of input neurons
        numPixelColumns = ArrayUtil.flatten(loader.fromFile(firstImage)).length;

        //each subdir is a person
        File[] people = listPersonDirs();
        numNames = people.length;

        // Start from a clean index so repeated calls do not duplicate entries.
        images.clear();
        outcomes.clear();
        for (File person : people) {
            outcomes.add(person.getAbsolutePath());
            for (File image : listChildren(person, false)) {
                images.add(image.getAbsolutePath());
            }
        }
    }

    /**
     * Stacks the given single-example data sets into one data set with one row
     * per example.
     *
     * @param images the examples to combine
     * @return a data set with {@code images.size()} rows of features and labels
     */
    public DataSet convertListPairs(List<DataSet> images) {
        INDArray inputs = Nd4j.create(images.size(), numPixelColumns);
        INDArray outputs = Nd4j.create(images.size(), numNames);

        for (int i = 0; i < images.size(); i++) {
            DataSet example = images.get(i);
            inputs.putRow(i, example.getFeatureMatrix());
            outputs.putRow(i, example.getLabels());
        }
        return new DataSet(inputs, outputs);
    }

    /**
     * Loads the i-th indexed image together with its label. The label is the
     * position of the image's parent directory in the person index.
     *
     * @param i index into the image list built by {@link #getIfNotExists()}
     * @return the image as a row vector paired with its outcome vector
     * @throws IllegalStateException if the image cannot be loaded
     */
    public DataSet getDataFor(int i) {
        File image = new File(images.get(i));
        int outcome = outcomes.indexOf(image.getParentFile().getAbsolutePath());
        try {
            return new DataSet(loader.asRowVector(image), FeatureUtil.toOutcomeVector(outcome, outcomes.size()));
        } catch (Exception e) {
            // Keep the cause so the underlying I/O or decoding problem stays visible.
            throw new IllegalStateException("Unable to get data for image " + i + " for path " + images.get(i), e);
        }
    }

    /**
     * Loads images person by person until at least {@code num} have been
     * collected. Whole directories are always loaded, so the result may hold
     * more than {@code num} entries, or fewer if the data set is smaller.
     *
     * @param num the minimum number of images to load
     * @return the loaded examples, in person-directory order
     * @throws Exception if an image cannot be loaded
     */
    public List<DataSet> getFeatureMatrix(int num) throws Exception {
        List<DataSet> ret = new ArrayList<>(num);
        int label = 0;
        for (File person : listPersonDirs()) {
            ret.addAll(getImages(label, person));
            label++;
            if (ret.size() >= num)
                break;
        }
        return ret;
    }

    /**
     * Loads every image of the data set into a single data set.
     *
     * @return one row per image
     * @throws Exception if an image cannot be loaded
     */
    public DataSet getAllImagesAsMatrix() throws Exception {
        List<DataSet> all = getImagesAsList();
        return convertListPairs(all);
    }

    /**
     * Loads the first {@code numRows} images into a single data set.
     *
     * @param numRows the number of images to include
     * @return a data set with exactly {@code numRows} rows
     * @throws IndexOutOfBoundsException if fewer than {@code numRows} images exist
     * @throws Exception                 if an image cannot be loaded
     */
    public DataSet getAllImagesAsMatrix(int numRows) throws Exception {
        // getFeatureMatrix walks the directories in the same order as
        // getImagesAsList but stops early, so only the needed people are loaded.
        List<DataSet> firstRows = getFeatureMatrix(numRows).subList(0, numRows);
        return convertListPairs(firstRows);
    }

    /**
     * Loads every image of every person.
     *
     * @return all examples, in person-directory order
     * @throws Exception if an image cannot be loaded
     */
    public List<DataSet> getImagesAsList() throws Exception {
        List<DataSet> list = new ArrayList<>();
        File[] people = listPersonDirs();
        for (int i = 0; i < people.length; i++) {
            list.addAll(getImages(i, people[i]));
        }
        return list;
    }

    /**
     * Loads all regular files directly inside {@code file} as images carrying
     * the given label.
     *
     * @param label the label index assigned to every image
     * @param file  the directory holding the images
     * @return the loaded examples; empty if {@code file} cannot be listed
     * @throws Exception if an image cannot be loaded
     */
    public List<DataSet> getImages(int label, File file) throws Exception {
        List<DataSet> ret = new ArrayList<>();
        for (File f : listChildren(file, false))
            ret.add(fromImageFile(label, f));
        return ret;
    }

    /**
     * Loads a single image file and pairs it with the outcome vector for
     * {@code label}; the vector length is the number of people found by
     * {@link #getIfNotExists()}.
     *
     * @param label the label index of the image
     * @param image the image file to load
     * @return the flattened image with its outcome vector
     * @throws Exception if the image cannot be loaded
     */
    public DataSet fromImageFile(int label, File image) throws Exception {
        INDArray outcome = FeatureUtil.toOutcomeVector(label, numNames);
        INDArray image2 = ArrayUtil.toNDArray(loader.flattenedImageFromFile(image));
        return new DataSet(image2, outcome);
    }

    /**
     * Unpacks the given archive into {@code baseDir}.
     *
     * @param baseDir the directory to extract into
     * @param tarFile the archive to extract
     * @throws IOException if extraction fails
     */
    public void untarFile(File baseDir, File tarFile) throws IOException {
        log.info("Untaring File: {}", tarFile);
        ArchiveUtils.unzipFileTo(tarFile.getAbsolutePath(), baseDir.getAbsolutePath());
    }

    /**
     * @return the number of people (labels) found by {@link #getIfNotExists()}
     */
    public int getNumNames() {
        return numNames;
    }

    /**
     * @return the length of one flattened image, as measured by
     *         {@link #getIfNotExists()}
     */
    public int getNumPixelColumns() {
        return numPixelColumns;
    }

    /** Creates the data directory, downloads the archive into it and unpacks it. */
    private void download() throws IOException {
        if (!lfwDir.mkdirs() && !lfwDir.isDirectory()) {
            throw new IOException("Unable to create directory " + lfwDir);
        }
        log.info("Grabbing LFW...");

        URL website = new URL(LFW_URL);
        // try-with-resources closes both ends even when the transfer fails.
        try (ReadableByteChannel rbc = Channels.newChannel(website.openStream());
             FileOutputStream fos = new FileOutputStream(lfwTarFile)) {
            fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE);
        }
        log.info("Downloaded lfw");
        untarFile(baseDir, lfwTarFile);
    }

    /** Returns the first image of the first non-empty person directory, or null if there is none. */
    private File findFirstImage() {
        for (File person : listPersonDirs()) {
            File[] pictures = listChildren(person, false);
            if (pictures.length > 0) {
                return pictures[0];
            }
        }
        return null;
    }

    /** Lists the person directories; plain files such as the downloaded archive are skipped. */
    private File[] listPersonDirs() {
        return listChildren(lfwDir.getAbsoluteFile(), true);
    }

    /**
     * Lists the direct children of {@code parent} that are directories (when
     * {@code wantDirectories} is true) or regular files (otherwise), sorted so
     * that label indexes are the same on every call. Never returns null.
     */
    private static File[] listChildren(File parent, boolean wantDirectories) {
        File[] entries = parent.listFiles();
        if (entries == null) {
            // Not a directory, or an I/O error: treat as empty.
            return new File[0];
        }
        List<File> kept = new ArrayList<>(entries.length);
        for (File entry : entries) {
            boolean matches = wantDirectories ? entry.isDirectory() : entry.isFile();
            if (matches) {
                kept.add(entry);
            }
        }
        File[] result = kept.toArray(new File[kept.size()]);
        Arrays.sort(result);
        return result;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy