All downloads are free. Search and download functionality uses the official Maven repository.

org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction Maven / Gradle / Ivy

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;


/**
 * Action type used to fetch trained model configurations.
 *
 * <p>The class bundles the singleton {@link #INSTANCE}, the {@link Request} and
 * {@link Response} transport messages, and the {@link Includes} value object that
 * describes which optional sections a caller asked for.
 */
public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {

    /** The single shared instance of this action type. */
    public static final GetTrainedModelsAction INSTANCE = new GetTrainedModelsAction();

    /** Action name handed to the {@link ActionType} constructor. */
    public static final String NAME = "cluster:monitor/xpack/ml/inference/get";

    /** Private: use {@link #INSTANCE}. Responses are deserialized with {@code Response::new}. */
    private GetTrainedModelsAction() {
        super(NAME, Response::new);
    }

    /**
     * The set of optional sections requested through the {@code include} parameter.
     *
     * <p>Only the four names held in {@code KNOWN_INCLUDES} are accepted by the
     * {@link #Includes(Set)} constructor.
     */
    public static class Includes implements Writeable {
        static final String DEFINITION = "definition";
        static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
        static final String FEATURE_IMPORTANCE_BASELINE = "feature_importance_baseline";
        static final String HYPERPARAMETERS = "hyperparameters";

        /** Every include name this class recognises; unmodifiable. */
        private static final Set<String> KNOWN_INCLUDES;
        static {
            // Capacity 4 with load factor 1.0: sized for exactly the four entries added below.
            HashSet<String> includes = new HashSet<>(4, 1.0f);
            includes.add(DEFINITION);
            includes.add(TOTAL_FEATURE_IMPORTANCE);
            includes.add(FEATURE_IMPORTANCE_BASELINE);
            includes.add(HYPERPARAMETERS);
            KNOWN_INCLUDES = Collections.unmodifiableSet(includes);
        }

        /** @return an instance that requests only the {@code definition} section */
        public static Includes forModelDefinition() {
            return new Includes(new HashSet<>(Collections.singletonList(DEFINITION)));
        }

        /** @return an instance that requests no optional section */
        public static Includes empty() {
            return new Includes(new HashSet<>());
        }

        /** @return an instance that requests every known section; it is backed by the shared unmodifiable set */
        public static Includes all() {
            return new Includes(KNOWN_INCLUDES);
        }

        private final Set<String> includes;

        /**
         * Creates an instance from the given include names.
         *
         * <p>The set is stored as passed (no copy is made). If it contains a name that
         * is not one of the known includes, the exception produced by
         * {@code ExceptionsHelper.badRequestException} is thrown.
         *
         * @param includes requested include names; {@code null} is treated as an empty set
         */
        public Includes(Set<String> includes) {
            this.includes = includes == null ? Collections.emptySet() : includes;
            Set<String> unknownIncludes = Sets.difference(this.includes, KNOWN_INCLUDES);
            if (unknownIncludes.isEmpty() == false) {
                throw ExceptionsHelper.badRequestException(
                    "unknown [include] parameters {}. Valid options are {}",
                    unknownIncludes,
                    KNOWN_INCLUDES);
            }
        }

        /**
         * Reads an instance from the stream as a set of strings.
         * Unlike {@link #Includes(Set)}, the names read here are not checked against the known includes.
         *
         * @param in stream to read from
         * @throws IOException if reading from the stream fails
         */
        public Includes(StreamInput in) throws IOException {
            this.includes = in.readSet(StreamInput::readString);
        }

        /**
         * Writes the include names as a collection of strings.
         *
         * @param out stream to write to
         * @throws IOException if writing to the stream fails
         */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeCollection(this.includes, StreamOutput::writeString);
        }

        /** @return {@code true} if {@code definition} was requested */
        public boolean isIncludeModelDefinition() {
            return this.includes.contains(DEFINITION);
        }

        /** @return {@code true} if {@code total_feature_importance} was requested */
        public boolean isIncludeTotalFeatureImportance() {
            return this.includes.contains(TOTAL_FEATURE_IMPORTANCE);
        }

        /** @return {@code true} if {@code feature_importance_baseline} was requested */
        public boolean isIncludeFeatureImportanceBaseline() {
            return this.includes.contains(FEATURE_IMPORTANCE_BASELINE);
        }

        /** @return {@code true} if {@code hyperparameters} was requested */
        public boolean isIncludeHyperparameters() {
            return this.includes.contains(HYPERPARAMETERS);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Includes includes1 = (Includes) o;
            return Objects.equals(includes, includes1.includes);
        }

        @Override
        public int hashCode() {
            return Objects.hash(includes);
        }

        /** Renders the requested names so that {@link Request#toString()} shows them instead of an identity hash. */
        @Override
        public String toString() {
            return "Includes{" +
                "includes=" + includes +
                '}';
        }
    }

    /**
     * Request for trained model configurations, optionally filtered by tags and
     * carrying the {@link Includes} the caller asked for.
     */
    public static class Request extends AbstractGetResourcesRequest {

        public static final ParseField INCLUDE = new ParseField("include");
        public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
        public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
        public static final ParseField TAGS = new ParseField("tags");

        private final Includes includes;
        private final List<String> tags;

        /**
         * Creates a request using a boolean flag in place of an include set.
         *
         * @param id resource id to look up
         * @param includeModelDefinition when {@code true} only the model definition is included,
         *                               otherwise nothing optional is included
         * @param tags tags to filter on; {@code null} is treated as an empty list
         * @deprecated use {@link #Request(String, List, Set)}, which accepts any known include name
         */
        @Deprecated
        public Request(String id, boolean includeModelDefinition, List<String> tags) {
            setResourceId(id);
            setAllowNoResources(true);
            this.tags = tags == null ? Collections.emptyList() : tags;
            if (includeModelDefinition) {
                this.includes = Includes.forModelDefinition();
            } else {
                this.includes = Includes.empty();
            }
        }

        /**
         * Creates a request. Allow-no-resources is set to {@code true}.
         *
         * @param id resource id to look up
         * @param tags tags to filter on; {@code null} is treated as an empty list
         * @param includes include names, validated by {@link Includes#Includes(Set)};
         *                 {@code null} is treated as an empty set
         */
        public Request(String id, List<String> tags, Set<String> includes) {
            setResourceId(id);
            setAllowNoResources(true);
            this.tags = tags == null ? Collections.emptyList() : tags;
            this.includes = new Includes(includes);
        }

        /**
         * Reads a request from the stream, honouring the sender's version.
         *
         * @param in stream to read from
         * @throws IOException if reading from the stream fails
         */
        public Request(StreamInput in) throws IOException {
            super(in);
            // Streams older than 7.10.0 carry a single boolean instead of the include set.
            if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
                this.includes = new Includes(in);
            } else {
                this.includes = in.readBoolean() ? Includes.forModelDefinition() : Includes.empty();
            }
            // Streams older than 7.7.0 carry no tags at all.
            if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
                this.tags = in.readStringList();
            } else {
                this.tags = Collections.emptyList();
            }
        }

        @Override
        public String getResourceIdField() {
            return TrainedModelConfig.MODEL_ID.getPreferredName();
        }

        /** @return the tags given at construction, or an empty list; never {@code null} */
        public List<String> getTags() {
            return tags;
        }

        /** @return the includes requested by the caller */
        public Includes getIncludes() {
            return includes;
        }

        /**
         * Writes the request in the same field order and with the same version gates
         * as {@link #Request(StreamInput)} reads it.
         *
         * @param out stream to write to
         * @throws IOException if writing to the stream fails
         */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
                this.includes.writeTo(out);
            } else {
                // Receivers older than 7.10.0 only get the model-definition flag.
                out.writeBoolean(this.includes.isIncludeModelDefinition());
            }
            if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
                out.writeStringCollection(tags);
            }
        }

        @Override
        public int hashCode() {
            return Objects.hash(super.hashCode(), includes, tags);
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Request other = (Request) obj;
            return super.equals(obj) && Objects.equals(includes, other.includes) && Objects.equals(tags, other.tags);
        }

        @Override
        public String toString() {
            return "Request{" +
                "includes=" + includes +
                ", tags=" + tags +
                ", page=" + getPageParams() +
                ", id=" + getResourceId() +
                ", allow_missing=" + isAllowNoResources() +
                '}';
        }
    }

    /**
     * Response wrapping a {@link QueryPage} of {@link TrainedModelConfig} objects.
     */
    public static class Response extends AbstractGetResourcesResponse<TrainedModelConfig> {

        public static final ParseField RESULTS_FIELD = new ParseField("trained_model_configs");

        /**
         * Reads a response from the stream.
         *
         * @param in stream to read from
         * @throws IOException if reading from the stream fails
         */
        public Response(StreamInput in) throws IOException {
            super(in);
        }

        /**
         * Creates a response from an already assembled page.
         *
         * @param trainedModels page of configurations to return
         */
        public Response(QueryPage<TrainedModelConfig> trainedModels) {
            super(trainedModels);
        }

        /** @return the reader used to deserialize each {@link TrainedModelConfig} */
        @Override
        protected Reader<TrainedModelConfig> getReader() {
            return TrainedModelConfig::new;
        }

        /** @return a new, empty {@link Builder} */
        public static Builder builder() {
            return new Builder();
        }

        /**
         * Fluent builder for {@link Response}. Starts with a total count of zero
         * and an empty list of configurations.
         */
        public static class Builder {

            private long totalCount;
            private List<TrainedModelConfig> configs = Collections.emptyList();

            private Builder() {
            }

            /**
             * @param totalCount count to pass to the resulting {@link QueryPage}
             * @return this builder
             */
            public Builder setTotalCount(long totalCount) {
                this.totalCount = totalCount;
                return this;
            }

            /**
             * @param configs configurations to place in the resulting page
             * @return this builder
             */
            public Builder setModels(List<TrainedModelConfig> configs) {
                this.configs = configs;
                return this;
            }

            /** @return a response whose page holds the configured models under {@link Response#RESULTS_FIELD} */
            public Response build() {
                return new Response(new QueryPage<>(configs, totalCount, RESULTS_FIELD));
            }
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy