All downloads are free. Search and download functionality uses the official Maven repository.

org.opensearch.ml.common.model.ModelGuardrail Maven / Gradle / Ivy

There is a newer version: 2.17.1.0
Show newest version
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.ml.common.model;

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;

@Log4j2
@EqualsAndHashCode
@Getter
public class ModelGuardrail extends Guardrail {
    public static final String MODEL_ID_FIELD = "model_id";
    public static final String RESPONSE_FILTER_FIELD = "response_filter";
    public static final String RESPONSE_VALIDATION_REGEX_FIELD = "response_validation_regex";

    private String modelId;
    private String responseFilter;
    private String responseAccept;
    private NamedXContentRegistry xContentRegistry;
    private Client client;
    private Pattern regexAcceptPattern;

    @Builder(toBuilder = true)
    public ModelGuardrail(String modelId, String responseFilter, String responseAccept) {
        this.modelId = modelId;
        this.responseFilter = responseFilter;
        this.responseAccept = responseAccept;
    }

    public ModelGuardrail(@NonNull Map params) {
        this(
            (String) params.get(MODEL_ID_FIELD),
            (String) params.get(RESPONSE_FILTER_FIELD),
            (String) params.get(RESPONSE_VALIDATION_REGEX_FIELD)
        );
    }

    public ModelGuardrail(StreamInput input) throws IOException {
        modelId = input.readString();
        responseFilter = input.readString();
        responseAccept = input.readString();
    }

    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(modelId);
        out.writeString(responseFilter);
        out.writeString(responseAccept);
    }

    private Boolean validateAcceptRegex(String input) {
        Matcher matcher = regexAcceptPattern.matcher(input);
        return matcher.matches();
    }

    @Override
    public Boolean validate(String in, Map parameters) {
        String input = parameters == null ? null : parameters.get("question");
        if (input == null || input.isEmpty()) {
            log.info("Guardrail request is empty.");
            return true;
        }
        log.info("Guardrail request: {}", input);
        AtomicBoolean isAccepted = new AtomicBoolean(true);
        ActionListener internalListener = ActionListener.wrap(predictionResponse -> {
            ModelTensorOutput output = (ModelTensorOutput) predictionResponse.getOutput();
            ModelTensor tensor = output.getMlModelOutputs().get(0).getMlModelTensors().get(0);
            String guardrailResponse = AccessController
                .doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(tensor.getDataAsMap().get("response")));
            log.info("Guardrail response: {}", guardrailResponse);
            if (!validateAcceptRegex(guardrailResponse)) {
                isAccepted.set(false);
            }
        }, e -> { log.error("[ModelGuardrail] Failed to get prediction response.", e); });
        ActionListener actionListener = wrapActionListener(internalListener, res -> {
            MLTaskResponse predictionResponse = MLTaskResponse.fromActionResponse(res);
            return predictionResponse;
        });
        CountDownLatch latch = new CountDownLatch(1);
        Map guardrailModelParams = new HashMap<>();
        guardrailModelParams.put("question", input);
        if (responseFilter != null && !responseFilter.isEmpty()) {
            guardrailModelParams.put("response_filter", responseFilter);
        }
        log.info("Guardrail resFilter: {}", responseFilter);
        ActionRequest request = new MLPredictionTaskRequest(
            modelId,
            RemoteInferenceMLInput
                .builder()
                .algorithm(FunctionName.REMOTE)
                .inputDataset(RemoteInferenceInputDataSet.builder().parameters(guardrailModelParams).build())
                .build()
        );
        client.execute(MLPredictionTaskAction.INSTANCE, request, new LatchedActionListener(actionListener, latch));
        try {
            latch.await(5, SECONDS);
        } catch (InterruptedException e) {
            log.error("[ModelGuardrail] Validation was timeout.", e);
        }

        return isAccepted.get();
    }

    @Override
    public void init(NamedXContentRegistry xContentRegistry, Client client) {
        this.xContentRegistry = xContentRegistry;
        this.client = client;
        regexAcceptPattern = Pattern.compile(responseAccept);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (modelId != null) {
            builder.field(MODEL_ID_FIELD, modelId);
        }
        if (responseFilter != null) {
            builder.field(RESPONSE_FILTER_FIELD, responseFilter);
        }
        if (responseAccept != null) {
            builder.field(RESPONSE_VALIDATION_REGEX_FIELD, responseAccept);
        }
        builder.endObject();
        return builder;
    }

    public static ModelGuardrail parse(XContentParser parser) throws IOException {
        String modelId = null;
        String responseFilter = null;
        String responseAccept = null;

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();

            switch (fieldName) {
                case MODEL_ID_FIELD:
                    modelId = parser.text();
                    break;
                case RESPONSE_FILTER_FIELD:
                    responseFilter = parser.text();
                    break;
                case RESPONSE_VALIDATION_REGEX_FIELD:
                    responseAccept = parser.text();
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }
        return ModelGuardrail.builder().modelId(modelId).responseFilter(responseFilter).responseAccept(responseAccept).build();
    }

    private  ActionListener wrapActionListener(
        final ActionListener listener,
        final Function recreate
    ) {
        ActionListener actionListener = ActionListener.wrap(r -> {
            listener.onResponse(recreate.apply(r));
            ;
        }, e -> { listener.onFailure(e); });
        return actionListener;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy