All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.dataset.MultiDataSet Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.dataset;

import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

/** Implementation of {@link org.nd4j.linalg.dataset.api.MultiDataSet}
 * @author Alex Black
 */
public class MultiDataSet implements org.nd4j.linalg.dataset.api.MultiDataSet {

    private INDArray[] features;
    private INDArray[] labels;
    private INDArray[] featuresMaskArrays;
    private INDArray[] labelsMaskArrays;

    /**
     * Creates an empty MultiDataSet: features, labels and both mask-array fields are all left null.
     */
    public MultiDataSet() {
        // Nothing to initialise; every field keeps its default (null) value.
    }

    /**
     * Convenience constructor for exactly one input and one output, with no mask arrays.
     * A null argument is stored as a null array, not as a one-element array containing null.
     *
     * @param features the single features array; may be null
     * @param labels   the single labels array; may be null
     */
    public MultiDataSet(INDArray features, INDArray labels) {
        this(features == null ? null : new INDArray[]{features},
             labels == null ? null : new INDArray[]{labels});
    }

    /**
     * Creates a MultiDataSet from the given features and labels; both mask-array fields are set to null.
     *
     * @param features the features (input) arrays; may be null
     * @param labels   the labels (output) arrays; may be null
     */
    public MultiDataSet(INDArray[] features, INDArray[] labels) {
        this(features, labels, (INDArray[]) null, (INDArray[]) null);
    }

    /**
     * Creates a MultiDataSet from its four component array groups. The arrays are stored by
     * reference; nothing is copied.
     *
     * @param features           the features (inputs) to the algorithm/neural network; may be null
     * @param labels             the labels (outputs) to the algorithm/neural network; may be null
     * @param featuresMaskArrays mask arrays for the features, e.g. for variable-length time series; may be null.
     *                           When both this and {@code features} are non-null their lengths must match
     * @param labelsMaskArrays   mask arrays for the labels; may be null. When both this and
     *                           {@code labels} are non-null their lengths must match
     * @throws IllegalArgumentException if a non-null mask group and its non-null data group differ in length
     */
    public MultiDataSet(INDArray[] features, INDArray[] labels, INDArray[] featuresMaskArrays, INDArray[] labelsMaskArrays) {
        boolean featureCountsDiffer = features != null && featuresMaskArrays != null
                && features.length != featuresMaskArrays.length;
        if (featureCountsDiffer) {
            throw new IllegalArgumentException("Invalid features / features mask arrays combination: "
                    + "features and features mask arrays must not be different lengths");
        }

        boolean labelCountsDiffer = labels != null && labelsMaskArrays != null
                && labels.length != labelsMaskArrays.length;
        if (labelCountsDiffer) {
            throw new IllegalArgumentException("Invalid labels / labels mask arrays combination: "
                    + "labels and labels mask arrays must not be different lengths");
        }

        this.features = features;
        this.labels = labels;
        this.featuresMaskArrays = featuresMaskArrays;
        this.labelsMaskArrays = labelsMaskArrays;

        // NOTE(review): presumably forces any queued grid ops to complete so the arrays are
        // materialised before the data set is used -- confirm against GridExecutioner docs.
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
        }
    }


    /**
     * @return the number of features (input) arrays, or 0 when no features array is set
     */
    @Override
    public int numFeatureArrays() {
        if (features == null) {
            return 0;
        }
        return features.length;
    }

    /**
     * @return the number of labels (output) arrays, or 0 when no labels array is set
     */
    @Override
    public int numLabelsArrays() {
        if (labels == null) {
            return 0;
        }
        return labels.length;
    }

    /**
     * @return the internal features array itself (not a copy); may be null
     */
    @Override
    public INDArray[] getFeatures() {
        return this.features;
    }

    /**
     * Returns a single features array. Throws NullPointerException if no features are set and
     * ArrayIndexOutOfBoundsException if {@code index} is out of range.
     *
     * @param index position of the input array
     * @return the features array at {@code index}
     */
    @Override
    public INDArray getFeatures(int index) {
        INDArray[] inputs = this.features;
        return inputs[index];
    }

    /**
     * Replaces the whole features array group (stored by reference).
     *
     * @param features the new features arrays; may be null
     */
    @Override
    public void setFeatures(INDArray[] features) {
        INDArray[] replacement = features;
        this.features = replacement;
    }

    /**
     * Overwrites one entry of the existing features array in place. The features array must
     * already exist, otherwise a NullPointerException is thrown.
     *
     * @param idx      position of the input array to replace
     * @param features the new array for that position
     */
    @Override
    public void setFeatures(int idx, INDArray features) {
        INDArray[] inputs = this.features;
        inputs[idx] = features;
    }

    /**
     * @return the internal labels array itself (not a copy); may be null
     */
    @Override
    public INDArray[] getLabels() {
        return this.labels;
    }

    /**
     * Returns a single labels array. Throws NullPointerException if no labels are set and
     * ArrayIndexOutOfBoundsException if {@code index} is out of range.
     *
     * @param index position of the output array
     * @return the labels array at {@code index}
     */
    @Override
    public INDArray getLabels(int index) {
        INDArray[] outputs = this.labels;
        return outputs[index];
    }

    /**
     * Replaces the whole labels array group (stored by reference).
     *
     * @param labels the new labels arrays; may be null
     */
    @Override
    public void setLabels(INDArray[] labels) {
        INDArray[] replacement = labels;
        this.labels = replacement;
    }

    /**
     * Overwrites one entry of the existing labels array in place. The labels array must
     * already exist, otherwise a NullPointerException is thrown.
     *
     * @param idx    position of the output array to replace
     * @param labels the new array for that position
     */
    @Override
    public void setLabels(int idx, INDArray labels) {
        INDArray[] outputs = this.labels;
        outputs[idx] = labels;
    }

    /**
     * Reports whether any actual mask is present. Null mask groups and null entries inside a
     * group are both treated as "no mask".
     *
     * @return true if at least one features or labels mask array entry is non-null
     */
    @Override
    public boolean hasMaskArrays() {
        INDArray[][] maskGroups = {featuresMaskArrays, labelsMaskArrays};
        for (INDArray[] group : maskGroups) {
            if (group == null) {
                continue;
            }
            for (INDArray mask : group) {
                if (mask != null) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * @return the internal features mask array group itself (not a copy); may be null
     */
    @Override
    public INDArray[] getFeaturesMaskArrays() {
        return this.featuresMaskArrays;
    }

    /**
     * Returns the mask for one input.
     *
     * @param index position of the input
     * @return the mask at {@code index}, or null when no features masks are set at all
     */
    @Override
    public INDArray getFeaturesMaskArray(int index) {
        if (featuresMaskArrays == null) {
            return null;
        }
        return featuresMaskArrays[index];
    }

    /**
     * Replaces the whole features mask array group (stored by reference, no length check).
     *
     * @param maskArrays the new features masks; may be null
     */
    @Override
    public void setFeaturesMaskArrays(INDArray[] maskArrays) {
        INDArray[] replacement = maskArrays;
        this.featuresMaskArrays = replacement;
    }

    /**
     * Sets the features mask for input {@code idx}.
     * If no features mask arrays exist yet, an array with one (initially null) slot per features
     * array is allocated first, so a mask can be added to a data set created without masks.
     *
     * @param idx       index of the input whose mask is set
     * @param maskArray the mask array; may be null
     * @throws ArrayIndexOutOfBoundsException if {@code idx} is outside the mask-array bounds
     */
    @Override
    public void setFeaturesMaskArray(int idx, INDArray maskArray) {
        if (this.featuresMaskArrays == null) {
            // getFeaturesMaskArray(int) treats a null group as "no masks"; allocate a group sized
            // to the features so the constructor's equal-length invariant still holds.
            this.featuresMaskArrays = new INDArray[numFeatureArrays()];
        }
        this.featuresMaskArrays[idx] = maskArray;
    }

    /**
     * @return the internal labels mask array group itself (not a copy); may be null
     */
    @Override
    public INDArray[] getLabelsMaskArrays() {
        return this.labelsMaskArrays;
    }

    /**
     * Returns the mask for one output.
     *
     * @param index position of the output
     * @return the mask at {@code index}, or null when no labels masks are set at all
     */
    @Override
    public INDArray getLabelsMaskArray(int index) {
        if (labelsMaskArrays == null) {
            return null;
        }
        return labelsMaskArrays[index];
    }

    /**
     * Replaces the whole labels mask array group (stored by reference, no length check).
     *
     * @param labelsMaskArrays the new labels masks; may be null
     */
    @Override
    public void setLabelsMaskArray(INDArray[] labelsMaskArrays) {
        INDArray[] replacement = labelsMaskArrays;
        this.labelsMaskArrays = replacement;
    }

    /**
     * Sets the labels mask for output {@code idx}.
     * If no labels mask arrays exist yet, an array with one (initially null) slot per labels
     * array is allocated first, so a mask can be added to a data set created without masks.
     *
     * @param idx             index of the output whose mask is set
     * @param labelsMaskArray the mask array; may be null
     * @throws ArrayIndexOutOfBoundsException if {@code idx} is outside the mask-array bounds
     */
    @Override
    public void setLabelsMaskArray(int idx, INDArray labelsMaskArray) {
        if (this.labelsMaskArrays == null) {
            // getLabelsMaskArray(int) treats a null group as "no masks"; allocate a group sized
            // to the labels so the constructor's equal-length invariant still holds.
            this.labelsMaskArrays = new INDArray[numLabelsArrays()];
        }
        this.labelsMaskArrays[idx] = labelsMaskArray;
    }

    /**
     * Serialises this data set to {@code to}, then closes {@code to}.
     * <p>
     * Layout:
     * <ul>
     *   <li>A header of four ints: the counts of features, labels, features masks and labels masks
     *       (0 for a null group).</li>
     *   <li>Every array of each non-null group, in that same order, written with
     *       {@code Nd4j.write}.</li>
     * </ul>
     *
     * @param to destination stream; it is closed when this method returns
     * @throws IOException if writing fails
     */
    @Override
    public void save(OutputStream to) throws IOException {
        INDArray[][] groups = {features, labels, featuresMaskArrays, labelsMaskArrays};

        try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(to))) {
            // Header: number of arrays per group, in fixed order
            for (INDArray[] group : groups) {
                out.writeInt(group == null ? 0 : group.length);
            }

            // Body: the arrays themselves, group by group in header order.
            // NOTE(review): hasMaskArrays() allows null entries inside mask groups; such entries are
            // passed straight to Nd4j.write here -- confirm it accepts null, or partially-masked
            // data sets cannot be saved.
            for (INDArray[] group : groups) {
                if (group == null) {
                    continue;
                }
                for (INDArray array : group) {
                    Nd4j.write(array, out);
                }
            }
        }
    }

    /**
     * Serialises this data set to the given file, using the same format as {@link #save(OutputStream)}.
     *
     * @param to destination file; created or overwritten
     * @throws IOException if the file cannot be opened or writing fails
     */
    @Override
    public void save(File to) throws IOException {
        // Close the stream opened here ourselves instead of relying on save(OutputStream), which is
        // overridable. That method currently closes it as well; a second close is a no-op.
        try (FileOutputStream fileOut = new FileOutputStream(to)) {
            save(fileOut);
        }
    }

    @Override
    public void load(InputStream from) throws IOException {

        try(DataInputStream dis = new DataInputStream(from)){
            int numFArr = dis.readInt();
            int numLArr = dis.readInt();
            int numFMArr = dis.readInt();
            int numLMArr = dis.readInt();

            if(numFArr > 0){
                features = new INDArray[numFArr];
                for( int i=0; i 0){
                labels = new INDArray[numLArr];
                for( int i=0; i 0){
                featuresMaskArrays = new INDArray[numFMArr];
                for( int i=0; i 0){
                labelsMaskArrays = new INDArray[numLMArr];
                for( int i=0; i asList() {
        int nExamples = features[0].size(0);

        List list = new ArrayList<>();

        for( int i=0; i toMerge){
        if(toMerge.size() == 1){
            org.nd4j.linalg.dataset.api.MultiDataSet mds = toMerge.iterator().next();
            if(mds instanceof MultiDataSet) return (MultiDataSet) mds;
            else return new MultiDataSet(mds.getFeatures(),mds.getLabels(),mds.getFeaturesMaskArrays(),mds.getLabelsMaskArrays());
        }

        List list;
        if(toMerge instanceof List) list = (List)toMerge;
        else list = new ArrayList<>(toMerge);

        int nInArrays = list.get(0).numFeatureArrays();
        int nOutArrays = list.get(0).numLabelsArrays();

        INDArray[][] features = new INDArray[list.size()][0];
        INDArray[][] labels = new INDArray[list.size()][0];
        INDArray[][] featuresMasks = new INDArray[list.size()][0];
        INDArray[][] labelsMasks = new INDArray[list.size()][0];

        int i=0;
        for( org.nd4j.linalg.dataset.api.MultiDataSet mds : list ){
            features[i] = mds.getFeatures();
            labels[i] = mds.getLabels();
            featuresMasks[i] = mds.getFeaturesMaskArrays();
            labelsMasks[i] = mds.getLabelsMaskArrays();

            if(features[i] == null || features[i].length != nInArrays){
                throw new IllegalStateException("Cannot merge MultiDataSets with different number of input arrays: toMerge[0] has "
                        + nInArrays + " input arrays; toMerge[" + i + "] has " + (features[i] != null ? features[i].length : null) + " arrays");
            }
            if(labels[i] == null || labels[i].length != nOutArrays){
                throw new IllegalStateException("Cannot merge MultiDataSets with different number of output arrays: toMerge[0] has "
                        + nOutArrays + " output arrays; toMerge[" + i + "] has " + (labels[i] != null ? labels[i].length : null) + " arrays");
            }

            i++;
        }

        //Now, merge:
        INDArray[] mergedFeatures = new INDArray[nInArrays];
        INDArray[] mergedLabels = new INDArray[nOutArrays];
        INDArray[] mergedFeaturesMasks = new INDArray[nInArrays];
        INDArray[] mergedLabelsMasks = new INDArray[nOutArrays];

        boolean needFeaturesMasks = false;
        for( i=0; i pair = merge(features,featuresMasks,i);
            mergedFeatures[i] = pair.getFirst();
            mergedFeaturesMasks[i] = pair.getSecond();
            if(mergedFeaturesMasks[i] != null) needFeaturesMasks = true;
        }
        if(!needFeaturesMasks) mergedFeaturesMasks = null;

        boolean needLabelsMasks = false;
        for( i=0; i pair = merge(labels,labelsMasks,i);
            mergedLabels[i] = pair.getFirst();
            mergedLabelsMasks[i] = pair.getSecond();
            if(mergedLabelsMasks[i] != null) needLabelsMasks = true;
        }
        if(!needLabelsMasks) mergedLabelsMasks = null;

        return new MultiDataSet(mergedFeatures,mergedLabels,mergedFeaturesMasks,mergedLabelsMasks);
    }

    /**
     * Merges the {@code column}-th input (or output) array of every data set into one array,
     * choosing the strategy from the rank of that array.
     *
     * @param arrays arrays indexed as {@code arrays[dataSetIdx][inOutIdx]}
     * @param masks  mask arrays indexed like {@code arrays}; only consulted for rank-3 (time series) data
     * @param column index of the input/output array to merge
     * @return pair of (merged array, merged mask array); the mask is null for rank 2 and rank 4
     * @throws UnsupportedOperationException if the array rank is not 2, 3 or 4
     */
    private static Pair<INDArray, INDArray> merge(INDArray[][] arrays, INDArray[][] masks, int column){
        // arrays is indexed [dataSet][inOut]: inspect the requested column of the first data set.
        // (Indexing it as [column][0] read the wrong array, or overflowed when column >= #dataSets.)
        int rank = arrays[0][column].rank();
        switch (rank) {
            case 2:
                return new Pair<>(merge2d(arrays, column), null);
            case 3:
                return mergeTimeSeries(arrays, masks, column);
            case 4:
                return new Pair<>(merge4d(arrays, column), null);
            default:
                throw new UnsupportedOperationException("Cannot merge arrays with rank " + rank
                        + ": only rank 2, 3 and 4 are supported (input/output number: " + column + ")");
        }
    }

    private static INDArray merge2d(INDArray[][] arrays, int inOutIdx){
        //Merge 2d data. Mask arrays don't really make sense for 2d, hence are not used here
        int nExamples = 0;
        int cols = arrays[0][inOutIdx].columns();
        for( int i=0; i mergeTimeSeries(INDArray[][] arrays, INDArray[][] masks, int inOutIdx){
        //Merge time series data, and handle masking etc for different length arrays

        //Complications with time series:
        //(a) They may have different lengths (if so: need input + output masking arrays)
        //(b) Even if they are all the same length, they may have masking arrays (if so: merge the masking arrays too)

        int firstLength = arrays[0][inOutIdx].size(2);
        int size = arrays[0][inOutIdx].size(1);
        int maxLength = firstLength;

        boolean hasMask = false;
        boolean lengthsDiffer = false;
        int totalExamples = 0;
        for(int i=0; i(arr,null);
        } else {
            //Either different length, or have mask arrays (or, both)
            for( int i=0; i(arr,mask);
    }

    private static INDArray merge4d(INDArray[][] arrays, int inOutIdx){
        //4d -> images. Mask arrays for images: not really used

        int nExamples = 0;
        int[] shape = arrays[0][inOutIdx].shape();
        for( int i=0; i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy