// org.deeplearning4j.util.InputSplit (Maven / Gradle / Ivy)
package org.deeplearning4j.util;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.deeplearning4j.berkeley.Pair;
import org.jblas.DoubleMatrix;
public class InputSplit {
public static void splitInputs(DoubleMatrix inputs,DoubleMatrix outcomes,List> train,List> test,double split) {
List inputRows = inputs.rowsAsList();
List outcomeRows = outcomes.rowsAsList();
assert inputRows.size() == outcomeRows.size();
List> list = new ArrayList<>();
for(int i = 0; i < inputRows.size(); i++) {
list.add(new Pair<>(inputRows.get(i),outcomeRows.get(i)));
}
splitInputs(list,train,test,split);
}
public static void splitInputs(List> pairs,List> train,List> test,double split) {
Random rand = new Random();
for(Pair pair : pairs)
if(rand.nextDouble() <= split)
train.add(pair);
else
test.add(pair);
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy