All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.util.InputSplit Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.util;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.deeplearning4j.berkeley.Pair;
import org.jblas.DoubleMatrix;


public class InputSplit {

	public static void splitInputs(DoubleMatrix inputs,DoubleMatrix outcomes,List> train,List> test,double split) {
		List inputRows = inputs.rowsAsList();
		List outcomeRows = outcomes.rowsAsList();
		assert inputRows.size() == outcomeRows.size();
		List> list = new ArrayList<>();
		for(int i = 0; i < inputRows.size(); i++) {
			list.add(new Pair<>(inputRows.get(i),outcomeRows.get(i)));
		}

		splitInputs(list,train,test,split);
	}

	public static void splitInputs(List> pairs,List> train,List> test,double split) {
		Random rand = new Random();

		for(Pair pair : pairs)
			if(rand.nextDouble() <= split) 
				train.add(pair);
			else
				test.add(pair);

			
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy