/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
package org.opensearch.search.suggest.phrase;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.codecs.TermStats;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.spell.DirectSpellChecker;
import org.apache.lucene.search.spell.SuggestMode;
import org.apache.lucene.search.spell.SuggestWord;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.CharsRefBuilder;
import org.opensearch.core.internal.io.IOUtils;
import java.io.CharArrayReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import static java.lang.Math.log10;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.lang.Math.round;
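/**
 * A {@link CandidateGenerator} that draws correction candidates for a single field
 * directly from the index terms, using Lucene's {@link DirectSpellChecker}. Candidate
 * scores combine the spell checker's string distance with an add-one smoothed unigram
 * probability of the candidate term in the field.
 *
 * <p>A minimal usage sketch (not part of the original source; the field name "body"
 * and the tuning values are illustrative assumptions, and {@code reader} is assumed
 * to be an open {@link IndexReader} over the target index):
 * <pre>{@code
 * DirectSpellChecker spellchecker = new DirectSpellChecker();
 * DirectCandidateGenerator generator = new DirectCandidateGenerator(
 *     spellchecker,
 *     "body",                                  // field to draw candidates from
 *     SuggestMode.SUGGEST_WHEN_NOT_IN_INDEX,   // only correct terms missing from the index
 *     reader,
 *     0.95,                                    // nonErrorLikelihood
 *     5                                        // number of candidates per term
 * );
 * }</pre>
 */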
public final class DirectCandidateGenerator extends CandidateGenerator {
private final DirectSpellChecker spellchecker;
private final String field;
private final SuggestMode suggestMode;
private final TermsEnum termsEnum;
private final IndexReader reader;
private final long sumTotalTermFreq;
private static final double LOG_BASE = 5;
private final long frequencyPlateau;
private final Analyzer preFilter;
private final Analyzer postFilter;
private final double nonErrorLikelihood;
private final CharsRefBuilder spare = new CharsRefBuilder();
private final BytesRefBuilder byteSpare = new BytesRefBuilder();
private final int numCandidates;
public DirectCandidateGenerator(
DirectSpellChecker spellchecker,
String field,
SuggestMode suggestMode,
IndexReader reader,
double nonErrorLikelihood,
int numCandidates
) throws IOException {
this(spellchecker, field, suggestMode, reader, nonErrorLikelihood, numCandidates, null, null, MultiTerms.getTerms(reader, field));
}
public DirectCandidateGenerator(
DirectSpellChecker spellchecker,
String field,
SuggestMode suggestMode,
IndexReader reader,
double nonErrorLikelihood,
int numCandidates,
Analyzer preFilter,
Analyzer postFilter,
Terms terms
) throws IOException {
if (terms == null) {
throw new IllegalArgumentException("generator field [" + field + "] doesn't exist");
}
this.spellchecker = spellchecker;
this.field = field;
this.numCandidates = numCandidates;
this.suggestMode = suggestMode;
this.reader = reader;
this.sumTotalTermFreq = terms.getSumTotalTermFreq() == -1 ? reader.maxDoc() : terms.getSumTotalTermFreq();
this.preFilter = preFilter;
this.postFilter = postFilter;
this.nonErrorLikelihood = nonErrorLikelihood;
float thresholdFrequency = spellchecker.getThresholdFrequency();
this.frequencyPlateau = thresholdFrequency >= 1.0f ? (int) thresholdFrequency : (int) (reader.maxDoc() * thresholdFrequency);
termsEnum = terms.iterator();
}
/* (non-Javadoc)
* @see org.opensearch.search.suggest.phrase.CandidateGenerator#isKnownWord(org.apache.lucene.util.BytesRef)
*/
@Override
public boolean isKnownWord(BytesRef term) throws IOException {
return termStats(term).docFreq > 0;
}
/* (non-Javadoc)
 * @see org.opensearch.search.suggest.phrase.CandidateGenerator#termStats(org.apache.lucene.util.BytesRef)
*/
@Override
public TermStats termStats(BytesRef term) throws IOException {
term = preFilter(term, spare, byteSpare);
return internalTermStats(term);
}
public TermStats internalTermStats(BytesRef term) throws IOException {
if (termsEnum.seekExact(term)) {
return new TermStats(
termsEnum.docFreq(),
/**
* We use the {@link TermsEnum#docFreq()} for fields that don't
* record the {@link TermsEnum#totalTermFreq()}.
*/
termsEnum.totalTermFreq() == -1 ? termsEnum.docFreq() : termsEnum.totalTermFreq()
);
}
return new TermStats(0, 0);
}
public String getField() {
return field;
}
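/**
 * Draws up to {@code numCandidates} correction candidates for the set's original term:
 * unless the suggest mode is {@link SuggestMode#SUGGEST_ALWAYS}, the spell checker's
 * frequency threshold is raised based on the original term's document frequency, then
 * {@link DirectSpellChecker#suggestSimilar} is queried for similar index terms, each
 * suggestion is run through the optional post-filter analyzer, and the results are
 * merged into the given set.
 */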
@Override
public CandidateSet drawCandidates(CandidateSet set) throws IOException {
Candidate original = set.originalTerm;
BytesRef term = preFilter(original.term, spare, byteSpare);
float origThreshold = spellchecker.getThresholdFrequency();
try {
if (suggestMode != SuggestMode.SUGGEST_ALWAYS) {
/**
* We use the {@link TermStats#docFreq} to compute the frequency threshold
* because that's what {@link DirectSpellChecker#suggestSimilar} expects
* when filtering terms.
*/
int threshold = thresholdTermFrequency(original.termStats.docFreq);
if (threshold == Integer.MAX_VALUE) {
// the threshold is the max possible frequency so we can skip the search
return set;
}
// don't override the threshold if the provided min_doc_freq is greater
// than the original term frequency.
if (spellchecker.getThresholdFrequency() < threshold) {
spellchecker.setThresholdFrequency(threshold);
}
}
SuggestWord[] suggestSimilar = spellchecker.suggestSimilar(new Term(field, term), numCandidates, reader, this.suggestMode);
List<Candidate> candidates = new ArrayList<>(suggestSimilar.length);
for (int i = 0; i < suggestSimilar.length; i++) {
SuggestWord suggestWord = suggestSimilar[i];
BytesRef candidate = new BytesRef(suggestWord.string);
TermStats termStats = internalTermStats(candidate);
postFilter(
new Candidate(candidate, termStats, suggestWord.score, score(termStats, suggestWord.score, sumTotalTermFreq), false),
spare,
byteSpare,
candidates
);
}
set.addCandidates(candidates);
return set;
} finally {
// restore the original value back
spellchecker.setThresholdFrequency(origThreshold);
}
}
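/**
 * Runs the term through the optional pre-filter analyzer, returning the last token it
 * produces; if no pre-filter is configured the term is returned unchanged.
 */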
protected BytesRef preFilter(final BytesRef term, final CharsRefBuilder spare, final BytesRefBuilder byteSpare) throws IOException {
if (preFilter == null) {
return term;
}
final BytesRefBuilder result = byteSpare;
analyze(preFilter, term, field, new TokenConsumer() {
@Override
public void nextToken() throws IOException {
this.fillBytesRef(result);
}
}, spare);
return result.get();
}
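/**
 * Runs the candidate through the optional post-filter analyzer, turning every produced
 * token into a candidate of its own: a token that still equals the original candidate
 * keeps its string distance, any other token falls back to the non-error likelihood.
 * With no post-filter configured the candidate is added as-is.
 */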
protected void postFilter(
final Candidate candidate,
final CharsRefBuilder spare,
BytesRefBuilder byteSpare,
final List<Candidate> candidates
) throws IOException {
if (postFilter == null) {
candidates.add(candidate);
} else {
final BytesRefBuilder result = byteSpare;
analyze(postFilter, candidate.term, field, new TokenConsumer() {
@Override
public void nextToken() throws IOException {
this.fillBytesRef(result);
if (posIncAttr.getPositionIncrement() > 0 && result.get().bytesEquals(candidate.term)) {
BytesRef term = result.toBytesRef();
// We should not use frequency(term) here because it will analyze the term again
// If preFilter and postFilter are the same analyzer it would fail.
TermStats termStats = internalTermStats(term);
candidates.add(
new Candidate(
result.toBytesRef(),
termStats,
candidate.stringDistance,
score(candidate.termStats, candidate.stringDistance, sumTotalTermFreq),
false
)
);
} else {
candidates.add(
new Candidate(
result.toBytesRef(),
candidate.termStats,
nonErrorLikelihood,
score(candidate.termStats, candidate.stringDistance, sumTotalTermFreq),
false
)
);
}
}
}, spare);
}
}
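/**
 * Combines the channel score (string distance or non-error likelihood) with an
 * add-one smoothed unigram probability of the candidate in the field. As a worked
 * example with illustrative numbers: a candidate with totalTermFreq = 9 in a field
 * whose sumTotalTermFreq is 999 and an errorScore of 0.5 scores
 * 0.5 * (9 + 1) / (999 + 1) = 0.005.
 */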
private double score(TermStats termStats, double errorScore, long dictionarySize) {
return errorScore * (((double) termStats.totalTermFreq + 1) / ((double) dictionarySize + 1));
}
// package-private for testing
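/**
 * Computes a frequency threshold that grows superlinearly with the original term's
 * document frequency: round(docFreq * log5(docFreq - frequencyPlateau) + 1), clamped
 * to [0, Integer.MAX_VALUE]. As a worked example with illustrative numbers: for
 * docFreq = 126 and frequencyPlateau = 1, log5(125) = 3, so the threshold is
 * round(126 * 3 + 1) = 379.
 */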
int thresholdTermFrequency(int docFreq) {
if (docFreq > 0) {
return (int) min(max(0, round(docFreq * (log10(docFreq - frequencyPlateau) * (1.0 / log10(LOG_BASE))) + 1)), Integer.MAX_VALUE);
}
return 0;
}
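/**
 * Callback handed to {@link #analyze(TokenStream, TokenConsumer)}; {@link #nextToken()}
 * is invoked once per token, with the attribute fields already positioned on it. A
 * minimal sketch (the analyzer and field name are illustrative assumptions) that
 * collects the analyzed tokens of a term:
 * <pre>{@code
 * List<String> tokens = new ArrayList<>();
 * DirectCandidateGenerator.analyze(analyzer, new BytesRef("Foo Bar"), "body",
 *     new TokenConsumer() {
 *         @Override
 *         public void nextToken() {
 *             tokens.add(charTermAttr.toString());
 *         }
 *     }, new CharsRefBuilder());
 * }</pre>
 */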
public abstract static class TokenConsumer {
protected CharTermAttribute charTermAttr;
protected PositionIncrementAttribute posIncAttr;
protected OffsetAttribute offsetAttr;
public void reset(TokenStream stream) {
charTermAttr = stream.addAttribute(CharTermAttribute.class);
posIncAttr = stream.addAttribute(PositionIncrementAttribute.class);
offsetAttr = stream.addAttribute(OffsetAttribute.class);
}
protected BytesRef fillBytesRef(BytesRefBuilder spare) {
spare.copyChars(charTermAttr);
return spare.get();
}
public abstract void nextToken() throws IOException;
public void end() {}
}
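/**
 * Holds the original term together with the candidates drawn for it; candidates added
 * through {@link #addCandidates(List)} are deduplicated against the existing ones and
 * the array is kept sorted from strongest to weakest.
 */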
public static class CandidateSet {
public Candidate[] candidates;
public final Candidate originalTerm;
public CandidateSet(Candidate[] candidates, Candidate originalTerm) {
this.candidates = candidates;
this.originalTerm = originalTerm;
}
public void addCandidates(List<Candidate> candidates) {
// Merge new candidates into existing ones,
// deduping:
final Set<Candidate> set = new HashSet<>(candidates);
Collections.addAll(set, this.candidates);
this.candidates = set.toArray(new Candidate[set.size()]);
// Sort strongest to weakest:
Arrays.sort(this.candidates, Collections.reverseOrder());
}
public void addOneCandidate(Candidate candidate) {
Candidate[] candidates = new Candidate[this.candidates.length + 1];
System.arraycopy(this.candidates, 0, candidates, 0, this.candidates.length);
candidates[candidates.length - 1] = candidate;
this.candidates = candidates;
}
}
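/**
 * A single correction candidate: the term, its index statistics, the string distance
 * reported by the spell checker, and the combined score used for ranking.
 */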
public static class Candidate implements Comparable<Candidate> {
public static final Candidate[] EMPTY = new Candidate[0];
public final BytesRef term;
public final double stringDistance;
public final TermStats termStats;
public final double score;
public final boolean userInput;
public Candidate(BytesRef term, TermStats termStats, double stringDistance, double score, boolean userInput) {
this.termStats = termStats;
this.term = term;
this.stringDistance = stringDistance;
this.score = score;
this.userInput = userInput;
}
@Override
public String toString() {
return "Candidate [term="
+ term.utf8ToString()
+ ", stringDistance="
+ stringDistance
+ ", score="
+ score
+ ", termStats="
+ termStats
+ (userInput ? ", userInput" : "")
+ "]";
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((term == null) ? 0 : term.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj) return true;
if (obj == null) return false;
if (getClass() != obj.getClass()) return false;
Candidate other = (Candidate) obj;
if (term == null) {
if (other.term != null) return false;
} else {
if (!term.equals(other.term)) return false;
}
return true;
}
/** Lower scores sort first; if scores are equal, then later (zzz) terms sort first */
@Override
public int compareTo(Candidate other) {
if (score == other.score) {
// Later (zzz) terms sort before earlier (aaa) terms:
return other.term.compareTo(term);
} else {
return Double.compare(score, other.score);
}
}
}
@Override
public Candidate createCandidate(BytesRef term, TermStats termStats, double channelScore, boolean userInput) throws IOException {
return new Candidate(term, termStats, channelScore, score(termStats, channelScore, sumTotalTermFreq), userInput);
}
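/**
 * Analyzes the given bytes with the supplied analyzer and feeds every resulting token
 * to the consumer, returning the number of tokens produced. See the
 * {@link TokenConsumer} documentation above for a usage sketch.
 */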
public static int analyze(Analyzer analyzer, BytesRef toAnalyze, String field, TokenConsumer consumer, CharsRefBuilder spare)
throws IOException {
spare.copyUTF8Bytes(toAnalyze);
CharsRef charsRef = spare.get();
try (TokenStream ts = analyzer.tokenStream(field, new CharArrayReader(charsRef.chars, charsRef.offset, charsRef.length))) {
return analyze(ts, consumer);
}
}
/** NOTE: this method closes the TokenStream, even on exception, which is awkward
* because really the caller who called {@link Analyzer#tokenStream} should close it,
* but when trying that there are recursion issues when we try to use the same
* TokenStream twice in the same recursion... */
public static int analyze(TokenStream stream, TokenConsumer consumer) throws IOException {
int numTokens = 0;
boolean success = false;
try {
stream.reset();
consumer.reset(stream);
while (stream.incrementToken()) {
consumer.nextToken();
numTokens++;
}
consumer.end();
success = true;
} finally {
if (success) {
stream.close();
} else {
IOUtils.closeWhileHandlingException(stream);
}
}
return numTokens;
}
}