All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.languagetool.rules.BERTSuggestionRanking Maven / Gradle / Ivy

Go to download

LanguageTool is open-source proofreading software for English, French, German, Polish, Romanian, and more than 20 other languages. It finds many errors that a simple spell checker cannot detect, such as mixing up there/their, and it also detects some grammar problems.

There is a newer version: 6.5
Show newest version
/*
 *  LanguageTool, a natural language style checker
 *  * Copyright (C) 2018 Fabian Richter
 *  *
 *  * This library is free software; you can redistribute it and/or
 *  * modify it under the terms of the GNU Lesser General Public
 *  * License as published by the Free Software Foundation; either
 *  * version 2.1 of the License, or (at your option) any later version.
 *  *
 *  * This library is distributed in the hope that it will be useful,
 *  * but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  * Lesser General Public License for more details.
 *  *
 *  * You should have received a copy of the GNU Lesser General Public
 *  * License along with this library; if not, write to the Free Software
 *  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301
 *  * USA
 *
 */

package org.languagetool.rules;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Streams;
import org.apache.commons.lang3.tuple.Pair;
import org.jetbrains.annotations.Nullable;
import org.languagetool.AnalyzedSentence;
import org.languagetool.Language;
import org.languagetool.languagemodel.bert.RemoteLanguageModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLException;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

/**
 * reorder suggestions from another rule using BERT as a LM
 */
public class BERTSuggestionRanking extends RemoteRule {

  // only for RemoteRuleConfig
  public static final String RULE_ID = "BERT_SUGGESTION_RANKING";

  private static final Logger logger = LoggerFactory.getLogger(BERTSuggestionRanking.class);

  private static final LoadingCache models =
    CacheBuilder.newBuilder().build(CacheLoader.from(serviceConfiguration -> {
      String host = serviceConfiguration.getUrl();
      int port = serviceConfiguration.getPort();
      boolean ssl = Boolean.parseBoolean(serviceConfiguration.getOptions().getOrDefault("secure", "false"));
      String key = serviceConfiguration.getOptions().get("clientKey");
      String cert = serviceConfiguration.getOptions().get("clientCertificate");
      String ca = serviceConfiguration.getOptions().get("rootCertificate");
      try {
        return new RemoteLanguageModel(host, port, ssl, key, cert, ca);
      } catch (SSLException e) {
        throw new RuntimeException(e);
      }
    }));

  static {
    shutdownRoutines.add(() -> models.asMap().values().forEach(RemoteLanguageModel::shutdown));
  }

  // default behavior for prepareSuggestions: limit to top n candidates
  protected int suggestionLimit = 10;
  private final RemoteLanguageModel model;
  private final Rule wrappedRule;

  public BERTSuggestionRanking(Language language, Rule rule, RemoteRuleConfig config, boolean inputLogging) {
    super(language, rule.messages, config, inputLogging, rule.getId());
    this.wrappedRule = rule;
    super.setCategory(wrappedRule.getCategory());
    synchronized (models) {
      RemoteLanguageModel model = null;
      try {
        model = models.get(serviceConfiguration);
      } catch (Exception e) {
        logger.error("Could not connect to BERT service at " + serviceConfiguration + " for suggestion reranking", e);
      }
      this.model = model;
    }
  }

  class MatchesForReordering extends RemoteRequest {
    final List sentences;
    final List matches;
    final List requests;
    MatchesForReordering(List sentences, List matches, List requests) {
      this.sentences = sentences;
      this.matches = matches;
      this.requests = requests;
    }
  }

  /**
   * transform suggestions before resorting, e.g. limit resorting to top-n candidates
   * @return transformed suggestions
   */
  protected List prepareSuggestions(List suggestions) {
    // include more suggestions for resorting if there are translations included as original order isn't that good
    if (suggestions.stream().anyMatch(s -> s.getType() == SuggestedReplacement.SuggestionType.Translation)) {
      suggestionLimit = 25;
    } else {
      suggestionLimit = 10;
    }
    return suggestions.subList(0, Math.min(suggestions.size(), suggestionLimit));
  }

  private static final int MIN_WORDS = 8;
  private static final double MAX_ERROR_RATE = 0.5;

  @Override
  protected RemoteRequest prepareRequest(List sentences, Long textSessionId) {
    List matches = new LinkedList<>();
    int totalWords = 0;
    try {
      for (AnalyzedSentence sentence : sentences) {
        RuleMatch[] sentenceMatches = wrappedRule.match(sentence);
        Collections.addAll(matches, sentenceMatches);

        // computing many suggestions (e.g. for requests with the language incorrectly set, or with garbled input) is very expensive
        // having this as a RemoteRule circumvents the normal ErrorRateTooHighException
        // so build this logic again here
        int words = sentence.getTokensWithoutWhitespace().length;
        totalWords += words;
        if (words > MIN_WORDS && (double) sentenceMatches.length / words > MAX_ERROR_RATE) {
          for (RuleMatch m : sentenceMatches) {
            m.discardLazySuggestedReplacements();
          }
          logger.info("Skipping suggestion generation for sentence, too many matches ({} matches in {} words)",
            sentenceMatches.length, words);
        }
      }
    } catch (IOException e) {
      logger.error("Error while executing rule " + wrappedRule.getId(), e);
      return new MatchesForReordering(sentences, Collections.emptyList(), Collections.emptyList());
    }
    if (totalWords > MIN_WORDS && (double) matches.size() / totalWords > MAX_ERROR_RATE) {
      logger.info("Skipping suggestion generation for request, too many matches ({} matches in {} words)",
        matches.size(), totalWords);
      matches.forEach(RuleMatch::discardLazySuggestedReplacements);
      return new MatchesForReordering(sentences, matches, Collections.emptyList());
    }
    List requests = new LinkedList<>();
    for (RuleMatch match : matches) {
      match.setSuggestedReplacementObjects(prepareSuggestions(match.getSuggestedReplacementObjects()));
      requests.add(buildRequest(match));
    }
    return new MatchesForReordering(sentences, matches, requests);
  }

  @Override
  protected RemoteRuleResult fallbackResults(RemoteRequest request) {
    MatchesForReordering req = (MatchesForReordering) request;
    return new RemoteRuleResult(false, false, req.matches, req.sentences);
  }

  @Override
  protected Callable executeRequest(RemoteRequest request, long timeoutMilliseconds) throws TimeoutException {
    return () -> {
      if (model == null) {
        return fallbackResults(request);
      }
      MatchesForReordering data = (MatchesForReordering) request;
      List matches = data.matches;
      List requests = data.requests;
      Streams.FunctionWithIndex mapIndices = (req, index) -> req != null ? index : null;
      List indices = Streams.mapWithIndex(requests.stream(), mapIndices)
        .filter(Objects::nonNull).collect(Collectors.toList());
      requests = requests.stream().filter(Objects::nonNull).collect(Collectors.toList());

      if (requests.isEmpty()) {
        return new RemoteRuleResult(false, true, matches, data.sentences);
      } else {
        List> results = model.batchScore(requests, timeoutMilliseconds);
        // put curated at the top, then compare probabilities
        for (int i = 0; i < indices.size(); i++) {
          List scores = results.get(i);
          String userWord = requests.get(i).text.substring(requests.get(i).start, requests.get(i).end);
          RuleMatch match = matches.get(indices.get(i).intValue());
          //RemoteLanguageModel.Request req = requests.get(i);
          //String error = req.text.substring(req.start, req.end);
          //logger.info("Scored suggestions for '{}': {} -> {}", error, match.getSuggestedReplacements(), Streams
          //  .zip(match.getSuggestedReplacementObjects().stream(), scores.stream(), Pair::of)
          //  .sorted(new CuratedAndSameCaseComparator(userWord))
          //  .map(scored -> String.format("%s (%e)", scored.getLeft().getReplacement(), scored.getRight()))
          //  .collect(Collectors.toList()));
          List ranked = Streams
            .zip(match.getSuggestedReplacementObjects().stream(), scores.stream(), Pair::of)
            .sorted(new CuratedAndSameCaseComparator(userWord))
            .map(Pair::getLeft)
            .collect(Collectors.toList());
          //logger.info("Reordered correction for '{}' from {} to {}", error, req.candidates, ranked);
          match.setSuggestedReplacementObjects(ranked);
        }
        return new RemoteRuleResult(true, true, matches, data.sentences);
      }
    };
  }

  @Nullable
  private RemoteLanguageModel.Request buildRequest(RuleMatch match) {
    List suggestions = match.getSuggestedReplacements();
    if (suggestions != null && suggestions.size() > 1) {
      return new RemoteLanguageModel.Request(
        match.getSentence().getText(), match.getFromPos(), match.getToPos(), suggestions);
    } else {
      return null;
    }
  }

  @Override
  public String getId() {
    // return values of wrapped rule so that enabling/disabling rules works
    return wrappedRule.getId();
  }

  @Override
  public String getDescription() {
    return wrappedRule.getDescription();
  }

  private static class CuratedAndSameCaseComparator implements Comparator> {
    private final String userWord;
    CuratedAndSameCaseComparator(String userWord) {
      this.userWord = userWord;
    }
    @Override
    public int compare(Pair a, Pair b) {
      //System.out.println(userWord + " -- " + b.getKey().getReplacement());
      if (a.getKey().getReplacement().equalsIgnoreCase(userWord)) {
        return -1;
      } else if (b.getKey().getReplacement().equalsIgnoreCase(userWord)) {
        return 1;
      } else if (a.getKey().getType() != b.getKey().getType()) {
        if (a.getKey().getType() == SuggestedReplacement.SuggestionType.Curated) {
          return -1;
        } else if (b.getKey().getType() == SuggestedReplacement.SuggestionType.Curated) {
          return 1;
        } else {
          return b.getRight().compareTo(a.getRight());
        }
      } else {
        return b.getRight().compareTo(a.getRight());
      }
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy