All downloads are free. Search and download functionality uses the official Maven repository.

org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator Maven / Gradle / Ivy

There is a newer version: 8.14.1
Show newest version
/*
 * Licensed to ElasticSearch and Shay Banon under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. ElasticSearch licenses this
 * file to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.search.suggest.phrase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.spell.DirectSpellChecker;
import org.apache.lucene.search.spell.SuggestMode;
import org.apache.lucene.search.spell.SuggestWord;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.elasticsearch.ElasticSearchIllegalArgumentException;
import org.elasticsearch.search.suggest.SuggestUtils;

//TODO public for tests
/**
 * {@link CandidateGenerator} that draws correction candidates for terms of a single field, using
 * Lucene's {@link DirectSpellChecker} against the given {@link IndexReader}.
 *
 * <p>Terms can optionally be run through a {@code preFilter} analyzer before they are looked up.
 * Each suggestion can be run through a {@code postFilter} analyzer before it is added to a
 * {@link CandidateSet}.
 *
 * <p>NOTE(review): an instance holds a single {@link TermsEnum} and reusable scratch buffers
 * ({@code spare}, {@code byteSpare}). {@link #drawCandidates} also mutates the threshold frequency
 * of the spellchecker passed in. Instances therefore should not be shared across threads —
 * confirm against callers.
 */
public final class DirectCandidateGenerator extends CandidateGenerator {

    private final DirectSpellChecker spellchecker;
    private final String field;
    private final SuggestMode suggestMode;
    /** Single enum over {@link #field}; re-positioned by every frequency lookup. */
    private final TermsEnum termsEnum;
    private final IndexReader reader;
    /**
     * Normalisation size for frequencies: the field's sum of total term frequencies, or
     * {@code reader.maxDoc()} when that statistic is reported as {@code -1}.
     */
    private final long dictSize;
    /** Logarithm base used by {@link #thresholdFrequency(long, long)}. */
    private final double logBase = 5;
    /** Frequency subtracted from a term's frequency before taking the log in the threshold formula. */
    private final long frequencyPlateau;
    private final Analyzer preFilter;
    private final Analyzer postFilter;
    /** Channel score given to candidates whose text the postFilter rewrote. */
    private final double nonErrorLikelihood;
    /** True when total term frequency is available; otherwise document frequency is used. */
    private final boolean useTotalTermFrequency;
    // Scratch buffers reused across calls to avoid per-term allocation.
    private final CharsRef spare = new CharsRef();
    private final BytesRef byteSpare = new BytesRef();
    private final int numCandidates;

    /**
     * Creates a generator without pre- or post-filter analyzers.
     *
     * @see #DirectCandidateGenerator(DirectSpellChecker, String, SuggestMode, IndexReader, double, int, Analyzer, Analyzer)
     */
    public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, SuggestMode suggestMode, IndexReader reader, double nonErrorLikelihood, int numCandidates) throws IOException {
        this(spellchecker, field, suggestMode, reader, nonErrorLikelihood, numCandidates, null, null);
    }

    /**
     * Creates a generator for {@code field}.
     *
     * @param spellchecker       spellchecker used to find similar terms; its threshold frequency is
     *                           read here and overwritten on every {@link #drawCandidates} call
     * @param field              field whose terms are used as the dictionary
     * @param suggestMode        mode passed to the spellchecker; {@code SUGGEST_ALWAYS} disables the
     *                           frequency threshold
     * @param reader             reader providing the field's terms
     * @param nonErrorLikelihood channel score given to candidates whose text the postFilter rewrote
     * @param numCandidates      number of suggestions requested from the spellchecker per term
     * @param preFilter          optional analyzer applied to input terms before lookup, may be {@code null}
     * @param postFilter         optional analyzer applied to suggestions, may be {@code null}
     * @throws ElasticSearchIllegalArgumentException if the reader has no terms for {@code field}
     */
    public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, SuggestMode suggestMode, IndexReader reader, double nonErrorLikelihood,  int numCandidates, Analyzer preFilter, Analyzer postFilter) throws IOException {
        this.spellchecker = spellchecker;
        this.field = field;
        this.numCandidates = numCandidates;
        this.suggestMode = suggestMode;
        this.reader = reader;
        Terms terms = MultiFields.getTerms(reader, field);
        if (terms == null) {
            throw new ElasticSearchIllegalArgumentException("generator field [" + field + "] doesn't exist");
        }
        final long sumTotalTermFreq = terms.getSumTotalTermFreq();
        // -1 signals that the statistic is unavailable; fall back to document counts.
        this.useTotalTermFrequency = sumTotalTermFreq != -1;
        this.dictSize = sumTotalTermFreq == -1 ? reader.maxDoc() : sumTotalTermFreq;
        this.preFilter = preFilter;
        this.postFilter = postFilter;
        this.nonErrorLikelihood = nonErrorLikelihood;
        float thresholdFrequency = spellchecker.getThresholdFrequency();
        // Thresholds >= 1 are absolute frequencies; smaller values are a fraction of the dictionary.
        // Use the resolved dictSize (never the -1 sentinel), and a long cast so that large
        // dictionaries do not saturate at Integer.MAX_VALUE.
        this.frequencyPlateau = thresholdFrequency >= 1.0f ? (long) thresholdFrequency : (long) (this.dictSize * thresholdFrequency);
        termsEnum = terms.iterator(null);
    }

    /**
     * Returns whether {@code term}, after pre-filtering, occurs in the field.
     */
    @Override
    public boolean isKnownWord(BytesRef term) throws IOException {
        return frequency(term) > 0;
    }

    /**
     * Returns the frequency of {@code term} after it has been run through the preFilter (if any).
     * Note that this reuses the instance scratch buffers.
     */
    @Override
    public long frequency(BytesRef term) throws IOException {
        term = preFilter(term, spare, byteSpare);
        return internalFrequency(term);
    }

    /**
     * Looks up {@code term} as-is, without any analysis.
     *
     * @return the total term frequency when available, otherwise the document frequency;
     *         {@code 0} if the term does not exist in the field
     */
    public long internalFrequency(BytesRef term) throws IOException {
        if (termsEnum.seekExact(term, true)) {
            return useTotalTermFrequency ? termsEnum.totalTermFreq() : termsEnum.docFreq();
        }
        return 0;
    }

    /** Returns the field this generator draws candidates from. */
    public String getField() {
        return field;
    }

    /**
     * Asks the spellchecker for terms similar to {@code set.originalTerm} and merges them,
     * post-filtered, into {@code set}.
     *
     * <p>Side effect: sets the spellchecker's threshold frequency for this lookup. The threshold is
     * 0 in {@code SUGGEST_ALWAYS} mode, otherwise it is derived from the original term's frequency.
     *
     * @return {@code set}, with the new candidates added
     */
    @Override
    public CandidateSet drawCandidates(CandidateSet set) throws IOException {
        Candidate original = set.originalTerm;
        BytesRef term = preFilter(original.term, spare, byteSpare);
        final long frequency = original.frequency;
        spellchecker.setThresholdFrequency(this.suggestMode == SuggestMode.SUGGEST_ALWAYS ? 0 : thresholdFrequency(frequency, dictSize));
        SuggestWord[] suggestSimilar = spellchecker.suggestSimilar(new Term(field, term), numCandidates, reader, this.suggestMode);
        List<Candidate> candidates = new ArrayList<Candidate>(suggestSimilar.length);
        for (SuggestWord suggestWord : suggestSimilar) {
            BytesRef candidate = new BytesRef(suggestWord.string);
            postFilter(new Candidate(candidate, internalFrequency(candidate), suggestWord.score, score(suggestWord.freq, suggestWord.score, dictSize)), spare, byteSpare, candidates);
        }
        set.addCandidates(candidates);
        return set;
    }

    /**
     * Runs {@code term} through the preFilter analyzer.
     *
     * <p>Returns {@code term} unchanged when there is no preFilter. Otherwise it returns
     * {@code byteSpare}, filled from the analyzed tokens, so the result aliases that buffer.
     * NOTE(review): if the analyzer emits no tokens, {@code byteSpare} keeps its previous contents.
     * If it emits several, presumably the last one wins. Confirm this is acceptable for callers.
     */
    protected BytesRef preFilter(final BytesRef term, final CharsRef spare, final BytesRef byteSpare) throws IOException {
        if (preFilter == null) {
            return term;
        }
        final BytesRef result = byteSpare;
        SuggestUtils.analyze(preFilter, term, field, new SuggestUtils.TokenConsumer() {

            @Override
            public void nextToken() throws IOException {
                this.fillBytesRef(result);
            }
        }, spare);
        return result;
    }

    /**
     * Adds {@code candidate} to {@code candidates}, first running it through the postFilter
     * analyzer when one is configured.
     *
     * <p>For each emitted token: if it is a new position and its bytes equal the original
     * candidate term, it is added with its own frequency and the original string distance.
     * Otherwise it is added with the original frequency and {@code nonErrorLikelihood} as its
     * channel score.
     */
    protected void postFilter(final Candidate candidate, final CharsRef spare, BytesRef byteSpare, final List<Candidate> candidates) throws IOException {
        if (postFilter == null) {
            candidates.add(candidate);
        } else {
            final BytesRef result = byteSpare;
            SuggestUtils.analyze(postFilter, candidate.term, field, new SuggestUtils.TokenConsumer() {
                @Override
                public void nextToken() throws IOException {
                    this.fillBytesRef(result);

                    if (posIncAttr.getPositionIncrement() > 0 && result.bytesEquals(candidate.term)) {
                        BytesRef term = BytesRef.deepCopyOf(result);
                        // Use internalFrequency(), not frequency(). frequency() would run the
                        // preFilter analyzer using the instance spare/byteSpare buffers — the same
                        // buffers drawCandidates hands to this method — while this postFilter
                        // analysis is still in progress. The token is already in index form anyway.
                        long freq = internalFrequency(term);
                        candidates.add(new Candidate(term, freq, candidate.stringDistance, score(candidate.frequency, candidate.stringDistance, dictSize)));
                    } else {
                        candidates.add(new Candidate(BytesRef.deepCopyOf(result), candidate.frequency, nonErrorLikelihood, score(candidate.frequency, candidate.stringDistance, dictSize)));
                    }
                }
            }, spare);
        }
    }

    /**
     * Computes {@code errorScore} scaled by the add-one-smoothed relative frequency
     * {@code (frequency + 1) / (dictionarySize + 1)}.
     */
    private double score(long frequency, double errorScore, long dictionarySize) {
        return errorScore * (((double) frequency + 1) / ((double) dictionarySize + 1));
    }

    /**
     * Computes the spellchecker threshold for a term with frequency {@code termFrequency}:
     * {@code round(termFrequency * log_logBase(termFrequency - frequencyPlateau) + 1)}, clamped at 0.
     * {@code dictionarySize} is currently unused.
     *
     * @return 0 for non-positive frequencies
     */
    protected long thresholdFrequency(long termFrequency, long dictionarySize) {
        if (termFrequency > 0) {
            // When termFrequency <= frequencyPlateau, log10 yields -Infinity or NaN. Math.round maps
            // these to Long.MIN_VALUE or 0, and Math.max then clamps the result to 0.
            return (long) Math.max(0, Math.round(termFrequency * (Math.log10(termFrequency - frequencyPlateau) * (1.0 / Math.log10(logBase))) + 1));
        }
        return 0;
    }

    /** The original term together with the candidates collected for it so far. */
    public static class CandidateSet {
        public Candidate[] candidates;
        public final Candidate originalTerm;

        public CandidateSet(Candidate[] candidates, Candidate originalTerm) {
            this.candidates = candidates;
            this.originalTerm = originalTerm;
        }

        /**
         * Merges {@code candidates} into this set, de-duplicating by term.
         *
         * <p>The new candidates are inserted first, so on a term collision the new candidate is
         * kept ({@link HashSet#add} does not replace). The order of the resulting array is
         * unspecified.
         */
        public void addCandidates(List<Candidate> candidates) {
            final Set<Candidate> set = new HashSet<Candidate>(candidates);
            for (Candidate existing : this.candidates) {
                set.add(existing);
            }
            this.candidates = set.toArray(new Candidate[set.size()]);
        }

        /** Appends {@code candidate} to the end of the array without de-duplication. */
        public void addOneCandidate(Candidate candidate) {
            Candidate[] candidates = new Candidate[this.candidates.length + 1];
            System.arraycopy(this.candidates, 0, candidates, 0, this.candidates.length);
            candidates[candidates.length - 1] = candidate;
            this.candidates = candidates;
        }

    }

    /**
     * A single candidate term with its frequency, channel score ({@code stringDistance}) and
     * combined score. Equality and hashing consider only {@link #term}.
     */
    public static class Candidate {
        public static final Candidate[] EMPTY = new Candidate[0];
        public final BytesRef term;
        public final double stringDistance;
        public final long frequency;
        public final double score;

        public Candidate(BytesRef term, long frequency, double stringDistance, double score) {
            this.frequency = frequency;
            this.term = term;
            this.stringDistance = stringDistance;
            this.score = score;
        }

        @Override
        public String toString() {
            return "Candidate [term=" + term.utf8ToString() + ", stringDistance=" + stringDistance + ", frequency=" + frequency + "]";
        }

        @Override
        public int hashCode() {
            // Same value as the classic 31 * 1 + h formulation.
            return 31 + ((term == null) ? 0 : term.hashCode());
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Candidate other = (Candidate) obj;
            return term == null ? other.term == null : term.equals(other.term);
        }
    }

    /**
     * Creates a candidate whose score is {@code channelScore} scaled by its smoothed relative
     * frequency in this generator's dictionary.
     */
    @Override
    public Candidate createCandidate(BytesRef term, long frequency, double channelScore) throws IOException {
        return new Candidate(term, frequency, channelScore, score(frequency, channelScore, dictSize));
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy