All downloads are free. Search and download functionality uses the official Maven repository.

org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall Maven / Gradle / Ivy

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;

/**
 * Outlier-detection evaluation metric that computes recall at one or more thresholds.
 * <p>
 * For every threshold the metric requests two aggregations, one counting true positives (TP)
 * and one counting false negatives (FN). It reports {@code TP / (TP + FN)} for that threshold.
 * When {@code TP + FN == 0} the recall for that threshold is reported as {@code 0.0}.
 */
public class Recall extends AbstractConfusionMatrixMetric {

    public static final ParseField NAME = new ParseField("recall");

    // Constructor arguments arrive as an untyped Object[]; AT is declared below as a double
    // array, so element 0 is cast to List<Double>. The cast cannot be checked at runtime.
    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<Recall, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(),
        a -> new Recall((List<Double>) a[0]));

    static {
        PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), AT);
    }

    /**
     * Parses a {@code recall} metric definition.
     *
     * @param parser the parser positioned at the metric's object
     * @return the parsed metric
     */
    public static Recall fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * Creates the metric for the given thresholds.
     *
     * @param at the thresholds at which recall is evaluated; handed to the superclass
     */
    public Recall(List<Double> at) {
        super(at);
    }

    /**
     * Reads the metric from a stream.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public Recall(StreamInput in) throws IOException {
        super(in);
    }

    /**
     * Returns the name under which this metric is registered for the outlier-detection evaluation.
     */
    @Override
    public String getWriteableName() {
        return registeredMetricName(OutlierDetection.NAME, NAME);
    }

    /**
     * Returns the metric's preferred name, {@code "recall"}.
     */
    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    /**
     * Two {@code Recall} metrics are equal when they have the same class and equal thresholds.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Recall that = (Recall) o;
        return Arrays.equals(thresholds, that.thresholds);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(thresholds);
    }

    /**
     * Builds the aggregations this metric needs: a TP and an FN aggregation for each threshold,
     * in threshold order.
     *
     * @param actualField field holding the actual class
     * @param predictedProbabilityField field holding the predicted probability
     * @return a list containing two aggregations per threshold
     */
    @Override
    protected List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField) {
        // Exactly two aggregations (TP, FN) are added per threshold.
        List<AggregationBuilder> aggs = new ArrayList<>(2 * thresholds.length);
        for (double threshold : thresholds) {
            aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TP));
            aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FN));
        }
        return aggs;
    }

    /**
     * Computes recall at each threshold from the TP/FN aggregations requested by
     * {@link #aggsAt(String, String)}.
     *
     * @param aggs the aggregation results containing the per-threshold TP and FN filters
     * @return recall scores keyed by threshold
     */
    @Override
    public EvaluationMetricResult evaluate(Aggregations aggs) {
        double[] recalls = new double[thresholds.length];
        for (int i = 0; i < thresholds.length; i++) {
            double threshold = thresholds[i];
            Filter tpAgg = aggs.get(aggName(threshold, Condition.TP));
            Filter fnAgg = aggs.get(aggName(threshold, Condition.FN));
            long tp = tpAgg.getDocCount();
            long fn = fnAgg.getDocCount();
            // Avoid 0/0 when no actual positives exist: report 0.0 instead of NaN.
            recalls[i] = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn);
        }
        return new ScoreByThresholdResult(NAME.getPreferredName(), thresholds, recalls);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy