All downloads are free. Search and download functionality uses the official Maven repository.

org.opensearch.ml.common.input.MLInput Maven / Gradle / Ivy

The newest version!
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.ml.common.input;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.input.remote.RemoteInferenceMLInput.ACTION_TYPE_FIELD;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.search.builder.SearchSourceBuilder;

import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * ML input data: algorithm name, parameters and input data set.
 *
 * <p>An instance can be created directly through one of the constructors (or the Lombok builder),
 * read back from a {@link StreamInput}, or parsed from XContent with
 * {@link #parse(XContentParser, String)}.
 */
@Data
@NoArgsConstructor
public class MLInput implements Input {

    public static final String ALGORITHM_FIELD = "algorithm";
    public static final String ML_PARAMETERS_FIELD = "parameters";
    public static final String INPUT_INDEX_FIELD = "input_index";
    public static final String INPUT_QUERY_FIELD = "input_query";
    public static final String INPUT_DATA_FIELD = "input_data";

    // For trained model
    // Return bytes in model output
    public static final String RETURN_BYTES_FIELD = "return_bytes";
    // Return number in model output. This can be used together with return_bytes.
    public static final String RETURN_NUMBER_FIELD = "return_number";
    // Filter target response with name in model output
    public static final String TARGET_RESPONSE_FIELD = "target_response";
    // Filter target response with position in model output
    public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions";
    // Input text sentences for text embedding model
    public static final String TEXT_DOCS_FIELD = "text_docs";
    // Input query text to compare against for text similarity model
    public static final String QUERY_TEXT_FIELD = "query_text";
    public static final String PARAMETERS_FIELD = "parameters";

    // Input question in question answering model
    public static final String QUESTION_FIELD = "question";

    // Input context in question answering model
    public static final String CONTEXT_FIELD = "context";

    // Algorithm name
    protected FunctionName algorithm;
    // ML algorithm parameters
    protected MLAlgoParams parameters;
    // Input data to train model, run trained model to predict or run ML algorithms(no-model-based) directly.
    protected MLInputDataset inputDataset;

    // Written last by writeTo and read last by the StreamInput constructor.
    private int version = 1;

    /**
     * Creates an input from an already-built data set.
     *
     * @param algorithm    algorithm to run; must not be null
     * @param parameters   algorithm parameters, may be null
     * @param inputDataset input data set, may be null
     * @throws IllegalArgumentException if {@code algorithm} is null
     */
    @Builder(toBuilder = true)
    public MLInput(FunctionName algorithm, MLAlgoParams parameters, MLInputDataset inputDataset) {
        validate(algorithm);
        this.algorithm = algorithm;
        this.parameters = parameters;
        this.inputDataset = inputDataset;
    }

    /**
     * Creates an input, deriving the data set from the raw pieces when none is supplied.
     *
     * <p>If {@code inputDataset} is non-null it is used as-is. Otherwise a data-frame data set is
     * built when {@code dataFrame} is non-null, else a search-query data set when both
     * {@code sourceIndices} and {@code searchSourceBuilder} are non-null, else the data set stays null.
     *
     * @param algorithm           algorithm to run; must not be null
     * @param parameters          algorithm parameters, may be null
     * @param searchSourceBuilder query used for a search-query data set, may be null
     * @param sourceIndices       indices used for a search-query data set, may be null
     * @param dataFrame           data used for a data-frame data set, may be null
     * @param inputDataset        explicit data set that takes precedence over the other inputs, may be null
     * @throws IllegalArgumentException if {@code algorithm} is null
     */
    public MLInput(
        FunctionName algorithm,
        MLAlgoParams parameters,
        SearchSourceBuilder searchSourceBuilder,
        List<String> sourceIndices,
        DataFrame dataFrame,
        MLInputDataset inputDataset
    ) {
        validate(algorithm);
        this.algorithm = algorithm;
        this.parameters = parameters;
        if (inputDataset != null) {
            this.inputDataset = inputDataset;
        } else {
            this.inputDataset = createInputDataSet(searchSourceBuilder, sourceIndices, dataFrame);
        }
    }

    /**
     * Rejects a missing algorithm.
     *
     * @throws IllegalArgumentException if {@code algorithm} is null
     */
    private static void validate(FunctionName algorithm) {
        if (algorithm == null) {
            throw new IllegalArgumentException("algorithm can't be null");
        }
    }

    /**
     * Reads an input from a stream, mirroring the layout produced by {@link #writeTo(StreamOutput)}:
     * algorithm enum, optional parameters, optional data set, then the version.
     *
     * @param in stream to read from
     * @throws IOException if reading from the stream fails
     */
    public MLInput(StreamInput in) throws IOException {
        this.algorithm = in.readEnum(FunctionName.class);
        if (in.readBoolean()) {
            this.parameters = MLCommonsClassLoader.initMLInstance(algorithm, in, StreamInput.class);
        }
        if (in.readBoolean()) {
            this.inputDataset = MLInputDataset.fromStream(in);
        }
        this.version = in.readInt();
    }

    /**
     * Writes the algorithm enum, a presence flag followed by the parameters (when non-null), a
     * presence flag followed by the data set (when non-null), and finally the version.
     *
     * @param out stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeEnum(algorithm);
        if (parameters != null) {
            out.writeBoolean(true);
            parameters.writeTo(out);
        } else {
            out.writeBoolean(false);
        }
        if (inputDataset != null) {
            out.writeBoolean(true);
            inputDataset.writeTo(out);
        } else {
            out.writeBoolean(false);
        }
        out.writeInt(version);
    }

    /**
     * Renders this input as a single XContent object: the algorithm name, the parameters when
     * present, and the fields that belong to the concrete data set type. Data set types without a
     * dedicated case contribute no fields.
     *
     * @param builder builder to write into
     * @param params  serialization params (not consulted by this method)
     * @return the same {@code builder}
     * @throws IOException if writing to the builder fails
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(ALGORITHM_FIELD, algorithm.name());
        if (parameters != null) {
            builder.field(ML_PARAMETERS_FIELD, parameters);
        }
        if (inputDataset != null) {
            switch (inputDataset.getInputDataType()) {
                case SEARCH_QUERY:
                    builder.field(INPUT_INDEX_FIELD, ((SearchQueryInputDataset) inputDataset).getIndices().toArray(new String[0]));
                    builder.field(INPUT_QUERY_FIELD, ((SearchQueryInputDataset) inputDataset).getSearchSourceBuilder());
                    break;
                case DATA_FRAME:
                    builder.startObject(INPUT_DATA_FIELD);
                    ((DataFrameInputDataset) inputDataset).getDataFrame().toXContent(builder, EMPTY_PARAMS);
                    builder.endObject();
                    break;
                case TEXT_DOCS:
                    TextDocsInputDataSet textInputDataSet = (TextDocsInputDataSet) this.inputDataset;
                    List<String> docs = textInputDataSet.getDocs();
                    ModelResultFilter resultFilter = textInputDataSet.getResultFilter();
                    if (docs != null && !docs.isEmpty()) {
                        builder.field(TEXT_DOCS_FIELD, docs.toArray(new String[0]));
                    }
                    if (resultFilter != null) {
                        builder.field(RETURN_BYTES_FIELD, resultFilter.isReturnBytes());
                        builder.field(RETURN_NUMBER_FIELD, resultFilter.isReturnNumber());
                        List<String> targetResponse = resultFilter.getTargetResponse();
                        if (targetResponse != null && !targetResponse.isEmpty()) {
                            builder.field(TARGET_RESPONSE_FIELD, targetResponse.toArray(new String[0]));
                        }
                        List<Integer> targetPositions = resultFilter.getTargetResponsePositions();
                        if (targetPositions != null && !targetPositions.isEmpty()) {
                            builder.field(TARGET_RESPONSE_POSITIONS_FIELD, targetPositions.toArray(new Integer[0]));
                        }
                    }
                    break;
                case TEXT_SIMILARITY:
                    TextSimilarityInputDataSet similarityDataSet = (TextSimilarityInputDataSet) this.inputDataset;
                    List<String> documents = similarityDataSet.getTextDocs();
                    String queryText = similarityDataSet.getQueryText();
                    builder.field(QUERY_TEXT_FIELD, queryText);
                    if (documents != null && !documents.isEmpty()) {
                        builder.startArray(TEXT_DOCS_FIELD);
                        for (String d : documents) {
                            builder.value(d);
                        }
                        builder.endArray();
                    }
                    break;
                case QUESTION_ANSWERING:
                    QuestionAnsweringInputDataSet qaInputDataSet = (QuestionAnsweringInputDataSet) this.inputDataset;
                    String question = qaInputDataSet.getQuestion();
                    String context = qaInputDataSet.getContext();
                    builder.field(QUESTION_FIELD, question);
                    builder.field(CONTEXT_FIELD, context);
                    break;
                case REMOTE:
                    RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) this.inputDataset;
                    // Named differently from the 'parameters' field so the local does not shadow it.
                    Map<String, ?> remoteParameters = remoteInferenceInputDataSet.getParameters();
                    builder.field(PARAMETERS_FIELD, remoteParameters);
                    builder.field(ACTION_TYPE_FIELD, remoteInferenceInputDataSet.getActionType());
                    break;
                default:
                    break;
            }

        }
        builder.endObject();
        return builder;
    }

    /**
     * Parses an input like {@link #parse(XContentParser, String)} and, when the resulting data set
     * is a {@link RemoteInferenceInputDataSet} without an action type, applies {@code actionType}.
     *
     * @param parser        parser positioned on the start of the input object
     * @param inputAlgoName algorithm name (case-insensitive)
     * @param actionType    action type used only when the parsed remote data set has none
     * @return the parsed input
     * @throws IOException if parsing fails
     */
    public static MLInput parse(XContentParser parser, String inputAlgoName, ActionType actionType) throws IOException {
        MLInput mlInput = parse(parser, inputAlgoName);
        if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
            RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset();
            if (remoteInferenceInputDataSet.getActionType() == null) {
                remoteInferenceInputDataSet.setActionType(actionType);
            }
        }
        return mlInput;
    }

    /**
     * Parses an input for the given algorithm.
     *
     * <p>When {@link MLCommonsClassLoader#canInitMLInput(FunctionName)} reports a dedicated input
     * class for the algorithm, parsing is delegated to it. Otherwise the known fields are read here,
     * unknown fields are skipped, and the data set is chosen from the algorithm: text-docs for
     * text embedding / sparse encoding / sparse tokenize, text-similarity, question-answering, or
     * (for everything else) a data-frame or search-query data set built from the parsed pieces.
     *
     * @param parser        parser positioned on the start of the input object
     * @param inputAlgoName algorithm name (case-insensitive, upper-cased with {@link Locale#ROOT})
     * @return the parsed input
     * @throws IOException if parsing fails
     */
    public static MLInput parse(XContentParser parser, String inputAlgoName) throws IOException {
        String algorithmName = inputAlgoName.toUpperCase(Locale.ROOT);
        FunctionName algorithm = FunctionName.from(algorithmName);

        if (MLCommonsClassLoader.canInitMLInput(algorithm)) {
            MLInput mlInput = MLCommonsClassLoader
                .initMLInput(algorithm, new Object[] { parser, algorithm }, XContentParser.class, FunctionName.class);
            mlInput.setAlgorithm(algorithm);
            return mlInput;
        }

        MLAlgoParams mlParameters = null;
        SearchSourceBuilder searchSourceBuilder = null;
        List<String> sourceIndices = new ArrayList<>();
        DataFrame dataFrame = null;

        boolean returnBytes = false;
        boolean returnNumber = true;
        List<String> targetResponse = new ArrayList<>();
        List<Integer> targetResponsePositions = new ArrayList<>();
        List<String> textDocs = new ArrayList<>();
        String queryText = null;
        String question = null;
        String context = null;

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            // Advance from the field name to its value before dispatching.
            parser.nextToken();

            switch (fieldName) {
                case ML_PARAMETERS_FIELD:
                    mlParameters = parser.namedObject(MLAlgoParams.class, algorithmName, null);
                    break;
                case INPUT_INDEX_FIELD:
                    ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                    while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
                        sourceIndices.add(parser.text());
                    }
                    break;
                case INPUT_QUERY_FIELD:
                    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
                    searchSourceBuilder = SearchSourceBuilder.fromXContent(parser, false);
                    break;
                case INPUT_DATA_FIELD:
                    dataFrame = DefaultDataFrame.parse(parser);
                    break;
                case RETURN_BYTES_FIELD:
                    returnBytes = parser.booleanValue();
                    break;
                case RETURN_NUMBER_FIELD:
                    returnNumber = parser.booleanValue();
                    break;
                case TARGET_RESPONSE_FIELD:
                    ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                    while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
                        targetResponse.add(parser.text());
                    }
                    break;
                case TARGET_RESPONSE_POSITIONS_FIELD:
                    ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                    while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
                        targetResponsePositions.add(parser.intValue());
                    }
                    break;
                case TEXT_DOCS_FIELD:
                    ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                    while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
                        textDocs.add(parser.text());
                    }
                    break;
                case QUERY_TEXT_FIELD:
                    queryText = parser.text();
                    break;
                case QUESTION_FIELD:
                    question = parser.text();
                    break;
                case CONTEXT_FIELD:
                    context = parser.text();
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }
        MLInputDataset inputDataSet = null;
        if (algorithm == FunctionName.TEXT_EMBEDDING
            || algorithm == FunctionName.SPARSE_ENCODING
            || algorithm == FunctionName.SPARSE_TOKENIZE) {
            ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
            inputDataSet = new TextDocsInputDataSet(textDocs, filter);
        } else if (algorithm == FunctionName.TEXT_SIMILARITY) {
            inputDataSet = new TextSimilarityInputDataSet(queryText, textDocs);
        } else if (algorithm == FunctionName.QUESTION_ANSWERING) {
            inputDataSet = new QuestionAnsweringInputDataSet(question, context);
        }
        return new MLInput(algorithm, mlParameters, searchSourceBuilder, sourceIndices, dataFrame, inputDataSet);
    }

    /**
     * Builds a data set from raw pieces: a data-frame data set when {@code dataFrame} is non-null,
     * otherwise a search-query data set when both the indices and the query are non-null.
     *
     * @return the derived data set, or null when neither combination is available
     */
    private MLInputDataset createInputDataSet(SearchSourceBuilder searchSourceBuilder, List<String> sourceIndices, DataFrame dataFrame) {
        if (dataFrame != null) {
            return new DataFrameInputDataset(dataFrame);
        }
        if (sourceIndices != null && searchSourceBuilder != null) {
            return new SearchQueryInputDataset(sourceIndices, searchSourceBuilder);
        }
        return null;
    }

    /**
     * @return the algorithm this input targets
     */
    @Override
    public FunctionName getFunctionName() {
        return this.algorithm;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy