All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation Maven / Gradle / Ivy

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;

import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;

/**
 * Defines an evaluation
 */
public interface Evaluation extends ToXContentObject, NamedWriteable {

    /**
     * Returns the evaluation name
     */
    String getName();

    /**
     * Returns the collection of fields required by evaluation
     */
    EvaluationFields getFields();

    /**
     * Returns the list of metrics to evaluate
     * @return list of metrics to evaluate
     */
    List getMetrics();

    default  List initMetrics(@Nullable List parsedMetrics, Supplier> defaultMetricsSupplier) {
        List metrics = parsedMetrics == null ? defaultMetricsSupplier.get() : new ArrayList<>(parsedMetrics);
        if (metrics.isEmpty()) {
            throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", getName());
        }
        Collections.sort(metrics, Comparator.comparing(EvaluationMetric::getName));
        checkRequiredFieldsAreSet(metrics);
        return metrics;
    }

    default  void checkRequiredFieldsAreSet(List metrics) {
        assert (metrics == null || metrics.isEmpty()) == false;
        for (Tuple requiredField : getFields().listPotentiallyRequiredFields()) {
            String fieldDescriptor = requiredField.v1();
            String field = requiredField.v2();
            if (field == null) {
                String metricNamesString =
                    metrics.stream()
                        .filter(m -> m.getRequiredFields().contains(fieldDescriptor))
                        .map(EvaluationMetric::getName)
                        .collect(joining(", "));
                if (metricNamesString.isEmpty() == false) {
                    throw ExceptionsHelper.badRequestException(
                        "[{}] must define [{}] as required by the following metrics [{}]",
                        getName(), fieldDescriptor, metricNamesString);
                }
            }
        }
    }

    /**
     * Builds the search required to collect data to compute the evaluation result
     * @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
     */
    default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) {
        Objects.requireNonNull(userProvidedQueryBuilder);
        Set requiredFields = new HashSet<>(getRequiredFields());
        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
        if (getFields().getActualField() != null && requiredFields.contains(getFields().getActualField())) {
            // Verify existence of the actual field if required
            boolQuery.filter(QueryBuilders.existsQuery(getFields().getActualField()));
        }
        if (getFields().getPredictedField() != null && requiredFields.contains(getFields().getPredictedField())) {
            // Verify existence of the predicted field if required
            boolQuery.filter(QueryBuilders.existsQuery(getFields().getPredictedField()));
        }
        if (getFields().getPredictedClassField() != null && requiredFields.contains(getFields().getPredictedClassField())) {
            assert getFields().getTopClassesField() != null;
            // Verify existence of the predicted class name field if required
            QueryBuilder predictedClassFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedClassField());
            boolQuery.filter(
                QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedClassFieldExistsQuery, ScoreMode.None)
                    .ignoreUnmapped(true));
        }
        if (getFields().getPredictedProbabilityField() != null && requiredFields.contains(getFields().getPredictedProbabilityField())) {
            // Verify existence of the predicted probability field if required
            QueryBuilder predictedProbabilityFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedProbabilityField());
            // predicted probability field may be either nested (just like in case of classification evaluation) or non-nested (just like
            // in case of outlier detection evaluation). Here we support both modes.
            if (getFields().isPredictedProbabilityFieldNested()) {
                assert getFields().getTopClassesField() != null;
                boolQuery.filter(
                    QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedProbabilityFieldExistsQuery, ScoreMode.None)
                        .ignoreUnmapped(true));
            } else {
                boolQuery.filter(predictedProbabilityFieldExistsQuery);
            }
        }
        // Apply user-provided query
        boolQuery.filter(userProvidedQueryBuilder);
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
        for (EvaluationMetric metric : getMetrics()) {
            // Fetch aggregations requested by individual metrics
            Tuple, List> aggs = metric.aggs(parameters, getFields());
            aggs.v1().forEach(searchSourceBuilder::aggregation);
            aggs.v2().forEach(searchSourceBuilder::aggregation);
        }
        return searchSourceBuilder;
    }

    /**
     * Processes {@link SearchResponse} from the search action
     * @param searchResponse response from the search action
     */
    default void process(SearchResponse searchResponse) {
        Objects.requireNonNull(searchResponse);
        if (searchResponse.getHits().getTotalHits().value == 0) {
            String requiredFieldsString = String.join(", ", getRequiredFields());
            throw ExceptionsHelper.badRequestException("No documents found containing all the required fields [{}]", requiredFieldsString);
        }
        for (EvaluationMetric metric : getMetrics()) {
            metric.process(searchResponse.getAggregations());
        }
    }

    /**
     * @return list of fields which are required by at least one of the metrics
     */
    default List getRequiredFields() {
        Set requiredFieldDescriptors =
            getMetrics().stream()
                .map(EvaluationMetric::getRequiredFields)
                .flatMap(Set::stream)
                .collect(toSet());
        List requiredFields =
            getFields().listPotentiallyRequiredFields().stream()
                .filter(f -> requiredFieldDescriptors.contains(f.v1()))
                .map(Tuple::v2)
                .collect(toList());
        return requiredFields;
    }

    /**
     * @return true iff all the metrics have their results computed
     */
    default boolean hasAllResults() {
        return getMetrics().stream().map(EvaluationMetric::getResult).allMatch(Optional::isPresent);
    }

    /**
     * Returns the list of evaluation results
     * @return list of evaluation results
     */
    default List getResults() {
        return getMetrics().stream()
            .map(EvaluationMetric::getResult)
            .filter(Optional::isPresent)
            .map(Optional::get)
            .collect(toList());
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy