org.languagetool.rules.BERTSuggestionRanking Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of languagetool-core Show documentation
Show all versions of languagetool-core Show documentation
LanguageTool is open-source proofreading software for English, French, German, Polish, Romanian, and more than 20 other languages. It finds many errors that a simple spell checker cannot detect, such as mixing up there/their, and it also detects some grammar problems.
/*
* LanguageTool, a natural language style checker
* * Copyright (C) 2018 Fabian Richter
* *
* * This library is free software; you can redistribute it and/or
* * modify it under the terms of the GNU Lesser General Public
* * License as published by the Free Software Foundation; either
* * version 2.1 of the License, or (at your option) any later version.
* *
* * This library is distributed in the hope that it will be useful,
* * but WITHOUT ANY WARRANTY; without even the implied warranty of
* * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* * Lesser General Public License for more details.
* *
* * You should have received a copy of the GNU Lesser General Public
* * License along with this library; if not, write to the Free Software
* * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
* * USA
*
*/
package org.languagetool.rules;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Streams;
import org.apache.commons.lang3.tuple.Pair;
import org.jetbrains.annotations.Nullable;
import org.languagetool.AnalyzedSentence;
import org.languagetool.UserConfig;
import org.languagetool.languagemodel.bert.RemoteLanguageModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLException;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
/**
 * Reorders the suggestions produced by another rule, using a remote BERT language model
 * to score every candidate replacement in the context of its sentence.
 * <p>
 * The wrapped rule is run locally to obtain matches; for each match with more than one
 * suggestion a scoring request is sent to the model and the suggestions are re-sorted by
 * descending score. Matches with zero or one suggestion are passed through unchanged.
 */
public class BERTSuggestionRanking extends RemoteRule {

  private static final Logger logger = LoggerFactory.getLogger(BERTSuggestionRanking.class);

  public static final String RULE_ID = "BERT_SUGGESTION_RANKING";

  // One model client per service configuration, created lazily on first lookup.
  private static final LoadingCache<RemoteRuleConfig, RemoteLanguageModel> models =
    CacheBuilder.newBuilder().build(CacheLoader.from(serviceConfiguration -> {
      String host = serviceConfiguration.getUrl();
      int port = serviceConfiguration.getPort();
      boolean ssl = Boolean.parseBoolean(serviceConfiguration.getOptions().getOrDefault("secure", "false"));
      String key = serviceConfiguration.getOptions().get("clientKey");
      String cert = serviceConfiguration.getOptions().get("clientCertificate");
      String ca = serviceConfiguration.getOptions().get("rootCertificate");
      try {
        return new RemoteLanguageModel(host, port, ssl, key, cert, ca);
      } catch (SSLException e) {
        // CacheLoader functions cannot throw checked exceptions; keep the cause
        throw new RuntimeException(e);
      }
    }));

  static {
    // shut down every cached model client when the registered shutdown routines run
    shutdownRoutines.add(() -> models.asMap().values().forEach(RemoteLanguageModel::shutdown));
  }

  // default behavior for prepareSuggestions: limit to top n candidates
  // NOTE(review): prepareSuggestions overwrites this field on every call, so a value set by a
  // subclass only survives if that subclass also overrides prepareSuggestions.
  protected int suggestionLimit = 10;

  // null when the A/B test is not active for this user or the model could not be obtained
  private final RemoteLanguageModel model;
  private final Rule wrappedRule;

  /**
   * @param rule       the rule whose matches get their suggestions reordered
   * @param config     configuration of the remote BERT service
   * @param userConfig user settings; the model is only looked up if its A/B test value equals this rule's id
   */
  public BERTSuggestionRanking(Rule rule, RemoteRuleConfig config, UserConfig userConfig) {
    super(rule.messages, config);
    this.wrappedRule = rule;
    synchronized (models) {
      RemoteLanguageModel loaded = null;
      if (getId().equals(userConfig.getAbTest())) {
        try {
          loaded = models.get(serviceConfiguration);
        } catch (Exception e) {
          // leave the model unset; executeRequest then returns the fallback results
          logger.error("Could not connect to BERT service at {} for suggestion reranking", serviceConfiguration, e);
        }
      }
      this.model = loaded;
    }
  }

  /**
   * Request payload: the matches of the wrapped rule together with one scoring request per match.
   * Both lists are index-aligned; {@code requests} holds {@code null} for matches that need no reordering.
   */
  class MatchesForReordering extends RemoteRequest {
    final List<RuleMatch> matches;
    final List<RemoteLanguageModel.Request> requests;

    MatchesForReordering(List<RuleMatch> matches, List<RemoteLanguageModel.Request> requests) {
      this.matches = matches;
      this.requests = requests;
    }
  }

  /**
   * Transform suggestions before resorting, e.g. limit resorting to top-n candidates.
   * <p>
   * Sets {@link #suggestionLimit} to 25 if any suggestion is a translation and to 10 otherwise,
   * then truncates the list to that many entries.
   *
   * @param suggestions suggestions in the order produced by the wrapped rule
   * @return transformed suggestions, as a {@code subList} view of the given list
   */
  protected List<SuggestedReplacement> prepareSuggestions(List<SuggestedReplacement> suggestions) {
    // include more suggestions for resorting if there are translations included as original order isn't that good
    if (suggestions.stream().anyMatch(s -> s.getType() == SuggestedReplacement.SuggestionType.Translation)) {
      suggestionLimit = 25;
    } else {
      suggestionLimit = 10;
    }
    return suggestions.subList(0, Math.min(suggestions.size(), suggestionLimit));
  }

  /**
   * Runs the wrapped rule on every sentence and builds one scoring request per match.
   *
   * @param sentences the sentences to check, in text order
   * @return a {@link MatchesForReordering}; empty if the wrapped rule threw an {@link IOException}
   */
  @Override
  protected RemoteRequest prepareRequest(List<AnalyzedSentence> sentences) {
    // array-backed lists: executeRequest accesses both lists by index
    List<RuleMatch> matches = new ArrayList<>();
    List<RemoteLanguageModel.Request> requests = new ArrayList<>();
    try {
      int offset = 0;
      for (AnalyzedSentence sentence : sentences) {
        RuleMatch[] sentenceMatches = wrappedRule.match(sentence);
        for (RuleMatch match : sentenceMatches) {
          match.setSuggestedReplacementObjects(prepareSuggestions(match.getSuggestedReplacementObjects()));
          // build request before correcting offset, as we send only sentence as text
          requests.add(buildRequest(match));
          match.setOffsetPosition(match.getFromPos() + offset, match.getToPos() + offset);
        }
        Collections.addAll(matches, sentenceMatches);
        offset += sentence.getText().length();
      }
      return new MatchesForReordering(matches, requests);
    } catch (IOException e) {
      logger.error("Error while executing rule {}", wrappedRule.getId(), e);
      return new MatchesForReordering(Collections.emptyList(), Collections.emptyList());
    }
  }

  /**
   * @return the matches of the wrapped rule in their original suggestion order, flagged as not remote
   */
  @Override
  protected RemoteRuleResult fallbackResults(RemoteRequest request) {
    return new RemoteRuleResult(false, ((MatchesForReordering) request).matches);
  }

  /**
   * Creates the task that scores all pending requests in one batch and re-sorts the suggestions
   * of the affected matches by descending score.
   *
   * @param request the {@link MatchesForReordering} created by {@link #prepareRequest(List)}
   * @return a task yielding the fallback results if no model is available, the untouched matches
   *         if nothing needs scoring, or the matches with reordered suggestions otherwise
   */
  @Override
  protected Callable<RemoteRuleResult> executeRequest(RemoteRequest request) {
    return () -> {
      if (model == null) {
        return fallbackResults(request);
      }
      MatchesForReordering data = (MatchesForReordering) request;
      List<RuleMatch> matches = data.matches;
      List<RemoteLanguageModel.Request> requests = data.requests;
      // remember which match each non-null request belongs to before dropping the null entries
      Streams.FunctionWithIndex<RemoteLanguageModel.Request, Long> mapIndices =
        (req, index) -> req != null ? index : null;
      List<Long> indices = Streams.mapWithIndex(requests.stream(), mapIndices)
        .filter(Objects::nonNull).collect(Collectors.toList());
      requests = requests.stream().filter(Objects::nonNull).collect(Collectors.toList());
      if (requests.isEmpty()) {
        return new RemoteRuleResult(false, matches);
      }
      List<List<Double>> results = model.batchScore(requests);
      // highest score first
      Comparator<Pair<SuggestedReplacement, Double>> suggestionOrdering = Comparator.comparing(Pair::getRight);
      suggestionOrdering = suggestionOrdering.reversed();
      for (int i = 0; i < indices.size(); i++) {
        List<Double> scores = results.get(i);
        RemoteLanguageModel.Request req = requests.get(i);
        RuleMatch match = matches.get(indices.get(i).intValue());
        String error = req.text.substring(req.start, req.end);
        // pair each suggestion with its score and sort once; reused for logging and for the new order
        List<Pair<SuggestedReplacement, Double>> scored = Streams
          .zip(match.getSuggestedReplacementObjects().stream(), scores.stream(), Pair::of)
          .sorted(suggestionOrdering)
          .collect(Collectors.toList());
        logger.info("Scored suggestions for '{}': {} -> {}", error, match.getSuggestedReplacements(),
          scored.stream()
            .map(entry -> String.format("%s (%e)", entry.getLeft().getReplacement(), entry.getRight()))
            .collect(Collectors.toList()));
        List<SuggestedReplacement> ranked = scored.stream()
          .map(Pair::getLeft)
          .collect(Collectors.toList());
        match.setSuggestedReplacementObjects(ranked);
      }
      return new RemoteRuleResult(true, matches);
    };
  }

  /**
   * Builds the scoring request for a single match.
   *
   * @param match a match whose positions are still relative to its sentence
   * @return the request, or {@code null} if the match has fewer than two suggestions
   */
  @Nullable
  private RemoteLanguageModel.Request buildRequest(RuleMatch match) {
    List<String> suggestions = match.getSuggestedReplacements();
    if (suggestions != null && suggestions.size() > 1) {
      return new RemoteLanguageModel.Request(
        match.getSentence().getText(), match.getFromPos(), match.getToPos(), suggestions);
    } else {
      return null;
    }
  }

  @Override
  public String getId() {
    return RULE_ID;
  }

  @Override
  public String getDescription() {
    return "Suggestion reordering based on the BERT model";
  }
}