All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.elasticsearch.xpack.core.ml.inference.results.LegacyFeatureImportance Maven / Gradle / Ivy

There is a newer version: 8.13.2
Show newest version
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * This class captures serialization of feature importance for
 * classification and regression prior to version 7.10.
 */
/**
 * This class captures serialization of feature importance for
 * classification and regression prior to version 7.10.
 * <p>
 * An instance whose class-importance list is {@code null} represents a regression
 * feature importance; an instance with a non-null list represents a classification one.
 * On the wire, streams older than {@link Version#V_7_10_0} encode the class importances
 * as a {@code Map<String, Double>}; newer streams encode them as a list of
 * {@link ClassImportance}.
 */
public class LegacyFeatureImportance implements Writeable {

    /**
     * Converts a {@link ClassificationFeatureImportance} into its legacy representation.
     *
     * @param classificationFeatureImportance the importance to convert
     * @return a legacy importance carrying the feature name, the total importance and a copy
     *         of every per-class importance
     */
    public static LegacyFeatureImportance fromClassification(ClassificationFeatureImportance classificationFeatureImportance) {
        List<ClassImportance> classImportances = classificationFeatureImportance.getClassImportance()
            .stream()
            .map(classImportance -> new ClassImportance(classImportance.getClassName(), classImportance.getImportance()))
            .collect(Collectors.toList());
        return new LegacyFeatureImportance(
            classificationFeatureImportance.getFeatureName(),
            classificationFeatureImportance.getTotalImportance(),
            classImportances
        );
    }

    /**
     * Converts a {@link RegressionFeatureImportance} into its legacy representation.
     *
     * @param regressionFeatureImportance the importance to convert
     * @return a legacy importance with no class importances (the regression marker)
     */
    public static LegacyFeatureImportance fromRegression(RegressionFeatureImportance regressionFeatureImportance) {
        return new LegacyFeatureImportance(
            regressionFeatureImportance.getFeatureName(),
            regressionFeatureImportance.getImportance(),
            null
        );
    }

    /** Per-class importances; {@code null} for regression, otherwise unmodifiable. */
    private final List<ClassImportance> classImportance;
    private final double importance;
    private final String featureName;

    /**
     * @param featureName     the feature name; must not be {@code null}
     * @param importance      the (total) importance of the feature
     * @param classImportance per-class importances, or {@code null} for regression; the list is copied
     * @throws NullPointerException if {@code featureName} is {@code null}
     */
    LegacyFeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
        this.featureName = Objects.requireNonNull(featureName, "featureName");
        this.importance = importance;
        // Copy before wrapping so later changes to the caller's list cannot leak into this instance.
        this.classImportance = classImportance == null ? null : Collections.unmodifiableList(new ArrayList<>(classImportance));
    }

    /**
     * Reads an instance written by {@link #writeTo(StreamOutput)}.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public LegacyFeatureImportance(StreamInput in) throws IOException {
        this.featureName = in.readString();
        this.importance = in.readDouble();
        // The boolean flag records whether class importances follow (classification) or not (regression).
        if (in.readBoolean()) {
            List<ClassImportance> readImportances;
            if (in.getVersion().before(Version.V_7_10_0)) {
                // Older streams carry a name -> importance map; class names arrive as strings.
                Map<String, Double> classImportanceMap = in.readMap(StreamInput::readString, StreamInput::readDouble);
                readImportances = ClassImportance.fromMap(classImportanceMap);
            } else {
                readImportances = in.readList(ClassImportance::new);
            }
            this.classImportance = Collections.unmodifiableList(readImportances);
        } else {
            this.classImportance = null;
        }
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(featureName);
        out.writeDouble(importance);
        out.writeBoolean(classImportance != null);
        if (classImportance != null) {
            if (out.getVersion().before(Version.V_7_10_0)) {
                // Legacy format: class names are flattened to their toString() form.
                out.writeMap(ClassImportance.toMap(classImportance), StreamOutput::writeString, StreamOutput::writeDouble);
            } else {
                out.writeList(classImportance);
            }
        }
    }

    @Override
    public boolean equals(Object object) {
        if (object == this) {
            return true;
        }
        if (object == null || getClass() != object.getClass()) {
            return false;
        }
        LegacyFeatureImportance that = (LegacyFeatureImportance) object;
        // Double.compare has the same semantics as Double.equals, without boxing.
        return Double.compare(importance, that.importance) == 0
            && Objects.equals(featureName, that.featureName)
            && Objects.equals(classImportance, that.classImportance);
    }

    @Override
    public int hashCode() {
        return Objects.hash(featureName, importance, classImportance);
    }

    /**
     * Converts this instance to a {@link RegressionFeatureImportance}.
     * Expected to be called only on instances without class importances (asserted).
     */
    public RegressionFeatureImportance forRegression() {
        assert classImportance == null;
        return new RegressionFeatureImportance(featureName, importance);
    }

    /**
     * Converts this instance to a {@link ClassificationFeatureImportance}.
     * Expected to be called only on instances with class importances (asserted).
     */
    public ClassificationFeatureImportance forClassification() {
        assert classImportance != null;
        List<ClassificationFeatureImportance.ClassImportance> converted = classImportance.stream()
            .map(aClassImportance -> new ClassificationFeatureImportance.ClassImportance(
                aClassImportance.className, aClassImportance.importance))
            .collect(Collectors.toList());
        return new ClassificationFeatureImportance(featureName, converted);
    }

    /** Importance of a single class in the legacy classification format. */
    public static class ClassImportance implements Writeable {

        private static ClassImportance fromMapEntry(Map.Entry<String, Double> entry) {
            return new ClassImportance(entry.getKey(), entry.getValue());
        }

        /** Builds class importances from the legacy name -> importance map. */
        private static List<ClassImportance> fromMap(Map<String, Double> classImportanceMap) {
            return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList());
        }

        /**
         * Flattens class importances into the legacy name -> importance map, keyed by
         * {@code className.toString()}. {@link Collectors#toMap} throws
         * {@link IllegalStateException} if two class names share the same string form.
         * NOTE(review): a {@code null} class name would throw NullPointerException here.
         */
        private static Map<String, Double> toMap(List<ClassImportance> importances) {
            return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
        }

        private final Object className;
        private final double importance;

        public ClassImportance(Object className, double importance) {
            this.className = className;
            this.importance = importance;
        }

        public ClassImportance(StreamInput in) throws IOException {
            this.className = in.readGenericValue();
            this.importance = in.readDouble();
        }

        double getImportance() {
            return importance;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeGenericValue(className);
            out.writeDouble(importance);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            ClassImportance that = (ClassImportance) o;
            return Double.compare(that.importance, importance) == 0 &&
                Objects.equals(className, that.className);
        }

        @Override
        public int hashCode() {
            return Objects.hash(className, importance);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy