Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download or modify it as often as you want.
org.nd4j.linalg.dataset.MultiDataSet Maven / Gradle / Ivy
package org.nd4j.linalg.dataset;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.io.*;
import java.util.*;
/**Implementation of {@link org.nd4j.linalg.dataset.api.MultiDataSet}
* @author Alex Black
*/
public class MultiDataSet implements org.nd4j.linalg.dataset.api.MultiDataSet {
/** Stand-in written by the save methods in place of a null mask array; the load methods map an equal array back to null. */
private static final INDArray EMPTY_MASK_ARRAY_PLACEHOLDER = Nd4j.create(new float[]{-1});
/** Feature (input) arrays; null until set. */
private INDArray[] features;
/** Label (output) arrays; null until set. */
private INDArray[] labels;
/** Mask arrays for the features; may be null, and individual entries may be null. */
private INDArray[] featuresMaskArrays;
/** Mask arrays for the labels; may be null, and individual entries may be null. */
private INDArray[] labelsMaskArrays;
// NOTE(review): declared as a raw List here; the setter's parameter suggests the
// elements are Serializable - confirm the intended element type against upstream.
private List exampleMetaData;
/**
 * Creates an empty MultiDataSet. No features, labels, mask arrays or
 * meta data are assigned, so every instance field is left null.
 */
public MultiDataSet() {
}
/**
 * Convenience constructor for a single features array and a single labels array, with no mask arrays.
 * A non-null argument is wrapped in a one-element array; a null argument is passed on as a null array.
 *
 * @param features the single features array, or null
 * @param labels   the single labels array, or null
 */
public MultiDataSet(INDArray features, INDArray labels) {
    this(features == null ? null : new INDArray[] {features},
         labels == null ? null : new INDArray[] {labels});
}
/**
 * Builds a MultiDataSet from features and labels arrays; both mask array fields are passed as null.
 *
 * @param features the features arrays
 * @param labels   the labels arrays
 */
public MultiDataSet(INDArray[] features, INDArray[] labels) {
    this(features, labels, null, null);
}
/**
 * Full constructor: stores the four array references as given (no copies are made).
 *
 * @param features           The features (inputs) to the algorithm/neural network
 * @param labels             The labels (outputs) to the algorithm/neural network
 * @param featuresMaskArrays The mask arrays for the features. May be null. Typically used with variable-length time series models, etc
 * @param labelsMaskArrays   The mask arrays for the labels. May be null. Typically used with variable-length time series models, etc
 * @throws IllegalArgumentException if an array and its mask array are both non-null but differ in length
 */
public MultiDataSet(INDArray[] features, INDArray[] labels, INDArray[] featuresMaskArrays, INDArray[] labelsMaskArrays) {
    // A mask array, when supplied, must line up one-to-one with the arrays it masks
    boolean featureMaskMismatch = features != null
            && featuresMaskArrays != null
            && features.length != featuresMaskArrays.length;
    if (featureMaskMismatch) {
        throw new IllegalArgumentException("Invalid features / features mask arrays combination: "
                + "features and features mask arrays must not be different lengths");
    }

    boolean labelMaskMismatch = labels != null
            && labelsMaskArrays != null
            && labels.length != labelsMaskArrays.length;
    if (labelMaskMismatch) {
        throw new IllegalArgumentException("Invalid labels / labels mask arrays combination: "
                + "labels and labels mask arrays must not be different lengths");
    }

    this.features = features;
    this.labels = labels;
    this.featuresMaskArrays = featuresMaskArrays;
    this.labelsMaskArrays = labelsMaskArrays;

    // When the active executioner is a GridExecutioner, block until its queue has been flushed
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        GridExecutioner gridExecutioner = (GridExecutioner) Nd4j.getExecutioner();
        gridExecutioner.flushQueueBlocking();
    }
}
/**
 * @return the example meta data list exactly as stored (null if none has been set)
 */
@Override
public List getExampleMetaData() {
    return this.exampleMetaData;
}
/**
 * Returns the stored example meta data, viewed as a list of the requested type.
 * No run-time check is made against {@code metaDataType}: the stored list is simply
 * cast, so a mismatched type only surfaces when an element is read.
 *
 * @param metaDataType the element type the caller expects
 * @param <T>          the expected meta data type
 * @return the stored meta data list (null if none has been set)
 */
@Override
@SuppressWarnings("unchecked") // the stored list is untyped; the caller vouches for T
public <T extends Serializable> List<T> getExampleMetaData(Class<T> metaDataType) {
    return (List<T>) exampleMetaData;
}
@Override
public void setExampleMetaData(List extends Serializable> exampleMetaData){
this.exampleMetaData = (List)exampleMetaData;
}
/**
 * @return the length of the features array, or 0 when no features are set
 */
@Override
public int numFeatureArrays() {
    if (features == null) {
        return 0;
    }
    return features.length;
}
/**
 * @return the length of the labels array, or 0 when no labels are set
 */
@Override
public int numLabelsArrays() {
    if (labels == null) {
        return 0;
    }
    return labels.length;
}
/**
 * @return the features array as stored (the same reference, possibly null)
 */
@Override
public INDArray[] getFeatures() {
    return this.features;
}
/**
 * @param index position in the features array
 * @return the features array stored at {@code index}
 */
@Override
public INDArray getFeatures(int index) {
    final INDArray[] allFeatures = this.features;
    return allFeatures[index];
}
/**
 * Replaces the whole features array with the given reference.
 *
 * @param features the new features arrays
 */
@Override
public void setFeatures(INDArray[] features) {
    this.features = features;
}
/**
 * Overwrites one entry of the existing features array.
 *
 * @param idx      position to overwrite
 * @param features the array to store at that position
 */
@Override
public void setFeatures(int idx, INDArray features) {
    final INDArray[] target = this.features;
    target[idx] = features;
}
/**
 * @return the labels array as stored (the same reference, possibly null)
 */
@Override
public INDArray[] getLabels() {
    return this.labels;
}
/**
 * @param index position in the labels array
 * @return the labels array stored at {@code index}
 */
@Override
public INDArray getLabels(int index) {
    final INDArray[] allLabels = this.labels;
    return allLabels[index];
}
/**
 * Replaces the whole labels array with the given reference.
 *
 * @param labels the new labels arrays
 */
@Override
public void setLabels(INDArray[] labels) {
    this.labels = labels;
}
/**
 * Overwrites one entry of the existing labels array.
 *
 * @param idx    position to overwrite
 * @param labels the array to store at that position
 */
@Override
public void setLabels(int idx, INDArray labels) {
    final INDArray[] target = this.labels;
    target[idx] = labels;
}
/**
 * Reports whether at least one non-null mask array is present, looking at the
 * features masks first and then the labels masks.
 *
 * @return true if any features or labels mask array entry is non-null
 */
@Override
public boolean hasMaskArrays() {
    INDArray[][] maskGroups = {featuresMaskArrays, labelsMaskArrays};
    for (INDArray[] group : maskGroups) {
        if (group == null) {
            continue;
        }
        for (INDArray mask : group) {
            if (mask != null) {
                return true;
            }
        }
    }
    return false;
}
/**
 * @return the features mask arrays as stored (the same reference, possibly null)
 */
@Override
public INDArray[] getFeaturesMaskArrays() {
    return this.featuresMaskArrays;
}
/**
 * @param index position in the features mask arrays
 * @return the mask stored at {@code index}, or null when no features mask arrays are set
 */
@Override
public INDArray getFeaturesMaskArray(int index) {
    if (featuresMaskArrays == null) {
        return null;
    }
    return featuresMaskArrays[index];
}
/**
 * Replaces the whole features mask array with the given reference.
 *
 * @param maskArrays the new features mask arrays; may be null
 */
@Override
public void setFeaturesMaskArrays(INDArray[] maskArrays) {
    featuresMaskArrays = maskArrays;
}
/**
 * Overwrites one entry of the existing features mask arrays.
 *
 * @param idx       position to overwrite
 * @param maskArray the mask to store at that position
 */
@Override
public void setFeaturesMaskArray(int idx, INDArray maskArray) {
    final INDArray[] target = this.featuresMaskArrays;
    target[idx] = maskArray;
}
/**
 * @return the labels mask arrays as stored (the same reference, possibly null)
 */
@Override
public INDArray[] getLabelsMaskArrays() {
    return this.labelsMaskArrays;
}
/**
 * @param index position in the labels mask arrays
 * @return the mask stored at {@code index}, or null when no labels mask arrays are set
 */
@Override
public INDArray getLabelsMaskArray(int index) {
    if (labelsMaskArrays == null) {
        return null;
    }
    return labelsMaskArrays[index];
}
/**
 * Replaces the whole labels mask array with the given reference.
 *
 * @param labelsMaskArrays the new labels mask arrays; may be null
 */
@Override
public void setLabelsMaskArray(INDArray[] labelsMaskArrays) {
    this.labelsMaskArrays = labelsMaskArrays;
}
/**
 * Overwrites one entry of the existing labels mask arrays.
 *
 * @param idx             position to overwrite
 * @param labelsMaskArray the mask to store at that position
 */
@Override
public void setLabelsMaskArray(int idx, INDArray labelsMaskArray) {
    final INDArray[] target = this.labelsMaskArrays;
    target[idx] = labelsMaskArray;
}
/**
 * Writes this MultiDataSet to the given stream: four int counts (features, labels,
 * features masks, labels masks; 0 for a null array) followed by the arrays themselves
 * in that same order. The stream is wrapped in a buffered data stream that is closed
 * when writing finishes, which also closes {@code to}.
 *
 * @param to destination stream
 * @throws IOException if writing fails
 */
@Override
public void save(OutputStream to) throws IOException {
    INDArray[][] groups = {features, labels, featuresMaskArrays, labelsMaskArrays};
    try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(to))) {
        // Header: one count per group, in a fixed order
        for (INDArray[] group : groups) {
            out.writeInt(group == null ? 0 : group.length);
        }
        // Payload: only mask groups may contain null entries (written as a placeholder)
        saveINDArrays(features, out, false);
        saveINDArrays(labels, out, false);
        saveINDArrays(featuresMaskArrays, out, true);
        saveINDArrays(labelsMaskArrays, out, true);
    }
}
/**
 * Writes each array of the group to the stream in order. Nothing is written for a null
 * or empty group. For mask groups, a null entry is replaced by the empty-mask placeholder.
 *
 * @param arrays the group to write; may be null
 * @param dos    destination stream
 * @param isMask whether the group holds mask arrays
 * @throws IOException if writing fails
 */
private void saveINDArrays(INDArray[] arrays, DataOutputStream dos, boolean isMask) throws IOException {
    if (arrays == null || arrays.length == 0) {
        return;
    }
    for (int idx = 0; idx < arrays.length; idx++) {
        INDArray toWrite = arrays[idx];
        if (toWrite == null && isMask) {
            toWrite = EMPTY_MASK_ARRAY_PLACEHOLDER;
        }
        Nd4j.write(toWrite, dos);
    }
}
/**
 * Writes this MultiDataSet to the given file by delegating to {@link #save(OutputStream)}.
 *
 * @param to destination file
 * @throws IOException if the file cannot be opened or writing fails
 */
@Override
public void save(File to) throws IOException {
    // Own the stream here so it is closed on every path, independent of what the
    // delegate does with it; closing an already-closed FileOutputStream is a no-op.
    try (OutputStream out = new FileOutputStream(to)) {
        save(out);
    }
}
/**
 * Reads a MultiDataSet from the given stream, replacing the four array fields of this
 * object. Four int counts are read first (features, labels, features masks, labels masks),
 * then the arrays in that same order. The stream is closed when reading finishes.
 *
 * @param from source stream
 * @throws IOException if reading fails
 */
@Override
public void load(InputStream from) throws IOException {
    try (DataInputStream in = new DataInputStream(from)) {
        // Header: the four group sizes, in a fixed order
        int[] counts = new int[4];
        for (int k = 0; k < counts.length; k++) {
            counts[k] = in.readInt();
        }
        features = loadINDArrays(counts[0], in, false);
        labels = loadINDArrays(counts[1], in, false);
        featuresMaskArrays = loadINDArrays(counts[2], in, true);
        labelsMaskArrays = loadINDArrays(counts[3], in, true);
    }
}
/**
 * Reads {@code numArrays} arrays from the stream. For mask groups, an array equal to the
 * empty-mask placeholder is stored as null.
 *
 * @param numArrays number of arrays to read
 * @param dis       source stream
 * @param isMask    whether the group holds mask arrays
 * @return the arrays read, or null when {@code numArrays} is not positive
 * @throws IOException if reading fails
 */
private INDArray[] loadINDArrays(int numArrays, DataInputStream dis, boolean isMask) throws IOException {
    if (numArrays <= 0) {
        return null;
    }
    INDArray[] loaded = new INDArray[numArrays];
    for (int idx = 0; idx < loaded.length; idx++) {
        INDArray read = Nd4j.read(dis);
        boolean isPlaceholder = isMask && read.equals(EMPTY_MASK_ARRAY_PLACEHOLDER);
        loaded[idx] = isPlaceholder ? null : read;
    }
    return loaded;
}
/**
 * Reads a MultiDataSet from the given file by delegating to {@link #load(InputStream)}.
 *
 * @param from source file
 * @throws IOException if the file cannot be opened or reading fails
 */
@Override
public void load(File from) throws IOException {
    // Own the stream here so it is closed on every path, independent of what the
    // delegate does with it; closing an already-closed FileInputStream is a no-op.
    try (InputStream in = new FileInputStream(from)) {
        load(in);
    }
}
@Override
public List asList() {
// The example count is read from dimension 0 of the first features array
int nExamples = features[0].size(0);
List list = new ArrayList<>();
// NOTE(review): the next line appears truncated by text extraction. The rest of the
// asList() loop/body and the start of the following method's declaration (a method
// taking a collection parameter named toMerge and returning a MultiDataSet) seem to
// be missing - restore them from the upstream source before compiling.
for( int i=0; i toMerge){
// Single element: return it as-is when it is already this class; otherwise build a
// new instance from its features, labels and both mask arrays
if(toMerge.size() == 1){
org.nd4j.linalg.dataset.api.MultiDataSet mds = toMerge.iterator().next();
if(mds instanceof MultiDataSet) return (MultiDataSet) mds;
else return new MultiDataSet(mds.getFeatures(),mds.getLabels(),mds.getFeaturesMaskArrays(),mds.getLabelsMaskArrays());
}
// Index-based access is needed below: reuse toMerge if it is already a List, else copy it
List list;
if(toMerge instanceof List) list = (List)toMerge;
else list = new ArrayList<>(toMerge);
// The first element defines how many feature/label arrays every element must have
int nInArrays = list.get(0).numFeatureArrays();
int nOutArrays = list.get(0).numLabelsArrays();
// Indexed as [element of toMerge][array within that element]; the placeholder
// zero-length rows are overwritten in the loop below
INDArray[][] features = new INDArray[list.size()][0];
INDArray[][] labels = new INDArray[list.size()][0];
INDArray[][] featuresMasks = new INDArray[list.size()][0];
INDArray[][] labelsMasks = new INDArray[list.size()][0];
int i=0;
for( org.nd4j.linalg.dataset.api.MultiDataSet mds : list ){
features[i] = mds.getFeatures();
labels[i] = mds.getLabels();
featuresMasks[i] = mds.getFeaturesMaskArrays();
labelsMasks[i] = mds.getLabelsMaskArrays();
// Reject elements whose features are null or whose array count differs from the first element
if(features[i] == null || features[i].length != nInArrays){
throw new IllegalStateException("Cannot merge MultiDataSets with different number of input arrays: toMerge[0] has "
+ nInArrays + " input arrays; toMerge[" + i + "] has " + (features[i] != null ? features[i].length : null) + " arrays");
}
// Same check for the labels
if(labels[i] == null || labels[i].length != nOutArrays){
throw new IllegalStateException("Cannot merge MultiDataSets with different number of output arrays: toMerge[0] has "
+ nOutArrays + " output arrays; toMerge[" + i + "] has " + (labels[i] != null ? labels[i].length : null) + " arrays");
}
i++;
}
//Now, merge:
INDArray[] mergedFeatures = new INDArray[nInArrays];
INDArray[] mergedLabels = new INDArray[nOutArrays];
INDArray[] mergedFeaturesMasks = new INDArray[nInArrays];
INDArray[] mergedLabelsMasks = new INDArray[nOutArrays];
boolean needFeaturesMasks = false;
// NOTE(review): the loop header below appears truncated by text extraction; it
// presumably iterates i over the nInArrays feature positions and declares a typed
// Pair - confirm against the upstream source.
for( i=0; i pair = merge(features,featuresMasks,i);
mergedFeatures[i] = pair.getFirst();
mergedFeaturesMasks[i] = pair.getSecond();
if(mergedFeaturesMasks[i] != null) needFeaturesMasks = true;
}
// No merged feature produced a mask: pass null rather than an array of nulls
if(!needFeaturesMasks) mergedFeaturesMasks = null;
boolean needLabelsMasks = false;
// NOTE(review): truncated loop header, as above; presumably iterates i over the
// nOutArrays label positions - confirm against the upstream source.
for( i=0; i pair = merge(labels,labelsMasks,i);
mergedLabels[i] = pair.getFirst();
mergedLabelsMasks[i] = pair.getSecond();
if(mergedLabelsMasks[i] != null) needLabelsMasks = true;
}
// No merged label produced a mask: pass null rather than an array of nulls
if(!needLabelsMasks) mergedLabelsMasks = null;
return new MultiDataSet(mergedFeatures,mergedLabels,mergedFeaturesMasks,mergedLabelsMasks);
}
/**
 * Merges one column (one input or output position) across all data sets, choosing the
 * merge routine from the rank of the first data set's array in that column.
 * Rank 2 and rank 4 arrays are merged without masks (the second pair element is null);
 * rank 3 arrays are delegated to the time series merge together with the mask arrays.
 *
 * @param arrays arrays to merge, indexed as [data set][column]
 * @param masks  mask arrays, indexed the same way; only passed on for rank 3
 * @param column the input/output position to merge
 * @return pair of (merged array, merged mask array or null)
 * @throws UnsupportedOperationException if the rank is anything other than 2, 3 or 4
 */
private static Pair<INDArray, INDArray> merge(INDArray[][] arrays, INDArray[][] masks, int column) {
    int rank = arrays[0][column].rank();
    switch (rank) {
        case 2:
            return new Pair<INDArray, INDArray>(merge2d(arrays, column), null);
        case 3:
            return mergeTimeSeries(arrays, masks, column);
        case 4:
            return new Pair<INDArray, INDArray>(merge4d(arrays, column), null);
        default:
            // Report the actual rank: ranks below 2 end up here as well as ranks above 4
            throw new UnsupportedOperationException("Cannot merge arrays with rank " + rank
                    + ": only rank 2, 3 and 4 arrays are supported (input/output number: " + column + ")");
    }
}
private static INDArray merge2d(INDArray[][] arrays, int inOutIdx){
//Merge 2d data. Mask arrays don't really make sense for 2d, hence are not used here
int nExamples = 0;
// Column count is taken from the first data set's array at this input/output index
int cols = arrays[0][inOutIdx].columns();
// NOTE(review): the next line appears truncated by text extraction. The rest of
// merge2d (its loop and return) and the start of the mergeTimeSeries declaration
// (modifiers and return type) seem to be missing - restore from the upstream source.
for( int i=0; i mergeTimeSeries(INDArray[][] arrays, INDArray[][] masks, int inOutIdx){
//Merge time series data, and handle masking etc for different length arrays
//Complications with time series:
//(a) They may have different lengths (if so: need input + output masking arrays)
//(b) Even if they are all the same length, they may have masking arrays (if so: merge the masking arrays too)
// Reference sizes are read from dimensions 2 and 1 of the first data set's array;
// presumably dimension 2 is the time series length - TODO confirm
int firstLength = arrays[0][inOutIdx].size(2);
int size = arrays[0][inOutIdx].size(1);
int maxLength = firstLength;
boolean hasMask = false;
boolean lengthsDiffer = false;
int totalExamples = 0;
// NOTE(review): the next line appears truncated by text extraction; the scan over the
// arrays and the no-mask branch that ends by returning a pair of (arr, null) seem to
// be missing - restore from the upstream source.
for(int i=0; i(arr,null);
} else {
//Either different length, or have mask arrays (or, both)
// NOTE(review): the next line is truncated as well; the body that builds arr and mask
// before they are returned as a pair seems to be missing.
for( int i=0; i(arr,mask);
}
private static INDArray merge4d(INDArray[][] arrays, int inOutIdx){
//4d -> images. Mask arrays for images: not really used
int nExamples = 0;
int[] shape = arrays[0][inOutIdx].shape();
for( int i=0; i