All downloads are free. The search and download functionality uses the official Maven repository.

org.opensearch.ml.common.model.Guardrails Maven / Gradle / Ivy

There is a newer version: 2.17.1.0
Show newest version
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.ml.common.model;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.Map;
import java.util.Set;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;

/**
 * Container for an optional input guardrail and an optional output guardrail that share one
 * guardrail {@code type}.
 *
 * <p>Instances can be created through the Lombok builder, parsed from XContent with
 * {@link #parse(XContentParser)}, or read from a stream with {@link #Guardrails(StreamInput)}.
 * The stream layout produced by {@link #writeTo(StreamOutput)} is: the type string, then for the
 * input guardrail and the output guardrail in turn a boolean presence flag followed by the
 * guardrail itself when the flag is {@code true}.
 */
@EqualsAndHashCode
@Getter
public class Guardrails implements ToXContentObject {
    public static final String TYPE_FIELD = "type";
    public static final String INPUT_GUARDRAIL_FIELD = "input_guardrail";
    public static final String OUTPUT_GUARDRAIL_FIELD = "output_guardrail";
    /** Guardrail type names accepted by {@link #parse(XContentParser)}. */
    public static final Set<String> types = Set.of("local_regex", "model");

    private static final String LOCAL_REGEX_TYPE = "local_regex";
    private static final String MODEL_TYPE = "model";

    private String type;
    private Guardrail inputGuardrail;
    private Guardrail outputGuardrail;

    /**
     * Creates a guardrails container from already-built parts. No validation is performed here.
     *
     * @param type            guardrail type name
     * @param inputGuardrail  guardrail for input, or {@code null} if none
     * @param outputGuardrail guardrail for output, or {@code null} if none
     */
    @Builder(toBuilder = true)
    public Guardrails(String type, Guardrail inputGuardrail, Guardrail outputGuardrail) {
        this.type = type;
        this.inputGuardrail = inputGuardrail;
        this.outputGuardrail = outputGuardrail;
    }

    /**
     * Reads a guardrails container from a stream, mirroring {@link #writeTo(StreamOutput)}.
     *
     * @param input stream positioned at the start of a serialized {@code Guardrails}
     * @throws IOException              if reading from the stream fails
     * @throws IllegalArgumentException if a guardrail is flagged as present but the type is unsupported
     */
    public Guardrails(StreamInput input) throws IOException {
        type = input.readString();
        inputGuardrail = readOptionalGuardrail(type, input);
        outputGuardrail = readOptionalGuardrail(type, input);
    }

    /**
     * Writes this container to a stream: the type, then each guardrail preceded by a presence flag.
     *
     * @param out destination stream
     * @throws IOException if writing to the stream fails
     */
    public void writeTo(StreamOutput out) throws IOException {
        // NOTE(review): type is written unconditionally; presumably it must be non-null here — confirm.
        out.writeString(type);
        writeOptionalGuardrail(out, inputGuardrail);
        writeOptionalGuardrail(out, outputGuardrail);
    }

    /**
     * Renders this container as an object, emitting only the fields that are non-null.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (type != null) {
            builder.field(TYPE_FIELD, type);
        }
        if (inputGuardrail != null) {
            builder.field(INPUT_GUARDRAIL_FIELD, inputGuardrail);
        }
        if (outputGuardrail != null) {
            builder.field(OUTPUT_GUARDRAIL_FIELD, outputGuardrail);
        }
        builder.endObject();
        return builder;
    }

    /**
     * Parses a guardrails object from XContent.
     *
     * <p>Unknown fields are skipped. When no type field is present the type defaults to
     * {@code "local_regex"}. A guardrail whose map is missing or empty is left {@code null}.
     *
     * @param parser parser whose current token is the object's {@code START_OBJECT}
     * @return the parsed guardrails
     * @throws IOException              if the parser fails
     * @throws IllegalArgumentException if the type is not one of {@link #types}
     */
    public static Guardrails parse(XContentParser parser) throws IOException {
        String type = null;
        Map<String, Object> inputGuardrailMap = null;
        Map<String, Object> outputGuardrailMap = null;

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();

            switch (fieldName) {
                case TYPE_FIELD:
                    type = parser.text();
                    break;
                case INPUT_GUARDRAIL_FIELD:
                    inputGuardrailMap = parser.map();
                    break;
                case OUTPUT_GUARDRAIL_FIELD:
                    outputGuardrailMap = parser.map();
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }
        if (type == null) {
            type = LOCAL_REGEX_TYPE;
        }
        // type can no longer be null at this point, so a failure here always means an unknown type.
        if (!validateType(type)) {
            throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
        }

        return Guardrails
            .builder()
            .type(type)
            .inputGuardrail(createGuardrail(type, inputGuardrailMap))
            .outputGuardrail(createGuardrail(type, outputGuardrailMap))
            .build();
    }

    /** Returns whether {@code type} is one of the supported {@link #types}. */
    private static boolean validateType(String type) {
        return types.contains(type);
    }

    /**
     * Builds a guardrail of the given type from its parsed parameters.
     *
     * @return the guardrail, or {@code null} when {@code params} is {@code null} or empty
     * @throws IllegalArgumentException if the type is unsupported
     */
    private static Guardrail createGuardrail(String type, Map<String, Object> params) {
        if (params == null || params.isEmpty()) {
            return null;
        }

        switch (type) {
            case LOCAL_REGEX_TYPE:
                return new LocalRegexGuardrail(params);
            case MODEL_TYPE:
                return new ModelGuardrail(params);
            default:
                throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
        }
    }

    /**
     * Reads a presence flag and, when set, the guardrail of the given type that follows it.
     *
     * <p>The guardrail body must be consumed for every type that {@link #writeTo(StreamOutput)}
     * can emit; otherwise the remainder of the stream would be read out of alignment.
     *
     * @return the guardrail, or {@code null} when the presence flag is {@code false}
     */
    private static Guardrail readOptionalGuardrail(String type, StreamInput input) throws IOException {
        if (!input.readBoolean()) {
            return null;
        }
        switch (type) {
            case LOCAL_REGEX_TYPE:
                return new LocalRegexGuardrail(input);
            case MODEL_TYPE:
                return new ModelGuardrail(input);
            default:
                throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
        }
    }

    /** Writes a presence flag and, when the guardrail is non-null, the guardrail itself. */
    private static void writeOptionalGuardrail(StreamOutput out, Guardrail guardrail) throws IOException {
        if (guardrail == null) {
            out.writeBoolean(false);
            return;
        }
        out.writeBoolean(true);
        guardrail.writeTo(out);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy