All downloads are free. Search and download functionality uses the official Maven repository.

org.opensearch.ml.common.input.nlp.QuestionAnsweringMLInput Maven / Gradle / Ivy

The newest version!
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.ml.common.input.nlp;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet;
import org.opensearch.ml.common.input.MLInput;

/**
 * {@link MLInput} variant for the question answering function.
 * <p>
 * The input consists of a question and the context in which it should be
 * answered; the answer is the output of the algorithm.
 */
@org.opensearch.ml.common.annotation.MLInput(functionNames = { FunctionName.QUESTION_ANSWERING })
public class QuestionAnsweringMLInput extends MLInput {

    /**
     * Builds an input from an algorithm and a dataset. No algorithm parameters
     * are supplied: {@code null} is handed to the superclass in their place.
     *
     * @param algorithm the function this input targets
     * @param dataset   the dataset to attach to this input
     */
    public QuestionAnsweringMLInput(FunctionName algorithm, MLInputDataset dataset) {
        super(algorithm, null, dataset);
    }

    /**
     * Reads an input from a stream by delegating to the superclass stream constructor.
     *
     * @param in the stream to read from
     * @throws IOException if the superclass fails to read the stream
     */
    public QuestionAnsweringMLInput(StreamInput in) throws IOException {
        super(in);
    }

    /**
     * Parses an input from XContent. The parser must currently be positioned on
     * the opening token of an object. The {@code QUESTION_FIELD} and
     * {@code CONTEXT_FIELD} entries are read as text; every other field is skipped.
     *
     * @param parser       parser positioned at {@code START_OBJECT}
     * @param functionName the function stored as this input's algorithm
     * @throws IOException              if reading from the parser fails
     * @throws IllegalArgumentException if the question or the context is absent
     */
    public QuestionAnsweringMLInput(XContentParser parser, FunctionName functionName) throws IOException {
        this.algorithm = functionName;

        String parsedQuestion = null;
        String parsedContext = null;

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            final String key = parser.currentName();
            // Step from the field name onto its value before reading or skipping it.
            parser.nextToken();

            if (key.equals(QUESTION_FIELD)) {
                parsedQuestion = parser.text();
            } else if (key.equals(CONTEXT_FIELD)) {
                parsedContext = parser.text();
            } else {
                // Unrecognized field: skip over any nested structure it holds.
                parser.skipChildren();
            }
        }

        // The question is validated first, so it is the one reported when both are missing.
        if (parsedQuestion == null) {
            throw new IllegalArgumentException("Question is not provided");
        }
        if (parsedContext == null) {
            throw new IllegalArgumentException("Context is not provided");
        }
        this.inputDataset = new QuestionAnsweringInputDataSet(parsedQuestion, parsedContext);
    }

    /**
     * Writes this input to a stream by delegating to the superclass.
     *
     * @param out the stream to write to
     * @throws IOException if the superclass fails to write
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
    }

    /**
     * Renders this input as a single XContent object containing the algorithm
     * name, the parameters when they are non-null, and the question and context
     * when an input dataset is present. The dataset is cast to
     * {@link QuestionAnsweringInputDataSet} to obtain those two values.
     *
     * @param builder the builder to write into
     * @param params  serialization parameters (not consulted here)
     * @return the same builder that was passed in
     * @throws IOException if writing to the builder fails
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject().field(ALGORITHM_FIELD, algorithm.name());

        if (parameters != null) {
            builder.field(ML_PARAMETERS_FIELD, parameters);
        }

        if (inputDataset != null) {
            final QuestionAnsweringInputDataSet qaDataset = (QuestionAnsweringInputDataSet) this.inputDataset;
            final String questionText = qaDataset.getQuestion();
            final String contextText = qaDataset.getContext();
            builder.field(QUESTION_FIELD, questionText).field(CONTEXT_FIELD, contextText);
        }

        return builder.endObject();
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy