// org.apache.lucene.queries.intervals.DisjunctionIntervalsSource (Maven / Gradle / Ivy)
// The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.intervals;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.MatchesIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.util.PriorityQueue;
class DisjunctionIntervalsSource extends IntervalsSource {
/** The sources combined by this disjunction, as stored by the constructor. */
final Collection<IntervalsSource> subSources;

/**
 * Selects what {@link #pullUpDisjunctions()} returns: {@link #subSources} when {@code true}, a
 * singleton list holding this source when {@code false}.
 */
final boolean pullUpDisjunctions;
/**
 * Builds a disjunction over the given sources.
 *
 * <p>The sources are first passed through {@link #simplify(Collection)}. When exactly one source
 * remains afterwards it is returned as-is rather than being wrapped in a disjunction.
 *
 * @param subSources the sources to combine
 * @param pullUpDisjunctions stored on the created source; see {@link #pullUpDisjunctions()}
 * @return the single remaining source, or a new {@code DisjunctionIntervalsSource}
 */
static IntervalsSource create(
    Collection<IntervalsSource> subSources, boolean pullUpDisjunctions) {
  Collection<IntervalsSource> simplified = simplify(subSources);
  if (simplified.size() == 1) {
    return simplified.iterator().next();
  }
  return new DisjunctionIntervalsSource(simplified, pullUpDisjunctions);
}
/**
 * Stores the simplified form of {@code subSources} together with the pull-up flag.
 *
 * @param subSources the sources to combine; passed through {@link #simplify(Collection)} before
 *     being stored
 * @param pullUpDisjunctions see {@link #pullUpDisjunctions()}
 */
private DisjunctionIntervalsSource(
    Collection<IntervalsSource> subSources, boolean pullUpDisjunctions) {
  this.subSources = simplify(subSources);
  this.pullUpDisjunctions = pullUpDisjunctions;
}
/**
 * Flattens and de-duplicates a collection of sources.
 *
 * <p>Each source that is itself a {@code DisjunctionIntervalsSource} contributes whatever its
 * {@link #pullUpDisjunctions()} returns; every other source is added unchanged. The result is a
 * {@link HashSet}, so sources that are equal to one another collapse into one entry and the
 * iteration order of the result is unspecified.
 *
 * @param sources the sources to simplify
 * @return a new set holding the simplified sources
 */
private static Collection<IntervalsSource> simplify(Collection<IntervalsSource> sources) {
  Set<IntervalsSource> simplified = new HashSet<>();
  for (IntervalsSource source : sources) {
    if (source instanceof DisjunctionIntervalsSource) {
      simplified.addAll(source.pullUpDisjunctions());
    } else {
      simplified.add(source);
    }
  }
  return simplified;
}
/**
 * Creates the interval iterator of this disjunction for one field and leaf context.
 *
 * <p>Every sub-source is asked for its iterator; sub-sources that return {@code null} are left
 * out.
 *
 * @param field the field passed on to each sub-source
 * @param ctx the leaf context passed on to each sub-source
 * @return a {@link DisjunctionIntervalIterator} over the non-null sub-iterators, or {@code null}
 *     when every sub-source returned {@code null}
 * @throws IOException if a sub-source throws it
 */
@Override
public IntervalIterator intervals(String field, LeafReaderContext ctx) throws IOException {
  List<IntervalIterator> subIterators = new ArrayList<>();
  for (IntervalsSource subSource : subSources) {
    IntervalIterator it = subSource.intervals(field, ctx);
    if (it != null) {
      subIterators.add(it);
    }
  }
  if (subIterators.isEmpty()) {
    return null;
  }
  return new DisjunctionIntervalIterator(subIterators);
}
/**
 * Creates the matches iterator of this disjunction for a single document.
 *
 * <p>The non-null matches of all sub-sources are collected, each is wrapped through {@code
 * IntervalMatches.wrapMatches}, and the wrapped iterators are combined in a {@link
 * DisjunctionIntervalIterator} that is then advanced to {@code doc}.
 *
 * @param field the field passed on to each sub-source
 * @param ctx the leaf context passed on to each sub-source
 * @param doc the document to report matches for
 * @return the combined matches, or {@code null} when no sub-source produced matches or the
 *     combined iterator did not land on {@code doc}
 * @throws IOException if a sub-source or the combined iterator throws it
 */
@Override
public IntervalMatchesIterator matches(String field, LeafReaderContext ctx, int doc)
    throws IOException {
  List<IntervalMatchesIterator> subMatches = new ArrayList<>();
  for (IntervalsSource subSource : subSources) {
    IntervalMatchesIterator mi = subSource.matches(field, ctx, doc);
    if (mi != null) {
      subMatches.add(mi);
    }
  }
  if (subMatches.isEmpty()) {
    return null;
  }
  // The wrapped list keeps the order of subMatches: DisjunctionMatchesIterator uses the index
  // reported by currentOrd() to look up the corresponding entry of subMatches.
  DisjunctionIntervalIterator it =
      new DisjunctionIntervalIterator(
          subMatches.stream().map(m -> IntervalMatches.wrapMatches(m, doc)).toList());
  if (it.advance(doc) != doc) {
    return null;
  }
  return new DisjunctionMatchesIterator(it, subMatches);
}
/**
 * Two disjunctions are equal when they have the same runtime class and equal {@code subSources}.
 * The {@code pullUpDisjunctions} flag takes no part in the comparison.
 */
@Override
public boolean equals(Object o) {
  if (o == this) {
    return true;
  }
  if (o == null || o.getClass() != getClass()) {
    return false;
  }
  DisjunctionIntervalsSource other = (DisjunctionIntervalsSource) o;
  return Objects.equals(this.subSources, other.subSources);
}
/** Hash derived from {@code subSources} only, in line with {@link #equals(Object)}. */
@Override
public int hashCode() {
  // Identical to Objects.hash(subSources): a one-element hash is 31 * 1 + elementHash.
  return 31 + Objects.hashCode(subSources);
}
/**
 * Renders this source as {@code or(a,b,...)}, where the entries are the string forms of the
 * sub-sources in natural string order.
 */
@Override
public String toString() {
  List<String> rendered = new ArrayList<>();
  for (Object subSource : subSources) {
    rendered.add(subSource.toString());
  }
  // Sorting makes the output independent of the iteration order of subSources.
  Collections.sort(rendered);
  return rendered.stream().collect(Collectors.joining(",", "or(", ")"));
}
/**
 * Visits every sub-source through a {@code SHOULD} sub-visitor whose parent is an {@code
 * IntervalQuery} built from {@code field} and this source.
 *
 * @param field the field being visited
 * @param visitor the visitor to obtain the sub-visitor from
 */
@Override
public void visit(String field, QueryVisitor visitor) {
  QueryVisitor subVisitor =
      visitor.getSubVisitor(BooleanClause.Occur.SHOULD, new IntervalQuery(field, this));
  for (IntervalsSource subSource : subSources) {
    subSource.visit(field, subVisitor);
  }
}
/**
 * Returns the smallest {@code minExtent()} reported by any sub-source, or {@link
 * Integer#MAX_VALUE} when there are no sub-sources.
 */
@Override
public int minExtent() {
  int smallest = Integer.MAX_VALUE;
  for (IntervalsSource subSource : subSources) {
    int extent = subSource.minExtent();
    if (extent < smallest) {
      smallest = extent;
    }
  }
  return smallest;
}
/**
 * Returns the sources to use when this disjunction is flattened into an enclosing one.
 *
 * @return the internal {@code subSources} collection itself (not a copy) when the {@code
 *     pullUpDisjunctions} flag is set; otherwise a singleton list containing this source
 */
@Override
public Collection<IntervalsSource> pullUpDisjunctions() {
  if (pullUpDisjunctions) {
    return subSources;
  }
  return Collections.singletonList(this);
}
/**
 * Interval iterator that merges the intervals of several sub-iterators.
 *
 * <p>Document iteration is delegated to a {@code DisjunctionDISIApproximation} built over all
 * sub-iterators. Within a document, the sub-iterators taken from {@code disiQueue.topList()} are
 * held in {@link #intervalQueue}, and {@link #current} refers to the sub-iterator that supplies
 * the interval currently exposed by {@link #start()}, {@link #end()} and {@link #gaps()}.
 */
static class DisjunctionIntervalIterator extends IntervalIterator {

  /** Document-level iterator over all sub-iterators. */
  final DocIdSetIterator approximation;

  /** Sub-iterators of the current document, ordered by {@code lessThan} (see constructor). */
  final PriorityQueue<IntervalIterator> intervalQueue;

  /** Queue of wrapped sub-iterators that backs {@link #approximation}. */
  final DisiPriorityQueue disiQueue;

  /** The sub-iterators in the order given to the constructor; indexed by {@link #currentOrd()}. */
  final List<IntervalIterator> iterators;

  /** Value returned by {@link #matchCost()}; computed once in the constructor. */
  final float matchCost;

  /** Sub-iterator supplying the current interval, or one of the EMPTY / EXHAUSTED sentinels. */
  IntervalIterator current = EMPTY;

  /**
   * Creates a disjunction over the given sub-iterators.
   *
   * @param iterators the sub-iterators; the list is stored as-is and its order defines the
   *     ordinals returned by {@link #currentOrd()}
   */
  DisjunctionIntervalIterator(List<IntervalIterator> iterators) {
    this.disiQueue = new DisiPriorityQueue(iterators.size());
    for (IntervalIterator it : iterators) {
      disiQueue.add(new DisiWrapper(it));
    }
    this.approximation = new DisjunctionDISIApproximation(disiQueue);
    this.iterators = iterators;
    this.intervalQueue =
        new PriorityQueue<>(iterators.size()) {
          // a sorts before b when it ends earlier, or ends at the same position and does not
          // start earlier than b.
          @Override
          protected boolean lessThan(IntervalIterator a, IntervalIterator b) {
            return a.end() < b.end() || (a.end() == b.end() && a.start() >= b.start());
          }
        };
    // NOTE(review): this sums cost(), not matchCost(), of the sub-iterators — confirm intended.
    float costsum = 0;
    for (IntervalIterator it : iterators) {
      costsum += it.cost();
    }
    this.matchCost = costsum;
  }

  @Override
  public float matchCost() {
    return matchCost;
  }

  @Override
  public int start() {
    return current.start();
  }

  @Override
  public int end() {
    return current.end();
  }

  @Override
  public int gaps() {
    return current.gaps();
  }

  /**
   * Rebuilds {@link #intervalQueue} after the document position changed: every wrapper on {@code
   * disiQueue.topList()} has its interval iterator moved to its next interval and is added to the
   * queue (regardless of what {@code nextInterval()} returned), and {@link #current} is set back
   * to EMPTY.
   */
  private void reset() throws IOException {
    intervalQueue.clear();
    for (DisiWrapper dw = disiQueue.topList(); dw != null; dw = dw.next) {
      dw.intervals.nextInterval();
      intervalQueue.add(dw.intervals);
    }
    current = EMPTY;
  }

  /**
   * Returns the index within {@link #iterators} of the sub-iterator that {@link #current} refers
   * to, found by identity comparison.
   *
   * @throws IllegalStateException if {@code current} is not one of the sub-iterators
   */
  int currentOrd() {
    assert current != EMPTY && current != EXHAUSTED;
    for (int i = 0; i < iterators.size(); i++) {
      if (iterators.get(i) == current) {
        return i;
      }
    }
    throw new IllegalStateException();
  }

  /**
   * Moves to the next interval of the disjunction within the current document.
   *
   * @return the start of the new current interval, or {@code NO_MORE_INTERVALS} once the queue
   *     has been drained
   */
  @Override
  public int nextInterval() throws IOException {
    if (current == EMPTY || current == EXHAUSTED) {
      // No interval is currently exposed: take the queue's top if there is one, otherwise
      // report the sentinel's own start value.
      if (intervalQueue.size() > 0) {
        current = intervalQueue.top();
      }
      return current.start();
    }
    int start = current.start(), end = current.end();
    // Step past every queued iterator whose interval contains the one just returned (this
    // includes the current iterator itself); iterators with no further interval are dropped.
    while (intervalQueue.size() > 0 && contains(intervalQueue.top(), start, end)) {
      IntervalIterator it = intervalQueue.pop();
      if (it != null && it.nextInterval() != NO_MORE_INTERVALS) {
        intervalQueue.add(it);
      }
    }
    if (intervalQueue.size() == 0) {
      current = EXHAUSTED;
      return NO_MORE_INTERVALS;
    }
    current = intervalQueue.top();
    return current.start();
  }

  /** Returns true when both {@code start} and {@code end} lie within {@code it}'s interval. */
  private boolean contains(IntervalIterator it, int start, int end) {
    return start >= it.start() && start <= it.end() && end >= it.start() && end <= it.end();
  }

  @Override
  public int docID() {
    return approximation.docID();
  }

  /** Advances the approximation to its next document and rebuilds the interval queue. */
  @Override
  public int nextDoc() throws IOException {
    int doc = approximation.nextDoc();
    reset();
    return doc;
  }

  /** Advances the approximation to {@code target} and rebuilds the interval queue. */
  @Override
  public int advance(int target) throws IOException {
    int doc = approximation.advance(target);
    reset();
    return doc;
  }

  @Override
  public long cost() {
    return approximation.cost();
  }
}
/**
 * Sentinel assigned to {@code DisjunctionIntervalIterator.current} while no interval is exposed
 * yet. It reports {@code -1} for start and end, {@code NO_MORE_INTERVALS} from {@code
 * nextInterval()} and a match cost of zero; every other operation is unsupported.
 */
private static final IntervalIterator EMPTY =
    new IntervalIterator() {
      // --- interval accessors ---

      @Override
      public int start() {
        return -1;
      }

      @Override
      public int end() {
        return -1;
      }

      @Override
      public int nextInterval() {
        return NO_MORE_INTERVALS;
      }

      @Override
      public float matchCost() {
        return 0;
      }

      @Override
      public int gaps() {
        throw new UnsupportedOperationException();
      }

      // --- document iteration: not available on the sentinel ---

      @Override
      public int docID() {
        throw new UnsupportedOperationException();
      }

      @Override
      public int nextDoc() throws IOException {
        throw new UnsupportedOperationException();
      }

      @Override
      public int advance(int target) throws IOException {
        throw new UnsupportedOperationException();
      }

      @Override
      public long cost() {
        throw new UnsupportedOperationException();
      }
    };
/**
 * Sentinel assigned to {@code DisjunctionIntervalIterator.current} once its interval queue has
 * been drained. It reports {@code NO_MORE_INTERVALS} for start, end and {@code nextInterval()}
 * and a match cost of zero; every other operation is unsupported.
 */
private static final IntervalIterator EXHAUSTED =
    new IntervalIterator() {
      // --- interval accessors ---

      @Override
      public int start() {
        return NO_MORE_INTERVALS;
      }

      @Override
      public int end() {
        return NO_MORE_INTERVALS;
      }

      @Override
      public int nextInterval() {
        return NO_MORE_INTERVALS;
      }

      @Override
      public float matchCost() {
        return 0;
      }

      @Override
      public int gaps() {
        throw new UnsupportedOperationException();
      }

      // --- document iteration: not available on the sentinel ---

      @Override
      public int docID() {
        throw new UnsupportedOperationException();
      }

      @Override
      public int nextDoc() throws IOException {
        throw new UnsupportedOperationException();
      }

      @Override
      public int advance(int target) throws IOException {
        throw new UnsupportedOperationException();
      }

      @Override
      public long cost() {
        throw new UnsupportedOperationException();
      }
    };
/**
 * Matches iterator backed by a {@link DisjunctionIntervalIterator}.
 *
 * <p>Positions, gaps and width come from the combined interval iterator. Offsets, sub-matches and
 * the query come from the entry of {@code subs} whose index equals the iterator's {@code
 * currentOrd()}, so {@code subs} must be in the same order as the iterators given to {@code it}.
 *
 * @param it the combined interval iterator, already positioned on the document
 * @param subs the per-source matches, index-aligned with the sub-iterators of {@code it}
 */
private record DisjunctionMatchesIterator(
    DisjunctionIntervalIterator it, List<IntervalMatchesIterator> subs)
    implements IntervalMatchesIterator {

  /** Returns the entry of {@code subs} that corresponds to the iterator's current interval. */
  private IntervalMatchesIterator currentSub() {
    return subs.get(it.currentOrd());
  }

  /** Advances to the next interval; returns false once the combined iterator has none left. */
  @Override
  public boolean next() throws IOException {
    return it.nextInterval() != IntervalIterator.NO_MORE_INTERVALS;
  }

  @Override
  public int startPosition() {
    return it.start();
  }

  @Override
  public int endPosition() {
    return it.end();
  }

  @Override
  public int startOffset() throws IOException {
    return currentSub().startOffset();
  }

  @Override
  public int endOffset() throws IOException {
    return currentSub().endOffset();
  }

  @Override
  public MatchesIterator getSubMatches() throws IOException {
    return currentSub().getSubMatches();
  }

  @Override
  public Query getQuery() {
    return currentSub().getQuery();
  }

  @Override
  public int gaps() {
    return it.gaps();
  }

  @Override
  public int width() {
    return it.width();
  }
}
}