/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
package org.opensearch.search.aggregations.matrix.stats;
import org.opensearch.OpenSearchException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
 * Descriptive stats gathered per shard. The coordinating node computes the final correlation and covariance stats
 * based on these descriptive stats. This single-pass, parallel approach is based on:
*
* http://prod.sandia.gov/techlib/access-control.cgi/2008/086212.pdf
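 *
 * A minimal usage sketch (the field names, values, and the {@code statsFromAnotherShard} variable below are
 * illustrative only):
 *
 * <pre>{@code
 * RunningStats stats = new RunningStats(new String[] { "a", "b" }, new double[] { 1.0, 2.0 });
 * stats.add(new String[] { "a", "b" }, new double[] { 3.0, 4.0 }); // one call per document
 * stats.merge(statsFromAnotherShard); // combine partial results on the coordinating node
 * }</pre>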
*/
public class RunningStats implements Writeable, Cloneable {
    /** count of observations (same number of observations per field) */
    protected long docCount = 0;
    /** per field sum of observations */
    protected HashMap<String, Double> fieldSum;
    /** counts */
    protected HashMap<String, Long> counts;
    /** mean values (first moment) */
    protected HashMap<String, Double> means;
    /** variance values (second moment) */
    protected HashMap<String, Double> variances;
    /** skewness values (third moment) */
    protected HashMap<String, Double> skewness;
    /** kurtosis values (fourth moment) */
    protected HashMap<String, Double> kurtosis;
    /** covariance values */
    protected HashMap<String, HashMap<String, Double>> covariances;
RunningStats() {
init();
}
RunningStats(final String[] fieldNames, final double[] fieldVals) {
if (fieldVals != null && fieldVals.length > 0) {
init();
this.add(fieldNames, fieldVals);
}
}
private void init() {
counts = new HashMap<>();
fieldSum = new HashMap<>();
means = new HashMap<>();
skewness = new HashMap<>();
kurtosis = new HashMap<>();
covariances = new HashMap<>();
variances = new HashMap<>();
}
/** Ctor to create an instance of running statistics */
@SuppressWarnings("unchecked")
public RunningStats(StreamInput in) throws IOException {
this();
// read doc count
docCount = (Long) in.readGenericValue();
// read fieldSum
        fieldSum = convertIfNeeded((Map<String, Double>) in.readGenericValue());
        // counts
        counts = convertIfNeeded((Map<String, Long>) in.readGenericValue());
        // means
        means = convertIfNeeded((Map<String, Double>) in.readGenericValue());
        // variances
        variances = convertIfNeeded((Map<String, Double>) in.readGenericValue());
        // skewness
        skewness = convertIfNeeded((Map<String, Double>) in.readGenericValue());
        // kurtosis
        kurtosis = convertIfNeeded((Map<String, Double>) in.readGenericValue());
        // read covariances
        covariances = convertIfNeeded((Map<String, HashMap<String, Double>>) in.readGenericValue());
}
    // Convert the map to a HashMap if it isn't one already
    private static <K, V> HashMap<K, V> convertIfNeeded(Map<K, V> map) {
        if (map instanceof HashMap) {
            return (HashMap<K, V>) map;
        } else {
            return new HashMap<>(map);
        }
    }
@Override
public void writeTo(StreamOutput out) throws IOException {
// marshall doc count
out.writeGenericValue(docCount);
// marshall fieldSum
out.writeGenericValue(fieldSum);
// counts
out.writeGenericValue(counts);
// mean
out.writeGenericValue(means);
// variances
out.writeGenericValue(variances);
// skewness
out.writeGenericValue(skewness);
// kurtosis
out.writeGenericValue(kurtosis);
// covariances
out.writeGenericValue(covariances);
}
    /** Updates running statistics with a document's field values */
public void add(final String[] fieldNames, final double[] fieldVals) {
if (fieldNames == null) {
throw new IllegalArgumentException("Cannot add statistics without field names.");
} else if (fieldVals == null) {
throw new IllegalArgumentException("Cannot add statistics without field values.");
} else if (fieldNames.length != fieldVals.length) {
throw new IllegalArgumentException("Number of field values do not match number of field names.");
}
// update total, mean, and variance
++docCount;
String fieldName;
double fieldValue;
double m1, m2, m3, m4; // moments
double d, dn, dn2, t1;
        final HashMap<String, Double> deltas = new HashMap<>();
for (int i = 0; i < fieldNames.length; ++i) {
fieldName = fieldNames[i];
fieldValue = fieldVals[i];
// update counts
counts.put(fieldName, 1 + (counts.containsKey(fieldName) ? counts.get(fieldName) : 0));
// update running sum
fieldSum.put(fieldName, fieldValue + (fieldSum.containsKey(fieldName) ? fieldSum.get(fieldName) : 0));
// update running deltas
deltas.put(fieldName, fieldValue * docCount - fieldSum.get(fieldName));
// update running mean, variance, skewness, kurtosis
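            // single-pass moment updates (per the Pebay 2008 report referenced above),
            // with n = docCount, d = x - mean, dn = d / n, and t1 = d * dn * (n - 1):
            //   mean' = mean + dn
            //   M2'   = M2 + t1
            //   M3'   = M3 + t1 * dn * (n - 2) - 3 * dn * M2
            //   M4'   = M4 + t1 * dn^2 * (n^2 - 3n + 3) + 6 * dn^2 * M2 - 4 * dn * M3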
if (means.containsKey(fieldName)) {
// update running means
m1 = means.get(fieldName);
d = fieldValue - m1;
means.put(fieldName, m1 + d / docCount);
// update running variances
dn = d / docCount;
t1 = d * dn * (docCount - 1);
m2 = variances.get(fieldName);
variances.put(fieldName, m2 + t1);
m3 = skewness.get(fieldName);
skewness.put(fieldName, m3 + (t1 * dn * (docCount - 2D) - 3D * dn * m2));
dn2 = dn * dn;
m4 = t1 * dn2 * (docCount * docCount - 3D * docCount + 3D) + 6D * dn2 * m2 - 4D * dn * m3;
kurtosis.put(fieldName, kurtosis.get(fieldName) + m4);
} else {
means.put(fieldName, fieldValue);
variances.put(fieldName, 0.0);
skewness.put(fieldName, 0.0);
kurtosis.put(fieldName, 0.0);
}
}
this.updateCovariance(fieldNames, deltas);
}
/** Update covariance matrix */
    private void updateCovariance(final String[] fieldNames, final Map<String, Double> deltas) {
        // working copy of the field names; each name is removed as its row is processed so every pair is visited once
        ArrayList<String> cFieldNames = new ArrayList<>(Arrays.asList(fieldNames));
String fieldName;
double dR, newVal;
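        // each delta holds (docCount * x - fieldSum), so the single-pass pairwise co-moment update below is
        //   C_ij' = C_ij + delta_i * delta_j / (docCount * (docCount - 1))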
for (int i = 0; i < fieldNames.length; ++i) {
fieldName = fieldNames[i];
cFieldNames.remove(fieldName);
// update running covariances
dR = deltas.get(fieldName);
            HashMap<String, Double> cFieldVals = (covariances.get(fieldName) != null) ? covariances.get(fieldName) : new HashMap<>();
for (String cFieldName : cFieldNames) {
if (cFieldVals.containsKey(cFieldName)) {
newVal = cFieldVals.get(cFieldName) + 1.0 / (docCount * (docCount - 1.0)) * dR * deltas.get(cFieldName);
cFieldVals.put(cFieldName, newVal);
} else {
cFieldVals.put(cFieldName, 0.0);
}
}
if (cFieldVals.size() > 0) {
covariances.put(fieldName, cFieldVals);
}
}
}
/**
* Merges the descriptive statistics of a second data set (e.g., per shard)
*
* running computations taken from: http://prod.sandia.gov/techlib/access-control.cgi/2008/086212.pdf
**/
public void merge(final RunningStats other) {
if (other == null) {
return;
} else if (this.docCount == 0) {
            for (Map.Entry<String, Double> fs : other.means.entrySet()) {
final String fieldName = fs.getKey();
this.means.put(fieldName, fs.getValue().doubleValue());
this.counts.put(fieldName, other.counts.get(fieldName).longValue());
this.fieldSum.put(fieldName, other.fieldSum.get(fieldName).doubleValue());
this.variances.put(fieldName, other.variances.get(fieldName).doubleValue());
this.skewness.put(fieldName, other.skewness.get(fieldName).doubleValue());
this.kurtosis.put(fieldName, other.kurtosis.get(fieldName).doubleValue());
if (other.covariances.containsKey(fieldName)) {
this.covariances.put(fieldName, other.covariances.get(fieldName));
}
this.docCount = other.docCount;
}
return;
}
final double nA = docCount;
final double nB = other.docCount;
// merge count
docCount += other.docCount;
        final HashMap<String, Double> deltas = new HashMap<>();
double meanA, varA, skewA, kurtA, meanB, varB, skewB, kurtB;
double d, d2, d3, d4, n2, nA2, nB2;
double newSkew, nk;
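        // pairwise merge formulas (Pebay 2008), with n = nA + nB and d = meanB - meanA:
        //   mean = (nA * meanA + nB * meanB) / n
        //   M2   = M2A + M2B + d^2 * nA * nB / n
        //   M3   = M3A + M3B + d^3 * nA * nB * (nA - nB) / n^2 + 3 * d * (nA * M2B - nB * M2A) / n
        //   M4   = M4A + M4B + d^4 * nA * nB * (nA^2 - nA * nB + nB^2) / n^3
        //            + 6 * d^2 * (nA^2 * M2B + nB^2 * M2A) / n^2 + 4 * d * (nA * M3B - nB * M3A) / n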
// across fields
        for (Map.Entry<String, Double> fs : other.means.entrySet()) {
final String fieldName = fs.getKey();
meanA = means.get(fieldName);
varA = variances.get(fieldName);
skewA = skewness.get(fieldName);
kurtA = kurtosis.get(fieldName);
meanB = other.means.get(fieldName);
varB = other.variances.get(fieldName);
skewB = other.skewness.get(fieldName);
kurtB = other.kurtosis.get(fieldName);
// merge counts of two sets
counts.put(fieldName, counts.get(fieldName) + other.counts.get(fieldName));
// merge means of two sets
means.put(fieldName, (nA * means.get(fieldName) + nB * other.means.get(fieldName)) / (nA + nB));
// merge deltas
deltas.put(fieldName, other.fieldSum.get(fieldName) / nB - fieldSum.get(fieldName) / nA);
// merge totals
fieldSum.put(fieldName, fieldSum.get(fieldName) + other.fieldSum.get(fieldName));
// merge variances, skewness, and kurtosis of two sets
d = meanB - meanA; // delta mean
d2 = d * d; // delta mean squared
d3 = d * d2; // delta mean cubed
d4 = d2 * d2; // delta mean 4th power
n2 = docCount * docCount; // num samples squared
nA2 = nA * nA; // doc A num samples squared
nB2 = nB * nB; // doc B num samples squared
// variance
variances.put(fieldName, varA + varB + d2 * nA * other.docCount / docCount);
            // skewness
newSkew = skewA + skewB + d3 * nA * nB * (nA - nB) / n2;
skewness.put(fieldName, newSkew + 3D * d * (nA * varB - nB * varA) / docCount);
// kurtosis
nk = kurtA + kurtB + d4 * nA * nB * (nA2 - nA * nB + nB2) / (n2 * docCount);
kurtosis.put(fieldName, nk + 6D * d2 * (nA2 * varB + nB2 * varA) / n2 + 4D * d * (nA * skewB - nB * skewA) / docCount);
}
this.mergeCovariance(other, deltas);
}
/** Merges two covariance matrices */
    private void mergeCovariance(final RunningStats other, final Map<String, Double> deltas) {
final double countA = docCount - other.docCount;
double f, dR, newVal;
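        // docCount has already been merged above, so countA recovers set A's size; the pairwise merge is
        //   C_ij = C_ij(A) + C_ij(B) + (nA * nB / n) * d_i * d_j, with d_k = meanB_k - meanA_k (the deltas)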
        for (Map.Entry<String, Double> fs : other.means.entrySet()) {
final String fieldName = fs.getKey();
// merge covariances of two sets
f = countA * other.docCount / this.docCount;
dR = deltas.get(fieldName);
// merge covariances
if (covariances.containsKey(fieldName)) {
                HashMap<String, Double> cFieldVals = covariances.get(fieldName);
for (String cFieldName : cFieldVals.keySet()) {
newVal = cFieldVals.get(cFieldName);
if (other.covariances.containsKey(fieldName) && other.covariances.get(fieldName).containsKey(cFieldName)) {
newVal += other.covariances.get(fieldName).get(cFieldName) + f * dR * deltas.get(cFieldName);
} else {
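                        // only one of the two symmetric entries is stored, so fall back to the mirrored key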
newVal += other.covariances.get(cFieldName).get(fieldName) + f * dR * deltas.get(cFieldName);
}
cFieldVals.put(cFieldName, newVal);
}
covariances.put(fieldName, cFieldVals);
}
}
}
@Override
public RunningStats clone() {
try {
return (RunningStats) super.clone();
} catch (CloneNotSupportedException e) {
            throw new OpenSearchException("Error trying to create a copy of RunningStats", e);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RunningStats that = (RunningStats) o;
return docCount == that.docCount
&& Objects.equals(fieldSum, that.fieldSum)
&& Objects.equals(counts, that.counts)
&& Objects.equals(means, that.means)
&& Objects.equals(variances, that.variances)
&& Objects.equals(skewness, that.skewness)
&& Objects.equals(kurtosis, that.kurtosis)
&& Objects.equals(covariances, that.covariances);
}
@Override
public int hashCode() {
return Objects.hash(docCount, fieldSum, counts, means, variances, skewness, kurtosis, covariances);
}
}