All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.base.DeepLearningTest Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.base;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.MnistManager;
import org.deeplearning4j.util.ArrayUtil;
import org.deeplearning4j.util.MathUtils;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;




public abstract class DeepLearningTest {

	private static Logger log = LoggerFactory.getLogger(DeepLearningTest.class);

	public static Pair getIris() throws IOException {
		Pair pair = IrisUtils.loadIris();
		return pair;
	}
	public static Pair getIris(int num) throws IOException {
		Pair pair = IrisUtils.loadIris(num);
		return pair;
	}
	
	
	
	
	/**
	 * LFW Dataset: pick first num faces
	 * @param num
	 * @return
	 * @throws Exception
	 */
	public static Pair getFaces(int num) throws Exception {
		LFWLoader loader = new LFWLoader();
		loader.getIfNotExists();
		return loader.getAllImagesAsMatrix(num);
	}
	
	
	/**
	 * LFW Dataset: pick all faces
	 * @param num
	 * @return
	 * @throws Exception
	 */
	public static Pair getFacesMatrix() throws Exception {
		LFWLoader loader = new LFWLoader();
		loader.getIfNotExists();
		return loader.getAllImagesAsMatrix();
	}
	
	
	
	/**
	 * LFW Dataset: pick first num faces
	 * @param num
	 * @return
	 * @throws Exception
	 */
	public static List> getFirstFaces(int num) throws Exception {
		LFWLoader loader = new LFWLoader();
		loader.getIfNotExists();
		return loader.getFirst(num);
	}
	
	
	/**
	 * LFW Dataset: pick all faces
	 * @param num
	 * @return
	 * @throws Exception
	 */
	public List> getFaces() throws Exception {
		LFWLoader loader = new LFWLoader();
		loader.getIfNotExists();
		return loader.getImagesAsList();
	}
	
	
	
	/**
	 * Gets an mnist example as an input, label pair.
	 * Keep in mind the return matrix for out come is a 1x1 matrix.
	 * If you need multiple labels, remember to convert to a zeros
	 * with 1 as index of the label for the output training vector.
	 * @param example the example to get
	 * @return the image,label pair
	 * @throws IOException
	 */
	public static Pair getMnistExample(int example) throws IOException {
		File ensureExists = new File("/tmp/MNIST");
		if(!ensureExists.exists())
			new MnistFetcher().downloadAndUntar();

		MnistManager man = new MnistManager("/tmp/MNIST/" + MnistFetcher.trainingFilesFilename_unzipped,"/tmp/MNIST/" + MnistFetcher.trainingFileLabelsFilename_unzipped);
		man.setCurrent(example);
		int[] imageExample = ArrayUtil.flatten(man.readImage());
		return new Pair(MatrixUtil.toMatrix(imageExample).transpose(),MatrixUtil.toOutcomeVector(man.readLabel(),10));
	}



	/**
	 * Gets an mnist example as an input, label pair.
	 * Keep in mind the return matrix for out come is a 1x1 matrix.
	 * If you need multiple labels, remember to convert to a zeros
	 * with 1 as index of the label for the output training vector.
	 * @param example the example to get
	 * @param batchSize the batch size of examples to get
	 * @return the image,label pair
	 * @throws IOException
	 */
	public List> getMnistExampleBatches(int batchSize,int numBatches) throws IOException {
		File ensureExists = new File("/tmp/MNIST");
		List> ret = new ArrayList<>();
		if(!ensureExists.exists()) 
			new MnistFetcher().downloadAndUntar();
		MnistManager man = new MnistManager("/tmp/MNIST/" + MnistFetcher.trainingFilesFilename_unzipped,"/tmp/MNIST/" + MnistFetcher.trainingFileLabelsFilename_unzipped);

		int[][] image = man.readImage();
		int[] imageExample = ArrayUtil.flatten(image);

		for(int batch = 0; batch < numBatches; batch++) {
			double[][] examples = new double[batchSize][imageExample.length];
			int[][] outcomeMatrix = new int[batchSize][10];

			for(int i = 1 + batch; i < batchSize + 1 + batch; i++) {
				//1 based indices
				man.setCurrent(i);
				double[] currExample = ArrayUtil.flatten(ArrayUtil.toDouble(man.readImage()));
				examples[i - 1 - batch] = currExample;
				outcomeMatrix[i - 1 - batch] = ArrayUtil.toOutcomeArray(man.readLabel(), 10);
			}
			DoubleMatrix training = new DoubleMatrix(examples);
			ret.add(new Pair<>(training,MatrixUtil.toMatrix(outcomeMatrix)));
		}

		return ret;
	}

	/**
	 * Gets an mnist example as an input, label pair.
	 * Keep in mind the return matrix for out come is a 1x1 matrix.
	 * If you need multiple labels, remember to convert to a zeros
	 * with 1 as index of the label for the output training vector.
	 * @param example the example to get
	 * @param batchSize the batch size of examples to get
	 * @return the image,label pair
	 * @throws IOException
	 */
	public static Pair getMnistExampleBatch(int batchSize) throws IOException {
		File ensureExists = new File("/tmp/MNIST");
		if(!ensureExists.exists() || !new File("/tmp/MNIST/" + MnistFetcher.trainingFilesFilename_unzipped).exists() || !new File("/tmp/MNIST/" + MnistFetcher.trainingFileLabelsFilename_unzipped).exists()) 
			new MnistFetcher().downloadAndUntar();
		MnistManager man = new MnistManager("/tmp/MNIST/" + MnistFetcher.trainingFilesFilename_unzipped,"/tmp/MNIST/" + MnistFetcher.trainingFileLabelsFilename_unzipped);

		int[][] image = man.readImage();
		int[] imageExample = ArrayUtil.flatten(image);
		double[][] examples = new double[batchSize][imageExample.length];
		int[][] outcomeMatrix = new int[batchSize][10];
		for(int i = 1; i < batchSize + 1; i++) {
			//1 based indices
			man.setCurrent(i);
			double[] currExample = ArrayUtil.flatten(ArrayUtil.toDouble(man.readImage()));
			for(int j = 0; j < currExample.length; j++)
				currExample[j] = MathUtils.normalize(currExample[j], 0, 255);
			examples[i - 1] = currExample;
			outcomeMatrix[i - 1] = ArrayUtil.toOutcomeArray(man.readLabel(), 10);
		}
		DoubleMatrix training = new DoubleMatrix(examples);
		return new Pair(training,MatrixUtil.toMatrix(outcomeMatrix));

	}


}









© 2015 - 2025 Weber Informatics LLC | Privacy Policy