All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.base.LFWLoader Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.base;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.apache.commons.io.FileUtils;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.util.ArrayUtil;
import org.deeplearning4j.util.ImageLoader;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


public class LFWLoader {

	private File baseDir = new File(System.getProperty("java.io.tmpdir"));
	public final static String LFW = "lfw";
	private File lfwDir = new File(baseDir,LFW);
	public final static String LFW_URL = "http://vis-www.cs.umass.edu/lfw/lfw.tgz";
	private File lfwTarFile = new File(lfwDir,"lfw.tgz");
	private static Logger log = LoggerFactory.getLogger(LFWLoader.class);
	private int numNames;
	private int numPixelColumns;
	private ImageLoader loader = new ImageLoader();
	private List images = new ArrayList();
	private List outcomes = new ArrayList();
	
	public void getIfNotExists() throws Exception {
		if(!lfwDir.exists()) {
			lfwDir.mkdir();
			FileUtils.copyURLToFile(new URL(LFW_URL), lfwTarFile);
			//untar to /tmp/lfw
			untarFile(baseDir,lfwTarFile);
			
		}
		
		File firstImage = lfwDir.listFiles()[0].listFiles()[0];
		//number of input neurons
		numPixelColumns = ArrayUtil.flatten(loader.fromFile(firstImage)).length;
		
		//each subdir is a person
		numNames = lfwDir.getAbsoluteFile().listFiles().length;
		
		Collection allImages = FileUtils.listFiles(lfwDir, org.apache.commons.io.filefilter.FileFileFilter.FILE, org.apache.commons.io.filefilter.DirectoryFileFilter.DIRECTORY);
		for(File f : allImages) {
			images.add(f.getAbsolutePath());
		}
		for(File dir : lfwDir.getAbsoluteFile().listFiles())
			outcomes.add(dir.getAbsolutePath());
		
	}
	
	
	
	public Pair convertListPairs(List> images) {
		DoubleMatrix inputs = new DoubleMatrix(images.size(),numPixelColumns);
		DoubleMatrix outputs = new DoubleMatrix(images.size(),numNames);
		
		for(int i = 0; i < images.size(); i++) {
			inputs.putRow(i,images.get(i).getFirst());
			outputs.putRow(i,images.get(i).getSecond());
		}
		return new Pair(inputs,outputs);
	}
	
	
	
	public Pair getDataFor(int i) {
		File image = new File(images.get(i));
		int outcome = outcomes.indexOf(image.getParentFile().getAbsolutePath());
		try {
			return new Pair(loader.asRowVector(image),MatrixUtil.toOutcomeVector(outcome, outcomes.size()));
		} catch (Exception e) {
			throw new IllegalStateException("Unable to get data for image " + i + " for path " + images.get(i));
		}
	}
	
	/**
	 * Get the first num found images
	 * @param num the number of images to get
	 * @return 
	 * @throws Exception
	 */
	public List> getFirst(int num) throws Exception {
		List> ret = new ArrayList<>(num);
		File[] files = lfwDir.listFiles();
		int label = 0;
		for(File file : files) {
			ret.addAll(getImages(label,file));
			label++;
			if(ret.size() >= num)
				break;
		}
		
		return ret;
	}
	
	public Pair getAllImagesAsMatrix() throws Exception {
		List> images = getImagesAsList();
		return convertListPairs(images);
	}
	

	public Pair getAllImagesAsMatrix(int numRows) throws Exception {
		List> images = getImagesAsList().subList(0, numRows);
		return convertListPairs(images);
	}
	
	public List> getImagesAsList() throws Exception {
		List> list = new ArrayList>();
		File[] dirs = lfwDir.listFiles();
		for(int i = 0; i < dirs.length; i++) {
			list.addAll(getImages(i,dirs[i]));
		}
		return list;
	}
	
	public List> getImages(int label,File file) throws Exception {
		File[] images = file.listFiles();
		List> ret = new ArrayList<>();
		for(File f : images)
			ret.add(fromImageFile(label,f));
		return ret;
	}
	
	
	public Pair fromImageFile(int label,File image) throws Exception {
		DoubleMatrix outcome = MatrixUtil.toOutcomeVector(label, numNames);
		DoubleMatrix image2 = MatrixUtil.toMatrix(loader.flattenedImageFromFile(image));
		return new Pair<>(image2,outcome);
	}
	
	

	public  void untarFile(File baseDir, File tarFile) throws IOException {

		log.info("Untaring File: " + tarFile.toString());

		Process p = Runtime.getRuntime().exec(String.format("tar -C %s -xvf %s", 
				baseDir.getAbsolutePath(), tarFile.getAbsolutePath()));
		BufferedReader stdError = new BufferedReader(new 
				InputStreamReader(p.getErrorStream()));
		log.info("Here is the standard error of the command (if any):\n");
		String s;
		while ((s = stdError.readLine()) != null) {
			log.info(s);
		}
		stdError.close();


	}

	public static void gunzipFile(File baseDir, File gzFile) throws IOException {

		log.info("gunzip'ing File: " + gzFile.toString());

		Process p = Runtime.getRuntime().exec(String.format("gunzip %s", 
				gzFile.getAbsolutePath()));
		BufferedReader stdError = new BufferedReader(new 
				InputStreamReader(p.getErrorStream()));
		log.info("Here is the standard error of the command (if any):\n");
		String s;
		while ((s = stdError.readLine()) != null) {
			log.info(s);
		}
		stdError.close();


	}



	public int getNumNames() {
		return numNames;
	}



	public int getNumPixelColumns() {
		return numPixelColumns;
	}
	
	

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy