All downloads are FREE. The search and download features use the official Maven repository.

org.opensearch.ml.common.model.LocalRegexGuardrail Maven / Gradle / Ivy

The newest version!
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.ml.common.model;

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.stopWordsIndices;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.builder.SearchSourceBuilder;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;

@Log4j2
@EqualsAndHashCode
@Getter
public class LocalRegexGuardrail extends Guardrail {
    public static final String STOP_WORDS_FIELD = "stop_words";
    public static final String REGEX_FIELD = "regex";

    private List stopWords;
    private String[] regex;
    private List regexPattern;
    private Map> stopWordsIndicesInput;
    private NamedXContentRegistry xContentRegistry;
    private Client client;

    @Builder(toBuilder = true)
    public LocalRegexGuardrail(List stopWords, String[] regex) {
        this.stopWords = stopWords;
        this.regex = regex;
    }

    public LocalRegexGuardrail(@NonNull Map params) {
        List words = (List) params.get(STOP_WORDS_FIELD);
        stopWords = new ArrayList<>();
        if (words != null && !words.isEmpty()) {
            for (Map e : words) {
                stopWords.add(new StopWords(e));
            }
        }
        List regexes = (List) params.get(REGEX_FIELD);
        if (regexes != null && !regexes.isEmpty()) {
            this.regex = regexes.toArray(new String[0]);
        }
    }

    public LocalRegexGuardrail(StreamInput input) throws IOException {
        if (input.readBoolean()) {
            stopWords = new ArrayList<>();
            int size = input.readInt();
            for (int i = 0; i < size; i++) {
                stopWords.add(new StopWords(input));
            }
        }
        regex = input.readOptionalStringArray();
    }

    public void writeTo(StreamOutput out) throws IOException {
        if (stopWords != null && stopWords.size() > 0) {
            out.writeBoolean(true);
            out.writeInt(stopWords.size());
            for (StopWords e : stopWords) {
                e.writeTo(out);
            }
        } else {
            out.writeBoolean(false);
        }
        out.writeOptionalStringArray(regex);
    }

    @Override
    public Boolean validate(String input, Map parameters) {
        return validateRegexList(input, regexPattern) && validateStopWords(input, stopWordsIndicesInput);
    }

    @Override
    public void init(NamedXContentRegistry xContentRegistry, Client client) {
        this.xContentRegistry = xContentRegistry;
        this.client = client;
        init();
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (stopWords != null && stopWords.size() > 0) {
            builder.field(STOP_WORDS_FIELD, stopWords);
        }
        if (regex != null) {
            builder.field(REGEX_FIELD, regex);
        }
        builder.endObject();
        return builder;
    }

    public static LocalRegexGuardrail parse(XContentParser parser) throws IOException {
        List stopWords = null;
        String[] regex = null;

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();

            switch (fieldName) {
                case STOP_WORDS_FIELD:
                    stopWords = new ArrayList<>();
                    ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                    while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
                        stopWords.add(StopWords.parse(parser));
                    }
                    break;
                case REGEX_FIELD:
                    regex = parser.list().toArray(new String[0]);
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }
        return LocalRegexGuardrail.builder().stopWords(stopWords).regex(regex).build();
    }

    private void init() {
        stopWordsIndicesInput = stopWordsToMap();
        List regexList = regex == null ? new ArrayList<>() : Arrays.asList(regex);
        regexPattern = regexList.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList());
    }

    private Map> stopWordsToMap() {
        Map> map = new HashMap<>();
        if (stopWords != null && !stopWords.isEmpty()) {
            for (StopWords e : stopWords) {
                if (e.getIndex() != null && e.getSourceFields() != null) {
                    map.put(e.getIndex(), Arrays.asList(e.getSourceFields()));
                }
            }
        }
        return map;
    }

    public Boolean validateRegexList(String input, List regexPatterns) {
        if (regexPatterns == null || regexPatterns.isEmpty()) {
            return true;
        }
        for (Pattern pattern : regexPatterns) {
            if (!validateRegex(input, pattern)) {
                return false;
            }
        }
        return true;
    }

    public Boolean validateRegex(String input, Pattern pattern) {
        Matcher matcher = pattern.matcher(input);
        return !matcher.matches();
    }

    public Boolean validateStopWords(String input, Map> stopWordsIndices) {
        if (stopWordsIndices == null || stopWordsIndices.isEmpty()) {
            return true;
        }
        for (Map.Entry entry : stopWordsIndices.entrySet()) {
            if (!validateStopWordsSingleIndex(input, (String) entry.getKey(), (List) entry.getValue())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Validate the input string against stop words
     * @param input the string to validate against stop words
     * @param indexName the index containing stop words
     * @param fieldNames a list of field names containing stop words
     * @return true if no stop words matching, otherwise false.
     */
    public Boolean validateStopWordsSingleIndex(String input, String indexName, List fieldNames) {
        SearchRequest searchRequest;
        AtomicBoolean hitStopWords = new AtomicBoolean(false);
        String queryBody;
        Map documentMap = new HashMap<>();
        for (String field : fieldNames) {
            documentMap.put(field, input);
        }
        Map queryBodyMap = Map.of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap)));
        CountDownLatch latch = new CountDownLatch(1);
        ThreadContext.StoredContext context = null;

        try {
            queryBody = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(queryBodyMap));
            SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
            XContentParser queryParser = XContentType.JSON
                .xContent()
                .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody);
            searchSourceBuilder.parseXContent(queryParser);
            searchSourceBuilder.size(1); // Only need 1 doc returned, if hit.
            searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName);
            if (isStopWordsSystemIndex(indexName)) {
                context = client.threadPool().getThreadContext().stashContext();
                ThreadContext.StoredContext finalContext = context;
                client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.wrap(r -> {
                    if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
                        hitStopWords.set(true);
                    }
                }, e -> {
                    log.error("Failed to search stop words index {}", indexName, e);
                    hitStopWords.set(true);
                }), latch), () -> finalContext.restore()));
            } else {
                client.search(searchRequest, new LatchedActionListener(ActionListener.wrap(r -> {
                    if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
                        hitStopWords.set(true);
                    }
                }, e -> {
                    log.error("Failed to search stop words index {}", indexName, e);
                    hitStopWords.set(true);
                }), latch));
            }
        } catch (Exception e) {
            log.error("[validateStopWords] Searching stop words index failed.", e);
            latch.countDown();
            hitStopWords.set(true);
        } finally {
            if (context != null) {
                context.close();
            }
        }

        try {
            latch.await(5, SECONDS);
        } catch (InterruptedException e) {
            log.error("[validateStopWords] Searching stop words index was timeout.", e);
            throw new IllegalStateException(e);
        }
        return hitStopWords.get();
    }

    private boolean isStopWordsSystemIndex(String index) {
        return stopWordsIndices.contains(index);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy