All downloads are free. Search and download functionality uses the official Maven repository.

org.elasticsearch.client.ml.inference.TrainedModelConfig Maven / Gradle / Ivy

There is a newer version: 8.0.0-alpha2
Show newest version
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */
package org.elasticsearch.client.ml.inference;

import org.elasticsearch.Version;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.client.ml.inference.NamedXContentObjectHelper.writeNamedObject;

/**
 * Client-side representation of a trained model configuration.
 * <p>
 * Instances are created either through {@link Builder} or by parsing XContent with
 * {@link #fromXContent(XContentParser)}, and are written back out by {@link #toXContent(XContentBuilder, Params)}.
 * Every field is optional: absent values are {@code null} and are omitted from the serialized form.
 * Collection-valued fields are copied on construction and exposed as unmodifiable views.
 */
public class TrainedModelConfig implements ToXContentObject {

    public static final String NAME = "trained_model_config";

    public static final ParseField MODEL_ID = new ParseField("model_id");
    public static final ParseField CREATED_BY = new ParseField("created_by");
    public static final ParseField VERSION = new ParseField("version");
    public static final ParseField DESCRIPTION = new ParseField("description");
    public static final ParseField CREATE_TIME = new ParseField("create_time");
    public static final ParseField DEFINITION = new ParseField("definition");
    public static final ParseField COMPRESSED_DEFINITION = new ParseField("compressed_definition");
    public static final ParseField TAGS = new ParseField("tags");
    public static final ParseField METADATA = new ParseField("metadata");
    public static final ParseField INPUT = new ParseField("input");
    /**
     * @deprecated superseded by {@link #MODEL_SIZE_BYTES}, which still accepts this name as a deprecated alias when parsing.
     */
    @Deprecated
    public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
    /** Model size in bytes; {@code estimated_heap_memory_usage_bytes} is registered as a deprecated alternative name. */
    public static final ParseField MODEL_SIZE_BYTES = new ParseField("model_size_bytes", "estimated_heap_memory_usage_bytes");
    public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
    public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
    public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map");
    public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");

    /**
     * Parser that fills a {@link Builder}. It is constructed with {@code ignoreUnknownFields = true},
     * so fields this client does not know about are skipped rather than rejected.
     */
    public static final ObjectParser<TrainedModelConfig.Builder, Void> PARSER = new ObjectParser<>(
        NAME,
        true,
        TrainedModelConfig.Builder::new
    );
    static {
        PARSER.declareString(TrainedModelConfig.Builder::setModelId, MODEL_ID);
        PARSER.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY);
        PARSER.declareString(TrainedModelConfig.Builder::setVersion, VERSION);
        PARSER.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION);
        PARSER.declareField(
            TrainedModelConfig.Builder::setCreateTime,
            (p, c) -> TimeUtil.parseTimeFieldToInstant(p, CREATE_TIME.getPreferredName()),
            CREATE_TIME,
            ObjectParser.ValueType.VALUE
        );
        PARSER.declareObject(TrainedModelConfig.Builder::setDefinition, (p, c) -> TrainedModelDefinition.fromXContent(p), DEFINITION);
        PARSER.declareString(TrainedModelConfig.Builder::setCompressedDefinition, COMPRESSED_DEFINITION);
        PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
        PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
        PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
        PARSER.declareLong(TrainedModelConfig.Builder::setModelSize, MODEL_SIZE_BYTES);
        PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
        PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
        PARSER.declareObject(TrainedModelConfig.Builder::setDefaultFieldMap, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAP);
        PARSER.declareNamedObject(
            TrainedModelConfig.Builder::setInferenceConfig,
            (p, c, n) -> p.namedObject(InferenceConfig.class, n, null),
            INFERENCE_CONFIG
        );
    }

    /**
     * Parses a trained model configuration from XContent.
     *
     * @param parser the parser positioned at the configuration object
     * @return the built configuration
     * @throws IOException if reading from the parser fails
     */
    public static TrainedModelConfig fromXContent(XContentParser parser) throws IOException {
        return PARSER.parse(parser, null).build();
    }

    private final String modelId;
    private final String createdBy;
    private final Version version;
    private final String description;
    private final Instant createTime;
    private final TrainedModelDefinition definition;
    private final String compressedDefinition;
    private final List<String> tags;
    private final Map<String, Object> metadata;
    private final TrainedModelInput input;
    private final Long modelSize;
    private final Long estimatedOperations;
    private final String licenseLevel;
    private final Map<String, String> defaultFieldMap;
    private final InferenceConfig inferenceConfig;

    TrainedModelConfig(
        String modelId,
        String createdBy,
        Version version,
        String description,
        Instant createTime,
        TrainedModelDefinition definition,
        String compressedDefinition,
        List<String> tags,
        Map<String, Object> metadata,
        TrainedModelInput input,
        Long modelSize,
        Long estimatedOperations,
        String licenseLevel,
        Map<String, String> defaultFieldMap,
        InferenceConfig inferenceConfig
    ) {
        this.modelId = modelId;
        this.createdBy = createdBy;
        this.version = version;
        // Truncate to millisecond precision: toXContent writes create_time as epoch millis, so this keeps
        // equals() stable across a serialize/parse round trip.
        this.createTime = createTime == null ? null : Instant.ofEpochMilli(createTime.toEpochMilli());
        this.definition = definition;
        this.compressedDefinition = compressedDefinition;
        this.description = description;
        // Copy before wrapping so that later changes to the caller's collections (e.g. a reused Builder
        // or the array behind Arrays.asList) cannot leak into this instance. Copies preserve iteration order.
        this.tags = tags == null ? null : Collections.unmodifiableList(new ArrayList<>(tags));
        this.metadata = metadata == null ? null : Collections.unmodifiableMap(new LinkedHashMap<>(metadata));
        this.input = input;
        this.modelSize = modelSize;
        this.estimatedOperations = estimatedOperations;
        this.licenseLevel = licenseLevel;
        this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(new LinkedHashMap<>(defaultFieldMap));
        this.inferenceConfig = inferenceConfig;
    }

    /** @return the model id, or {@code null} if not set. */
    public String getModelId() {
        return modelId;
    }

    /** @return the {@code created_by} value, or {@code null} if not set. */
    public String getCreatedBy() {
        return createdBy;
    }

    /** @return the parsed {@link Version}, or {@code null} if not set. */
    public Version getVersion() {
        return version;
    }

    /** @return the description, or {@code null} if not set. */
    public String getDescription() {
        return description;
    }

    /** @return the creation time truncated to millisecond precision, or {@code null} if not set. */
    public Instant getCreateTime() {
        return createTime;
    }

    /** @return an unmodifiable list of tags, or {@code null} if not set. */
    public List<String> getTags() {
        return tags;
    }

    /** @return an unmodifiable view of the metadata map, or {@code null} if not set. */
    public Map<String, Object> getMetadata() {
        return metadata;
    }

    /** @return the model definition, or {@code null} if not set. */
    public TrainedModelDefinition getDefinition() {
        return definition;
    }

    /** @return the compressed definition string, or {@code null} if not set. */
    public String getCompressedDefinition() {
        return compressedDefinition;
    }

    /** @return the model input description, or {@code null} if not set. */
    public TrainedModelInput getInput() {
        return input;
    }

    /**
     * @deprecated use {@link TrainedModelConfig#getModelSize()} instead
     * @return the {@link ByteSizeValue} of the model size if available.
     */
    @Deprecated
    public ByteSizeValue getEstimatedHeapMemory() {
        return modelSize == null ? null : new ByteSizeValue(modelSize);
    }

    /**
     * @deprecated use {@link TrainedModelConfig#getModelSizeBytes()} instead
     * @return the model size in bytes if available.
     */
    @Deprecated
    public Long getEstimatedHeapMemoryBytes() {
        return modelSize;
    }

    /**
     * @return the {@link ByteSizeValue} of the model size if available.
     */
    public ByteSizeValue getModelSize() {
        return modelSize == null ? null : new ByteSizeValue(modelSize);
    }

    /**
     * @return the model size in bytes if available.
     */
    public Long getModelSizeBytes() {
        return modelSize;
    }

    /**
     * @return the {@code estimated_operations} value if available, otherwise {@code null}.
     */
    public Long getEstimatedOperations() {
        return estimatedOperations;
    }

    /** @return the license level, or {@code null} if not set. */
    public String getLicenseLevel() {
        return licenseLevel;
    }

    /** @return an unmodifiable view of the default field map, or {@code null} if not set. */
    public Map<String, String> getDefaultFieldMap() {
        return defaultFieldMap;
    }

    /** @return the inference configuration, or {@code null} if not set. */
    public InferenceConfig getInferenceConfig() {
        return inferenceConfig;
    }

    /** @return a new, empty {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Writes this configuration as an object. Only non-null fields are emitted.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (modelId != null) {
            builder.field(MODEL_ID.getPreferredName(), modelId);
        }
        if (createdBy != null) {
            builder.field(CREATED_BY.getPreferredName(), createdBy);
        }
        if (version != null) {
            builder.field(VERSION.getPreferredName(), version.toString());
        }
        if (description != null) {
            builder.field(DESCRIPTION.getPreferredName(), description);
        }
        if (createTime != null) {
            builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
        }
        if (definition != null) {
            builder.field(DEFINITION.getPreferredName(), definition);
        }
        if (tags != null) {
            builder.field(TAGS.getPreferredName(), tags);
        }
        if (metadata != null) {
            builder.field(METADATA.getPreferredName(), metadata);
        }
        if (input != null) {
            builder.field(INPUT.getPreferredName(), input);
        }
        if (modelSize != null) {
            builder.field(MODEL_SIZE_BYTES.getPreferredName(), modelSize);
        }
        if (estimatedOperations != null) {
            builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
        }
        if (compressedDefinition != null) {
            builder.field(COMPRESSED_DEFINITION.getPreferredName(), compressedDefinition);
        }
        if (licenseLevel != null) {
            builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel);
        }
        if (defaultFieldMap != null) {
            builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap);
        }
        if (inferenceConfig != null) {
            writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), inferenceConfig);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        TrainedModelConfig that = (TrainedModelConfig) o;
        return Objects.equals(modelId, that.modelId)
            && Objects.equals(createdBy, that.createdBy)
            && Objects.equals(version, that.version)
            && Objects.equals(description, that.description)
            && Objects.equals(createTime, that.createTime)
            && Objects.equals(definition, that.definition)
            && Objects.equals(compressedDefinition, that.compressedDefinition)
            && Objects.equals(tags, that.tags)
            && Objects.equals(input, that.input)
            && Objects.equals(modelSize, that.modelSize)
            && Objects.equals(estimatedOperations, that.estimatedOperations)
            && Objects.equals(licenseLevel, that.licenseLevel)
            && Objects.equals(defaultFieldMap, that.defaultFieldMap)
            && Objects.equals(inferenceConfig, that.inferenceConfig)
            && Objects.equals(metadata, that.metadata);
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            modelId,
            createdBy,
            version,
            createTime,
            definition,
            compressedDefinition,
            description,
            tags,
            modelSize,
            estimatedOperations,
            metadata,
            licenseLevel,
            input,
            inferenceConfig,
            defaultFieldMap
        );
    }

    /**
     * Mutable builder for {@link TrainedModelConfig}. Setters that are only fed by {@link #PARSER}
     * (created_by, version, create_time, model size, estimated operations, license level) are private.
     */
    public static class Builder {

        private String modelId;
        private String createdBy;
        private Version version;
        private String description;
        private Instant createTime;
        private Map<String, Object> metadata;
        private List<String> tags;
        private TrainedModelDefinition definition;
        private String compressedDefinition;
        private TrainedModelInput input;
        private Long modelSize;
        private Long estimatedOperations;
        private String licenseLevel;
        private Map<String, String> defaultFieldMap;
        private InferenceConfig inferenceConfig;

        public Builder setModelId(String modelId) {
            this.modelId = modelId;
            return this;
        }

        private Builder setCreatedBy(String createdBy) {
            this.createdBy = createdBy;
            return this;
        }

        private Builder setVersion(Version version) {
            this.version = version;
            return this;
        }

        private Builder setVersion(String version) {
            return this.setVersion(Version.fromString(version));
        }

        public Builder setDescription(String description) {
            this.description = description;
            return this;
        }

        private Builder setCreateTime(Instant createTime) {
            this.createTime = createTime;
            return this;
        }

        public Builder setTags(List<String> tags) {
            this.tags = tags;
            return this;
        }

        public Builder setTags(String... tags) {
            return setTags(Arrays.asList(tags));
        }

        public Builder setMetadata(Map<String, Object> metadata) {
            this.metadata = metadata;
            return this;
        }

        /** Builds the given definition builder immediately; a {@code null} argument clears the definition. */
        public Builder setDefinition(TrainedModelDefinition.Builder definition) {
            this.definition = definition == null ? null : definition.build();
            return this;
        }

        public Builder setCompressedDefinition(String compressedDefinition) {
            this.compressedDefinition = compressedDefinition;
            return this;
        }

        public Builder setDefinition(TrainedModelDefinition definition) {
            this.definition = definition;
            return this;
        }

        public Builder setInput(TrainedModelInput input) {
            this.input = input;
            return this;
        }

        private Builder setModelSize(Long modelSize) {
            this.modelSize = modelSize;
            return this;
        }

        private Builder setEstimatedOperations(Long estimatedOperations) {
            this.estimatedOperations = estimatedOperations;
            return this;
        }

        private Builder setLicenseLevel(String licenseLevel) {
            this.licenseLevel = licenseLevel;
            return this;
        }

        public Builder setDefaultFieldMap(Map<String, String> defaultFieldMap) {
            this.defaultFieldMap = defaultFieldMap;
            return this;
        }

        public Builder setInferenceConfig(InferenceConfig inferenceConfig) {
            this.inferenceConfig = inferenceConfig;
            return this;
        }

        /**
         * @return a new {@link TrainedModelConfig}; collections are copied, so later changes to this
         *         builder's inputs do not affect the returned instance.
         */
        public TrainedModelConfig build() {
            return new TrainedModelConfig(
                modelId,
                createdBy,
                version,
                description,
                createTime,
                definition,
                compressedDefinition,
                tags,
                metadata,
                input,
                modelSize,
                estimatedOperations,
                licenseLevel,
                defaultFieldMap,
                inferenceConfig
            );
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy