All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModel Maven / Gradle / Ivy

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceClassification;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceRegression;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TARGET_TYPE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS;

public class EnsembleInferenceModel implements InferenceModel {

    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
    private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class);

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
        "ensemble_inference_model",
        true,
        a -> new EnsembleInferenceModel((List)a[0],
            (OutputAggregator)a[1],
            TargetType.fromString((String)a[2]),
            (List)a[3],
            (List)a[4]));
    static {
        PARSER.declareNamedObjects(constructorArg(),
            (p, c, n) -> p.namedObject(InferenceModel.class, n, null),
            (ensembleBuilder) -> {},
            TRAINED_MODELS);
        PARSER.declareNamedObject(constructorArg(),
            (p, c, n) -> p.namedObject(LenientlyParsedOutputAggregator.class, n, null),
            AGGREGATE_OUTPUT);
        PARSER.declareString(constructorArg(), TARGET_TYPE);
        PARSER.declareStringArray(optionalConstructorArg(), CLASSIFICATION_LABELS);
        PARSER.declareDoubleArray(optionalConstructorArg(), CLASSIFICATION_WEIGHTS);
    }

    public static EnsembleInferenceModel fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private String[] featureNames = new String[0];
    private final List models;
    private final OutputAggregator outputAggregator;
    private final TargetType targetType;
    private final List classificationLabels;
    private final double[] classificationWeights;
    private volatile boolean preparedForInference = false;

    private EnsembleInferenceModel(List models,
                                   OutputAggregator outputAggregator,
                                   TargetType targetType,
                                   @Nullable List classificationLabels,
                                   List classificationWeights) {
        this.models = ExceptionsHelper.requireNonNull(models, TRAINED_MODELS);
        this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
        this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
        this.classificationLabels = classificationLabels;
        this.classificationWeights = classificationWeights == null ?
            null :
            classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
    }

    @Override
    public String[] getFeatureNames() {
        return featureNames;
    }

    @Override
    public TargetType targetType() {
        return targetType;
    }

    @Override
    public InferenceResults infer(Map fields, InferenceConfig config, Map featureDecoderMap) {
        return innerInfer(InferenceModel.extractFeatures(featureNames, fields), config, featureDecoderMap);
    }

    @Override
    public InferenceResults infer(double[] features, InferenceConfig config) {
        return innerInfer(features, config, Collections.emptyMap());
    }

    private InferenceResults innerInfer(double[] features, InferenceConfig config, Map featureDecoderMap) {
        if (config.isTargetTypeSupported(targetType) == false) {
            throw ExceptionsHelper.badRequestException(
                "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
        }
        if (preparedForInference == false) {
            throw ExceptionsHelper.serverError("model is not prepared for inference");
        }
        LOGGER.debug(
            () -> new ParameterizedMessage("Inference called with feature names [{}]", Strings.arrayToCommaDelimitedString(featureNames))
        );
        double[][] inferenceResults = new double[this.models.size()][];
        double[][] featureInfluence = new double[features.length][];
        int i = 0;
        NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
        for (InferenceModel model : models) {
            InferenceResults result = model.infer(features, subModelInferenceConfig);
            assert result instanceof RawInferenceResults;
            RawInferenceResults inferenceResult = (RawInferenceResults) result;
            inferenceResults[i++] = inferenceResult.getValue();
            if (config.requestingImportance()) {
                addFeatureImportance(featureInfluence, inferenceResult);
            }
        }
        double[] processed = outputAggregator.processValues(inferenceResults);
        return buildResults(processed, featureInfluence, featureDecoderMap, config);
    }

    //For testing
    double[][] featureImportance(double[] features) {
        double[][] featureInfluence = new double[features.length][];
        NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(true);
        for (InferenceModel model : models) {
            InferenceResults result = model.infer(features, subModelInferenceConfig);
            assert result instanceof RawInferenceResults;
            RawInferenceResults inferenceResult = (RawInferenceResults) result;
            addFeatureImportance(featureInfluence, inferenceResult);
        }
        return featureInfluence;
    }

    private void addFeatureImportance(double[][] featureInfluence, RawInferenceResults inferenceResult) {
        double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
        assert modelFeatureImportance.length == featureInfluence.length;
        for (int j = 0; j < modelFeatureImportance.length; j++) {
            if (featureInfluence[j] == null) {
                featureInfluence[j] = new double[modelFeatureImportance[j].length];
            }
            featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
        }
    }

    private InferenceResults buildResults(double[] processedInferences,
                                          double[][] featureImportance,
                                          Map featureDecoderMap,
                                          InferenceConfig config) {
        // Indicates that the config is useless and the caller just wants the raw value
        if (config instanceof NullInferenceConfig) {
            return new RawInferenceResults(
                new double[] {outputAggregator.aggregate(processedInferences)},
                featureImportance);
        }
        Map decodedFeatureImportance = config.requestingImportance() ?
            decodeFeatureImportances(featureDecoderMap,
                IntStream.range(0, featureImportance.length)
                    .boxed()
                    .collect(Collectors.toMap(i -> featureNames[i], i -> featureImportance[i]))) :
            Collections.emptyMap();
        switch(targetType) {
            case REGRESSION:
                return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
                    config,
                    transformFeatureImportanceRegression(decodedFeatureImportance));
            case CLASSIFICATION:
                ClassificationConfig classificationConfig = (ClassificationConfig) config;
                assert classificationWeights == null || processedInferences.length == classificationWeights.length;
                // Adjust the probabilities according to the thresholds
                Tuple> topClasses = InferenceHelpers.topClasses(
                    processedInferences,
                    classificationLabels,
                    classificationWeights,
                    classificationConfig.getNumTopClasses(),
                    classificationConfig.getPredictionFieldType());
                final InferenceHelpers.TopClassificationValue value = topClasses.v1();
                return new ClassificationInferenceResults(value.getValue(),
                    classificationLabel(topClasses.v1().getValue(), classificationLabels),
                    topClasses.v2(),
                    transformFeatureImportanceClassification(decodedFeatureImportance,
                        classificationLabels,
                        classificationConfig.getPredictionFieldType()),
                    config,
                    value.getProbability(),
                    value.getScore());
            default:
                throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
        }
    }

    @Override
    public boolean supportsFeatureImportance() {
        return models.stream().allMatch(InferenceModel::supportsFeatureImportance);
    }

    @Override
    public String getName() {
        return "ensemble";
    }

    @Override
    public void rewriteFeatureIndices(final Map newFeatureIndexMapping) {
        LOGGER.debug(() -> new ParameterizedMessage("rewriting features {}", newFeatureIndexMapping));
        if (preparedForInference) {
            return;
        }
        preparedForInference = true;
        Map featureIndexMapping = new HashMap<>();
        if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
            Set referencedFeatures = subModelFeatures();
            LOGGER.debug(() -> new ParameterizedMessage("detected submodel feature names {}", referencedFeatures));
            int newFeatureIndex = 0;
            featureIndexMapping = new HashMap<>();
            this.featureNames = new String[referencedFeatures.size()];
            for (String featureName : referencedFeatures) {
                featureIndexMapping.put(featureName, newFeatureIndex);
                this.featureNames[newFeatureIndex++] = featureName;
            }
        } else {
            this.featureNames = new String[0];
        }
        for (InferenceModel model : models) {
            model.rewriteFeatureIndices(featureIndexMapping);
        }
    }

    private Set subModelFeatures() {
        Set referencedFeatures = new LinkedHashSet<>();
        for (InferenceModel model : models) {
            if (model instanceof EnsembleInferenceModel) {
                referencedFeatures.addAll(((EnsembleInferenceModel) model).subModelFeatures());
            } else {
                for (String featureName : model.getFeatureNames()) {
                    referencedFeatures.add(featureName);
                }
            }
        }
        return referencedFeatures;
    }

    @Override
    public long ramBytesUsed() {
        long size = SHALLOW_SIZE;
        size += RamUsageEstimator.sizeOf(featureNames);
        size += RamUsageEstimator.sizeOfCollection(classificationLabels);
        size += RamUsageEstimator.sizeOfCollection(models);
        if (classificationWeights != null) {
            size += RamUsageEstimator.sizeOf(classificationWeights);
        }
        size += outputAggregator.ramBytesUsed();
        return size;
    }

    public List getModels() {
        return models;
    }

    public OutputAggregator getOutputAggregator() {
        return outputAggregator;
    }

    public TargetType getTargetType() {
        return targetType;
    }

    public double[] getClassificationWeights() {
        return classificationWeights;
    }

    @Override
    public String toString() {
        return "EnsembleInferenceModel{" +
            "featureNames=" + Arrays.toString(featureNames) +
            ", models=" + models +
            ", outputAggregator=" + outputAggregator +
            ", targetType=" + targetType +
            ", classificationLabels=" + classificationLabels +
            ", classificationWeights=" + Arrays.toString(classificationWeights) +
            ", preparedForInference=" + preparedForInference +
            '}';
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy