// io.github.mightguy.symspell.solr.component.SpellcheckComponent - Maven / Gradle / Ivy
// The newest version!
package io.github.mightguy.symspell.solr.component;
import io.github.mightguy.spellcheck.symspell.api.CharDistance;
import io.github.mightguy.spellcheck.symspell.api.DataHolder;
import io.github.mightguy.spellcheck.symspell.api.SpellChecker;
import io.github.mightguy.spellcheck.symspell.api.StringDistance;
import io.github.mightguy.spellcheck.symspell.common.DictionaryItem;
import io.github.mightguy.spellcheck.symspell.common.Murmur3HashFunction;
import io.github.mightguy.spellcheck.symspell.common.SpellCheckSettings;
import io.github.mightguy.spellcheck.symspell.common.SuggestionItem;
import io.github.mightguy.spellcheck.symspell.common.Verbosity;
import io.github.mightguy.spellcheck.symspell.common.WeightedDamerauLevenshteinDistance;
import io.github.mightguy.spellcheck.symspell.exception.SpellCheckException;
import io.github.mightguy.spellcheck.symspell.impl.InMemoryDataHolder;
import io.github.mightguy.spellcheck.symspell.impl.SymSpellCheck;
import io.github.mightguy.symspell.solr.eventlistner.CustomSpellCheckListner;
import io.github.mightguy.symspell.solr.utils.Constants;
import io.github.mightguy.symspell.solr.utils.SearchRequestUtil;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.solr.common.StringUtils;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.SearchComponent;
import org.apache.solr.util.plugin.SolrCoreAware;
/**
 * Solr {@link SearchComponent} that adds SymSpell based spelling suggestions to a response.
 *
 * <p>The spell checker is configured from the {@code spellcheckers} entry of the component's init
 * arguments when the core is handed over in {@link #inform(SolrCore)}. At request time
 * {@link #prepare(ResponseBuilder)} optionally triggers a rebuild of the dictionary and
 * {@link #process(ResponseBuilder)} looks up suggestions for the query and adds them to the
 * response under {@code spell_suggestions}.
 *
 * <p>No per-request state is kept in instance fields: everything request specific is read from
 * the request parameters inside the method that needs it.
 *
 * <p>Raw {@link NamedList} types are kept wherever the value is passed to or received from APIs
 * whose generic signature is not visible from this class.
 */
@Slf4j
public class SpellcheckComponent extends SearchComponent implements SolrCoreAware {

  public static final String COMPONENT_NAME = "custom_spellcheck";

  /** Name of the init-args entry that carries the spell checker configuration. */
  private static final String SPELLCHECKERS_KEY = "spellcheckers";

  /** Result threshold used when the request does not supply {@code SPELLCHECK_THRESHOLD}. */
  private static final int DEFAULT_THRESHOLD = 15;

  /** Column separator of the unigram and bigram dictionary files. */
  private static final String WHITESPACE_REGEX = "\\s+";

  /** Init arguments captured in {@link #init(NamedList)} and consumed in {@link #inform}. */
  private NamedList initParams;

  /** Spell checker used for lookups; {@code null} until one has been configured. */
  @Getter
  private SpellChecker spellChecker;

  /** Listener used to reload the dictionary on request; {@code null} until configured. */
  private CustomSpellCheckListner customSpellCheckListner;

  /**
   * Stores the init arguments so that the spell checker can be built once the core is available.
   *
   * @param args init arguments of this component
   */
  @Override
  public void init(NamedList args) {
    super.init(args);
    this.initParams = args;
  }

  /**
   * Rebuilds the spell check dictionary when the request enables this component and asks for a
   * build.
   *
   * @param rb the current response builder
   * @throws IOException if the reload fails with a {@link SpellCheckException}
   */
  @Override
  public void prepare(ResponseBuilder rb) throws IOException {
    SolrParams params = rb.req.getParams();
    boolean buildRequested = params.getBool(COMPONENT_NAME, false)
        && params.getBool(Constants.SPELLCHECK_BUILD, false);
    if (!buildRequested) {
      return;
    }
    // Both fields stay null when no spell checker was configured for this component.
    if (customSpellCheckListner == null || spellChecker == null) {
      log.warn("Spellcheck build requested but no spell checker is configured");
      return;
    }
    try {
      customSpellCheckListner.reload(rb.req.getSearcher(), spellChecker);
      rb.rsp.add("command", "build");
    } catch (SpellCheckException ex) {
      log.error("Unable to build spellcheck indexes", ex);
      throw new IOException(ex);
    }
  }

  /**
   * Looks up spelling suggestions for the query and adds them to the response.
   *
   * <p>Nothing is added when spell checking is disabled by the request, when
   * {@code SearchRequestUtil.resultGreaterThanThreshold} reports that the response is above the
   * threshold, when no spell checker is configured, or when there is no query text.
   *
   * @param rb the current response builder
   * @throws IOException if the lookup fails with a {@link SpellCheckException}
   */
  @Override
  public void process(ResponseBuilder rb) throws IOException {
    SolrParams params = rb.req.getParams();
    // Read per request: the component instance is shared, so this must not live in a field.
    int threshold = params.getInt(Constants.SPELLCHECK_THRESHOLD, DEFAULT_THRESHOLD);
    if (!params.getBool(Constants.SPELLCHECK_ENABLE, true)
        || SearchRequestUtil.resultGreaterThanThreshold(rb.rsp, threshold)) {
      log.debug("Spellcheck is disbaled either by query or result greater than threshold [{}]",
          threshold);
      return;
    }

    String q = params.get(Constants.SPELLCHECK_Q, params.get(CommonParams.Q));
    if (spellChecker == null || StringUtils.isEmpty(q)) {
      log.debug("Skipping spellcheck: no spell checker configured or no query text");
      return;
    }

    boolean sow = params.getBool(Constants.SPELLCHECK_SOW, true);
    try {
      List<SuggestionItem> suggestions = sow
          ? spellChecker.lookupCompound(q)
          : spellChecker.lookupCompound(q, 2, false);
      if (!CollectionUtils.isEmpty(suggestions)) {
        addToResponse(rb, suggestions);
      }
    } catch (SpellCheckException ex) {
      log.error("exception occured while looking for spelling suggestions", ex);
      throw new IOException(ex);
    }
  }

  /** Adds the given suggestions to the response under {@code spell_suggestions}. */
  private void addToResponse(ResponseBuilder rb, List<SuggestionItem> suggestions) {
    rb.rsp.add("spell_suggestions", toNamedList(suggestions));
  }

  /**
   * Converts suggestions into the response structure.
   *
   * <p>The result holds a {@code spellcheck} map from term to {@code "count,distance,score"} and
   * a {@code correctlySpelled} flag that is true when the first suggestion has distance 0.
   *
   * @param suggestionItems suggestions to convert, may be {@code null} or empty
   * @return the populated list, or an empty list when there are no suggestions
   */
  private NamedList<Object> toNamedList(List<SuggestionItem> suggestionItems) {
    NamedList<Object> result = new NamedList<>();
    if (CollectionUtils.isEmpty(suggestionItems)) {
      return result;
    }
    // Sequential + LinkedHashMap keeps the order in which the suggestions were returned; the
    // merge function keeps the first entry instead of failing when a term occurs twice.
    Map<String, String> suggestions = suggestionItems.stream().collect(
        Collectors.toMap(
            SuggestionItem::getTerm,
            item -> item.getCount() + "," + item.getDistance() + "," + item.getScore(),
            (first, duplicate) -> first,
            LinkedHashMap::new));
    result.add("spellcheck", suggestions);
    result.add("correctlySpelled", suggestionItems.get(0).getDistance() == 0);
    return result;
  }

  @Override
  public String getDescription() {
    return "SymSpell based Spellchecker Component";
  }

  /**
   * Builds the spell checker from the {@code spellcheckers} init argument.
   *
   * <p>The entry may be a {@link NamedList}, a {@link Map}, or a {@link List} of maps. Every
   * configuration found is passed to {@code addSpellChecker}; since that method assigns the
   * {@code spellChecker} field, the last configuration is the one used for lookups.
   *
   * @param core the core this component belongs to
   */
  @Override
  public void inform(SolrCore core) {
    if (initParams == null) {
      return;
    }
    log.info("Initializing spell checkers");
    // Looked up by name so an empty argument list or a different entry order cannot break
    // core start-up.
    Object cfg = initParams.get(SPELLCHECKERS_KEY);
    if (cfg instanceof NamedList) {
      addSpellChecker(core, (NamedList) cfg);
    } else if (cfg instanceof Map) {
      addSpellChecker(core, new NamedList((Map) cfg));
    } else if (cfg instanceof List) {
      for (Object entry : (List<?>) cfg) {
        if (entry instanceof Map) {
          addSpellChecker(core, new NamedList((Map) entry));
        }
      }
    }
    log.info("Spell checker Initialization completed");
  }

  @Override
  public Category getCategory() {
    return Category.SPELLCHECKER;
  }

  /**
   * Creates the string distance used by the spell checker.
   *
   * <p>When {@code chardistance_classname} is configured, that class is loaded through the
   * core's resource loader and passed to the distance; otherwise {@code null} is passed.
   *
   * @param spellchecker configuration of the spell checker
   * @param spellCheckSettings settings providing the edit operation weights
   * @param core core whose resource loader is used to load the class
   * @return a weighted Damerau-Levenshtein distance
   */
  private StringDistance getStringDistance(NamedList spellchecker,
      SpellCheckSettings spellCheckSettings, SolrCore core) {
    String chardistanceClassname = SearchRequestUtil
        .getFromNamedList(spellchecker, "chardistance_classname", null);
    CharDistance charDistance = null;
    if (chardistanceClassname != null) {
      charDistance = SearchRequestUtil
          .getClassFromLoader(chardistanceClassname, core.getResourceLoader(), CharDistance.class,
              new String[0], toObjectArr());
    }
    return new WeightedDamerauLevenshteinDistance(spellCheckSettings.getDeletionWeight(),
        spellCheckSettings.getInsertionWeight(), spellCheckSettings.getReplaceWeight(),
        spellCheckSettings.getTranspositionWeight(), charDistance);
  }

  /**
   * Builds a spell checker from one configuration, registers its listeners on the core and
   * loads the configured dictionary files.
   *
   * @param core core to register the searcher listeners on
   * @param spellcheckerNL configuration of the spell checker
   */
  private void addSpellChecker(SolrCore core, NamedList spellcheckerNL) {
    // NOTE: the key "deleteionWeight" is misspelled but kept as is, because existing
    // configurations may already use it.
    SpellCheckSettings spellCheckSettings = SpellCheckSettings.builder()
        .deletionWeight(SearchRequestUtil.getFromNamedList(spellcheckerNL, "deleteionWeight", 1.0f))
        .insertionWeight(
            SearchRequestUtil.getFromNamedList(spellcheckerNL, "insertionWeight", 1.0f))
        .replaceWeight(SearchRequestUtil.getFromNamedList(spellcheckerNL, "replaceWeight", 1.0f))
        .transpositionWeight(
            SearchRequestUtil.getFromNamedList(spellcheckerNL, "transpositionWeight", 1.0f))
        .maxEditDistance(
            SearchRequestUtil.getFromNamedList(spellcheckerNL, "maxEditDistance", 2.0d))
        .prefixLength(SearchRequestUtil.getFromNamedList(spellcheckerNL, "prefixLength", 7))
        .verbosity(Verbosity.valueOf(
            SearchRequestUtil
                .getFromNamedList(spellcheckerNL, "verbosity", Verbosity.ALL.name())))
        .countThreshold(SearchRequestUtil.getFromNamedList(spellcheckerNL, "countThreshold", 10))
        .doKeySplit(
            SearchRequestUtil.getFromNamedList(spellcheckerNL, "createBigram", true))
        .keySplitRegex(
            SearchRequestUtil.getFromNamedList(spellcheckerNL, "bigramSplitRegex", "\\s+"))
        .build();

    StringDistance stringDistance = getStringDistance(spellcheckerNL, spellCheckSettings, core);
    DataHolder dataHolder = new InMemoryDataHolder(spellCheckSettings, new Murmur3HashFunction());
    spellChecker = new SymSpellCheck(dataHolder, stringDistance, spellCheckSettings);

    String[] fieldList = SearchRequestUtil.getFromNamedList(spellcheckerNL, "field_names", "")
        .split("\\s+");

    // Register event listeners for this SpellChecker
    customSpellCheckListner = new CustomSpellCheckListner(core, spellChecker, fieldList);
    core.registerFirstSearcherListener(customSpellCheckListner);

    String unigramsFile = SearchRequestUtil.getFromNamedList(spellcheckerNL, "unigrams_file", null);
    String bigramsFile = SearchRequestUtil.getFromNamedList(spellcheckerNL, "bigrams_file", null);
    String exclusionsFile = SearchRequestUtil
        .getFromNamedList(spellcheckerNL, "exclusions_file", null);
    String exclusionsFileSeparator = SearchRequestUtil
        .getFromNamedList(spellcheckerNL, "exclusions_file_sp", "\\s+");
    loadDefault(unigramsFile, bigramsFile, exclusionsFile, spellChecker, core,
        exclusionsFileSeparator);

    if (isEnabled(spellcheckerNL, "buildOnCommit")
        || isEnabled(spellcheckerNL, "buildOnOptimize")) {
      log.info("Registering newSearcher listener for spellChecker");
      core.registerNewSearcherListener(
          new CustomSpellCheckListner(core, spellChecker, fieldList));
    }
  }

  /**
   * Reads a boolean flag from the configuration.
   *
   * <p>The value is converted with {@link String#valueOf(Object)} before parsing, so a flag
   * stored as {@link Boolean} works as well as one stored as {@link String}; a missing flag is
   * treated as {@code false}.
   */
  private boolean isEnabled(NamedList config, String key) {
    return Boolean.parseBoolean(String.valueOf(config.get(key)));
  }

  /**
   * Loads the configured unigram, bigram and exclusion files into the spell checker's data
   * holder. Files whose name is empty are skipped.
   *
   * <p>Loading is best effort: a failure is logged together with its cause and does not
   * propagate to the caller.
   */
  private void loadDefault(String unigramsFile, String bigramsFile, String exclusionsFile,
      SpellChecker spellChecker,
      SolrCore core, String exclusionListSperatorRegex) {
    try {
      if (!StringUtils.isEmpty(unigramsFile)) {
        loadUniGramFile(core.getResourceLoader().openResource(unigramsFile),
            spellChecker.getDataHolder());
      }
      if (!StringUtils.isEmpty(bigramsFile)) {
        loadBiGramFile(core.getResourceLoader().openResource(bigramsFile),
            spellChecker.getDataHolder());
      }
      if (!StringUtils.isEmpty(exclusionsFile)) {
        loadExclusions(core.getResourceLoader().openResource(exclusionsFile),
            spellChecker.getDataHolder(), exclusionListSperatorRegex);
      }
    } catch (SpellCheckException | IOException ex) {
      log.error("Error occured while loading default Configs for Spellcheck", ex);
    }
  }

  /**
   * Loads {@code term frequency} lines into the data holder.
   *
   * <p>Blank lines are ignored; lines with fewer than two columns or a non-numeric frequency
   * are skipped with a warning. The stream is read as UTF-8 and closed when done.
   */
  private void loadUniGramFile(InputStream inputStream, DataHolder dataHolder)
      throws IOException, SpellCheckException {
    try (BufferedReader reader = newUtf8Reader(inputStream)) {
      String line;
      while ((line = reader.readLine()) != null) {
        String[] columns = splitColumns(line);
        if (columns.length < 2) {
          warnIfMalformed(columns, line);
          continue;
        }
        addDictionaryItem(dataHolder, columns[0], columns[1]);
      }
    }
  }

  /**
   * Loads {@code first second frequency} lines into the data holder; the two words are joined
   * with a single space to form the term.
   *
   * <p>Blank lines are ignored; lines with fewer than three columns or a non-numeric frequency
   * are skipped with a warning. The stream is read as UTF-8 and closed when done.
   */
  private void loadBiGramFile(InputStream inputStream, DataHolder dataHolder)
      throws IOException, SpellCheckException {
    try (BufferedReader reader = newUtf8Reader(inputStream)) {
      String line;
      while ((line = reader.readLine()) != null) {
        String[] columns = splitColumns(line);
        if (columns.length < 3) {
          warnIfMalformed(columns, line);
          continue;
        }
        addDictionaryItem(dataHolder, columns[0] + " " + columns[1], columns[2]);
      }
    }
  }

  /**
   * Loads exclusion pairs into the data holder. Each line is split with {@code seperatorRegex}
   * and only lines that yield exactly two parts are added. The stream is read as UTF-8 and
   * closed when done.
   */
  private void loadExclusions(InputStream inputStream, DataHolder dataHolder, String seperatorRegex)
      throws IOException, SpellCheckException {
    try (BufferedReader reader = newUtf8Reader(inputStream)) {
      String line;
      while ((line = reader.readLine()) != null) {
        String[] pair = line.split(seperatorRegex);
        if (pair.length == 2) {
          dataHolder.addExclusionItem(pair[0], pair[1]);
        }
      }
    }
  }

  /** Wraps the stream in a buffered reader with an explicit UTF-8 charset. */
  private BufferedReader newUtf8Reader(InputStream inputStream) {
    return new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
  }

  /** Splits a dictionary line on whitespace; a blank line yields an empty array. */
  private String[] splitColumns(String line) {
    String trimmed = line.trim();
    return trimmed.isEmpty() ? new String[0] : trimmed.split(WHITESPACE_REGEX);
  }

  /** Logs a warning for a skipped line unless the line was simply blank. */
  private void warnIfMalformed(String[] columns, String line) {
    if (columns.length > 0) {
      log.warn("Skipping malformed dictionary line [{}]", line);
    }
  }

  /**
   * Adds one dictionary entry, skipping it with a warning when the frequency is not a number.
   */
  private void addDictionaryItem(DataHolder dataHolder, String term, String frequency)
      throws SpellCheckException {
    double parsedFrequency;
    try {
      parsedFrequency = Double.parseDouble(frequency);
    } catch (NumberFormatException ex) {
      log.warn("Skipping dictionary entry [{}]: frequency [{}] is not a number", term, frequency);
      return;
    }
    // The third constructor argument is passed as -1.0 for every file-loaded entry.
    dataHolder.addItem(new DictionaryItem(term, parsedFrequency, -1.0));
  }

  /** Returns the given varargs as an {@code Object[]}; with no arguments, an empty array. */
  private Object[] toObjectArr(Object... args) {
    return args;
  }
}