All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.datasets.DataSet Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.datasets;

import java.io.*;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.fetchers.MnistDataFetcher;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.Persistable;
import org.deeplearning4j.util.MathUtils;
import org.jblas.DoubleMatrix;
import org.jblas.SimpleBlas;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.Lists;

/**
 * A data set (example/outcome pairs)
 * The outcomes are specifically for neural network encoding such that
 * any labels that are considered true are 1s. The rest are zeros.
 * @author Adam Gibson
 *
 */
public class DataSet extends Pair implements Persistable,Iterable {

	private static final long serialVersionUID = 1935520764586513365L;
	private static Logger log = LoggerFactory.getLogger(DataSet.class);

	public DataSet(Pair pair) {
		this(pair.getFirst(),pair.getSecond());
	}

	public DataSet(DoubleMatrix first, DoubleMatrix second) {
		super(first, second);
		if(first.rows != second.rows)
			throw new IllegalStateException("Invalid data set; first and second do not have equal rows. First was " + first.rows + " second was " + second.rows);
		
		
	}

	public DataSetIterator iterator(int batches) {
		List list = this.dataSetBatches(batches);
		return new ListDataSetIterator(list);
	}
	
	
	public DataSet copy() {
		return new DataSet(getFirst(),getSecond());
	}
	
	
	public static DataSet empty() {
		return new DataSet(DoubleMatrix.zeros(1),DoubleMatrix.zeros(1));
	}
	
	public static DataSet merge(List data) {
		if(data.isEmpty())
			throw new IllegalArgumentException("Unable to merge empty dataset");
		DataSet first = data.iterator().next();

		DoubleMatrix in = new DoubleMatrix(data.size(),first.getFirst().columns);
		DoubleMatrix out = new DoubleMatrix(data.size(),first.getSecond().columns);


		for(int i = 0; i < data.size(); i++) {
			in.putRow(i,data.get(i).getFirst());
			out.putRow(i,data.get(i).getSecond());

		}
		return new DataSet(in,out);
	}

	
	
	public int numInputs() {
		return getFirst().columns;
	}
	
	public void validate() {
		if(getFirst().rows != getSecond().rows)
			throw new IllegalStateException("Invalid dataset");
	}
	
	public int outcome() {
		if(this.numExamples() > 1)
			throw new IllegalStateException("Unable to derive outcome for dataset greater than one row");
		return SimpleBlas.iamax(getSecond());
	}

	public DataSet get(int i) {
		return new DataSet(getFirst().getRow(i),getSecond().getRow(i));
	}
	
	public List> batchBy(int num) {
		return Lists.partition(asList(),num);
	}
	
	
	
	
	public List dataSetBatches(int num) {
		List> list =  Lists.partition(asList(),num);
		List ret = new ArrayList<>();
		for(List l : list)
			ret.add(DataSet.merge(l));
		return ret;
	
	}
	

	/**
	 * Sorts the dataset by label:
	 * Splits the data set such that examples are sorted by their labels.
	 * A ten label dataset would produce lists with batches like the following:
	 * x1   y = 1
	 * x2   y = 2
	 * ...
	 * x10  y = 10
	 * @return a list of data sets partitioned by outcomes
	 */
	public List> sortAndBatchByNumLabels() {
		sortByLabel();
		return Lists.partition(asList(),numOutcomes());
	}

	public List> batchByNumLabels() {
		return Lists.partition(asList(),numOutcomes());
	}


	public List asList() {
		List list = new ArrayList(numExamples());
		for(int i = 0; i < numExamples(); i++)  {
			list.add(new DataSet(getFirst().getRow(i),getSecond().getRow(i)));
		}
		return list;
	}

	public Pair splitTestAndTrain(int numHoldout) {

		if(numHoldout >= numExamples())
			throw new IllegalArgumentException("Unable to split on size larger than the number of rows");


		List list = asList();

		Collections.rotate(list, 3);
		Collections.shuffle(list);
		List> partition = new ArrayList>();
		partition.add(list.subList(0, numHoldout));
		partition.add(list.subList(numHoldout, list.size()));
		DataSet train = merge(partition.get(0));
		DataSet test = merge(partition.get(1));
		return new Pair<>(train,test);
	}

	/**
	 * Organizes the dataset to minimize sampling error
	 * while still allowing efficient batching.
	 */
	public void sortByLabel() {
		Map> map = new HashMap>();
		List data = asList();
		int numLabels = numOutcomes();
		int examples = numExamples();
		for(DataSet d : data) {
			int label = getLabel(d);
			Queue q = map.get(label);
			if(q == null) {
				q = new ArrayDeque();
				map.put(label, q);
			}
			q.add(d);
		}

		for(Integer label : map.keySet()) {
			log.info("Label " + label + " has " + map.get(label).size() + " elements");
		}

		//ideal input splits: 1 of each label in each batch
		//after we run out of ideal batches: fall back to a new strategy
		boolean optimal = true;
		for(int i = 0; i < examples; i++) {
			if(optimal) {
				for(int j = 0; j < numLabels; j++) {
					Queue q = map.get(j);
					DataSet next = q.poll();
					//add a row; go to next
					if(next != null) {
						addRow(next,i);
						i++;
					}
					else {
						optimal = false;
						break;
					}
				}
			}
			else {
				DataSet add = null;
				for(Queue q : map.values()) {
					if(!q.isEmpty()) {
						add = q.poll();
						break;
					}
				}

				addRow(add,i);

			}


		}


	}


	public void addRow(DataSet d, int i) {
		if(i > numExamples() || d == null)
			throw new IllegalArgumentException("Invalid index for adding a row");
		getFirst().putRow(i, d.getFirst());
		getSecond().putRow(i,d.getSecond());
	}


	private int getLabel(DataSet data) {
		return SimpleBlas.iamax(data.getSecond());
	}


	public DoubleMatrix exampleSums() {
		return getFirst().columnSums();
	}

	public DoubleMatrix exampleMaxs() {
		return getFirst().columnMaxs();
	}

	public DoubleMatrix exampleMeans() {
		return getFirst().columnMeans();
	}

	public void saveTo(File file,boolean binary) throws IOException {
		if(file.exists())
			file.delete();
		file.createNewFile();

		if(binary) {
			DataOutputStream bos = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)));
			getFirst().out(bos);
			getSecond().out(bos);
			bos.flush();
			bos.close();

		}
		else {
			BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(file));
			for(int i = 0; i < numExamples(); i++) {
				bos.write(getFirst().getRow(i).toString("%.3f", "[", "]", ", ", ";").getBytes());
				bos.write("\t".getBytes());
				bos.write(getSecond().getRow(i).toString("%.3f", "[", "]", ", ", ";").getBytes());
				bos.write("\n".getBytes())	;


			}
			bos.flush();
			bos.close();

		}
	}


	public static DataSet load(File path) throws IOException {
		DataInputStream bis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)));
		DoubleMatrix x = new DoubleMatrix(1,1);
		DoubleMatrix y = new DoubleMatrix(1,1);
		x.in(bis);
		y.in(bis);
		bis.close();
		return new DataSet(x,y);
	}

	/**
	 * Sample without replacement and a random rng
	 * @param numSamples the number of samples to get
	 * @return a sample data set without replacement
	 */
	public DataSet sample(int numSamples) {
		return sample(numSamples,new MersenneTwister(System.currentTimeMillis()));
	}

	/**
	 * Sample without replacement
	 * @param numSamples the number of samples to get
	 * @param rng the rng to use
	 * @return the sampled dataset without replacement
	 */
	public DataSet sample(int numSamples,RandomGenerator rng) {
		return sample(numSamples,rng,false);
	}

	/**
	 * Sample a dataset numSamples times
	 * @param numSamples the number of samples to get
	 * @param withReplacement the rng to use
	 * @return the sampled dataset without replacement
	 */
	public DataSet sample(int numSamples,boolean withReplacement) {
		return sample(numSamples,new MersenneTwister(System.currentTimeMillis()),withReplacement);
	}

	/**
	 * Sample a dataset
	 * @param numSamples the number of samples to get
	 * @param rng the rng to use
	 * @param withReplacement whether to allow duplicates (only tracked by example row number)
	 * @return the sample dataset
	 */
	public DataSet sample(int numSamples,RandomGenerator rng,boolean withReplacement) {
		if(numSamples >= numExamples())
			return this;
		else {
			DoubleMatrix examples = new DoubleMatrix(numSamples,getFirst().columns);
			DoubleMatrix outcomes = new DoubleMatrix(numSamples,numOutcomes());
			Set added = new HashSet();
			for(int i = 0; i < numSamples; i++) {
				int picked = rng.nextInt(numExamples());
				while(added.contains(picked)) {
					picked = rng.nextInt(numExamples());

				}
				examples.putRow(i,getFirst().getRow(i));
				outcomes.putRow(i,getSecond().getRow(i));

			}
			return new DataSet(examples,outcomes);
		}
	}

	public void roundToTheNearest(int roundTo) {
		for(int i = 0; i < getFirst().length; i++) {
			double curr = getFirst().get(i);
			getFirst().put(i,MathUtils.roundDouble(curr, roundTo));
		}
	}

	public int numOutcomes() {
		return getSecond().columns;
	}

	public int numExamples() {
		return getFirst().rows;
	}

	
	

	@Override
	public String toString() {
		StringBuilder builder = new StringBuilder();
		builder.append("===========INPUT===================\n")
		.append(getFirst().toString().replaceAll(";","\n"))
		.append("\n=================OUTPUT==================\n")
		.append(getSecond().toString().replaceAll(";","\n"));
		return builder.toString();
	}

	public static void main(String[] args) throws IOException {
		MnistDataFetcher fetcher = new MnistDataFetcher();
		fetcher.fetch(100);
		DataSet write = new DataSet(fetcher.next());
		write.saveTo(new File(args[0]), false);


	}

	@Override
	public void write(OutputStream os) {
		DataOutputStream dos = new DataOutputStream(os);
		
		try {
			getFirst().out(dos);
			getSecond().out(dos);

		} catch (IOException e) {
			throw new RuntimeException(e);
		}
		
	}

	@Override
	public void load(InputStream is) {
		DataInputStream dis = new DataInputStream(is);
		try {
			getFirst().in(dis);
			getSecond().in(dis);

		} catch (IOException e) {
			throw new RuntimeException(e);
		}
	}

	@Override
	public Iterator iterator() {
		return asList().iterator();
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy