All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.dataset.DataSet Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.dataset;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

import com.google.common.base.Function;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.Lists;

/**
 * A data applyTransformToDestination (example/outcome pairs)
 * The outcomes are specifically for neural network encoding such that
 * any labels that are considered true are 1s. The rest are zeros.
 * @author Adam Gibson
 *
 */
public class DataSet  implements org.nd4j.linalg.dataset.api.DataSet {

    private static final long serialVersionUID = 1935520764586513365L;
    private static Logger log = LoggerFactory.getLogger(DataSet.class);
    private List columnNames = new ArrayList<>();
    private List labelNames = new ArrayList<>();
    private INDArray features,labels;

    public DataSet() {
        this(Nd4j.zeros(new int[]{1}), Nd4j.zeros(new int[]{1}));
    }

    /**
     * Creates a dataset with the specified input matrix and labels
     * @param first the feature matrix
     * @param second the labels (these should be binarized label matrices such that the specified label
     *               has a value of 1 in the desired column with the label)
     */
    public DataSet(INDArray first, INDArray second) {
        if(first.rows() != second.rows())
            throw new IllegalStateException("Invalid data applyTransformToDestination; first and second do not have equal rows. First was " + first.rows() + " second was " + second.rows());
        this.features = first;
        this.labels = second;
    }



    @Override
    public INDArray getFeatures() {
        return features;
    }

    @Override
    public void apply(Condition condition, Function function) {
        BooleanIndexing.applyWhere(getFeatureMatrix(),condition,function);
    }

    @Override
    public void setFeatures(INDArray features) {
        this.features = features;
    }

    @Override
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    /**
     * Clone the dataset
     * @return a clone of the dataset
     */
    @Override
    public DataSet copy() {
        return new DataSet(getFeatures().dup(),getLabels().dup());
    }


    /**
     * Returns a single dataset
     * @return an empty dataset with 2 1x1 zero matrices
     */
    public static DataSet empty() {
        return new DataSet(Nd4j.zeros(new int[]{1}), Nd4j.zeros(new int[]{1}));
    }

    /**
     * Merge the list of datasets in to one list.
     * All the rows are merged in to one dataset
     * @param data the data to merge
     * @return a single dataset
     */
    public static DataSet merge(List data) {
        if(data.isEmpty())
            throw new IllegalArgumentException("Unable to merge empty dataset");
        DataSet first = data.get(0);
        int numExamples = totalExamples(data);
        INDArray in = Nd4j.create(numExamples, first.getFeatures().columns());
        INDArray out = Nd4j.create(numExamples, first.getLabels().columns());
        int count = 0;

        for(int i = 0; i < data.size(); i++) {
            DataSet d1 = data.get(i);
            for(int j = 0; j < d1.numExamples(); j++) {
                DataSet example = d1.get(j);
                in.putRow(count,example.getFeatures());
                out.putRow(count,example.getLabels());
                count++;
            }


        }
        return new DataSet(in,out);
    }

    /**
     * Reshapes the input in to the given rows and columns
     * @param rows the row size
     * @param cols the column size
     * @return a copy of this data applyTransformToDestination with the input resized
     */
    @Override
    public DataSet reshape(int rows, int cols) {
        DataSet ret = new DataSet(getFeatures().reshape(new int[]{rows,cols}),getLabels());
        return ret;
    }


    @Override
    public void multiplyBy(double num) {
        getFeatures().muli(Nd4j.scalar(num));
    }

    @Override
    public void divideBy(int num) {
        getFeatures().divi(Nd4j.scalar(num));
    }

    @Override
    public void shuffle() {
        List list = asList();
        Collections.shuffle(list);
        DataSet ret = DataSet.merge(list);
        setFeatures(ret.getFeatures());
        setLabels(ret.getLabels());
    }





    /**
     * Squeezes input data to a max and a min
     * @param min the min value to occur in the dataset
     * @param max the max value to ccur in the dataset
     */
    @Override
    public void squishToRange(double min, double max) {
        for(int i = 0;i  < getFeatures().length(); i++) {
            double curr = (double) getFeatures().getScalar(i).element();
            if(curr < min)
                getFeatures().put(i, Nd4j.scalar(min));
            else if(curr> max)
                getFeatures().put(i, Nd4j.scalar(max));
        }
    }

    @Override
    public void scaleMinAndMax(double min, double max) {
        FeatureUtil.scaleMinMax(min,max,getFeatureMatrix());
    }

    /**
     * Divides the input data applyTransformToDestination by the max number in each row
     */
    @Override
    public void scale() {
        FeatureUtil.scaleByMax(getFeatures());
    }

    /**
     * Adds a feature for each example on to the current feature vector
     * @param toAdd the feature vector to add
     */
    @Override
    public void addFeatureVector(INDArray toAdd) {
        setFeatures(Nd4j.hstack());
    }


    /**
     * The feature to add, and the example/row number
     * @param feature the feature vector to add
     * @param example the number of the example to append to
     */
    @Override
    public void addFeatureVector(INDArray feature, int example) {
        getFeatures().putRow(example, Nd4j.hstack());
    }

    @Override
    public void normalize() {
        FeatureUtil.normalizeMatrix(getFeatures());
    }


    /**
     * Same as calling binarize(0)
     */
    @Override
    public void binarize() {
        binarize(0);
    }

    /**
     * Binarizes the dataset such that any number greater than cutoff is 1 otherwise zero
     * @param cutoff the cutoff point
     */
    @Override
    public void binarize(double cutoff) {
        for(int i = 0; i < getFeatures().length(); i++) {
            double curr = (double) getFeatures().getScalar(i).element();
            if (curr > cutoff)
                getFeatures().put(i, Nd4j.scalar(1));
            else
                getFeatures().put(i, Nd4j.scalar(0));
        }
    }


    /**
     * Subtract by the column means and divide by the standard deviation
     */
    @Override
    public void normalizeZeroMeanZeroUnitVariance() {
        INDArray columnMeans = getFeatures().mean(0);
        INDArray columnStds = getFeatureMatrix().std(0);

        setFeatures(getFeatures().subiRowVector(columnMeans));
        columnStds.addi(Nd4j.scalar(1e-6));
        setFeatures(getFeatures().diviRowVector(columnStds));
    }

    private static int totalExamples(Collection coll) {
        int count = 0;
        for(DataSet d : coll)
            count += d.numExamples();
        return count;
    }


    /**
     * The number of inputs in the feature matrix
     * @return
     */
    @Override
    public int numInputs() {
        return getFeatures().columns();
    }

    @Override
    public void validate() {
        if(getFeatures().rows() != getLabels().rows())
            throw new IllegalStateException("Invalid dataset");
    }

    @Override
    public int outcome() {
        if(this.numExamples() > 1)
            throw new IllegalStateException("Unable to derive outcome for dataset greater than one row");
        return Nd4j.getBlasWrapper().iamax(getLabels());
    }


    /**
     * Clears the outcome matrix setting a new number of labels
     * @param labels the number of labels/columns in the outcome matrix
     * Note that this clears the labels for each example
     */
    @Override
    public void setNewNumberOfLabels(int labels) {
        int examples = numExamples();
        INDArray newOutcomes = Nd4j.create(examples, labels);
        setLabels(newOutcomes);
    }

    /**
     * Sets the outcome of a particular example
     * @param example the example to applyTransformToDestination
     * @param label the label of the outcome
     */
    @Override
    public void setOutcome(int example, int label) {
        if(example > numExamples())
            throw new IllegalArgumentException("No example at " + example);
        if(label > numOutcomes() || label < 0)
            throw new IllegalArgumentException("Illegal label");

        INDArray outcome = FeatureUtil.toOutcomeVector(label, numOutcomes());
        getLabels().putRow(example,outcome);
    }

    /**
     * Gets a copy of example i
     * @param i the example to getFromOrigin
     * @return the example at i (one example)
     */
    @Override
    public DataSet get(int i) {
        if(i > numExamples() || i < 0)
            throw new IllegalArgumentException("invalid example number");

        return new DataSet(getFeatures().getRow(i),getLabels().getRow(i));
    }

    /**
     * Gets a copy of example i
     * @param i the example to getFromOrigin
     * @return the example at i (one example)
     */
    @Override
    public DataSet get(int[] i) {
        return new DataSet(getFeatures().getRows(i),getLabels().getRows(i));
    }


    /**
     * Partitions a dataset in to mini batches where
     * each dataset in each list is of the specified number of examples
     * @param num the number to split by
     * @return the partitioned datasets
     */
    @Override
    public List> batchBy(int num) {
        return Lists.partition(asList(),num);
    }




    /**
     * Strips the data applyTransformToDestination of all but the passed in labels
     * @param labels strips the data applyTransformToDestination of all but the passed in labels
     * @return the dataset with only the specified labels
     */
    @Override
    public DataSet filterBy(int[] labels) {
        List list = asList();
        List newList = new ArrayList<>();
        List labelList = new ArrayList<>();
        for(int i : labels)
            labelList.add(i);
        for(DataSet d : list) {
            int outcome = d.outcome();
            if(labelList.contains(outcome)) {
                newList.add(d);
            }
        }

        return DataSet.merge(newList);
    }


    /**
     * Strips the dataset down to the specified labels
     * and remaps them
     * @param labels the labels to strip down to
     */
    @Override
    public void filterAndStrip(int[] labels) {
        DataSet filtered = filterBy(labels);
        List newLabels = new ArrayList<>();

        //map new labels to index according to passed in labels
        Map labelMap = new HashMap<>();

        for(int i = 0; i < labels.length; i++)
            labelMap.put(labels[i],i);

        //map examples
        for(int i = 0; i < filtered.numExamples(); i++)  {
            int o2 = filtered.get(i).outcome();
            Integer outcome = labelMap.get(o2);
            newLabels.add(outcome);

        }


        INDArray newLabelMatrix = Nd4j.create(filtered.numExamples(), labels.length);

        if(newLabelMatrix.rows() != newLabels.size())
            throw new IllegalStateException("Inconsistent label sizes");

        for(int i = 0; i < newLabelMatrix.rows(); i++) {
            Integer i2 = newLabels.get(i);
            if(i2 == null)
                throw new IllegalStateException("Label not found on row " + i);
            INDArray newRow = FeatureUtil.toOutcomeVector(i2, labels.length);
            newLabelMatrix.putRow(i,newRow);

        }

        setFeatures(filtered.getFeatures());
        setLabels(newLabelMatrix);
    }




    /**
     * Partitions the data applyTransformToDestination by the specified number.
     * @param num the number to split by
     * @return the partitioned data applyTransformToDestination
     */
    @Override
    public List dataSetBatches(int num) {
        List> list =  Lists.partition(asList(),num);
        List ret = new ArrayList<>();
        for(List l : list)
            ret.add(DataSet.merge(l));
        return ret;

    }


    /**
     * Sorts the dataset by label:
     * Splits the data applyTransformToDestination such that examples are sorted by their labels.
     * A ten label dataset would produce lists with batches like the following:
     * x1   y = 1
     * x2   y = 2
     * ...
     * x10  y = 10
     * @return a list of data sets partitioned by outcomes
     */
    @Override
    public List> sortAndBatchByNumLabels() {
        sortByLabel();
        return Lists.partition(asList(),numOutcomes());
    }

    @Override
    public List> batchByNumLabels() {
        return Lists.partition(asList(),numOutcomes());
    }


    @Override
    public List asList() {
        List list = new ArrayList<>(numExamples());
        for(int i = 0; i < numExamples(); i++)  {
            list.add(new DataSet(getFeatures().getRow(i),getLabels().getRow(i)));
        }
        return list;
    }


    /**
     * Splits a dataset in to test and train
     * @param numHoldout the number to hold out for training
     * @return the pair of datasets for the train test split
     */
    @Override
    public SplitTestAndTrain splitTestAndTrain(int numHoldout) {

        if(numHoldout >= numExamples())
            throw new IllegalArgumentException("Unable to split on size larger than the number of rows");


        List list = asList();

        Collections.rotate(list, 3);
        Collections.shuffle(list);
        List> partition = new ArrayList<>();
        partition.add(list.subList(0, numHoldout));
        partition.add(list.subList(numHoldout, list.size()));
        DataSet train = merge(partition.get(0));
        DataSet test = merge(partition.get(1));
        return new SplitTestAndTrain(train,test);
    }


    /**
     * Returns the labels for the dataset
     * @return the labels for the dataset
     */
    @Override
    public INDArray getLabels() {
        return labels;
    }

    /**
     * Get the feature matrix (inputs for the data)
     * @return the feature matrix for the dataset
     */
    @Override
    public INDArray getFeatureMatrix() {
        return getFeatures();
    }


    /**
     * Organizes the dataset to minimize sampling error
     * while still allowing efficient batching.
     */
    @Override
    public void sortByLabel() {
        Map> map = new HashMap<>();
        List data = asList();
        int numLabels = numOutcomes();
        int examples = numExamples();
        for(DataSet d : data) {
            int label = d.outcome();
            Queue q = map.get(label);
            if(q == null) {
                q = new ArrayDeque<>();
                map.put(label, q);
            }
            q.add(d);
        }

        for(Integer label : map.keySet()) {
            log.info("Label " + label + " has " + map.get(label).size() + " elements");
        }

        //ideal input splits: 1 of each label in each batch
        //after we run out of ideal batches: fall back to a new strategy
        boolean optimal = true;
        for(int i = 0; i < examples; i++) {
            if(optimal) {
                for(int j = 0; j < numLabels; j++) {
                    Queue q = map.get(j);
                    if(q == null) {
                        optimal = false;
                        break;
                    }
                    DataSet next = q.poll();
                    //add a row; go to next
                    if(next != null) {
                        addRow(next,i);
                        i++;
                    }
                    else {
                        optimal = false;
                        break;
                    }
                }
            }
            else {
                DataSet add = null;
                for(Queue q : map.values()) {
                    if(!q.isEmpty()) {
                        add = q.poll();
                        break;
                    }
                }

                addRow(add,i);

            }


        }


    }


    @Override
    public void addRow(DataSet d, int i) {
        if(i > numExamples() || d == null)
            throw new IllegalArgumentException("Invalid index for adding a row");
        getFeatures().putRow(i, d.getFeatures());
        getLabels().putRow(i,d.getLabels());
    }


    private int getLabel(DataSet data) {
        Float f = (Float) data.getLabels().max(Integer.MAX_VALUE).element();
        return f.intValue();
    }


    @Override
    public INDArray exampleSums() {
        return getFeatures().sum(1);
    }

    @Override
    public INDArray exampleMaxs() {
        return getFeatures().max(1);
    }

    @Override
    public INDArray exampleMeans() {
        return getFeatures().mean(1);
    }


    /**
     * Sample without replacement and a random rng
     * @param numSamples the number of samples to getFromOrigin
     * @return a sample data applyTransformToDestination without replacement
     */
    @Override
    public DataSet sample(int numSamples) {
        return sample(numSamples,new MersenneTwister(System.currentTimeMillis()));
    }

    /**
     * Sample without replacement
     * @param numSamples the number of samples to getFromOrigin
     * @param rng the rng to use
     * @return the sampled dataset without replacement
     */
    @Override
    public DataSet sample(int numSamples, RandomGenerator rng) {
        return sample(numSamples,rng,false);
    }

    /**
     * Sample a dataset numSamples times
     * @param numSamples the number of samples to getFromOrigin
     * @param withReplacement the rng to use
     * @return the sampled dataset without replacement
     */
    @Override
    public DataSet sample(int numSamples, boolean withReplacement) {
        return sample(numSamples,new MersenneTwister(System.currentTimeMillis()),withReplacement);
    }

    /**
     * Sample a dataset
     * @param numSamples the number of samples to getFromOrigin
     * @param rng the rng to use
     * @param withReplacement whether to allow duplicates (only tracked by example row number)
     * @return the sample dataset
     */
    @Override
    public DataSet sample(int numSamples, RandomGenerator rng, boolean withReplacement) {
        if(numSamples >= numExamples())
            return this;
        else {
            INDArray examples = Nd4j.create(numSamples, getFeatures().columns());
            INDArray outcomes = Nd4j.create(numSamples, numOutcomes());
            Set added = new HashSet();
            for(int i = 0; i < numSamples; i++) {
                int picked = rng.nextInt(numExamples());
                if(!withReplacement)
                    while(added.contains(picked)) {
                        picked = rng.nextInt(numExamples());

                    }
                examples.putRow(i,get(picked).getFeatures());
                outcomes.putRow(i,get(picked).getLabels());

            }
            return new DataSet(examples,outcomes);
        }
    }

    @Override
    public void roundToTheNearest(int roundTo) {
        for(int i = 0; i < getFeatures().length(); i++) {
            double curr = (double) getFeatures().getScalar(i).element();
            getFeatures().put(i, Nd4j.scalar(MathUtils.roundDouble(curr, roundTo)));
        }
    }

    @Override
    public int numOutcomes() {
        return getLabels().columns();
    }

    @Override
    public int numExamples() {
        return getFeatures().rows();
    }




    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("===========INPUT===================\n")
                .append(getFeatures().toString().replaceAll(";","\n"))
                .append("\n=================OUTPUT==================\n")
                .append(getLabels().toString().replaceAll(";","\n"));
        return builder.toString();
    }




    /**
     * Gets the optional label names
     * @return
     */
    @Override
    public List getLabelNames() {
        return labelNames;
    }

    /**
     * Sets the label names, will throw an exception if the passed
     * in label names doesn't equal the number of outcomes
     * @param labelNames the label names to use
     */
    @Override
    public void setLabelNames(List labelNames) {
        if(labelNames == null || labelNames.size() != numOutcomes())
            throw new IllegalArgumentException("Unable to applyTransformToDestination label names, does not match number of possible outcomes");
        this.labelNames = labelNames;
    }

    /**
     * Optional column names of the data applyTransformToDestination, this is mainly used
     * for interpeting what columns are in the dataset
     * @return
     */
    @Override
    public List getColumnNames() {
        return columnNames;
    }

    /**
     * Sets the column names, will throw an exception if the column names
     * don't match the number of columns
     * @param columnNames
     */
    @Override
    public void setColumnNames(List columnNames) {
        if(columnNames.size() != numInputs())
            throw new IllegalArgumentException("Column names don't match input");
        this.columnNames = columnNames;
    }



    @Override
    public Iterator iterator() {
        return asList().iterator();
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy