All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.opensearch.ml.common.model.TextEmbeddingModelConfig Maven / Gradle / Ivy

The newest version!
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.ml.common.model;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.Locale;

import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

@Setter
@Getter
public class TextEmbeddingModelConfig extends MLModelConfig {
    public static final String PARSE_FIELD_NAME = FunctionName.TEXT_EMBEDDING.name();
    public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
        TextEmbeddingModelConfig.class,
        new ParseField(PARSE_FIELD_NAME),
        it -> parse(it)
    );

    public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension";
    public static final String FRAMEWORK_TYPE_FIELD = "framework_type";
    public static final String POOLING_MODE_FIELD = "pooling_mode";
    public static final String NORMALIZE_RESULT_FIELD = "normalize_result";
    public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length";
    public static final String QUERY_PREFIX = "query_prefix";
    public static final String PASSAGE_PREFIX = "passage_prefix";

    private final Integer embeddingDimension;
    private final FrameworkType frameworkType;
    private final PoolingMode poolingMode;
    private final boolean normalizeResult;
    private final Integer modelMaxLength;
    private final String queryPrefix;
    private final String passagePrefix;

    public TextEmbeddingModelConfig(
        String modelType,
        Integer embeddingDimension,
        FrameworkType frameworkType,
        String allConfig,
        PoolingMode poolingMode,
        boolean normalizeResult,
        Integer modelMaxLength
    ) {
        this(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, null, null);
    }

    @Builder(toBuilder = true)
    public TextEmbeddingModelConfig(
        String modelType,
        Integer embeddingDimension,
        FrameworkType frameworkType,
        String allConfig,
        PoolingMode poolingMode,
        boolean normalizeResult,
        Integer modelMaxLength,
        String queryPrefix,
        String passagePrefix
    ) {
        super(modelType, allConfig);
        if (embeddingDimension == null) {
            throw new IllegalArgumentException("embedding dimension is null");
        }
        if (frameworkType == null) {
            throw new IllegalArgumentException("framework type is null");
        }
        this.embeddingDimension = embeddingDimension;
        this.frameworkType = frameworkType;
        this.poolingMode = poolingMode;
        this.normalizeResult = normalizeResult;
        this.modelMaxLength = modelMaxLength;
        this.queryPrefix = queryPrefix;
        this.passagePrefix = passagePrefix;
    }

    public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException {
        String modelType = null;
        Integer embeddingDimension = null;
        FrameworkType frameworkType = null;
        String allConfig = null;
        PoolingMode poolingMode = null;
        boolean normalizeResult = false;
        Integer modelMaxLength = null;
        String queryPrefix = null;
        String passagePrefix = null;

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();

            switch (fieldName) {
                case MODEL_TYPE_FIELD:
                    modelType = parser.text();
                    break;
                case EMBEDDING_DIMENSION_FIELD:
                    embeddingDimension = parser.intValue();
                    break;
                case FRAMEWORK_TYPE_FIELD:
                    frameworkType = FrameworkType.from(parser.text().toUpperCase(Locale.ROOT));
                    break;
                case ALL_CONFIG_FIELD:
                    allConfig = parser.text();
                    break;
                case POOLING_MODE_FIELD:
                    poolingMode = PoolingMode.from(parser.text().toUpperCase(Locale.ROOT));
                    break;
                case NORMALIZE_RESULT_FIELD:
                    normalizeResult = parser.booleanValue();
                    break;
                case MODEL_MAX_LENGTH_FIELD:
                    modelMaxLength = parser.intValue();
                    break;
                case QUERY_PREFIX:
                    queryPrefix = parser.text();
                    break;
                case PASSAGE_PREFIX:
                    passagePrefix = parser.text();
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }
        return new TextEmbeddingModelConfig(
            modelType,
            embeddingDimension,
            frameworkType,
            allConfig,
            poolingMode,
            normalizeResult,
            modelMaxLength,
            queryPrefix,
            passagePrefix
        );
    }

    @Override
    public String getWriteableName() {
        return PARSE_FIELD_NAME;
    }

    public TextEmbeddingModelConfig(StreamInput in) throws IOException {
        super(in);
        embeddingDimension = in.readInt();
        frameworkType = in.readEnum(FrameworkType.class);
        if (in.readBoolean()) {
            poolingMode = in.readEnum(PoolingMode.class);
        } else {
            poolingMode = null;
        }
        normalizeResult = in.readBoolean();
        modelMaxLength = in.readOptionalInt();
        queryPrefix = in.readOptionalString();
        passagePrefix = in.readOptionalString();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeInt(embeddingDimension);
        out.writeEnum(frameworkType);
        if (poolingMode != null) {
            out.writeBoolean(true);
            out.writeEnum(poolingMode);
        } else {
            out.writeBoolean(false);
        }
        out.writeBoolean(normalizeResult);
        out.writeOptionalInt(modelMaxLength);
        out.writeOptionalString(queryPrefix);
        out.writeOptionalString(passagePrefix);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (modelType != null) {
            builder.field(MODEL_TYPE_FIELD, modelType);
        }
        if (embeddingDimension != null) {
            builder.field(EMBEDDING_DIMENSION_FIELD, embeddingDimension);
        }
        if (frameworkType != null) {
            builder.field(FRAMEWORK_TYPE_FIELD, frameworkType);
        }
        if (allConfig != null) {
            builder.field(ALL_CONFIG_FIELD, allConfig);
        }
        if (modelMaxLength != null) {
            builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength);
        }
        if (poolingMode != null) {
            builder.field(POOLING_MODE_FIELD, poolingMode);
        }
        if (normalizeResult) {
            builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
        }
        if (queryPrefix != null) {
            builder.field(QUERY_PREFIX, queryPrefix);
        }
        if (passagePrefix != null) {
            builder.field(PASSAGE_PREFIX, passagePrefix);
        }
        builder.endObject();
        return builder;
    }

    public enum PoolingMode {
        MEAN("mean"),
        MEAN_SQRT_LEN("mean_sqrt_len"),
        MAX("max"),
        WEIGHTED_MEAN("weightedmean"),
        CLS("cls"),
        LAST_TOKEN("lasttoken");

        private String name;

        public String getName() {
            return name;
        }

        PoolingMode(String name) {
            this.name = name;
        }

        public static PoolingMode from(String value) {
            try {
                return PoolingMode.valueOf(value.toUpperCase(Locale.ROOT));
            } catch (Exception e) {
                throw new IllegalArgumentException("Wrong pooling method");
            }
        }
    }

    public enum FrameworkType {
        HUGGINGFACE_TRANSFORMERS,
        SENTENCE_TRANSFORMERS,
        HUGGINGFACE_TRANSFORMERS_NEURON;

        public static FrameworkType from(String value) {
            try {
                return FrameworkType.valueOf(value.toUpperCase(Locale.ROOT));
            } catch (Exception e) {
                throw new IllegalArgumentException("Wrong framework type");
            }
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy