All downloads are free. Search and download functionality uses the official Maven repository.

org.elasticsearch.client.analytics.ParsedInference Maven / Gradle / Ivy

There is a newer version: 8.0.0-alpha2
Show newest version
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.client.analytics;

import org.elasticsearch.client.ml.inference.results.FeatureImportance;
import org.elasticsearch.client.ml.inference.results.TopClassEntry;
import org.elasticsearch.search.aggregations.ParsedAggregation;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParseException;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
 * This class parses the superset of all possible fields that may be written by
 * InferenceResults. The warning field is mutually exclusive with all the other fields.
 *
 * In the case of classification results {@link #getValue()} may return a String,
 * Boolean or a Double. For regression results {@link #getValue()} is always
 * a Double.
 */
public class ParsedInference extends ParsedAggregation {

    /** Name of the field holding the list of feature importance entries. */
    public static final ParseField FEATURE_IMPORTANCE = new ParseField("feature_importance");
    /** Name of the field holding a warning message; when present, no other result fields are written. */
    public static final ParseField WARNING = new ParseField("warning");
    /** Name of the field holding the list of top class entries. */
    public static final ParseField TOP_CLASSES = new ParseField("top_classes");

    // The args[] indices follow the order in which the optionalConstructorArg() fields are
    // declared in the static block below: value, feature_importance, top_classes, warning.
    // The casts are unchecked, but the element parsers declared below only ever produce
    // FeatureImportance / TopClassEntry instances for those slots.
    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<ParsedInference, Void> PARSER = new ConstructingObjectParser<>(
        ParsedInference.class.getSimpleName(),
        true, // lenient: unknown fields are ignored
        args -> new ParsedInference(args[0], (List<FeatureImportance>) args[1], (List<TopClassEntry>) args[2], (String) args[3])
    );

    static {
        // Declaration order must stay in sync with the args[] indices used by PARSER above.
        PARSER.declareField(optionalConstructorArg(), (p, c) -> parseValue(p), CommonFields.VALUE, ObjectParser.ValueType.VALUE);
        PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FeatureImportance.fromXContent(p), FEATURE_IMPORTANCE);
        PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p), TOP_CLASSES);
        PARSER.declareString(optionalConstructorArg(), WARNING);
        declareAggregationFields(PARSER);
    }

    /**
     * Reads the scalar {@code value} field at the parser's current token.
     *
     * @param p parser positioned on the value token
     * @return a {@code String}, {@code Boolean} or {@code Double} (every number is read as a double)
     * @throws XContentParseException if the current token is not a string, boolean or number
     * @throws IOException if reading from the parser fails
     */
    private static Object parseValue(XContentParser p) throws IOException {
        XContentParser.Token token = p.currentToken();
        if (token == XContentParser.Token.VALUE_STRING) {
            return p.text();
        } else if (token == XContentParser.Token.VALUE_BOOLEAN) {
            return p.booleanValue();
        } else if (token == XContentParser.Token.VALUE_NUMBER) {
            return p.doubleValue();
        }
        throw new XContentParseException(
            p.getTokenLocation(),
            "["
                + ParsedInference.class.getSimpleName()
                + "] failed to parse field ["
                + CommonFields.VALUE
                + "] "
                + "value ["
                + token
                + "] is not a string, boolean or number"
        );
    }

    /**
     * Parses an inference aggregation result and assigns it the given aggregation name.
     *
     * @param parser parser positioned on the aggregation body
     * @param name   the aggregation name to set on the parsed result
     * @return the parsed aggregation
     */
    public static ParsedInference fromXContent(XContentParser parser, final String name) {
        ParsedInference parsed = PARSER.apply(parser, null);
        parsed.setName(name);
        return parsed;
    }

    private final Object value;
    private final List<FeatureImportance> featureImportance;
    private final List<TopClassEntry> topClasses;
    private final String warning;

    ParsedInference(Object value, List<FeatureImportance> featureImportance, List<TopClassEntry> topClasses, String warning) {
        this.value = value;
        this.warning = warning;
        this.featureImportance = featureImportance;
        this.topClasses = topClasses;
    }

    /**
     * Returns the inferred value: a {@code String}, {@code Boolean} or {@code Double}.
     * May be {@code null} when the field was absent (e.g. when a warning was returned instead).
     */
    public Object getValue() {
        return value;
    }

    /** Returns the feature importance entries; may be {@code null} when the field was absent. */
    public List<FeatureImportance> getFeatureImportance() {
        return featureImportance;
    }

    /** Returns the top class entries; may be {@code null} when the field was absent. */
    public List<TopClassEntry> getTopClasses() {
        return topClasses;
    }

    /** Returns the warning message, or {@code null} if none was returned. */
    public String getWarning() {
        return warning;
    }

    /**
     * Writes the result body. If a warning is present, only the warning is written.
     * Otherwise, the value is always written, and top classes and feature importance
     * are written only when non-null and non-empty.
     */
    @Override
    protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
        if (warning != null) {
            builder.field(WARNING.getPreferredName(), warning);
        } else {
            builder.field(CommonFields.VALUE.getPreferredName(), value);
            if (topClasses != null && topClasses.isEmpty() == false) {
                builder.field(TOP_CLASSES.getPreferredName(), topClasses);
            }
            if (featureImportance != null && featureImportance.isEmpty() == false) {
                builder.field(FEATURE_IMPORTANCE.getPreferredName(), featureImportance);
            }
        }
        return builder;
    }

    /** Returns the aggregation type name, {@link InferencePipelineAggregationBuilder#NAME}. */
    @Override
    public String getType() {
        return InferencePipelineAggregationBuilder.NAME;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy