org.hipparchus.stat.descriptive.MultivariateSummaryStatistics Maven / Gradle / Ivy
Show all versions of hipparchus-stat Show documentation
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This is not the original file distributed by the Apache Software Foundation
* It has been modified by the Hipparchus project
*/
package org.hipparchus.stat.descriptive;
import java.io.Serializable;
import java.util.Arrays;
import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.stat.descriptive.moment.GeometricMean;
import org.hipparchus.stat.descriptive.moment.Mean;
import org.hipparchus.stat.descriptive.rank.Max;
import org.hipparchus.stat.descriptive.rank.Min;
import org.hipparchus.stat.descriptive.summary.Sum;
import org.hipparchus.stat.descriptive.summary.SumOfLogs;
import org.hipparchus.stat.descriptive.summary.SumOfSquares;
import org.hipparchus.stat.descriptive.vector.VectorialCovariance;
import org.hipparchus.stat.descriptive.vector.VectorialStorelessStatistic;
import org.hipparchus.util.FastMath;
import org.hipparchus.util.MathArrays;
import org.hipparchus.util.MathUtils;
/**
* Computes summary statistics for a stream of n-tuples added using the
* {@link #addValue(double[]) addValue} method. The data values are not stored
* in memory, so this class can be used to compute statistics for very large
* n-tuple streams.
*
* To compute statistics for a stream of n-tuples, construct a
* {@link MultivariateSummaryStatistics} instance with dimension n and then use
* {@link #addValue(double[])} to add n-tuples. The getXxx
* methods where Xxx is a statistic return an array of double
* values, where for i = 0,...,n-1
the ith array element
* is the value of the given statistic for data range consisting of the ith
* element of each of the input n-tuples. For example, if addValue
is
* called with actual parameters {0, 1, 2}, then {3, 4, 5} and finally {6, 7, 8},
* getSum
will return a three-element array with values {0+3+6, 1+4+7, 2+5+8}
*
* Note: This class is not thread-safe.
*/
public class MultivariateSummaryStatistics
implements StatisticalMultivariateSummary, Serializable {
/** Serialization UID */
private static final long serialVersionUID = 20160424L;
/** Dimension of the data. */
private final int k;
/** Sum statistic implementation */
private final StorelessMultivariateStatistic sumImpl;
/** Sum of squares statistic implementation */
private final StorelessMultivariateStatistic sumSqImpl;
/** Minimum statistic implementation */
private final StorelessMultivariateStatistic minImpl;
/** Maximum statistic implementation */
private final StorelessMultivariateStatistic maxImpl;
/** Sum of log statistic implementation */
private final StorelessMultivariateStatistic sumLogImpl;
/** Geometric mean statistic implementation */
private final StorelessMultivariateStatistic geoMeanImpl;
/** Mean statistic implementation */
private final StorelessMultivariateStatistic meanImpl;
/** Covariance statistic implementation */
private final VectorialCovariance covarianceImpl;
/** Count of values that have been added */
private long n;
/**
* Construct a MultivariateSummaryStatistics instance for the given
* dimension. The returned instance will compute the unbiased sample
* covariance.
*
* The returned instance is not thread-safe.
*
* @param dimension dimension of the data
*/
public MultivariateSummaryStatistics(int dimension) {
this(dimension, true);
}
/**
* Construct a MultivariateSummaryStatistics instance for the given
* dimension.
*
* The returned instance is not thread-safe.
*
* @param dimension dimension of the data
* @param covarianceBiasCorrection if true, the returned instance will compute
* the unbiased sample covariance, otherwise the population covariance
*/
public MultivariateSummaryStatistics(int dimension, boolean covarianceBiasCorrection) {
this.k = dimension;
sumImpl = new VectorialStorelessStatistic(k, new Sum());
sumSqImpl = new VectorialStorelessStatistic(k, new SumOfSquares());
minImpl = new VectorialStorelessStatistic(k, new Min());
maxImpl = new VectorialStorelessStatistic(k, new Max());
sumLogImpl = new VectorialStorelessStatistic(k, new SumOfLogs());
geoMeanImpl = new VectorialStorelessStatistic(k, new GeometricMean());
meanImpl = new VectorialStorelessStatistic(k, new Mean());
covarianceImpl = new VectorialCovariance(k, covarianceBiasCorrection);
}
/**
* Add an n-tuple to the data
*
* @param value the n-tuple to add
* @throws MathIllegalArgumentException if the array is null or the length
* of the array does not match the one used at construction
*/
public void addValue(double[] value) throws MathIllegalArgumentException {
MathUtils.checkNotNull(value, LocalizedCoreFormats.INPUT_ARRAY);
MathUtils.checkDimension(value.length, k);
sumImpl.increment(value);
sumSqImpl.increment(value);
minImpl.increment(value);
maxImpl.increment(value);
sumLogImpl.increment(value);
geoMeanImpl.increment(value);
meanImpl.increment(value);
covarianceImpl.increment(value);
n++;
}
/**
* Resets all statistics and storage.
*/
public void clear() {
this.n = 0;
minImpl.clear();
maxImpl.clear();
sumImpl.clear();
sumLogImpl.clear();
sumSqImpl.clear();
geoMeanImpl.clear();
meanImpl.clear();
covarianceImpl.clear();
}
/** {@inheritDoc} **/
@Override
public int getDimension() {
return k;
}
/** {@inheritDoc} **/
@Override
public long getN() {
return n;
}
/** {@inheritDoc} **/
@Override
public double[] getSum() {
return sumImpl.getResult();
}
/** {@inheritDoc} **/
@Override
public double[] getSumSq() {
return sumSqImpl.getResult();
}
/** {@inheritDoc} **/
@Override
public double[] getSumLog() {
return sumLogImpl.getResult();
}
/** {@inheritDoc} **/
@Override
public double[] getMean() {
return meanImpl.getResult();
}
/** {@inheritDoc} **/
@Override
public RealMatrix getCovariance() {
return covarianceImpl.getResult();
}
/** {@inheritDoc} **/
@Override
public double[] getMax() {
return maxImpl.getResult();
}
/** {@inheritDoc} **/
@Override
public double[] getMin() {
return minImpl.getResult();
}
/** {@inheritDoc} **/
@Override
public double[] getGeometricMean() {
return geoMeanImpl.getResult();
}
/**
* Returns an array whose ith entry is the standard deviation of the
* ith entries of the arrays that have been added using
* {@link #addValue(double[])}
*
* @return the array of component standard deviations
*/
@Override
public double[] getStandardDeviation() {
double[] stdDev = new double[k];
if (getN() < 1) {
Arrays.fill(stdDev, Double.NaN);
} else if (getN() < 2) {
Arrays.fill(stdDev, 0.0);
} else {
RealMatrix matrix = getCovariance();
for (int i = 0; i < k; ++i) {
stdDev[i] = FastMath.sqrt(matrix.getEntry(i, i));
}
}
return stdDev;
}
/**
* Generates a text report displaying
* summary statistics from values that
* have been added.
* @return String with line feeds displaying statistics
*/
@Override
public String toString() {
final String separator = ", ";
final String suffix = System.getProperty("line.separator");
StringBuilder outBuffer = new StringBuilder(200); // the size is just a wild guess
outBuffer.append("MultivariateSummaryStatistics:").append(suffix).
append("n: ").append(getN()).append(suffix);
append(outBuffer, getMin(), "min: ", separator, suffix);
append(outBuffer, getMax(), "max: ", separator, suffix);
append(outBuffer, getMean(), "mean: ", separator, suffix);
append(outBuffer, getGeometricMean(), "geometric mean: ", separator, suffix);
append(outBuffer, getSumSq(), "sum of squares: ", separator, suffix);
append(outBuffer, getSumLog(), "sum of logarithms: ", separator, suffix);
append(outBuffer, getStandardDeviation(), "standard deviation: ", separator, suffix);
outBuffer.append("covariance: ").append(getCovariance().toString()).append(suffix);
return outBuffer.toString();
}
/**
* Append a text representation of an array to a buffer.
* @param buffer buffer to fill
* @param data data array
* @param prefix text prefix
* @param separator elements separator
* @param suffix text suffix
*/
private void append(StringBuilder buffer, double[] data,
String prefix, String separator, String suffix) {
buffer.append(prefix);
for (int i = 0; i < data.length; ++i) {
if (i > 0) {
buffer.append(separator);
}
buffer.append(data[i]);
}
buffer.append(suffix);
}
/**
* Returns true iff object
is a MultivariateSummaryStatistics
* instance and all statistics have the same values as this.
* @param object the object to test equality against.
* @return true if object equals this
*/
@Override
public boolean equals(Object object) {
if (object == this) {
return true;
}
if (!(object instanceof MultivariateSummaryStatistics)) {
return false;
}
MultivariateSummaryStatistics other = (MultivariateSummaryStatistics) object;
return other.getN() == getN() &&
MathArrays.equalsIncludingNaN(other.getGeometricMean(), getGeometricMean()) &&
MathArrays.equalsIncludingNaN(other.getMax(), getMax()) &&
MathArrays.equalsIncludingNaN(other.getMean(), getMean()) &&
MathArrays.equalsIncludingNaN(other.getMin(), getMin()) &&
MathArrays.equalsIncludingNaN(other.getSum(), getSum()) &&
MathArrays.equalsIncludingNaN(other.getSumSq(), getSumSq()) &&
MathArrays.equalsIncludingNaN(other.getSumLog(), getSumLog()) &&
other.getCovariance().equals(getCovariance());
}
/**
* Returns hash code based on values of statistics
*
* @return hash code
*/
@Override
public int hashCode() {
int result = 31 + MathUtils.hash(getN());
result = result * 31 + MathUtils.hash(getGeometricMean());
result = result * 31 + MathUtils.hash(getMax());
result = result * 31 + MathUtils.hash(getMean());
result = result * 31 + MathUtils.hash(getMin());
result = result * 31 + MathUtils.hash(getSum());
result = result * 31 + MathUtils.hash(getSumSq());
result = result * 31 + MathUtils.hash(getSumLog());
result = result * 31 + getCovariance().hashCode();
return result;
}
}