All downloads are free. Search and download functionality uses the official Maven repository.

org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.search.suggest.phrase;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.*;
import org.apache.lucene.search.spell.DirectSpellChecker;
import org.apache.lucene.search.spell.SuggestMode;
import org.apache.lucene.search.spell.SuggestWord;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.CharsRefBuilder;
import org.elasticsearch.search.suggest.SuggestUtils;

import java.io.IOException;
import java.util.*;

//TODO public for tests
/**
 * Candidate generator that asks a {@link DirectSpellChecker} for terms similar to an input term
 * in a single index field and wraps the suggestions as scored {@link Candidate}s.
 *
 * <p>Terms may optionally be run through a pre-filter analyzer before they are looked up and
 * through a post-filter analyzer after the spell checker returned them.
 *
 * <p>NOTE(review): the instance keeps mutable scratch builders and a shared {@link TermsEnum};
 * it does not look safe for concurrent use - confirm before sharing an instance across threads.
 */
public final class DirectCandidateGenerator extends CandidateGenerator {

    /** Base of the logarithm applied in {@link #thresholdFrequency(long, long)}. */
    private static final double LOG_BASE = 5;

    private final DirectSpellChecker spellchecker;
    private final String field;
    private final SuggestMode suggestMode;
    private final TermsEnum termsEnum;
    private final IndexReader reader;
    /** Sum of total term frequencies of the field, or {@code reader.maxDoc()} when that statistic is unavailable. */
    private final long dictSize;
    /** Value subtracted from a term's frequency before the logarithm in {@link #thresholdFrequency(long, long)}. */
    private final long frequencyPlateau;
    private final Analyzer preFilter;
    private final Analyzer postFilter;
    /** String distance assigned to post-filter tokens that differ from the candidate they came from. */
    private final double nonErrorLikelihood;
    /** {@code true} when frequencies are total term frequencies, {@code false} when they are document frequencies. */
    private final boolean useTotalTermFrequency;
    // Scratch buffers handed to the pre/post filters; their content is overwritten on every use.
    private final CharsRefBuilder spare = new CharsRefBuilder();
    private final BytesRefBuilder byteSpare = new BytesRefBuilder();
    private final int numCandidates;

    /**
     * Creates a generator without pre/post filters, reading the terms of {@code field} from {@code reader}.
     *
     * @throws IllegalArgumentException if {@code field} has no terms in {@code reader}
     * @throws IOException if the terms cannot be read
     */
    public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, SuggestMode suggestMode, IndexReader reader, double nonErrorLikelihood, int numCandidates) throws IOException {
        this(spellchecker, field, suggestMode, reader,  nonErrorLikelihood, numCandidates, null, null, MultiFields.getTerms(reader, field));
    }

    /**
     * Creates a generator.
     *
     * @param spellchecker       spell checker used to draw similar terms; its threshold frequency is read
     *                           here and overwritten on every {@link #drawCandidates(CandidateSet)} call
     * @param field              name of the field the candidates are drawn from
     * @param suggestMode        mode passed to the spell checker
     * @param reader             index reader the spell checker runs against
     * @param nonErrorLikelihood string distance given to post-filter tokens that differ from their candidate
     * @param numCandidates      number of suggestions requested from the spell checker per term
     * @param preFilter          analyzer applied to terms before lookup, or {@code null} for none
     * @param postFilter         analyzer applied to drawn candidates, or {@code null} for none
     * @param terms              terms of {@code field}; must not be {@code null}
     * @throws IllegalArgumentException if {@code terms} is {@code null}
     * @throws IOException if the term statistics or the terms iterator cannot be obtained
     */
    public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, SuggestMode suggestMode, IndexReader reader, double nonErrorLikelihood,  int numCandidates, Analyzer preFilter, Analyzer postFilter, Terms terms) throws IOException {
        if (terms == null) {
            throw new IllegalArgumentException("generator field [" + field + "] doesn't exist");
        }
        this.spellchecker = spellchecker;
        this.field = field;
        this.numCandidates = numCandidates;
        this.suggestMode = suggestMode;
        this.reader = reader;
        // -1 is treated as "statistic not available"; fall back to document based numbers in that case.
        final long sumTotalTermFreq = terms.getSumTotalTermFreq();
        this.useTotalTermFrequency = sumTotalTermFreq != -1;
        this.dictSize = this.useTotalTermFrequency ? sumTotalTermFreq : reader.maxDoc();
        this.preFilter = preFilter;
        this.postFilter = postFilter;
        this.nonErrorLikelihood = nonErrorLikelihood;
        final float thresholdFrequency = spellchecker.getThresholdFrequency();
        // A threshold >= 1 is used as an absolute value, anything below as a fraction of the dictionary
        // size. The effective dictionary size is used so that the -1 sentinel never enters the
        // arithmetic, and the result is kept as a long to match the field instead of being capped at int.
        this.frequencyPlateau = thresholdFrequency >= 1.0f
                ? (long) thresholdFrequency
                : (long) (this.dictSize * thresholdFrequency);
        this.termsEnum = terms.iterator();
    }

    /**
     * Returns {@code true} if the (pre-filtered) term has a frequency greater than zero in the field.
     */
    @Override
    public boolean isKnownWord(BytesRef term) throws IOException {
        return frequency(term) > 0;
    }

    /**
     * Returns the frequency of {@code term} after running it through the pre-filter.
     *
     * @see #internalFrequency(BytesRef)
     */
    @Override
    public long frequency(BytesRef term) throws IOException {
        term = preFilter(term, spare, byteSpare);
        return internalFrequency(term);
    }

    /**
     * Looks up {@code term} as-is (no pre-filtering) and returns its total term frequency, or its
     * document frequency when total term frequencies are unavailable; {@code 0} if the term is not found.
     */
    public long internalFrequency(BytesRef term) throws IOException {
        if (termsEnum.seekExact(term)) {
            return useTotalTermFrequency ? termsEnum.totalTermFreq() : termsEnum.docFreq();
        }
        return 0;
    }

    /** Returns the name of the field candidates are drawn from. */
    public String getField() {
        return field;
    }

    /**
     * Draws candidates for the original term of {@code set} from the spell checker, post-filters
     * them and merges them into {@code set}.
     *
     * <p>As a side effect the spell checker's threshold frequency is overwritten: with {@code 0} for
     * {@link SuggestMode#SUGGEST_ALWAYS}, otherwise with {@link #thresholdFrequency(long, long)} of
     * the original term's frequency.
     *
     * @param set candidate set whose {@code originalTerm} is corrected; modified in place
     * @return the same {@code set} instance
     */
    @Override
    public CandidateSet drawCandidates(CandidateSet set) throws IOException {
        final Candidate original = set.originalTerm;
        final BytesRef term = preFilter(original.term, spare, byteSpare);
        final long frequency = original.frequency;
        spellchecker.setThresholdFrequency(this.suggestMode == SuggestMode.SUGGEST_ALWAYS ? 0 : thresholdFrequency(frequency, dictSize));
        final SuggestWord[] suggestSimilar = spellchecker.suggestSimilar(new Term(field, term), numCandidates, reader, this.suggestMode);
        final List<Candidate> candidates = new ArrayList<>(suggestSimilar.length);
        for (SuggestWord suggestWord : suggestSimilar) {
            final BytesRef candidate = new BytesRef(suggestWord.string);
            postFilter(new Candidate(candidate, internalFrequency(candidate), suggestWord.score, score(suggestWord.freq, suggestWord.score, dictSize), false), spare, byteSpare, candidates);
        }
        set.addCandidates(candidates);
        return set;
    }

    /**
     * Runs {@code term} through the pre-filter analyzer, if one is configured.
     *
     * <p>NOTE(review): every emitted token is written into {@code byteSpare}, so with several tokens
     * presumably the last one wins, and with no token the builder is returned with whatever it held
     * before - confirm against {@code SuggestUtils.TokenConsumer#fillBytesRef}.
     *
     * @return {@code term} itself when there is no pre-filter, otherwise the content of {@code byteSpare}
     *         (a view that is invalidated the next time {@code byteSpare} is written)
     */
    protected BytesRef preFilter(final BytesRef term, final CharsRefBuilder spare, final BytesRefBuilder byteSpare) throws IOException {
        if (preFilter == null) {
            return term;
        }
        final BytesRefBuilder result = byteSpare;
        SuggestUtils.analyze(preFilter, term, field, new SuggestUtils.TokenConsumer() {

            @Override
            public void nextToken() throws IOException {
                this.fillBytesRef(result);
            }
        }, spare);
        return result.get();
    }

    /**
     * Adds {@code candidate} to {@code candidates}, expanding it through the post-filter analyzer
     * when one is configured.
     *
     * <p>Without a post-filter the candidate is added unchanged. Otherwise one candidate is added per
     * emitted token: a token with a positive position increment whose bytes equal the candidate's
     * term keeps the candidate's string distance and gets a freshly looked-up frequency; any other
     * token keeps the candidate's frequency and gets {@code nonErrorLikelihood} as string distance.
     */
    protected void postFilter(final Candidate candidate, final CharsRefBuilder spare, BytesRefBuilder byteSpare, final List<Candidate> candidates) throws IOException {
        if (postFilter == null) {
            candidates.add(candidate);
        } else {
            final BytesRefBuilder result = byteSpare;
            SuggestUtils.analyze(postFilter, candidate.term, field, new SuggestUtils.TokenConsumer() {
                @Override
                public void nextToken() throws IOException {
                    this.fillBytesRef(result);
                    // Copy the token once; the builder is reused for the next token.
                    final BytesRef analyzed = result.toBytesRef();

                    if (posIncAttr.getPositionIncrement() > 0 && analyzed.bytesEquals(candidate.term))  {
                        // We should not use frequency(term) here because it will analyze the term again
                        // If preFilter and postFilter are the same analyzer it would fail.
                        final long freq = internalFrequency(analyzed);
                        candidates.add(new Candidate(analyzed, freq, candidate.stringDistance, score(candidate.frequency, candidate.stringDistance, dictSize), false));
                    } else {
                        candidates.add(new Candidate(analyzed, candidate.frequency, nonErrorLikelihood, score(candidate.frequency, candidate.stringDistance, dictSize), false));
                    }
                }
            }, spare);
        }
    }

    /**
     * Scores a term as {@code errorScore * (frequency + 1) / (dictionarySize + 1)}, i.e. the error
     * score weighted by the add-one smoothed relative frequency of the term.
     */
    private double score(long frequency, double errorScore, long dictionarySize) {
        return errorScore * (((double)frequency + 1) / ((double)dictionarySize +1));
    }

    /**
     * Computes the threshold frequency handed to the spell checker for a term of the given frequency:
     * {@code round(termFrequency * log_LOG_BASE(termFrequency - frequencyPlateau) + 1)}, clamped at 0.
     *
     * <p>When {@code termFrequency <= frequencyPlateau} the logarithm is {@code -Infinity} or
     * {@code NaN}; {@link Math#round(double)} maps those to {@link Long#MIN_VALUE} and {@code 0}
     * respectively, so the clamp yields {@code 0}.
     *
     * @param termFrequency  frequency of the original term
     * @param dictionarySize currently unused
     * @return the threshold, or {@code 0} if {@code termFrequency} is not positive
     */
    protected long thresholdFrequency(long termFrequency, long dictionarySize) {
        if (termFrequency > 0) {
            return (long) Math.max(0, Math.round(termFrequency * (Math.log10(termFrequency - frequencyPlateau) * (1.0 / Math.log10(LOG_BASE))) + 1));
        }
        return 0;

    }

    /**
     * The original term together with the correction candidates collected for it so far.
     */
    public static class CandidateSet {
        public Candidate[] candidates;
        public final Candidate originalTerm;

        public CandidateSet(Candidate[] candidates, Candidate originalTerm) {
            this.candidates = candidates;
            this.originalTerm = originalTerm;
        }

        /**
         * Merges {@code candidates} into this set, drops duplicates (candidates are equal when their
         * terms are equal) and re-sorts the result from highest to lowest score.
         *
         * <p>The new candidates are inserted first, so on a duplicate term the new candidate is kept
         * and the existing one is discarded.
         */
        public void addCandidates(List<Candidate> candidates) {
            // Merge new candidates into existing ones,
            // deduping:
            final Set<Candidate> merged = new HashSet<>(candidates);
            Collections.addAll(merged, this.candidates);
            this.candidates = merged.toArray(new Candidate[merged.size()]);
            // Sort strongest to weakest:
            Arrays.sort(this.candidates, Collections.reverseOrder());
        }

        /**
         * Appends {@code candidate} at the end of the array without deduplicating or re-sorting.
         */
        public void addOneCandidate(Candidate candidate) {
            final Candidate[] extended = Arrays.copyOf(this.candidates, this.candidates.length + 1);
            extended[extended.length - 1] = candidate;
            this.candidates = extended;
        }

    }

    /**
     * A single term with its frequency, string distance and score.
     *
     * <p>Equality and hash code are based on {@link #term} only, while the natural ordering is based
     * on {@link #score} first; the ordering is therefore not consistent with {@code equals}.
     */
    public static class Candidate implements Comparable<Candidate> {
        public static final Candidate[] EMPTY = new Candidate[0];
        public final BytesRef term;
        public final double stringDistance;
        public final long frequency;
        public final double score;
        public final boolean userInput;

        public Candidate(BytesRef term, long frequency, double stringDistance, double score, boolean userInput) {
            this.frequency = frequency;
            this.term = term;
            this.stringDistance = stringDistance;
            this.score = score;
            this.userInput = userInput;
        }

        @Override
        public String toString() {
            return "Candidate [term=" + term.utf8ToString() + ", stringDistance=" + stringDistance + ", score=" + score + ", frequency=" + frequency +
                    (userInput ? ", userInput" : "" ) + "]";
        }

        /** Hash code derived from {@link #term} only. */
        @Override
        public int hashCode() {
            return 31 + Objects.hashCode(term);
        }

        /** Two candidates are equal when they are of the same class and their terms are equal. */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            return Objects.equals(term, ((Candidate) obj).term);
        }

        /** Lower scores sort first; if scores are equal, then later (zzz) terms sort first */
        @Override
        public int compareTo(Candidate other) {
            if (score == other.score) {
                // Later (zzz) terms sort before earlier (aaa) terms:
                return other.term.compareTo(term);
            } else {
                return Double.compare(score, other.score);
            }
        }
    }

    /**
     * Creates a candidate for {@code term}, using {@code channelScore} as its string distance and
     * deriving its score from the frequency, the channel score and the dictionary size.
     */
    @Override
    public Candidate createCandidate(BytesRef term, long frequency, double channelScore, boolean userInput) throws IOException {
        return new Candidate(term, frequency, channelScore, score(frequency, channelScore, dictSize), userInput);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy