org.elasticsearch.search.aggregations.matrix.stats.MatrixStatsResults Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of aggs-matrix-stats-client Show documentation
Show all versions of aggs-matrix-stats-client Show documentation
Adds aggregations whose input is a list of numeric fields and whose output includes a matrix.
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.search.aggregations.matrix.stats;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* Descriptive stats gathered per shard. Coordinating node computes final pearson product coefficient
* based on these descriptive stats
*/
class MatrixStatsResults implements Writeable {
/** object holding results - computes results in place */
protected final RunningStats results;
/** pearson product correlation coefficients */
protected final Map> correlation;
/** Base ctor */
MatrixStatsResults() {
results = new RunningStats();
this.correlation = new HashMap<>();
}
/** creates and computes result from provided stats */
MatrixStatsResults(RunningStats stats) {
this.results = stats.clone();
this.correlation = new HashMap<>();
this.compute();
}
/** creates a results object from the given stream */
@SuppressWarnings("unchecked")
protected MatrixStatsResults(StreamInput in) {
try {
results = new RunningStats(in);
correlation = (Map>) in.readGenericValue();
} catch (IOException e) {
throw new ElasticsearchException("Error trying to create multifield_stats results from stream input", e);
}
}
/** Marshalls MatrixStatsResults */
@Override
public void writeTo(StreamOutput out) throws IOException {
// marshall results
results.writeTo(out);
// marshall correlation
out.writeGenericValue(correlation);
}
/** return document count */
public final long getDocCount() {
return results.docCount;
}
/** return the field counts - not public, used for getProperty() */
protected Map getFieldCounts() {
return Collections.unmodifiableMap(results.counts);
}
/** return the fied count for the requested field */
public long getFieldCount(String field) {
if (results.counts.containsKey(field) == false) {
return 0;
}
return results.counts.get(field);
}
/** return the means - not public, used for getProperty() */
protected Map getMeans() {
return Collections.unmodifiableMap(results.means);
}
/** return the mean for the requested field */
public double getMean(String field) {
checkField(field, results.means);
return results.means.get(field);
}
/** return the variances - not public, used for getProperty() */
protected Map getVariances() {
return Collections.unmodifiableMap(results.variances);
}
/** return the variance for the requested field */
public double getVariance(String field) {
checkField(field, results.variances);
return results.variances.get(field);
}
/** return the skewness - not public, used for getProperty() */
protected Map getSkewness() {
return Collections.unmodifiableMap(results.skewness);
}
/** return the skewness for the requested field */
public double getSkewness(String field) {
checkField(field, results.skewness);
return results.skewness.get(field);
}
/** return the kurtosis */
protected Map getKurtosis() {
return Collections.unmodifiableMap(results.kurtosis);
}
/** return the kurtosis for the requested field */
public double getKurtosis(String field) {
checkField(field, results.kurtosis);
return results.kurtosis.get(field);
}
/** return the covariances as a map - not public, used for getProperty() */
protected Map> getCovariances() {
return Collections.unmodifiableMap(results.covariances);
}
/** return the covariance between two fields */
public double getCovariance(String fieldX, String fieldY) {
if (fieldX.equals(fieldY)) {
checkField(fieldX, results.variances);
return results.variances.get(fieldX);
}
return getValFromUpperTriangularMatrix(results.covariances, fieldX, fieldY);
}
/** return the correlations as a map - not public, used for getProperty() */
protected Map> getCorrelations() {
return Collections.unmodifiableMap(correlation);
}
/** return the correlation coefficient between two fields */
public Double getCorrelation(String fieldX, String fieldY) {
if (fieldX.equals(fieldY)) {
return 1.0;
}
return getValFromUpperTriangularMatrix(correlation, fieldX, fieldY);
}
/** return the value for two fields in an upper triangular matrix, regardless of row col location. */
static > double getValFromUpperTriangularMatrix(Map map, String fieldX, String fieldY) {
// for the co-value to exist, one of the two (or both) fields has to be a row key
if (map.containsKey(fieldX) == false && map.containsKey(fieldY) == false) {
throw new IllegalArgumentException("neither field " + fieldX + " nor " + fieldY + " exist");
} else if (map.containsKey(fieldX)) {
// fieldX exists as a row key
if (map.get(fieldX).containsKey(fieldY)) {
// fieldY exists as a col key to fieldX
return map.get(fieldX).get(fieldY);
} else {
// otherwise fieldX is the col key to fieldY
return map.get(fieldY).get(fieldX);
}
} else if (map.containsKey(fieldY)) {
// fieldX did not exist as a row key, it must be a col key
return map.get(fieldY).get(fieldX);
}
throw new IllegalArgumentException("Coefficient not computed between fields: " + fieldX + " and " + fieldY);
}
private void checkField(String field, Map map) {
if (field == null) {
throw new IllegalArgumentException("field name cannot be null");
}
if (map.containsKey(field) == false) {
throw new IllegalArgumentException("field " + field + " does not exist");
}
}
/** Computes final covariance, variance, and correlation */
private void compute() {
final double nM1 = results.docCount - 1D;
// compute final skewness and kurtosis
for (String fieldName : results.means.keySet()) {
final double var = results.variances.get(fieldName);
// update skewness
results.skewness.put(fieldName, Math.sqrt(results.docCount) * results.skewness.get(fieldName) / Math.pow(var, 1.5D));
// update kurtosis
results.kurtosis.put(fieldName, (double) results.docCount * results.kurtosis.get(fieldName) / (var * var));
// update variances
results.variances.put(fieldName, results.variances.get(fieldName) / nM1);
}
// compute final covariances and correlation
double cor;
for (Map.Entry> row : results.covariances.entrySet()) {
final String rowName = row.getKey();
final HashMap covRow = row.getValue();
final HashMap corRow = new HashMap<>();
for (Map.Entry col : covRow.entrySet()) {
final String colName = col.getKey();
// update covariance
covRow.put(colName, covRow.get(colName) / nM1);
// update correlation
// if there is no variance in the data then correlation is NaN
if (results.variances.get(rowName) == 0d || results.variances.get(colName) == 0d) {
cor = Double.NaN;
} else {
final double corDen = Math.sqrt(results.variances.get(rowName)) * Math.sqrt(results.variances.get(colName));
cor = covRow.get(colName) / corDen;
}
corRow.put(colName, cor);
}
results.covariances.put(rowName, covRow);
correlation.put(rowName, corRow);
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
MatrixStatsResults that = (MatrixStatsResults) o;
return Objects.equals(results, that.results) && Objects.equals(correlation, that.correlation);
}
@Override
public int hashCode() {
return Objects.hash(results, correlation);
}
}