All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AbstractConfusionMatrixMetric Maven / Gradle / Ivy

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.OutlierDetection.actualIsTrueQuery;

abstract class AbstractConfusionMatrixMetric implements EvaluationMetric {

    public static final ParseField AT = new ParseField("at");

    protected final double[] thresholds;
    private EvaluationMetricResult result;

    protected AbstractConfusionMatrixMetric(List at) {
        this.thresholds = ExceptionsHelper.requireNonNull(at, AT).stream().mapToDouble(Double::doubleValue).toArray();
        if (thresholds.length == 0) {
            throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName() + "] must have at least one value");
        }
        for (double threshold : thresholds) {
            if (threshold < 0 || threshold > 1.0) {
                throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName()
                    + "] values must be in [0.0, 1.0]");
            }
        }
    }

    protected AbstractConfusionMatrixMetric(StreamInput in) throws IOException {
        this.thresholds = in.readDoubleArray();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeDoubleArray(thresholds);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(AT.getPreferredName(), thresholds);
        builder.endObject();
        return builder;
    }

    @Override
    public Set getRequiredFields() {
        return Sets.newHashSet(
            EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_PROBABILITY_FIELD.getPreferredName());
    }

    @Override
    public Tuple, List> aggs(EvaluationParameters parameters,
                                                                                  EvaluationFields fields) {
        if (result != null) {
            return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
        }
        String actualField = fields.getActualField();
        String predictedProbabilityField = fields.getPredictedProbabilityField();
        return Tuple.tuple(aggsAt(actualField, predictedProbabilityField), Collections.emptyList());
    }

    @Override
    public void process(Aggregations aggs) {
        result = evaluate(aggs);
    }

    @Override
    public Optional getResult() {
        return Optional.ofNullable(result);
    }

    protected abstract List aggsAt(String actualField, String predictedProbabilityField);

    protected abstract EvaluationMetricResult evaluate(Aggregations aggs);

    enum Condition {
        TP(true, true),
        FP(false, true),
        TN(false, false),
        FN(true, false);

        final boolean actual;
        final boolean predicted;

        Condition(boolean actual, boolean predicted) {
            this.actual = actual;
            this.predicted = predicted;
        }
    }

    protected String aggName(double threshold, Condition condition) {
        return getName() + "_at_" + threshold + "_" + condition.name();
    }

    protected AggregationBuilder buildAgg(String actualField, String predictedProbabilityField, double threshold, Condition condition) {
        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
        QueryBuilder actualIsTrueQuery = actualIsTrueQuery(actualField);
        QueryBuilder predictedIsTrueQuery = QueryBuilders.rangeQuery(predictedProbabilityField).gte(threshold);
        if (condition.actual) {
            boolQuery.must(actualIsTrueQuery);
        } else {
            boolQuery.mustNot(actualIsTrueQuery);
        }
        if (condition.predicted) {
            boolQuery.must(predictedIsTrueQuery);
        } else {
            boolQuery.mustNot(predictedIsTrueQuery);
        }
        return AggregationBuilders.filter(aggName(threshold, condition), boolQuery);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy