All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.dataset.DistributionStats Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.dataset;

import lombok.Getter;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.*;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Statistics about the normal distribution of values in data (means and standard deviations).
 * Can be constructed incrementally by using the {@link Builder}, which is useful for obtaining these statistics
 * from an iterator. Can also load and save from files.
 *
 * @author Ede Meijer
 */
@Getter
public class DistributionStats {
    private static final Logger logger = LoggerFactory.getLogger(DistributionStats.class);

    private final INDArray mean;
    private final INDArray std;

    /**
     * Creates the statistics from precomputed means and standard deviations.
     * <p>
     * Note that {@code std} is clamped <em>in place</em>: every entry smaller than {@link Nd4j#EPS_THRESHOLD} is
     * raised to that threshold, and the same array instance is then stored in this object.
     *
     * @param mean row vector of means
     * @param std  row vector of standard deviations
     */
    public DistributionStats(@NonNull INDArray mean, @NonNull INDArray std) {
        // Look at the smallest entry BEFORE clamping. Comparing INDArray instances with == (as was done previously)
        // only compares references and therefore never detected the situation being reported here.
        boolean needsRounding = std.minNumber().doubleValue() < Nd4j.EPS_THRESHOLD;

        // dup = false: the clamp is applied to the passed array itself
        Transforms.max(std, Nd4j.EPS_THRESHOLD, false);
        if (needsRounding) {
            logger.info("API_INFO: Std deviation found to be zero. Transform will round up to epsilon to avoid nans.");
        }

        this.mean = mean;
        this.std = std;
    }

    /**
     * Load distribution statistics from the file system
     *
     * @param meanFile file containing the means
     * @param stdFile  file containing the standard deviations
     * @return the statistics built from the two arrays read from the given files
     * @throws IOException if reading either file fails
     */
    public static DistributionStats load(@NonNull File meanFile, @NonNull File stdFile) throws IOException {
        return new DistributionStats(Nd4j.readBinary(meanFile), Nd4j.readBinary(stdFile));
    }

    /**
     * Save distribution statistics to the file system
     *
     * @param meanFile file to contain the means
     * @param stdFile  file to contain the standard deviations
     * @throws IOException if writing either file fails
     */
    public void save(@NonNull File meanFile, @NonNull File stdFile) throws IOException {
        Nd4j.saveBinary(getMean(), meanFile);
        Nd4j.saveBinary(getStd(), stdFile);
    }

    /**
     * Builder class that can incrementally update a running mean and variance in order to create statistics for a
     * large set of data
     */
    public static class Builder {
        // Kept as long so that neither the total number of rows nor the products used in the variance update
        // can overflow 32-bit integer arithmetic.
        private long runningCount = 0;
        private INDArray runningMean;
        private INDArray runningVariance;

        /**
         * Add the features of a DataSet to the statistics
         *
         * @param dataSet the data set whose features (and features mask) are passed to {@link #add}
         * @return this builder
         */
        public Builder addFeatures(@NonNull org.nd4j.linalg.dataset.api.DataSet dataSet) {
            return add(dataSet.getFeatures(), dataSet.getFeaturesMaskArray());
        }

        /**
         * Add the labels of a DataSet to the statistics
         *
         * @param dataSet the data set whose labels (and labels mask) are passed to {@link #add}
         * @return this builder
         */
        public Builder addLabels(@NonNull org.nd4j.linalg.dataset.api.DataSet dataSet) {
            return add(dataSet.getLabels(), dataSet.getLabelsMaskArray());
        }

        /**
         * Add rows of data to the statistics
         *
         * @param data the matrix containing multiple rows of data to include
         * @param mask (optionally) the mask of the data, useful for e.g. time series
         * @return this builder
         */
        public Builder add(@NonNull INDArray data, INDArray mask) {
            // NOTE(review): tailor2d presumably reshapes the input to 2d and applies the mask - confirm in DataSetUtil
            data = DataSetUtil.tailor2d(data, mask);

            // Using https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
            INDArray mean = data.mean(0);
            INDArray variance = data.var(false, 0); // biasCorrected = false
            long count = data.size(0);

            if (runningMean == null) {
                // First batch
                runningMean = mean;
                runningVariance = variance;
                runningCount = count;
            } else {
                long totalCount = runningCount + count;

                // Update running variance:
                //   (var_a * n_a + var_b * n_b + delta^2 * n_a * n_b / (n_a + n_b)) / (n_a + n_b)
                INDArray deltaSquared = Transforms.pow(mean.subRowVector(runningMean), 2);
                INDArray mB = variance.muli(count);
                // Computed in double: an int product n_a * n_b silently overflows for large data sets
                double deltaWeight = ((double) runningCount * count) / totalCount;
                runningVariance.muli(runningCount)
                    .addiRowVector(mB)
                    .addiRowVector(deltaSquared.muli(deltaWeight))
                    .divi(totalCount);

                // Update running count
                runningCount = totalCount;

                // Update running mean: shift it by sum(x - oldMean) / newCount
                INDArray xMinusMean = data.subRowVector(runningMean);
                runningMean.addi(xMinusMean.sum(0).divi(runningCount));
            }

            return this;
        }

        /**
         * Create a DistributionStats object from the data ingested so far. Can be used multiple times when updating
         * online.
         *
         * @return statistics holding a copy of the running mean and the square root of the running variance
         * @throws IllegalStateException if no data has been added to this builder yet
         */
        public DistributionStats build() {
            if (runningMean == null) {
                throw new IllegalStateException("Cannot build DistributionStats: no data has been added to the builder");
            }
            // Both arguments are copies (dup), so the clamping done by the constructor does not touch builder state
            return new DistributionStats(runningMean.dup(), Transforms.sqrt(runningVariance, true));
        }

        /**
         * Utility function for building a list of DistributionStat objects from a list of builders
         *
         * @param builders the builders
         * @return the list of DistributionStat objects
         */
        public static List<DistributionStats> buildList(@NonNull List<Builder> builders) {
            List<DistributionStats> result = new ArrayList<>(builders.size());
            for (Builder builder : builders) {
                result.add(builder.build());
            }
            return result;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy