All downloads are free. Search and download functionality uses the official Maven repository.

com.o19s.es.ltr.query.RankerQuery Maven / Gradle / Ivy

There is a newer version: 6.8.0
Show newest version
/*
 * Copyright [2017] Doug Turnbull, Wikimedia Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.o19s.es.ltr.query;

import com.o19s.es.ltr.LtrQueryContext;
import com.o19s.es.ltr.feature.Feature;
import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.feature.LtrModel;
import com.o19s.es.ltr.feature.PrebuiltLtrModel;
import com.o19s.es.ltr.ranker.LogLtrRanker;
import com.o19s.es.ltr.ranker.LtrRanker;
import com.o19s.es.ltr.ranker.NullRanker;
import com.o19s.es.ltr.utils.Suppliers;
import com.o19s.es.ltr.utils.Suppliers.MutableSupplier;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.RandomAccess;
import java.util.Set;
import java.util.stream.Stream;

/**
 * Lucene query designed to apply a ranking model provided by {@link LtrRanker}.
 * This query is not designed for retrieval: it scores all the docs in the index
 * and thus must be used either in a rescore phase or within a BooleanQuery
 * together with an appropriate filter clause.
 * <p>
 * Each feature of the {@link FeatureSet} is turned into one lucene query; the
 * query at index {@code i} provides the value of feature ordinal {@code i}.
 */
public class RankerQuery extends Query {
    /** Feature queries; index {@code i} is the query for feature ordinal {@code i}. */
    private final List<Query> queries;
    private final FeatureSet features;
    private final LtrRanker ranker;

    private RankerQuery(List<Query> queries, FeatureSet features, LtrRanker ranker) {
        this.queries = Objects.requireNonNull(queries);
        this.features = Objects.requireNonNull(features);
        this.ranker = Objects.requireNonNull(ranker);
    }

    /**
     * Build a RankerQuery based on a prebuilt model.
     * Prebuilt models are not parametrized as they contain only {@link com.o19s.es.ltr.feature.PrebuiltFeature}
     *
     * @param model a prebuilt model
     * @return the lucene query
     */
    public static RankerQuery build(PrebuiltLtrModel model) {
        return build(model.ranker(), model.featureSet(), new LtrQueryContext(null, Collections.emptySet()),
                Collections.emptyMap());
    }

    /**
     * Build a RankerQuery.
     *
     * @param model   The model
     * @param context the context used to parse features into lucene queries
     * @param params  the query params
     * @return the lucene query
     */
    public static RankerQuery build(LtrModel model, LtrQueryContext context, Map<String, Object> params) {
        return build(model.ranker(), model.featureSet(), context, params);
    }

    /**
     * Turn every feature of {@code features} into a lucene query and pair them with {@code ranker}.
     */
    private static RankerQuery build(LtrRanker ranker, FeatureSet features, LtrQueryContext context,
                                     Map<String, Object> params) {
        List<Query> queries = features.toQueries(context, params);
        return new RankerQuery(queries, features, ranker);
    }

    /**
     * Build a RankerQuery whose ranker is a {@link LogLtrRanker} created from {@code consumer}
     * and the size of the feature set.
     *
     * @param consumer consumer handed to the {@link LogLtrRanker}
     * @param features the feature set to turn into lucene queries
     * @param context  the context used to parse features into lucene queries
     * @param params   the query params
     * @return the lucene query
     */
    public static RankerQuery buildLogQuery(LogLtrRanker.LogConsumer consumer, FeatureSet features,
                                            LtrQueryContext context, Map<String, Object> params) {
        List<Query> queries = features.toQueries(context, params);
        return new RankerQuery(queries, features, new LogLtrRanker(consumer, features.size()));
    }

    /**
     * Copy this query, wrapping its ranker into a {@link LogLtrRanker} bound to {@code consumer}.
     *
     * @param consumer              consumer handed to the {@link LogLtrRanker}
     * @param replaceWithNullRanker when true, and the current ranker is not already a {@link NullRanker},
     *                              a new {@link NullRanker} sized to the feature set is wrapped instead
     *                              of the current ranker
     * @return a new RankerQuery sharing this query's feature queries and feature set
     */
    public RankerQuery toLoggerQuery(LogLtrRanker.LogConsumer consumer, boolean replaceWithNullRanker) {
        LtrRanker newRanker = ranker;
        if (replaceWithNullRanker && !(ranker instanceof NullRanker)) {
            newRanker = new NullRanker(features.size());
        }
        return new RankerQuery(queries, features, new LogLtrRanker(newRanker, consumer));
    }

    /**
     * Rewrite every feature query. Returns {@code this} when no feature query changed,
     * otherwise a new RankerQuery over the rewritten queries.
     */
    @Override
    public Query rewrite(IndexReader reader) throws IOException {
        List<Query> rewrittenQueries = new ArrayList<>(queries.size());
        boolean rewritten = false;
        for (Query query : queries) {
            Query rewrittenQuery = query.rewrite(reader);
            rewritten |= rewrittenQuery != query;
            rewrittenQueries.add(rewrittenQuery);
        }
        return rewritten ? new RankerQuery(rewrittenQueries, features, ranker) : this;
    }

    @SuppressWarnings("EqualsWhichDoesntCheckParameterClass")
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        // The class check is performed by sameClassAs().
        // Note: caching of this query is disabled through isCacheable() on its weights, not here.
        if (!sameClassAs(obj)) {
            return false;
        }
        RankerQuery that = (RankerQuery) obj;
        return Objects.deepEquals(queries, that.queries) &&
                Objects.deepEquals(features, that.features) &&
                Objects.equals(ranker, that.ranker);
    }

    /**
     * Stream over the feature queries, in feature ordinal order.
     */
    Stream<Query> stream() {
        return queries.stream();
    }

    @Override
    public int hashCode() {
        return 31 * classHash() + Objects.hash(features, queries, ranker);
    }

    @Override
    public String toString(String field) {
        return "rankerquery:" + field;
    }

    /**
     * Return feature at ordinal
     */
    Feature getFeature(int ordinal) {
        return features.feature(ordinal);
    }

    /**
     * The ranker used by this query
     */
    LtrRanker ranker() {
        return ranker;
    }

    /**
     * The feature set this query was built from.
     */
    public FeatureSet featureSet() {
        return features;
    }

    /**
     * Create the weight for this query.
     * <p>
     * When scores are not needed, every doc of the segment matches with a constant score and the model is
     * not run. Otherwise a {@link RankerWeight} is built over one weight per feature query.
     */
    @Override
    public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException {
        if (!needsScores) {
            // If scores are not needed simply return a constant score on all docs
            return new ConstantScoreWeight(this, boost) {
                @Override
                public Scorer scorer(LeafReaderContext context) throws IOException {
                    return new ConstantScoreScorer(this, score(), DocIdSetIterator.all(context.reader().maxDoc()));
                }

                @Override
                public boolean isCacheable(LeafReaderContext ctx) {
                    return false;
                }
            };
        }

        List<Weight> weights = new ArrayList<>(queries.size());
        // XXX: a single vector supplier is shared by every scorer created from this weight. This is not
        // thread safe and may run into extremely weird issues if the searcher uses a parallel collector.
        // NOTE(review): this assumes Elasticsearch never scores this query concurrently — confirm.
        MutableSupplier<LtrRanker.FeatureVector> vectorSupplier = new Suppliers.FeatureVectorSupplier();
        FVLtrRankerWrapper ltrRankerWrapper = new FVLtrRankerWrapper(ranker, vectorSupplier);
        for (Query query : queries) {
            Query featureQuery = query;
            if (featureQuery instanceof LtrRewritableQuery) {
                // Rewritable features receive the shared supplier, which FVLtrRankerWrapper updates
                // every time a new feature vector is created.
                featureQuery = ((LtrRewritableQuery) featureQuery).ltrRewrite(vectorSupplier);
            }
            weights.add(searcher.createWeight(featureQuery, true, boost));
        }
        return new RankerWeight(this, weights, ltrRankerWrapper, features);
    }

    /**
     * Weight combining one sub weight per feature and scoring documents with the model.
     */
    public static class RankerWeight extends Weight {
        /** Sub weights; index {@code i} is the weight for feature ordinal {@code i}. */
        private final List<Weight> weights;
        private final FVLtrRankerWrapper ranker;
        private final FeatureSet features;

        RankerWeight(RankerQuery query, List<Weight> weights, FVLtrRankerWrapper ranker, FeatureSet features) {
            super(query);
            // Sub weights are looked up by ordinal.
            assert weights instanceof RandomAccess;
            this.weights = weights;
            this.ranker = Objects.requireNonNull(ranker);
            this.features = Objects.requireNonNull(features);
        }

        @Override
        public boolean isCacheable(LeafReaderContext ctx) {
            return false;
        }

        @Override
        public void extractTerms(Set<Term> terms) {
            for (Weight weight : weights) {
                weight.extractTerms(terms);
            }
        }

        /**
         * Explain the model score of {@code doc}: one sub explanation per feature followed by the
         * model output computed from the matching features.
         */
        @Override
        public Explanation explain(LeafReaderContext context, int doc) throws IOException {
            List<Explanation> subs = new ArrayList<>(weights.size());

            // Created before explaining the sub weights: through FVLtrRankerWrapper this also publishes
            // the vector to the supplier shared with rewritten (LtrRewritableQuery) features.
            LtrRanker.FeatureVector d = ranker.newFeatureVector(null);
            for (int ordinal = 0; ordinal < weights.size(); ordinal++) {
                Explanation featureExplanation = weights.get(ordinal).explain(context, doc);
                String featureName = features.feature(ordinal).name();
                String featureString = "Feature " + Integer.toString(ordinal);
                if (featureName != null) {
                    featureString += "(" + featureName + ")";
                }
                featureString += ":";
                if (!featureExplanation.isMatch()) {
                    // Non-matching features are not written into the vector.
                    subs.add(Explanation.noMatch(featureString + " [no match, default value 0.0 used]"));
                } else {
                    subs.add(Explanation.match(featureExplanation.getValue(), featureString, featureExplanation));
                    d.setFeatureScore(ordinal, featureExplanation.getValue());
                }
            }
            float modelScore = ranker.score(d);
            return Explanation.match(modelScore, " LtrModel: " + ranker.name() + " using features:", subs);
        }

        /**
         * Build a scorer visiting every document of the segment, with one sub scorer per feature.
         */
        @Override
        public RankerScorer scorer(LeafReaderContext context) throws IOException {
            List<Scorer> scorers = new ArrayList<>(weights.size());
            DisiPriorityQueue disiPriorityQueue = new DisiPriorityQueue(weights.size());
            for (Weight weight : weights) {
                Scorer scorer = weight.scorer(context);
                if (scorer == null) {
                    // Keep a placeholder so that scorer positions stay aligned with feature ordinals.
                    scorer = new NoopScorer(this, DocIdSetIterator.empty());
                }
                scorers.add(scorer);
                disiPriorityQueue.add(new DisiWrapper(scorer));
            }

            DisjunctionDISI rankerIterator = new DisjunctionDISI(
                    DocIdSetIterator.all(context.reader().maxDoc()), disiPriorityQueue);
            return new RankerScorer(scorers, rankerIterator, ranker);
        }

        /**
         * Scorer feeding the feature scores of the current doc into the model.
         */
        class RankerScorer extends Scorer {
            /**
             * NOTE: Switch to ChildScorer and {@link #getChildren()} if it appears
             * to be useful for logging
             */
            private final List<Scorer> scorers;
            private final DisjunctionDISI iterator;
            private final FVLtrRankerWrapper ranker;
            /** Feature vector reused across calls to {@link #score()}. */
            private LtrRanker.FeatureVector fv;

            RankerScorer(List<Scorer> scorers, DisjunctionDISI iterator, FVLtrRankerWrapper ranker) {
                super(RankerWeight.this);
                this.scorers = scorers;
                this.iterator = iterator;
                this.ranker = ranker;
            }

            @Override
            public int docID() {
                return iterator.docID();
            }

            @Override
            public float score() throws IOException {
                fv = ranker.newFeatureVector(fv);
                int currentDoc = docID();
                // FIXME: probably inefficient, all scorers are visited for every doc;
                // a DisiPriorityQueue could help to avoid looping on all of them.
                for (int ordinal = 0; ordinal < scorers.size(); ordinal++) {
                    Scorer scorer = scorers.get(ordinal);
                    // Only scorers positioned on the current doc set a value for their ordinal.
                    if (scorer.docID() == currentDoc) {
                        // XXX: bold assumption that all models are dense, i.e. the feature ordinal
                        // is also the model's feature id.
                        fv.setFeatureScore(ordinal, scorer.score());
                    }
                }
                return ranker.score(fv);
            }

            @Override
            public DocIdSetIterator iterator() {
                return iterator;
            }
        }
    }

    /**
     * Driven by a main iterator and tries to keep a set of sub iterators positioned on or
     * after the main iterator's current doc.
     * Mostly needed to avoid calling {@link Scorer#iterator()} to directly advance
     * from {@link RankerWeight.RankerScorer#score()} as some Scorer implementations
     * will instantiate new objects every time iterator() is called.
     */
    static class DisjunctionDISI extends DocIdSetIterator {
        private final DocIdSetIterator main;
        private final DisiPriorityQueue subIteratorsPriorityQueue;

        DisjunctionDISI(DocIdSetIterator main, DisiPriorityQueue subIteratorsPriorityQueue) {
            this.main = main;
            this.subIteratorsPriorityQueue = subIteratorsPriorityQueue;
        }

        @Override
        public int docID() {
            return main.docID();
        }

        @Override
        public int nextDoc() throws IOException {
            int doc = main.nextDoc();
            advanceSubIterators(doc);
            return doc;
        }

        @Override
        public int advance(int target) throws IOException {
            int docId = main.advance(target);
            advanceSubIterators(docId);
            return docId;
        }

        /**
         * Advance every sub iterator lagging behind {@code target} to {@code target} or beyond.
         */
        private void advanceSubIterators(int target) throws IOException {
            // Nothing to do once the main iterator is exhausted, or when there are no sub iterators
            // at all (empty feature set): top() must not be called on an empty queue.
            if (target == NO_MORE_DOCS || subIteratorsPriorityQueue.size() == 0) {
                return;
            }
            DisiWrapper top = subIteratorsPriorityQueue.top();
            while (top.doc < target) {
                top.doc = top.iterator.advance(target);
                top = subIteratorsPriorityQueue.updateTop();
            }
        }

        @Override
        public long cost() {
            return main.cost();
        }
    }

    /**
     * Ranker delegating to {@code wrapped}, which also publishes every newly created feature
     * vector to {@code vectorSupplier}.
     */
    static class FVLtrRankerWrapper implements LtrRanker {
        private final LtrRanker wrapped;
        private final MutableSupplier<LtrRanker.FeatureVector> vectorSupplier;

        FVLtrRankerWrapper(LtrRanker wrapped, MutableSupplier<LtrRanker.FeatureVector> vectorSupplier) {
            this.wrapped = Objects.requireNonNull(wrapped);
            this.vectorSupplier = Objects.requireNonNull(vectorSupplier);
        }

        @Override
        public String name() {
            return wrapped.name();
        }

        /**
         * Create (or reuse) a feature vector through the wrapped ranker and make it the
         * supplier's current value.
         */
        @Override
        public FeatureVector newFeatureVector(FeatureVector reuse) {
            FeatureVector fv = wrapped.newFeatureVector(reuse);
            vectorSupplier.set(fv);
            return fv;
        }

        @Override
        public float score(FeatureVector point) {
            return wrapped.score(point);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            FVLtrRankerWrapper that = (FVLtrRankerWrapper) o;
            return Objects.equals(wrapped, that.wrapped) &&
                    Objects.equals(vectorSupplier, that.vectorSupplier);
        }

        @Override
        public int hashCode() {
            return Objects.hash(wrapped, vectorSupplier);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy