org.languagetool.rules.OpenNMTRule Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of languagetool-core Show documentation
Show all versions of languagetool-core Show documentation
LanguageTool is an Open Source proofreading software for English, French, German, Polish, Romanian, and more than 20 other languages. It finds many errors that a simple spell checker cannot detect like mixing up there/their and it detects some grammar problems.
/* LanguageTool, a natural language style checker
* Copyright (C) 2017 Daniel Naber (http://www.danielnaber.de)
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
* USA
*/
package org.languagetool.rules;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.base.Charsets;
import com.google.common.io.CharStreams;
import org.jetbrains.annotations.NotNull;
import org.languagetool.AnalyzedSentence;
import org.languagetool.Experimental;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
/**
* Queries an OpenNMT server started like this:
* {@code th tools/rest_translation_server.lua -replace_unk -model ...}
*/
@Experimental
public class OpenNMTRule extends Rule {
private static final String defaultServerUrl = "http://127.0.0.1:7784/translator/translate";
private final String serverUrl;
private final ObjectMapper mapper = new ObjectMapper();
/**
* @param serverUrl URL of the OpenNMT server, like {@code http://127.0.0.1:7784/translator/translate}
*/
public OpenNMTRule(String serverUrl) {
this.serverUrl = serverUrl;
setDefaultOff();
}
/**
* Expects an OpenNMT server running at http://127.0.0.1:7784/translator/translate
*/
public OpenNMTRule() {
this(defaultServerUrl);
}
@Override
public String getId() {
return "OPENNMT_RULE";
}
@Override
public String getDescription() {
return "Get corrections from an OpenNMT server (beta)";
}
@Override
public RuleMatch[] match(AnalyzedSentence sentence) throws IOException {
// TODO: send all sentences at once for less overhead?
String json = createJson(sentence);
URL url = new URL(serverUrl);
HttpURLConnection conn = postToServer(json, url);
int responseCode = conn.getResponseCode();
if (responseCode == 200) {
InputStream inputStr = conn.getInputStream();
JsonNode response = mapper.readTree(inputStr);
String translation = response.get(0).get(0).get("tgt").textValue();
if (translation.contains("")) {
throw new RuntimeException("'' found in translation - please start the OpenNMT server with the -replace_unk option");
}
// TODO: whitespace is introduced and needs to be properly removed - we should use the 'src'
// key for comparison, but we'll still need to clean up when using the suggestion...
translation = detokenize(translation);
String cleanTranslation = translation.replaceAll(" ([.,;:])", "$1");
String sentenceText = sentence.getText();
if (!cleanTranslation.trim().equals(sentenceText.trim())) {
List ruleMatches = new ArrayList<>();
int from = getLeftWordBoundary(sentenceText, getFirstDiffPosition(sentenceText, cleanTranslation));
int to = getRightWordBoundary(sentenceText, getLastDiffPosition(sentenceText, cleanTranslation));
int replacementTo = getRightWordBoundary(cleanTranslation, getLastDiffPosition(cleanTranslation, sentenceText));
String message = "OpenNMT suggests that this might(!) be better phrased differently, please check.";
RuleMatch ruleMatch = new RuleMatch(this, sentence, from, to, message);
ruleMatch.setSuggestedReplacement(cleanTranslation.substring(from, replacementTo));
ruleMatches.add(ruleMatch);
return toRuleMatchArray(ruleMatches);
}
return new RuleMatch[0];
} else {
InputStream inputStr = conn.getErrorStream();
String error = CharStreams.toString(new InputStreamReader(inputStr, Charsets.UTF_8));
throw new RuntimeException("Got error " + responseCode + " from " + url + ": " + error);
}
}
private String detokenize(String translation) {
return translation
.replaceAll(" 's", "'s")
.replaceAll(" ' s", "'s")
.replaceAll(" ' t", "'t")
.replace(" .", ".")
.replace(" ,", ",");
}
int getFirstDiffPosition(String text1, String text2) {
for (int i = 0; i < text1.length(); i++) {
if (i < text2.length()) {
if (text1.charAt(i) != text2.charAt(i)) {
return i;
}
} else {
return i;
}
}
if (text1.length() != text2.length()) {
return text1.length();
}
return -1;
}
int getLastDiffPosition(String text1, String text2) {
StringBuilder reverse1 = new StringBuilder(text1.trim()).reverse();
StringBuilder reverse2 = new StringBuilder(text2.trim()).reverse();
int revDiffPos = getFirstDiffPosition(reverse1.toString(), reverse2.toString());
if (revDiffPos != -1) {
return text1.length() - revDiffPos;
} else {
return -1;
}
}
int getLeftWordBoundary(String text, int pos) {
while (pos >= 1) {
if (Character.isAlphabetic(text.charAt(pos - 1))) {
pos--;
} else {
break;
}
}
return pos;
}
int getRightWordBoundary(String text, int pos) {
while (pos >= 0 && pos < text.length()) {
if (Character.isAlphabetic(text.charAt(pos))) {
pos++;
} else {
break;
}
}
return pos;
}
private String createJson(AnalyzedSentence sentence) throws JsonProcessingException {
ArrayNode list = mapper.createArrayNode();
ObjectNode node = list.addObject();
node.put("src", sentence.getText());
return mapper.writerWithDefaultPrettyPrinter().writeValueAsString(list);
}
@NotNull
private HttpURLConnection postToServer(String json, URL url) throws IOException {
HttpURLConnection conn = (HttpURLConnection)url.openConnection();
conn.setRequestMethod("POST");
conn.setUseCaches(false);
conn.setRequestProperty("Content-Type", "application/json");
conn.setDoOutput(true);
try {
try (DataOutputStream dos = new DataOutputStream(conn.getOutputStream())) {
//System.out.println("POSTING: " + json);
dos.write(json.getBytes("utf-8"));
}
} catch (IOException e) {
throw new RuntimeException("Could not connect OpenNMT server at " + url);
}
return conn;
}
}