Please wait. This can take a few minutes...
Downloading a project requires significant server resources, and we have to cover our server costs. Thank you for your understanding.
Project price: only $1
You can buy this project and download or modify it as often as you want.
org.opensearch.ml.common.model.LocalRegexGuardrail Maven / Gradle / Ivy
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.ml.common.model;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.stopWordsIndices;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.builder.SearchSourceBuilder;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;
@Log4j2
@EqualsAndHashCode
@Getter
public class LocalRegexGuardrail extends Guardrail {
public static final String STOP_WORDS_FIELD = "stop_words";
public static final String REGEX_FIELD = "regex";
private List stopWords;
private String[] regex;
private List regexPattern;
private Map> stopWordsIndicesInput;
private NamedXContentRegistry xContentRegistry;
private Client client;
@Builder(toBuilder = true)
public LocalRegexGuardrail(List stopWords, String[] regex) {
this.stopWords = stopWords;
this.regex = regex;
}
public LocalRegexGuardrail(@NonNull Map params) {
List words = (List) params.get(STOP_WORDS_FIELD);
stopWords = new ArrayList<>();
if (words != null && !words.isEmpty()) {
for (Map e : words) {
stopWords.add(new StopWords(e));
}
}
List regexes = (List) params.get(REGEX_FIELD);
if (regexes != null && !regexes.isEmpty()) {
this.regex = regexes.toArray(new String[0]);
}
}
public LocalRegexGuardrail(StreamInput input) throws IOException {
if (input.readBoolean()) {
stopWords = new ArrayList<>();
int size = input.readInt();
for (int i = 0; i < size; i++) {
stopWords.add(new StopWords(input));
}
}
regex = input.readOptionalStringArray();
}
public void writeTo(StreamOutput out) throws IOException {
if (stopWords != null && stopWords.size() > 0) {
out.writeBoolean(true);
out.writeInt(stopWords.size());
for (StopWords e : stopWords) {
e.writeTo(out);
}
} else {
out.writeBoolean(false);
}
out.writeOptionalStringArray(regex);
}
@Override
public Boolean validate(String input, Map parameters) {
return validateRegexList(input, regexPattern) && validateStopWords(input, stopWordsIndicesInput);
}
@Override
public void init(NamedXContentRegistry xContentRegistry, Client client) {
this.xContentRegistry = xContentRegistry;
this.client = client;
init();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (stopWords != null && stopWords.size() > 0) {
builder.field(STOP_WORDS_FIELD, stopWords);
}
if (regex != null) {
builder.field(REGEX_FIELD, regex);
}
builder.endObject();
return builder;
}
public static LocalRegexGuardrail parse(XContentParser parser) throws IOException {
List stopWords = null;
String[] regex = null;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
switch (fieldName) {
case STOP_WORDS_FIELD:
stopWords = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
stopWords.add(StopWords.parse(parser));
}
break;
case REGEX_FIELD:
regex = parser.list().toArray(new String[0]);
break;
default:
parser.skipChildren();
break;
}
}
return LocalRegexGuardrail.builder().stopWords(stopWords).regex(regex).build();
}
private void init() {
stopWordsIndicesInput = stopWordsToMap();
List regexList = regex == null ? new ArrayList<>() : Arrays.asList(regex);
regexPattern = regexList.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList());
}
private Map> stopWordsToMap() {
Map> map = new HashMap<>();
if (stopWords != null && !stopWords.isEmpty()) {
for (StopWords e : stopWords) {
if (e.getIndex() != null && e.getSourceFields() != null) {
map.put(e.getIndex(), Arrays.asList(e.getSourceFields()));
}
}
}
return map;
}
public Boolean validateRegexList(String input, List regexPatterns) {
if (regexPatterns == null || regexPatterns.isEmpty()) {
return true;
}
for (Pattern pattern : regexPatterns) {
if (!validateRegex(input, pattern)) {
return false;
}
}
return true;
}
public Boolean validateRegex(String input, Pattern pattern) {
Matcher matcher = pattern.matcher(input);
return !matcher.matches();
}
public Boolean validateStopWords(String input, Map> stopWordsIndices) {
if (stopWordsIndices == null || stopWordsIndices.isEmpty()) {
return true;
}
for (Map.Entry entry : stopWordsIndices.entrySet()) {
if (!validateStopWordsSingleIndex(input, (String) entry.getKey(), (List) entry.getValue())) {
return false;
}
}
return true;
}
/**
* Validate the input string against stop words
* @param input the string to validate against stop words
* @param indexName the index containing stop words
* @param fieldNames a list of field names containing stop words
* @return true if no stop words matching, otherwise false.
*/
public Boolean validateStopWordsSingleIndex(String input, String indexName, List fieldNames) {
SearchRequest searchRequest;
AtomicBoolean hitStopWords = new AtomicBoolean(false);
String queryBody;
Map documentMap = new HashMap<>();
for (String field : fieldNames) {
documentMap.put(field, input);
}
Map queryBodyMap = Map.of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap)));
CountDownLatch latch = new CountDownLatch(1);
ThreadContext.StoredContext context = null;
try {
queryBody = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(queryBodyMap));
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON
.xContent()
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody);
searchSourceBuilder.parseXContent(queryParser);
searchSourceBuilder.size(1); // Only need 1 doc returned, if hit.
searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName);
if (isStopWordsSystemIndex(indexName)) {
context = client.threadPool().getThreadContext().stashContext();
ThreadContext.StoredContext finalContext = context;
client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
hitStopWords.set(true);
}
}, e -> {
log.error("Failed to search stop words index {}", indexName, e);
hitStopWords.set(true);
}), latch), () -> finalContext.restore()));
} else {
client.search(searchRequest, new LatchedActionListener(ActionListener.wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
hitStopWords.set(true);
}
}, e -> {
log.error("Failed to search stop words index {}", indexName, e);
hitStopWords.set(true);
}), latch));
}
} catch (Exception e) {
log.error("[validateStopWords] Searching stop words index failed.", e);
latch.countDown();
hitStopWords.set(true);
} finally {
if (context != null) {
context.close();
}
}
try {
latch.await(5, SECONDS);
} catch (InterruptedException e) {
log.error("[validateStopWords] Searching stop words index was timeout.", e);
throw new IllegalStateException(e);
}
return hitStopWords.get();
}
private boolean isStopWordsSystemIndex(String index) {
return stopWordsIndices.contains(index);
}
}