All downloads are free. The search and download functions use the official Maven repository.

org.languagetool.rules.BERTSuggestionRanking Maven / Gradle / Ivy

Go to download

LanguageTool is open-source proofreading software for English, French, German, Polish, Romanian, and more than 20 other languages. It finds many errors that a simple spell checker cannot detect, such as mixing up there/their, and it also detects some grammar problems.

There is a newer version: 6.5
Show newest version
/*
 *  LanguageTool, a natural language style checker
 *  * Copyright (C) 2018 Fabian Richter
 *  *
 *  * This library is free software; you can redistribute it and/or
 *  * modify it under the terms of the GNU Lesser General Public
 *  * License as published by the Free Software Foundation; either
 *  * version 2.1 of the License, or (at your option) any later version.
 *  *
 *  * This library is distributed in the hope that it will be useful,
 *  * but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  * Lesser General Public License for more details.
 *  *
 *  * You should have received a copy of the GNU Lesser General Public
 *  * License along with this library; if not, write to the Free Software
 *  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301
 *  * USA
 *
 */

package org.languagetool.rules;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Streams;
import org.apache.commons.lang3.tuple.Pair;
import org.jetbrains.annotations.Nullable;
import org.languagetool.AnalyzedSentence;
import org.languagetool.UserConfig;
import org.languagetool.languagemodel.bert.RemoteLanguageModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLException;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;

/**
 * Reorders the suggestions produced by another rule, using a remote BERT language model
 * to score each candidate replacement in the context of its sentence.
 *
 * <p>The wrapped rule is executed locally in {@link #prepareRequest(List)}; the scoring
 * request is sent in {@link #executeRequest(RemoteRequest)}. If no model is available
 * (A/B test not active for this rule, or the connection could not be set up), the
 * matches of the wrapped rule are returned in their original order.
 */
public class BERTSuggestionRanking extends RemoteRule {
  private static final Logger logger = LoggerFactory.getLogger(BERTSuggestionRanking.class);
  public static final String RULE_ID = "BERT_SUGGESTION_RANKING";

  /** Number of top suggestions that are re-ranked by default. */
  private static final int DEFAULT_SUGGESTION_LIMIT = 10;
  /** Number of top suggestions that are re-ranked when translations are among them. */
  private static final int TRANSLATION_SUGGESTION_LIMIT = 25;

  /** One model client per service configuration, created on first use and shared by all rule instances. */
  private static final LoadingCache<RemoteRuleConfig, RemoteLanguageModel> models =
    CacheBuilder.newBuilder().build(CacheLoader.from(serviceConfiguration -> {
      String host = serviceConfiguration.getUrl();
      int port = serviceConfiguration.getPort();
      boolean ssl = Boolean.parseBoolean(serviceConfiguration.getOptions().getOrDefault("secure", "false"));
      String key = serviceConfiguration.getOptions().get("clientKey");
      String cert = serviceConfiguration.getOptions().get("clientCertificate");
      String ca = serviceConfiguration.getOptions().get("rootCertificate");
      try {
        return new RemoteLanguageModel(host, port, ssl, key, cert, ca);
      } catch (SSLException e) {
        // CacheLoader functions cannot throw checked exceptions; the cause is preserved
        throw new RuntimeException(e);
      }
    }));

  static {
    // shut down every model client that was created when the remote rules are shut down
    shutdownRoutines.add(() -> models.asMap().values().forEach(RemoteLanguageModel::shutdown));
  }

  // default behavior for prepareSuggestions: limit to top n candidates
  // NOTE: prepareSuggestions overwrites this value on every call
  protected int suggestionLimit = DEFAULT_SUGGESTION_LIMIT;
  /** Scoring client; {@code null} when re-ranking is disabled or the client could not be created. */
  private final RemoteLanguageModel model;
  private final Rule wrappedRule;

  /**
   * @param rule the rule whose suggestions are re-ranked
   * @param config configuration of the remote scoring service
   * @param userConfig user configuration; re-ranking is only enabled if its A/B test
   *                   setting equals this rule's id. A {@code null} value disables re-ranking.
   */
  public BERTSuggestionRanking(Rule rule, RemoteRuleConfig config, UserConfig userConfig) {
    super(rule.messages, config);
    this.wrappedRule = rule;

    synchronized (models) {
      RemoteLanguageModel cachedModel = null;
      if (userConfig != null && getId().equals(userConfig.getAbTest())) {
        try {
          cachedModel = models.get(serviceConfiguration);
        } catch (Exception e) {
          // best effort: without a model this rule falls back to the wrapped rule's own ordering
          logger.error("Could not connect to BERT service at " + serviceConfiguration + " for suggestion reranking", e);
        }
      }
      this.model = cachedModel;
    }
  }

  /**
   * Request payload: the matches of the wrapped rule plus one scoring request per match.
   * Both lists are index-aligned; a {@code null} request marks a match that needs no re-ranking.
   */
  class MatchesForReordering extends RemoteRequest {
    final List<RuleMatch> matches;
    final List<RemoteLanguageModel.Request> requests;
    MatchesForReordering(List<RuleMatch> matches, List<RemoteLanguageModel.Request> requests) {
      this.matches = matches;
      this.requests = requests;
    }
  }

  /**
   * transform suggestions before resorting, e.g. limit resorting to top-n candidates
   * <p>As a side effect, {@link #suggestionLimit} is set to the limit that was applied.
   * @param suggestions suggestions of a single match, in their original order
   * @return transformed suggestions (a view on the first n elements of the given list)
   */
  protected List<SuggestedReplacement> prepareSuggestions(List<SuggestedReplacement> suggestions) {
    // include more suggestions for resorting if there are translations included as original order isn't that good
    boolean hasTranslation = suggestions.stream()
      .anyMatch(s -> s.getType() == SuggestedReplacement.SuggestionType.Translation);
    int limit = hasTranslation ? TRANSLATION_SUGGESTION_LIMIT : DEFAULT_SUGGESTION_LIMIT;
    // keep the field in sync for subclasses, but use the local value so the result
    // never depends on a concurrent update of the field
    suggestionLimit = limit;
    return suggestions.subList(0, Math.min(suggestions.size(), limit));
  }

  /**
   * Runs the wrapped rule on every sentence, trims the suggestions of each match and
   * builds the corresponding scoring requests. Match positions are shifted by the
   * accumulated length of the preceding sentences.
   *
   * @param sentences the sentences to check, in text order
   * @return a {@link MatchesForReordering}; empty if the wrapped rule threw an {@link IOException}
   */
  @Override
  protected RemoteRequest prepareRequest(List<AnalyzedSentence> sentences) {
    // ArrayList: matches are accessed by index when the scores are applied
    List<RuleMatch> matches = new ArrayList<>();
    List<RemoteLanguageModel.Request> requests = new ArrayList<>();
    try {
      int offset = 0;
      for (AnalyzedSentence sentence : sentences) {
        RuleMatch[] sentenceMatches = wrappedRule.match(sentence);
        for (RuleMatch match : sentenceMatches) {
          match.setSuggestedReplacementObjects(prepareSuggestions(match.getSuggestedReplacementObjects()));
          // build request before correcting offset, as we send only sentence as text
          requests.add(buildRequest(match));
          match.setOffsetPosition(match.getFromPos() + offset, match.getToPos() + offset);
        }
        Collections.addAll(matches, sentenceMatches);
        offset += sentence.getText().length();
      }
      return new MatchesForReordering(matches, requests);
    } catch (IOException e) {
      logger.error("Error while executing rule " + wrappedRule.getId(), e);
      return new MatchesForReordering(Collections.emptyList(), Collections.emptyList());
    }
  }

  /**
   * @return the matches of the wrapped rule in their original suggestion order, flagged as not remote
   */
  @Override
  protected RemoteRuleResult fallbackResults(RemoteRequest request) {
    return new RemoteRuleResult(false, ((MatchesForReordering) request).matches);
  }

  /**
   * Creates the task that scores all suggestions in one batch and sorts the suggestions
   * of every match by descending score.
   *
   * @param request a {@link MatchesForReordering} as created by {@link #prepareRequest(List)}
   * @return a task yielding the re-ranked matches, or the fallback result if no model is available
   */
  @Override
  protected Callable<RemoteRuleResult> executeRequest(RemoteRequest request) {
    return () -> {
      if (model == null) {
        return fallbackResults(request);
      }
      MatchesForReordering data = (MatchesForReordering) request;
      List<RuleMatch> matches = data.matches;

      // keep only the matches that need scoring and remember which match each request belongs to
      List<RemoteLanguageModel.Request> requests = new ArrayList<>();
      List<Integer> matchIndices = new ArrayList<>();
      int matchIndex = 0;
      for (RemoteLanguageModel.Request candidate : data.requests) {
        if (candidate != null) {
          requests.add(candidate);
          matchIndices.add(matchIndex);
        }
        matchIndex++;
      }

      if (requests.isEmpty()) {
        return new RemoteRuleResult(false, matches);
      }

      List<List<Double>> results = model.batchScore(requests);
      Comparator<Pair<SuggestedReplacement, Double>> suggestionOrdering = Comparator.comparing(Pair::getRight);
      // highest score first
      suggestionOrdering = suggestionOrdering.reversed();

      for (int i = 0; i < matchIndices.size(); i++) {
        List<Double> scores = results.get(i);
        RemoteLanguageModel.Request req = requests.get(i);
        RuleMatch match = matches.get(matchIndices.get(i));
        List<SuggestedReplacement> suggestions = match.getSuggestedReplacementObjects();

        // Streams.zip stops at the shorter input; re-ranking with an incomplete score list
        // would silently drop suggestions, so the original order is kept instead
        if (scores == null || scores.size() != suggestions.size()) {
          logger.warn("Got {} scores for {} suggestions of rule {}, keeping original suggestion order",
            scores == null ? 0 : scores.size(), suggestions.size(), wrappedRule.getId());
          continue;
        }

        List<Pair<SuggestedReplacement, Double>> scored = Streams
          .zip(suggestions.stream(), scores.stream(), Pair::of)
          .sorted(suggestionOrdering)
          .collect(Collectors.toList());

        if (logger.isInfoEnabled()) {
          String error = req.text.substring(req.start, req.end);
          List<String> scoredDescriptions = scored.stream()
            .map(s -> String.format("%s (%e)", s.getLeft().getReplacement(), s.getRight()))
            .collect(Collectors.toList());
          // logged before the match is updated, so the second argument shows the original order
          logger.info("Scored suggestions for '{}': {} -> {}", error, match.getSuggestedReplacements(),
            scoredDescriptions);
        }

        List<SuggestedReplacement> ranked = scored.stream()
          .map(Pair::getLeft)
          .collect(Collectors.toList());
        match.setSuggestedReplacementObjects(ranked);
      }
      return new RemoteRuleResult(true, matches);
    };
  }

  /**
   * Builds the scoring request for a match, using the match's sentence as context.
   *
   * @return the request, or {@code null} if the match has fewer than two suggestions
   *         and therefore nothing to reorder
   */
  @Nullable
  private RemoteLanguageModel.Request buildRequest(RuleMatch match) {
    List<String> suggestions = match.getSuggestedReplacements();
    if (suggestions != null && suggestions.size() > 1) {
      return new RemoteLanguageModel.Request(
        match.getSentence().getText(), match.getFromPos(), match.getToPos(), suggestions);
    } else {
      return null;
    }
  }

  @Override
  public String getId() {
    return RULE_ID;
  }

  @Override
  public String getDescription() {
    return "Suggestion reordering based on the BERT model";
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy