All downloads are free. Search and download functionality uses the official Maven repository.

org.elasticsearch.xpack.core.ml.stats.StatsAccumulator Maven / Gradle / Ivy

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
package org.elasticsearch.xpack.core.ml.stats;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.aggregations.metrics.Stats;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Helper class to collect min, max, avg and total statistics for a quantity.
 *
 * <p>Values are supplied one at a time through {@link #add(double)} or in bulk by
 * {@link #merge(StatsAccumulator) merging} another accumulator. The minimum and maximum
 * are held as nullable {@link Double}s; while they are unset the corresponding getters
 * report {@code 0.0}.
 */
public class StatsAccumulator implements Writeable {

    /** Keys used for the entries of the map returned by {@link #asMap()}. */
    public static class Fields {
        public static final String MIN = "min";
        public static final String MAX = "max";
        public static final String AVG = "avg";
        public static final String TOTAL = "total";
    }

    /** Number of values accumulated so far. */
    private long count;
    /** Sum of all accumulated values. */
    private double total;
    /** Smallest value seen, or {@code null} if none has been recorded. */
    private Double min;
    /** Largest value seen, or {@code null} if none has been recorded. */
    private Double max;

    /** Creates an empty accumulator: zero count, zero total, no min and no max. */
    public StatsAccumulator() {
    }

    /**
     * Reads an accumulator from a stream, consuming the fields in the same order in which
     * {@link #writeTo(StreamOutput)} emits them: count, total, optional min, optional max.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public StatsAccumulator(StreamInput in) throws IOException {
        count = in.readLong();
        total = in.readDouble();
        min = in.readOptionalDouble();
        max = in.readOptionalDouble();
    }

    /**
     * Creates an accumulator with every statistic preset. Because {@code min} and
     * {@code max} are primitives here, both are always set on the resulting instance.
     */
    private StatsAccumulator(long count, double total, double min, double max) {
        this.count = count;
        this.total = total;
        this.min = min;
        this.max = max;
    }

    /**
     * Records a single value: increments the count, adds the value to the total and
     * updates the min and max if the value lies beyond the current ones.
     *
     * @param value the value to record
     */
    public void add(double value) {
        count++;
        total += value;
        if (min == null || value < min) {
            min = value;
        }
        if (max == null || value > max) {
            max = value;
        }
    }

    /**
     * @return the smallest recorded value, or {@code 0.0} if no minimum is set
     */
    public double getMin() {
        return min == null ? 0.0 : min;
    }

    /**
     * @return the largest recorded value, or {@code 0.0} if no maximum is set
     */
    public double getMax() {
        return max == null ? 0.0 : max;
    }

    /**
     * @return the total divided by the count, or {@code 0.0} if the count is zero
     */
    public double getAvg() {
        return count == 0 ? 0.0 : total / count;
    }

    /**
     * @return the sum of all recorded values
     */
    public double getTotal() {
        return total;
    }

    /**
     * Folds the statistics of another accumulator into this one. Counts and totals are
     * summed; the min and max become the more extreme of the two, ignoring whichever side
     * is unset.
     *
     * @param other the accumulator to merge in; it is not modified
     */
    public void merge(StatsAccumulator other) {
        count += other.count;
        total += other.total;

        // note: not using Math.min/max as some internal prefetch optimization causes an NPE
        // Explicit null checks keep every comparison on non-null boxed values.
        if (other.min != null && (min == null || other.min < min)) {
            min = other.min;
        }
        if (other.max != null && (max == null || other.max > max)) {
            max = other.max;
        }
    }

    /**
     * Exposes the statistics as a map keyed by the {@link Fields} constants.
     *
     * @return a new mutable map holding the min, max, avg and total values
     */
    public Map<String, Double> asMap() {
        Map<String, Double> map = new HashMap<>();
        map.put(Fields.MIN, getMin());
        map.put(Fields.MAX, getMax());
        map.put(Fields.AVG, getAvg());
        map.put(Fields.TOTAL, getTotal());
        return map;
    }

    /**
     * Writes count, total, optional min and optional max, in that order.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeLong(count);
        out.writeDouble(total);
        out.writeOptionalDouble(min);
        out.writeOptionalDouble(max);
    }

    /**
     * Builds an accumulator from the count, sum, min and max reported by a stats
     * aggregation.
     *
     * @param statsAggregation the aggregation result to copy the statistics from
     * @return a new accumulator initialised with those statistics
     */
    public static StatsAccumulator fromStatsAggregation(Stats statsAggregation) {
        return new StatsAccumulator(statsAggregation.getCount(), statsAggregation.getSum(), statsAggregation.getMin(),
                statsAggregation.getMax());
    }

    @Override
    public int hashCode() {
        return Objects.hash(count, total, min, max);
    }

    /**
     * Two accumulators are equal when they are of the same class and their count, total,
     * min and max all match.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }

        StatsAccumulator other = (StatsAccumulator) obj;
        // Double.compare treats doubles the same way Double.equals does (NaN equals NaN,
        // 0.0 differs from -0.0), which keeps equals consistent with hashCode.
        return count == other.count
                && Double.compare(total, other.total) == 0
                && Objects.equals(min, other.min)
                && Objects.equals(max, other.max);
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy