All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.lucene.queries.XIntervals Maven / Gradle / Ivy

There is a newer version: 8.13.2
Show newest version
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.lucene.queries;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.queries.intervals.IntervalIterator;
import org.apache.lucene.queries.intervals.IntervalQuery;
import org.apache.lucene.queries.intervals.Intervals;
import org.apache.lucene.queries.intervals.IntervalsSource;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.MatchesIterator;
import org.apache.lucene.search.MatchesUtils;
import org.apache.lucene.search.PrefixQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.PriorityQueue;
import org.apache.lucene.util.automaton.CompiledAutomaton;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;

/**
 * Replacement for {@link Intervals#wildcard(BytesRef)} and {@link Intervals#prefix(BytesRef)}
 * until LUCENE-9050 is merged
 */
public final class XIntervals {

    /**
     * Default upper bound on the number of terms that {@link #wildcard(BytesRef)},
     * {@link #prefix(BytesRef)} and {@link #multiterm(CompiledAutomaton, String)} may expand to.
     */
    public static final int DEFAULT_MAX_EXPANSIONS = 128;

    private XIntervals() {}

    /**
     * Returns a source over every term accepted by the given wildcard pattern, expanding to at most
     * {@link #DEFAULT_MAX_EXPANSIONS} terms.
     */
    public static IntervalsSource wildcard(BytesRef wildcard) {
        return wildcard(wildcard, DEFAULT_MAX_EXPANSIONS);
    }

    /**
     * Returns a source over every term accepted by the given wildcard pattern.
     *
     * @param wildcard      the pattern, in {@link WildcardQuery} syntax
     * @param maxExpansions maximum number of terms the pattern may expand to in a segment
     * @throws IllegalArgumentException if {@code maxExpansions} exceeds
     *                                  {@link BooleanQuery#getMaxClauseCount()}
     */
    public static IntervalsSource wildcard(BytesRef wildcard, int maxExpansions) {
        CompiledAutomaton ca = new CompiledAutomaton(WildcardQuery.toAutomaton(new Term("", wildcard)));
        return new MultiTermIntervalsSource(ca, maxExpansions, wildcard.utf8ToString());
    }

    /**
     * Returns a source over every term starting with the given prefix, expanding to at most
     * {@link #DEFAULT_MAX_EXPANSIONS} terms.
     */
    public static IntervalsSource prefix(BytesRef prefix) {
        return prefix(prefix, DEFAULT_MAX_EXPANSIONS);
    }

    /**
     * Returns a source over every term starting with the given prefix.
     *
     * @param prefix        the term prefix
     * @param maxExpansions maximum number of terms the prefix may expand to in a segment
     * @throws IllegalArgumentException if {@code maxExpansions} exceeds
     *                                  {@link BooleanQuery#getMaxClauseCount()}
     */
    public static IntervalsSource prefix(BytesRef prefix, int maxExpansions) {
        CompiledAutomaton ca = new CompiledAutomaton(PrefixQuery.toAutomaton(prefix));
        return new MultiTermIntervalsSource(ca, maxExpansions, prefix.utf8ToString());
    }

    /**
     * Returns a source over every term accepted by {@code ca}, expanding to at most
     * {@link #DEFAULT_MAX_EXPANSIONS} terms.
     *
     * @param ca    the automaton to expand
     * @param label human-readable description used in {@code toString()} and error messages
     */
    public static IntervalsSource multiterm(CompiledAutomaton ca, String label) {
        return multiterm(ca, DEFAULT_MAX_EXPANSIONS, label);
    }

    /**
     * Returns a source over every term accepted by {@code ca}.
     *
     * @param ca            the automaton to expand
     * @param maxExpansions maximum number of terms the automaton may expand to in a segment
     * @param label         human-readable description used in {@code toString()} and error messages
     * @throws IllegalArgumentException if {@code maxExpansions} exceeds
     *                                  {@link BooleanQuery#getMaxClauseCount()}
     */
    public static IntervalsSource multiterm(CompiledAutomaton ca, int maxExpansions, String label) {
        return new MultiTermIntervalsSource(ca, maxExpansions, label);
    }

    /**
     * An {@link IntervalsSource} that expands a {@link CompiledAutomaton} against the terms of a
     * field and exposes the union of the positions of every accepted term.
     * <p>
     * Expansion is bounded: if the automaton accepts more than {@code maxExpansions} terms in a
     * segment, an {@link IllegalStateException} is thrown.
     */
    static class MultiTermIntervalsSource extends IntervalsSource {

        private final CompiledAutomaton automaton;
        private final int maxExpansions;
        private final String pattern;

        /**
         * @param automaton     automaton whose accepted terms are expanded
         * @param maxExpansions maximum number of terms the automaton may expand to in a segment
         * @param pattern       human-readable description of the automaton, used in
         *                      {@link #toString()} and error messages
         * @throws IllegalArgumentException if {@code maxExpansions} exceeds
         *                                  {@link BooleanQuery#getMaxClauseCount()}
         */
        MultiTermIntervalsSource(CompiledAutomaton automaton, int maxExpansions, String pattern) {
            this.automaton = automaton;
            if (maxExpansions > BooleanQuery.getMaxClauseCount()) {
                throw new IllegalArgumentException("maxExpansions [" + maxExpansions
                    + "] cannot be greater than BooleanQuery.getMaxClauseCount [" + BooleanQuery.getMaxClauseCount() + "]");
            }
            this.maxExpansions = maxExpansions;
            this.pattern = pattern;
        }

        /**
         * Returns the disjunction of the intervals of every expanded term, or {@code null} if the
         * field is absent from this segment or no term matches.
         *
         * @throws IllegalArgumentException if the field was indexed without positions
         * @throws IllegalStateException    if more than {@code maxExpansions} terms match
         */
        @Override
        public IntervalIterator intervals(String field, LeafReaderContext ctx) throws IOException {
            Terms terms = ctx.reader().terms(field);
            if (terms == null) {
                return null;
            }
            requirePositions(terms, field);
            List<IntervalIterator> subSources = new ArrayList<>();
            TermsEnum te = automaton.getTermsEnum(terms);
            BytesRef term;
            int count = 0;
            while ((term = te.next()) != null) {
                subSources.add(TermIntervalsSource.intervals(term, te));
                checkExpansions(++count);
            }
            if (subSources.isEmpty()) {
                return null;
            }
            return new DisjunctionIntervalIterator(subSources);
        }

        /**
         * Returns the disjunction of the matches of every expanded term that occurs in {@code doc},
         * or {@code null} if there are none.
         *
         * @throws IllegalArgumentException if the field was indexed without positions
         * @throws IllegalStateException    if more than {@code maxExpansions} terms match
         */
        @Override
        public MatchesIterator matches(String field, LeafReaderContext ctx, int doc) throws IOException {
            Terms terms = ctx.reader().terms(field);
            if (terms == null) {
                return null;
            }
            requirePositions(terms, field);
            List<MatchesIterator> subMatches = new ArrayList<>();
            TermsEnum te = automaton.getTermsEnum(terms);
            int count = 0;
            while (te.next() != null) {
                // Every expanded term counts towards the limit, not only those occurring in doc, so
                // that the same limit is enforced here as in intervals().
                checkExpansions(++count);
                MatchesIterator mi = TermIntervalsSource.matches(te, doc);
                if (mi != null) {
                    subMatches.add(mi);
                }
            }
            return MatchesUtils.disjunction(subMatches);
        }

        /** Throws once {@code count} expanded terms exceed {@link #maxExpansions}. */
        private void checkExpansions(int count) {
            if (count > maxExpansions) {
                throw new IllegalStateException("Automaton [" + this.pattern + "] expanded to too many terms (limit "
                    + maxExpansions + ")");
            }
        }

        /** Rejects fields indexed without positions, with the same message as TermIntervalsSource. */
        private static void requirePositions(Terms terms, String field) {
            if (terms.hasPositions() == false) {
                throw new IllegalArgumentException("Cannot create an IntervalIterator over field " + field
                    + " because it has no indexed positions");
            }
        }

        @Override
        public void visit(String field, QueryVisitor visitor) {
            visitor.visitLeaf(new IntervalQuery(field, this));
        }

        @Override
        public int minExtent() {
            return 1;
        }

        @Override
        public Collection<IntervalsSource> pullUpDisjunctions() {
            return Collections.singleton(this);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            MultiTermIntervalsSource that = (MultiTermIntervalsSource) o;
            return maxExpansions == that.maxExpansions &&
                Objects.equals(automaton, that.automaton) &&
                Objects.equals(pattern, that.pattern);
        }

        @Override
        public int hashCode() {
            return Objects.hash(automaton, maxExpansions, pattern);
        }

        @Override
        public String toString() {
            return "MultiTerm(" + pattern + ")";
        }
    }

    /**
     * Heap entry that pairs an {@link IntervalIterator} with the bookkeeping needed by
     * {@link DisiPriorityQueue}: the document it is currently on, its cost figures, and a link
     * used to chain the entries that share the top document.
     */
    static class DisiWrapper {

        // The wrapped iterator, seen through its three roles; all three reference the same object.
        public final IntervalIterator intervals;
        public final DocIdSetIterator iterator;
        // Iterator used for doc-level advancing (here always the wrapped iterator itself).
        public final DocIdSetIterator approximation;

        // Copied once from the wrapped iterator at construction.
        public final long cost;
        public final float matchCost;

        // Document the wrapped iterator is positioned on; the heap key. -1 before iteration starts.
        public int doc;
        // Next entry in the chain built by DisiPriorityQueue#topList.
        public DisiWrapper next;

        DisiWrapper(IntervalIterator iterator) {
            this.doc = -1;
            this.intervals = iterator;
            this.iterator = iterator;
            this.approximation = iterator;
            this.cost = iterator.cost();
            this.matchCost = iterator.matchCost();
        }

    }

    /**
     * A fixed-capacity binary min-heap of {@link DisiWrapper}s keyed on {@link DisiWrapper#doc}, so
     * that {@link #top()} is always an entry positioned on the smallest document. Callers must not
     * {@link #add} more entries than the {@code maxSize} passed to the constructor.
     */
    static final class DisiPriorityQueue implements Iterable<DisiWrapper> {

        /** Index of the left child of {@code node} in the array-backed heap. */
        static int leftNode(int node) {
            return ((node + 1) << 1) - 1;
        }

        /** Index of the right child, given the index of its left sibling. */
        static int rightNode(int leftNode) {
            return leftNode + 1;
        }

        /** Index of the parent of {@code node}; {@code -1} for the root. */
        static int parentNode(int node) {
            return ((node + 1) >>> 1) - 1;
        }

        private final DisiWrapper[] heap;
        private int size;

        DisiPriorityQueue(int maxSize) {
            heap = new DisiWrapper[maxSize];
            size = 0;
        }

        public int size() {
            return size;
        }

        /** Returns the entry on the smallest document without removing it. */
        public DisiWrapper top() {
            return heap[0];
        }

        /**
         * Get the list of scorers which are on the current doc: every entry whose {@code doc}
         * equals the top's, chained through {@link DisiWrapper#next}. The chain is rebuilt on
         * every call.
         */
        DisiWrapper topList() {
            final DisiWrapper[] heap = this.heap;
            final int size = this.size;
            DisiWrapper list = heap[0];
            list.next = null;
            if (size >= 3) {
                list = topList(list, heap, size, 1);
                list = topList(list, heap, size, 2);
            } else if (size == 2 && heap[1].doc == list.doc) {
                list = prepend(heap[1], list);
            }
            return list;
        }

        // prepend w1 (iterator) to w2 (list)
        private DisiWrapper prepend(DisiWrapper w1, DisiWrapper w2) {
            w1.next = w2;
            return w1;
        }

        // Collects heap[i] and its descendants that are on list.doc. A subtree whose root is on a
        // different (hence larger, by heap order) doc is pruned: its children cannot be smaller.
        private DisiWrapper topList(DisiWrapper list, DisiWrapper[] heap,
                                    int size, int i) {
            final DisiWrapper w = heap[i];
            if (w.doc == list.doc) {
                list = prepend(w, list);
                final int left = leftNode(i);
                final int right = left + 1;
                if (right < size) {
                    list = topList(list, heap, size, left);
                    list = topList(list, heap, size, right);
                } else if (left < size && heap[left].doc == list.doc) {
                    list = prepend(heap[left], list);
                }
            }
            return list;
        }

        /** Inserts {@code entry} and returns the (possibly new) top. */
        public DisiWrapper add(DisiWrapper entry) {
            final DisiWrapper[] heap = this.heap;
            final int size = this.size;
            heap[size] = entry;
            upHeap(size);
            this.size = size + 1;
            return heap[0];
        }

        /** Removes and returns the top entry. */
        public DisiWrapper pop() {
            final DisiWrapper[] heap = this.heap;
            final DisiWrapper result = heap[0];
            final int i = --size;
            heap[0] = heap[i];
            heap[i] = null;
            downHeap(i);
            return result;
        }

        /** Restores heap order after the top entry's {@code doc} has changed; returns the new top. */
        DisiWrapper updateTop() {
            downHeap(size);
            return heap[0];
        }

        // Moves heap[i] towards the root while it is on a smaller doc than its parent.
        void upHeap(int i) {
            final DisiWrapper node = heap[i];
            final int nodeDoc = node.doc;
            int j = parentNode(i);
            while (j >= 0 && nodeDoc < heap[j].doc) {
                heap[i] = heap[j];
                i = j;
                j = parentNode(j);
            }
            heap[i] = node;
        }

        // Moves heap[0] down, always swapping with the smaller child, considering only the first
        // `size` slots.
        void downHeap(int size) {
            int i = 0;
            final DisiWrapper node = heap[0];
            int j = leftNode(i);
            if (j < size) {
                int k = rightNode(j);
                if (k < size && heap[k].doc < heap[j].doc) {
                    j = k;
                }
                if (heap[j].doc < node.doc) {
                    do {
                        heap[i] = heap[j];
                        i = j;
                        j = leftNode(i);
                        k = rightNode(j);
                        if (k < size && heap[k].doc < heap[j].doc) {
                            j = k;
                        }
                    } while (j < size && heap[j].doc < node.doc);
                    heap[i] = node;
                }
            }
        }

        /** Iterates over the live entries in heap-array order (not sorted by doc). */
        @Override
        public Iterator<DisiWrapper> iterator() {
            return Arrays.asList(heap).subList(0, size).iterator();
        }

    }

    /**
     * A {@link DocIdSetIterator} over the union of the documents of the entries in a
     * {@link DisiPriorityQueue}. The current document is always the one held by the queue's top.
     */
    static class DisjunctionDISIApproximation extends DocIdSetIterator {

        final DisiPriorityQueue subIterators;
        // Sum of the costs of all sub-iterators, computed once.
        final long cost;

        DisjunctionDISIApproximation(DisiPriorityQueue subIterators) {
            this.subIterators = subIterators;
            long total = 0;
            for (DisiWrapper wrapper : subIterators) {
                total += wrapper.cost;
            }
            this.cost = total;
        }

        @Override
        public long cost() {
            return cost;
        }

        @Override
        public int docID() {
            return subIterators.top().doc;
        }

        /**
         * Advances every sub-iterator that sits on the current document, one top at a time, until
         * the top of the queue lands on a different document.
         */
        @Override
        public int nextDoc() throws IOException {
            DisiWrapper head = subIterators.top();
            final int current = head.doc;
            while (true) {
                head.doc = head.approximation.nextDoc();
                head = subIterators.updateTop();
                if (head.doc != current) {
                    return head.doc;
                }
            }
        }

        /** Advances sub-iterators lagging behind {@code target} until the top reaches it. */
        @Override
        public int advance(int target) throws IOException {
            DisiWrapper head = subIterators.top();
            while (true) {
                head.doc = head.approximation.advance(target);
                head = subIterators.updateTop();
                if (head.doc >= target) {
                    return head.doc;
                }
            }
        }
    }

    /**
     * An {@link IntervalIterator} over the union of several sub-iterators. Documents come from a
     * {@link DisjunctionDISIApproximation}; within a document, intervals are returned in order of
     * increasing end position (ties: larger start first), and sub-intervals that contain the
     * previously returned interval are skipped.
     */
    static class DisjunctionIntervalIterator extends IntervalIterator {

        final DocIdSetIterator approximation;
        // Sub-iterators positioned on the current document, ordered as described above.
        final PriorityQueue<IntervalIterator> intervalQueue;
        final DisiPriorityQueue disiQueue;
        final List<IntervalIterator> iterators;
        final float matchCost;

        // EMPTY until the first nextInterval() on a document; EXHAUSTED once the queue drains.
        IntervalIterator current = EMPTY;

        DisjunctionIntervalIterator(List<IntervalIterator> iterators) {
            this.disiQueue = new DisiPriorityQueue(iterators.size());
            for (IntervalIterator it : iterators) {
                disiQueue.add(new DisiWrapper(it));
            }
            this.approximation = new DisjunctionDISIApproximation(disiQueue);
            this.iterators = iterators;
            this.intervalQueue = new PriorityQueue<IntervalIterator>(iterators.size()) {
                @Override
                protected boolean lessThan(IntervalIterator a, IntervalIterator b) {
                    return a.end() < b.end() || (a.end() == b.end() && a.start() >= b.start());
                }
            };
            // matchCost is a per-document figure, so sum the sub-iterators' matchCost, not their
            // document-count cost().
            float costsum = 0;
            for (IntervalIterator it : iterators) {
                costsum += it.matchCost();
            }
            this.matchCost = costsum;
        }

        @Override
        public float matchCost() {
            return matchCost;
        }

        @Override
        public int start() {
            return current.start();
        }

        @Override
        public int end() {
            return current.end();
        }

        @Override
        public int gaps() {
            return current.gaps();
        }

        // Moves every sub-iterator on the new current document to its first interval and queues it.
        private void reset() throws IOException {
            intervalQueue.clear();
            for (DisiWrapper dw = disiQueue.topList(); dw != null; dw = dw.next) {
                dw.intervals.nextInterval();
                intervalQueue.add(dw.intervals);
            }
            current = EMPTY;
        }

        @Override
        public int nextInterval() throws IOException {
            if (current == EMPTY || current == EXHAUSTED) {
                if (intervalQueue.size() > 0) {
                    current = intervalQueue.top();
                }
                return current.start();
            }
            int start = current.start(), end = current.end();
            // Advance past every queued interval that contains the one just returned (including
            // the current sub-iterator itself).
            while (intervalQueue.size() > 0 && contains(intervalQueue.top(), start, end)) {
                IntervalIterator it = intervalQueue.pop();
                if (it != null && it.nextInterval() != NO_MORE_INTERVALS) {
                    intervalQueue.add(it);
                }
            }
            if (intervalQueue.size() == 0) {
                current = EXHAUSTED;
                return NO_MORE_INTERVALS;
            }
            current = intervalQueue.top();
            return current.start();
        }

        // True if [start, end] lies within it's current interval.
        private boolean contains(IntervalIterator it, int start, int end) {
            return start >= it.start() && start <= it.end() && end >= it.start() && end <= it.end();
        }

        @Override
        public int docID() {
            return approximation.docID();
        }

        @Override
        public int nextDoc() throws IOException {
            int doc = approximation.nextDoc();
            reset();
            return doc;
        }

        @Override
        public int advance(int target) throws IOException {
            int doc = approximation.advance(target);
            reset();
            return doc;
        }

        @Override
        public long cost() {
            return approximation.cost();
        }
    }

    /**
     * Shared base for the two sentinel iterators below. They are never positioned on a document:
     * every document-level operation and {@code gaps()} is unsupported, {@code nextInterval()}
     * always reports {@code NO_MORE_INTERVALS}, and matching them costs nothing. Subclasses only
     * decide what {@code start()} and {@code end()} report.
     */
    private abstract static class SentinelIntervalIterator extends IntervalIterator {

        @Override
        public final int docID() {
            throw new UnsupportedOperationException();
        }

        @Override
        public final int nextDoc() {
            throw new UnsupportedOperationException();
        }

        @Override
        public final int advance(int target) {
            throw new UnsupportedOperationException();
        }

        @Override
        public final long cost() {
            throw new UnsupportedOperationException();
        }

        @Override
        public final int gaps() {
            throw new UnsupportedOperationException();
        }

        @Override
        public final int nextInterval() {
            return NO_MORE_INTERVALS;
        }

        @Override
        public final float matchCost() {
            return 0;
        }
    }

    /** Sentinel meaning "no interval read yet on this document"; start and end are -1. */
    private static final IntervalIterator EMPTY = new SentinelIntervalIterator() {

        @Override
        public int start() {
            return -1;
        }

        @Override
        public int end() {
            return -1;
        }
    };

    /** Sentinel meaning "all intervals on this document consumed"; start and end are NO_MORE_INTERVALS. */
    private static final IntervalIterator EXHAUSTED = new SentinelIntervalIterator() {

        @Override
        public int start() {
            return NO_MORE_INTERVALS;
        }

        @Override
        public int end() {
            return NO_MORE_INTERVALS;
        }
    };

    /**
     * An {@link IntervalsSource} over the positions of a single term. The static helpers
     * {@link #intervals(BytesRef, TermsEnum)} and {@link #matches(TermsEnum, int)} work on a
     * {@link TermsEnum} that is already positioned on a term, so {@link MultiTermIntervalsSource}
     * can reuse them for every expanded term.
     */
    static class TermIntervalsSource extends IntervalsSource {

        final BytesRef term;

        TermIntervalsSource(BytesRef term) {
            this.term = term;
        }

        @Override
        public IntervalIterator intervals(String field, LeafReaderContext ctx) throws IOException {
            TermsEnum te = seekTerm(field, ctx);
            return te == null ? null : intervals(term, te);
        }

        @Override
        public MatchesIterator matches(String field, LeafReaderContext ctx, int doc) throws IOException {
            TermsEnum te = seekTerm(field, ctx);
            return te == null ? null : matches(te, doc);
        }

        /**
         * Returns a {@link TermsEnum} positioned on {@link #term} in {@code field}, or {@code null}
         * if the field or the term is absent from this segment.
         *
         * @throws IllegalArgumentException if the field was indexed without positions
         */
        private TermsEnum seekTerm(String field, LeafReaderContext ctx) throws IOException {
            Terms terms = ctx.reader().terms(field);
            if (terms == null) {
                return null;
            }
            if (terms.hasPositions() == false) {
                throw new IllegalArgumentException("Cannot create an IntervalIterator over field " + field
                    + " because it has no indexed positions");
            }
            TermsEnum te = terms.iterator();
            return te.seekExact(term) ? te : null;
        }

        /**
         * Creates an iterator over the positions of the term {@code te} is positioned on. Each
         * interval is a single position ({@code start() == end()}, no gaps).
         *
         * @param term used only by the iterator's {@code toString()}
         * @param te   terms enum already positioned on the term
         */
        static IntervalIterator intervals(BytesRef term, TermsEnum te) throws IOException {
            PostingsEnum pe = te.postings(null, PostingsEnum.POSITIONS);
            float cost = termPositionsCost(te);
            return new IntervalIterator() {

                @Override
                public int docID() {
                    return pe.docID();
                }

                @Override
                public int nextDoc() throws IOException {
                    int doc = pe.nextDoc();
                    reset();
                    return doc;
                }

                @Override
                public int advance(int target) throws IOException {
                    int doc = pe.advance(target);
                    reset();
                    return doc;
                }

                @Override
                public long cost() {
                    return pe.cost();
                }

                // pos: current position (-1 before the first, NO_MORE_INTERVALS after the last);
                // upto: positions still unread in the current document.
                int pos = -1, upto;

                @Override
                public int start() {
                    return pos;
                }

                @Override
                public int end() {
                    return pos;
                }

                @Override
                public int gaps() {
                    return 0;
                }

                @Override
                public int nextInterval() throws IOException {
                    if (upto <= 0) {
                        return pos = NO_MORE_INTERVALS;
                    }
                    upto--;
                    return pos = pe.nextPosition();
                }

                @Override
                public float matchCost() {
                    return cost;
                }

                // Re-arms the position counter for the document the postings just moved to.
                private void reset() throws IOException {
                    if (pe.docID() == NO_MORE_DOCS) {
                        upto = -1;
                        pos = NO_MORE_INTERVALS;
                    } else {
                        upto = pe.freq();
                        pos = -1;
                    }
                }

                @Override
                public String toString() {
                    return term.utf8ToString() + ":" + super.toString();
                }
            };
        }

        /**
         * Returns a {@link MatchesIterator} over the positions and offsets of the term {@code te}
         * is positioned on within {@code doc}, or {@code null} if the term does not occur in it.
         */
        static MatchesIterator matches(TermsEnum te, int doc) throws IOException {
            PostingsEnum pe = te.postings(null, PostingsEnum.OFFSETS);
            if (pe.advance(doc) != doc) {
                return null;
            }
            return new MatchesIterator() {

                int upto = pe.freq();
                int pos = -1;

                @Override
                public boolean next() throws IOException {
                    if (upto <= 0) {
                        pos = IntervalIterator.NO_MORE_INTERVALS;
                        return false;
                    }
                    upto--;
                    pos = pe.nextPosition();
                    return true;
                }

                @Override
                public int startPosition() {
                    return pos;
                }

                @Override
                public int endPosition() {
                    return pos;
                }

                @Override
                public int startOffset() throws IOException {
                    return pe.startOffset();
                }

                @Override
                public int endOffset() throws IOException {
                    return pe.endOffset();
                }

                @Override
                public MatchesIterator getSubMatches() {
                    return null;
                }

                @Override
                public Query getQuery() {
                    throw new UnsupportedOperationException();
                }
            };
        }

        @Override
        public int minExtent() {
            return 1;
        }

        @Override
        public Collection<IntervalsSource> pullUpDisjunctions() {
            return Collections.singleton(this);
        }

        @Override
        public int hashCode() {
            return Objects.hash(term);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            TermIntervalsSource that = (TermIntervalsSource) o;
            return Objects.equals(term, that.term);
        }

        @Override
        public String toString() {
            return term.utf8ToString();
        }

        @Override
        public void visit(String field, QueryVisitor visitor) {
            visitor.consumeTerms(new IntervalQuery(field, this), new Term(field, term));
        }

        private static final int TERM_POSNS_SEEK_OPS_PER_DOC = 128;

        private static final int TERM_OPS_PER_POS = 7;

        /**
         * Estimates the per-document cost of iterating a term's positions: a fixed seek cost plus
         * a per-position cost times the average term frequency in matching documents.
         */
        static float termPositionsCost(TermsEnum termsEnum) throws IOException {
            int docFreq = termsEnum.docFreq();
            assert docFreq > 0;
            long totalTermFreq = termsEnum.totalTermFreq();
            float expOccurrencesInMatchingDoc = totalTermFreq / (float) docFreq;
            return TERM_POSNS_SEEK_OPS_PER_DOC + expOccurrencesInMatchingDoc * TERM_OPS_PER_POS;
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy