All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.github.mightguy.symspell.solr.component.SpellcheckComponent Maven / Gradle / Ivy

There is a newer version: 6.6.154
Show newest version

package io.github.mightguy.symspell.solr.component;

import io.github.mightguy.spellcheck.symspell.api.CharDistance;
import io.github.mightguy.spellcheck.symspell.api.DataHolder;
import io.github.mightguy.spellcheck.symspell.api.SpellChecker;
import io.github.mightguy.spellcheck.symspell.api.StringDistance;
import io.github.mightguy.spellcheck.symspell.common.DictionaryItem;
import io.github.mightguy.spellcheck.symspell.common.Murmur3HashFunction;
import io.github.mightguy.spellcheck.symspell.common.SpellCheckSettings;
import io.github.mightguy.spellcheck.symspell.common.SuggestionItem;
import io.github.mightguy.spellcheck.symspell.common.Verbosity;
import io.github.mightguy.spellcheck.symspell.common.WeightedDamerauLevenshteinDistance;
import io.github.mightguy.spellcheck.symspell.exception.SpellCheckException;
import io.github.mightguy.spellcheck.symspell.impl.InMemoryDataHolder;
import io.github.mightguy.spellcheck.symspell.impl.SymSpellCheck;
import io.github.mightguy.symspell.solr.eventlistner.CustomSpellCheckListner;
import io.github.mightguy.symspell.solr.utils.Constants;
import io.github.mightguy.symspell.solr.utils.SearchRequestUtil;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.solr.common.StringUtils;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.SearchComponent;
import org.apache.solr.util.plugin.SolrCoreAware;

@Slf4j
public class SpellcheckComponent extends SearchComponent implements SolrCoreAware {

  private NamedList initParams;
  @Getter
  private SpellChecker spellChecker;
  private CustomSpellCheckListner customSpellCheckListner;
  private int threshold = 0;

  public static final String COMPONENT_NAME = "custom_spellcheck";

  @Override
  public void init(NamedList args) {
    super.init(args);
    this.initParams = args;
  }

  @Override
  public void prepare(ResponseBuilder rb) throws IOException {
    SolrParams params = rb.req.getParams();
    threshold = params.getInt(Constants.SPELLCHECK_THRESHOLD, 15);
    if (!params.getBool(COMPONENT_NAME, false)) {
      return;
    }
    try {
      if (params.getBool(Constants.SPELLCHECK_BUILD, false)) {
        customSpellCheckListner.reload(rb.req.getSearcher(), spellChecker);
        rb.rsp.add("command", "build");
      }
    } catch (SpellCheckException ex) {
      log.error("Unable to build spellcheck indexes");
      throw new IOException(ex);
    }

  }

  @Override
  public void process(ResponseBuilder rb) throws IOException {
    if (!rb.req.getParams().getBool(Constants.SPELLCHECK_ENABLE, true) || SearchRequestUtil
        .resultGreaterThanThreshold(rb.rsp, threshold)) {
      log.debug("Spellcheck is disbaled either by query or result  greater than threshold [{}]",
          threshold);
      return;
    }
    SolrParams params = rb.req.getParams();
    String q = params.get(Constants.SPELLCHECK_Q, params.get(CommonParams.Q));
    boolean sow = params.getBool(Constants.SPELLCHECK_SOW, true);
    List suggestions;
    try {
      if (sow) {
        suggestions = spellChecker.lookupCompound(q);
      } else {
        suggestions = spellChecker.lookupCompound(q, 2, false);
      }
      if (!CollectionUtils.isEmpty(suggestions)) {
        addToResponse(rb, suggestions);
      }
    } catch (SpellCheckException ex) {
      log.error("exception occured while looking for spelling suggestions");
      throw new IOException(ex);
    }

  }


  private void addToResponse(ResponseBuilder rb, List suggestions) {
    rb.rsp.add("spell_suggestions", toNamedList(suggestions));
  }

  private NamedList toNamedList(List suggestionItems) {
    NamedList result = new NamedList();
    if (CollectionUtils.isEmpty(suggestionItems)) {
      return result;
    }
    Map suggestions = suggestionItems.parallelStream().collect(
        Collectors.toMap(SuggestionItem::getTerm,
            s -> s.getCount() + "," + s.getDistance() + "," + s.getScore()));

    result.add("spellcheck", suggestions);
    result.add("correctlySpelled", suggestionItems.get(0).getDistance() == 0);
    return result;
  }

  @Override
  public String getDescription() {
    return "SymSpell based Spellchecker Component";
  }

  @Override
  public void inform(SolrCore core) {
    if (initParams == null) {
      return;
    }

    log.info("Initializing spell checkers");

    if (initParams.getName(0).equals("spellcheckers")) {
      Object cfg = initParams.getVal(0);
      if (cfg instanceof NamedList) {
        addSpellChecker(core, (NamedList) cfg);
      } else if (cfg instanceof Map) {
        addSpellChecker(core, new NamedList((Map) cfg));
      } else if (cfg instanceof List) {
        for (Object o : (List) cfg) {
          if (o instanceof Map) {
            addSpellChecker(core, new NamedList((Map) o));
          }
        }
      }
    }

    log.info("Spell checker  Initialization completed");
  }

  @Override
  public Category getCategory() {
    return Category.SPELLCHECKER;
  }

  private StringDistance getStringDistance(NamedList spellchecker,
      SpellCheckSettings spellCheckSettings, SolrCore core) {

    String chardistanceClassname = SearchRequestUtil
        .getFromNamedList(spellchecker, "chardistance_classname", null);

    CharDistance charDistance = null;
    if (chardistanceClassname != null) {
      charDistance = SearchRequestUtil
          .getClassFromLoader(chardistanceClassname, core.getResourceLoader(), CharDistance.class,
              new String[0], toObjectArr());
    }

    return new WeightedDamerauLevenshteinDistance(spellCheckSettings.getDeletionWeight(),
        spellCheckSettings.getInsertionWeight(), spellCheckSettings.getReplaceWeight(),
        spellCheckSettings.getTranspositionWeight(), charDistance);
  }

  private void addSpellChecker(SolrCore core, NamedList spellcheckerNL) {

    SpellCheckSettings spellCheckSettings = SpellCheckSettings.builder()
        .deletionWeight(SearchRequestUtil.getFromNamedList(spellcheckerNL, "deleteionWeight", 1.0f))
        .insertionWeight(
            SearchRequestUtil.getFromNamedList(spellcheckerNL, "insertionWeight", 1.0f))
        .replaceWeight(SearchRequestUtil.getFromNamedList(spellcheckerNL, "replaceWeight", 1.0f))
        .transpositionWeight(
            SearchRequestUtil.getFromNamedList(spellcheckerNL, "transpositionWeight", 1.0f))
        .maxEditDistance(
            SearchRequestUtil.getFromNamedList(spellcheckerNL, "maxEditDistance", 2.0d))
        .prefixLength(SearchRequestUtil.getFromNamedList(spellcheckerNL, "prefixLength", 7))
        .verbosity(Verbosity.valueOf(
            SearchRequestUtil
                .getFromNamedList(spellcheckerNL, "verbosity", Verbosity.ALL.name())))
        .countThreshold(SearchRequestUtil.getFromNamedList(spellcheckerNL, "countThreshold", 10))
        .doKeySplit(
            SearchRequestUtil.getFromNamedList(spellcheckerNL, "createBigram", true))
        .keySplitRegex(
            SearchRequestUtil.getFromNamedList(spellcheckerNL, "bigramSplitRegex", "\\s+"))
        .build();

    StringDistance stringDistance = getStringDistance(spellcheckerNL, spellCheckSettings, core);

    DataHolder dataHolder = new InMemoryDataHolder(spellCheckSettings, new Murmur3HashFunction());

    spellChecker = new SymSpellCheck(dataHolder, stringDistance, spellCheckSettings);

    String[] fieldList = SearchRequestUtil.getFromNamedList(spellcheckerNL, "field_names", "")
        .split("\\s+");

    // Register event listeners for this SpellChecker
    customSpellCheckListner = new CustomSpellCheckListner(core, spellChecker, fieldList);
    core.registerFirstSearcherListener(customSpellCheckListner);

    String unigramsFile = SearchRequestUtil.getFromNamedList(spellcheckerNL, "unigrams_file", null);
    String bigramsFile = SearchRequestUtil.getFromNamedList(spellcheckerNL, "bigrams_file", null);
    String exclusionsFile = SearchRequestUtil
        .getFromNamedList(spellcheckerNL, "exclusions_file", null);
    String exclustionnFileSeperator = SearchRequestUtil
        .getFromNamedList(spellcheckerNL, "exclusions_file_sp", "\\s+");
    loadDefault(unigramsFile, bigramsFile, exclusionsFile, spellChecker, core,
        exclustionnFileSeperator);
    boolean buildOnCommit = Boolean.parseBoolean((String) spellcheckerNL.get("buildOnCommit"));
    boolean buildOnOptimize = Boolean.parseBoolean((String) spellcheckerNL.get("buildOnOptimize"));
    if (buildOnCommit || buildOnOptimize) {
      log.info("Registering newSearcher listener for spellChecker");
      core.registerNewSearcherListener(
          new CustomSpellCheckListner(core, spellChecker, fieldList));
    }

  }

  private void loadDefault(String unigramsFile, String bigramsFile, String exclusionsFile,
      SpellChecker spellChecker,
      SolrCore core, String exclusionListSperatorRegex) {
    try {
      if (!StringUtils.isEmpty(unigramsFile)) {
        loadUniGramFile(core.getResourceLoader().openResource(unigramsFile),
            spellChecker.getDataHolder());
      }

      if (!StringUtils.isEmpty(bigramsFile)) {
        loadBiGramFile(core.getResourceLoader().openResource(bigramsFile),
            spellChecker.getDataHolder());
      }

      if (!StringUtils.isEmpty(exclusionsFile)) {
        loadExclusions(core.getResourceLoader().openResource(exclusionsFile),
            spellChecker.getDataHolder(), exclusionListSperatorRegex);
      }

    } catch (SpellCheckException | IOException ex) {
      log.error("Error occured while loading default Configs for Spellcheck");
    }
  }

  private void loadUniGramFile(InputStream inputStream, DataHolder dataHolder)
      throws IOException, SpellCheckException {
    try (BufferedReader br = new BufferedReader(
        new InputStreamReader(inputStream))) {
      String line;
      while ((line = br.readLine()) != null) {
        String[] arr = line.split("\\s+");
        dataHolder.addItem(new DictionaryItem(arr[0], Double.parseDouble(arr[1]), -1.0));
      }
    }
  }

  private void loadBiGramFile(InputStream inputStream, DataHolder dataHolder)
      throws IOException, SpellCheckException {
    try (BufferedReader br = new BufferedReader(
        new InputStreamReader(inputStream))) {
      String line;
      while ((line = br.readLine()) != null) {
        String[] arr = line.split("\\s+");
        dataHolder
            .addItem(new DictionaryItem(arr[0] + " " + arr[1], Double.parseDouble(arr[2]), -1.0));
      }
    }
  }

  private void loadExclusions(InputStream inputStream, DataHolder dataHolder, String seperatorRegex)
      throws IOException, SpellCheckException {
    try (BufferedReader br = new BufferedReader(
        new InputStreamReader(inputStream))) {
      String line;
      while ((line = br.readLine()) != null) {
        String[] arr = line.split(seperatorRegex);
        if (arr.length == 2) {
          dataHolder.addExclusionItem(arr[0], arr[1]);
        }
      }
    }
  }

  private Object[] toObjectArr(Object... args) {
    return args;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy