All downloads are free. The search and download features use the official Maven repository.

com.o19s.es.ltr.logging.LoggingFetchSubPhase Maven / Gradle / Ivy

There is a newer version: 6.8.0
Show newest version
/*
 * Copyright [2017] Wikimedia Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.o19s.es.ltr.logging;

import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.query.RankerQuery;
import com.o19s.es.ltr.ranker.LogLtrRanker;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.fetch.FetchPhaseExecutionException;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.rescore.QueryRescorer;
import org.elasticsearch.search.rescore.RescoreContext;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class LoggingFetchSubPhase implements FetchSubPhase {
    @Override
    public void hitsExecute(SearchContext context, SearchHit[] hits) throws IOException {
        LoggingSearchExtBuilder ext = (LoggingSearchExtBuilder) context.getSearchExt(LoggingSearchExtBuilder.NAME);
        if (ext == null) {
            return;
        }

        // Use a boolean query with all the models to log
        // This way we reuse existing code to advance through multiple scorers/iterators
        BooleanQuery.Builder builder = new BooleanQuery.Builder();
        List loggers = new ArrayList<>();
        Map namedQueries = context.parsedQuery().namedFilters();
        ext.logSpecsStream().filter((l) -> l.getNamedQuery() != null).forEach((l) -> {
            Tuple query = extractQuery(l, namedQueries);
            builder.add(new BooleanClause(query.v1(), BooleanClause.Occur.MUST));
            loggers.add(query.v2());
        });

        ext.logSpecsStream().filter((l) -> l.getRescoreIndex() != null).forEach((l) -> {
            Tuple query = extractRescore(l, context.rescore());
            builder.add(new BooleanClause(query.v1(), BooleanClause.Occur.MUST));
            loggers.add(query.v2());
        });

        try {
            doLog(builder.build(), loggers, context.searcher(), hits);
        } catch (LtrLoggingException e) {
            throw new FetchPhaseExecutionException(context, e.getMessage(), e);
        }
    }

    void doLog(Query query, List loggers, IndexSearcher searcher, SearchHit[] hits) throws IOException {
        // Reorder hits by id so we can scan all the docs belonging to the same
        // segment by reusing the same scorer.
        SearchHit[] reordered = new SearchHit[hits.length];
        System.arraycopy(hits, 0, reordered, 0, hits.length);
        Arrays.sort(reordered, Comparator.comparingInt(SearchHit::docId));

        int hitUpto = 0;
        int readerUpto = -1;
        int endDoc = 0;
        int docBase = 0;
        Scorer scorer = null;
        Weight weight = searcher.createNormalizedWeight(query, true);
        // Loop logic borrowed from lucene QueryRescorer
        while (hitUpto < reordered.length) {
            SearchHit hit = reordered[hitUpto];
            int docID = hit.docId();
            loggers.forEach((l) -> l.nextDoc(hit));
            LeafReaderContext readerContext = null;
            while (docID >= endDoc) {
                readerUpto++;
                readerContext = searcher.getTopReaderContext().leaves().get(readerUpto);
                endDoc = readerContext.docBase + readerContext.reader().maxDoc();
            }

            if (readerContext != null) {
                // We advanced to another segment:
                docBase = readerContext.docBase;
                scorer = weight.scorer(readerContext);
            }

            if(scorer != null) {
                int targetDoc = docID - docBase;
                int actualDoc = scorer.docID();
                if (actualDoc < targetDoc) {
                    actualDoc = scorer.iterator().advance(targetDoc);
                }
                if (actualDoc == targetDoc) {
                    // Scoring will trigger log collection
                    scorer.score();
                }
            }

            hitUpto++;
        }
    }

    private Tuple extractQuery(LoggingSearchExtBuilder.LogSpec logSpec, Map namedQueries) {
        Query q = namedQueries.get(logSpec.getNamedQuery());
        if (q == null) {
            throw new IllegalArgumentException("No query named [" + logSpec.getNamedQuery() + "] found");
        }
        return toLogger(logSpec, inspectQuery(q)
                .orElseThrow(() -> new IllegalArgumentException("Query named [" + logSpec.getNamedQuery() +
                    "] must be a [sltr] query [" +
                    ((q instanceof BoostQuery) ? ((BoostQuery)q).getQuery().getClass().getSimpleName()  : q.getClass().getSimpleName()) +
                    "] found")));
    }

    private Tuple extractRescore(LoggingSearchExtBuilder.LogSpec logSpec,
                                                              List contexts) {
        if (logSpec.getRescoreIndex() >= contexts.size()) {
            throw new IllegalArgumentException("rescore index [" + logSpec.getRescoreIndex()+"] is out of bounds, only " +
                    "[" + contexts.size() + "] rescore context(s) are available");
        }
        RescoreContext context = contexts.get(logSpec.getRescoreIndex());
        if (!(context instanceof QueryRescorer.QueryRescoreContext)) {
            throw new IllegalArgumentException("Expected a [QueryRescoreContext] but found a " +
                    "[" + context.getClass().getSimpleName() + "] " +
                    "at index [" + logSpec.getRescoreIndex() + "]");
        }
        QueryRescorer.QueryRescoreContext qrescore = (QueryRescorer.QueryRescoreContext) context;
        return toLogger(logSpec, inspectQuery(qrescore.query())
                .orElseThrow(() -> new IllegalArgumentException("Expected a [sltr] query but found a " +
                    "[" + qrescore.query().getClass().getSimpleName() + "] " +
                    "at index [" + logSpec.getRescoreIndex() + "]")));
    }

    private Optional inspectQuery(Query q) {
        if (q instanceof RankerQuery) {
            return Optional.of((RankerQuery) q);
        } else if (q instanceof BoostQuery && ((BoostQuery) q).getQuery() instanceof RankerQuery) {
            return Optional.of((RankerQuery) ((BoostQuery) q).getQuery());
        }
        return Optional.empty();
    }

    private Tuple toLogger(LoggingSearchExtBuilder.LogSpec logSpec, RankerQuery query) {
        HitLogConsumer consumer = new HitLogConsumer(logSpec.getLoggerName(), query.featureSet(), logSpec.isMissingAsZero());
        // Use a null ranker, we don't care about the final score here so don't spend time on it.
        query = query.toLoggerQuery(consumer, true);

        return new Tuple<>(query, consumer);
    }

    static class HitLogConsumer implements LogLtrRanker.LogConsumer {
        private static final String FIELD_NAME = "_ltrlog";
        private final String name;
        private final FeatureSet set;
        private final boolean missingAsZero;

        // [
        //      {
        //          "name": "featureName",
        //          "value": 1.33
        //      },
        //      {
        //          "name": "otherFeatureName",
        //      }
        // ]
        private List> currentLog;
        private SearchHit currentHit;


        HitLogConsumer(String name, FeatureSet set, boolean missingAsZero) {
            this.name = name;
            this.set = set;
            this.missingAsZero = missingAsZero;
        }

        private void rebuild() {
            List> ini = new ArrayList<>(set.size()) ;

            for (int i = 0; i < set.size(); i++) {
                Map defaultKeyVal = new HashMap<>();
                defaultKeyVal.put("name", set.feature(i).name());
                if (missingAsZero) {
                    defaultKeyVal.put("value", 0.0F);
                }
                ini.add(i, defaultKeyVal);
            }
            currentLog = ini;
        }

        @Override
        public void accept(int featureOrdinal, float score) {
            assert currentLog != null;
            assert currentHit != null;
            if (Float.isNaN(score)) {
                // NOTE: should we fail on Float#isInfinite() as well?
                throw new LtrLoggingException("Feature [" + set.feature(featureOrdinal).name() +"] produced a NaN value " +
                        "for doc [" + currentHit.getId() + "]" );
            }
            currentLog.get(featureOrdinal).put("value", score);
        }

        void nextDoc(SearchHit hit) {
            if (hit.fieldsOrNull() == null) {
                hit.fields(new HashMap<>());
            }
            DocumentField logs = hit.getFields()
                    .computeIfAbsent(FIELD_NAME,(k) -> newLogField());
            Map>> entries = logs.getValue();
            rebuild();
            currentHit = hit;
            entries.put(name, currentLog);
        }

        DocumentField newLogField() {
            List logList = Collections.singletonList(new HashMap>>());
            return new DocumentField(FIELD_NAME, logList);
        }
    }
}