package org.wikimedia.highlighter.cirrus.opensearch;


import static java.lang.Boolean.FALSE;
import static java.lang.Boolean.TRUE;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.regex.Pattern;

import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.Query;
import org.opensearch.common.logging.Loggers;
import org.opensearch.common.text.Text;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.LocaleUtils;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.search.fetch.FetchPhaseExecutionException;
import org.opensearch.search.fetch.subphase.highlight.HighlightField;
import org.opensearch.search.fetch.subphase.highlight.Highlighter;
import org.opensearch.search.fetch.subphase.highlight.FieldHighlightContext;
import org.opensearch.search.fetch.subphase.highlight.SearchHighlightContext.FieldOptions;
import org.wikimedia.highlighter.cirrus.lucene.hit.AutomatonHitEnum;
import org.wikimedia.highlighter.cirrus.lucene.hit.weight.BasicQueryWeigher;
import org.wikimedia.search.highlighter.cirrus.HitEnum;
import org.wikimedia.search.highlighter.cirrus.Snippet;
import org.wikimedia.search.highlighter.cirrus.Snippet.HitBuilder;
import org.wikimedia.search.highlighter.cirrus.SnippetChooser;
import org.wikimedia.search.highlighter.cirrus.SnippetFormatter;
import org.wikimedia.search.highlighter.cirrus.SnippetWeigher;
import org.wikimedia.search.highlighter.cirrus.hit.ConcatHitEnum;
import org.wikimedia.search.highlighter.cirrus.hit.EmptyHitEnum;
import org.wikimedia.search.highlighter.cirrus.hit.MergingHitEnum;
import org.wikimedia.search.highlighter.cirrus.hit.OverlapMergingHitEnumWrapper;
import org.wikimedia.search.highlighter.cirrus.hit.RegexHitEnum;
import org.wikimedia.search.highlighter.cirrus.hit.ReplayingHitEnum.HitEnumAndLength;
import org.wikimedia.search.highlighter.cirrus.snippet.BasicScoreBasedSnippetChooser;
import org.wikimedia.search.highlighter.cirrus.snippet.BasicSourceOrderSnippetChooser;
import org.wikimedia.search.highlighter.cirrus.snippet.ExponentialSnippetWeigher;
import org.wikimedia.search.highlighter.cirrus.snippet.SumSnippetWeigher;
import org.wikimedia.search.highlighter.cirrus.tools.GraphvizHit;
import org.wikimedia.search.highlighter.cirrus.tools.GraphvizHitEnum;
import org.wikimedia.search.highlighter.cirrus.tools.GraphvizSnippetFormatter;

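/**
 * Cirrus (formerly "experimental") highlighter implementation for OpenSearch. Delegates the
 * per-field work to {@link HighlightExecutionContext} and shares state across the highlighted
 * fields of a hit through a {@link CacheEntry} stored in the highlighter cache under
 * {@link #CACHE_KEY}.
 */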
@SuppressWarnings("checkstyle:classfanoutcomplexity") // to improve if we ever touch that code again
public class CirrusHighlighter implements Highlighter {
    public static final String BC_NAME = "experimental";
    public static final String NAME = "cirrus";
    private static final String CACHE_KEY = "highlight-cirrus";
    private static final Text EMPTY_STRING = new Text("");

    @Override
    public boolean canHighlight(MappedFieldType field) {
        return true;
    }

    @Override
    @SuppressWarnings("checkstyle:IllegalCatch")
    public HighlightField highlight(FieldHighlightContext context) {
        try {
            CacheEntry entry = (CacheEntry) context.cache.get(CACHE_KEY);
            if (entry == null) {
                entry = new CacheEntry();
                context.cache.put(CACHE_KEY, entry);
            }
            HighlightExecutionContext executionContext = new HighlightExecutionContext(context, entry);
            try {
                return executionContext.highlight();
            } finally {
                executionContext.cleanup();
            }
        } catch (Exception e) {
            getLogger(context).error("Failed to highlight field [{}]", context.fieldName, e);
            throw new FetchPhaseExecutionException(context.hitContext.hit().getShard(), "Failed to highlight field [" + context.fieldName + "]", e);
        }
    }

    private Logger getLogger(FieldHighlightContext context) {
        return Loggers.getLogger(CirrusHighlighter.class, context.context.getIndexName());
    }

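    /**
     * State shared across highlighted fields: query weighers keyed by {@link QueryCacheKey},
     * compiled automaton factories keyed by regex source, and the last seen doc id and match
     * flag used to implement the {@code skip_if_last_matched} option.
     */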
    static class CacheEntry {
        private final Map<QueryCacheKey, BasicQueryWeigher> queryWeighers = new HashMap<>();
        private Map<String, AutomatonHitEnum.Factory> automatonHitEnumFactories;
        private boolean lastMatched;
        private int lastDocId = -1;
    }

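    /**
     * Cache key for {@link BasicQueryWeigher}s: a weigher can be reused whenever the query and
     * the options that influence query flattening are identical.
     */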
    static class QueryCacheKey {
        private final Query query;
        private final int maxExpandedTerms;
        private final boolean phraseAsTerms;
        private final boolean removeHighFrequencyTermsFromCommonTerms;

        QueryCacheKey(Query query, int maxExpandedTerms, boolean phraseAsTerms, boolean removeHighFrequencyTermsFromCommonTerms) {
            this.query = query;
            this.maxExpandedTerms = maxExpandedTerms;
            this.phraseAsTerms = phraseAsTerms;
            this.removeHighFrequencyTermsFromCommonTerms = removeHighFrequencyTermsFromCommonTerms;
        }

        @Override
        public int hashCode() {
            return Objects.hash(maxExpandedTerms, phraseAsTerms, removeHighFrequencyTermsFromCommonTerms, query);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            if (obj == null)
                return false;
            if (getClass() != obj.getClass())
                return false;
            QueryCacheKey other = (QueryCacheKey) obj;
            return maxExpandedTerms == other.maxExpandedTerms
                    && phraseAsTerms == other.phraseAsTerms
                    && removeHighFrequencyTermsFromCommonTerms == other.removeHighFrequencyTermsFromCommonTerms
                    && Objects.equals(query, other.query);
        }
    }

    static class HighlightExecutionContext {
        private static final String OPTION_RETURN_DEBUG_GRAPH = "return_debug_graph";
        private static final String OPTION_RETURN_SNIPPETS_WITH_OFFSET = "return_snippets_and_offsets";
        private static final int DEFAULT_MAX_DETERMINIZED_STATES = 20000;
        private final FieldHighlightContext context;
        private final CacheEntry cache;
        private BasicQueryWeigher weigher;
        private FieldWrapper defaultField;
        private List<FieldWrapper> extraFields;
        private SegmenterFactory segmenterFactory;
        private DelayedSegmenter segmenter;
        private boolean scoreMatters;
        private Locale locale;
        private int maxDeterminizedStates;

        HighlightExecutionContext(FieldHighlightContext context, CacheEntry cache) {
            this.context = context;
            this.cache = cache;
        }

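        /**
         * Highlights the current field: picks snippets with the configured chooser and formats
         * them, falling back to a no-match fragment when nothing matched and {@code no_match_size}
         * is positive. Returns {@code null} when there is nothing to return.
         */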
        HighlightField highlight() throws IOException {
            if (shouldSkip()) {
                return null;
            }

            // TODO it might be possible to not build the weigher at all if just
            // using regex highlighting
            ensureWeigher();
            scoreMatters = context.field.fieldOptions().scoreOrdered();
            if (!scoreMatters) {
                Boolean topScoring = (Boolean) getOption("top_scoring");
                scoreMatters = topScoring != null && topScoring;
            }
            defaultField = new FieldWrapper(this, context, weigher);
            int numberOfSnippets = context.field.fieldOptions().numberOfFragments();
            if (numberOfSnippets == 0) {
                numberOfSnippets = 1;
            }
            segmenter = new DelayedSegmenter(defaultField);
            List<Snippet> snippets = buildChooser().choose(segmenter, buildHitEnum(), numberOfSnippets);
            if (!snippets.isEmpty()) {
                cache.lastMatched = true;
                return new HighlightField(context.fieldName, formatSnippets(snippets));
            }
            cache.lastMatched = false;
            int noMatchSize = context.field.fieldOptions().noMatchSize();
            if (noMatchSize <= 0) {
                return null;
            }
            List<String> fieldValues = defaultField.getFieldValues();
            if (fieldValues.isEmpty()) {
                return null;
            }
            Text fragment = new Text(getSegmenterFactory().extractNoMatchFragment(fieldValues.get(0), noMatchSize));
            return new HighlightField(context.fieldName, new Text[] {fragment});
        }

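        /**
         * Whether highlighting should be skipped entirely because the {@code skip_if_last_matched}
         * option is set and a previous field of the same document already produced a match.
         */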
        private boolean shouldSkip() {
            // Maintain lastMatched - it should be false if we shift to a new
            // doc.
            if (cache.lastDocId != context.hitContext.docId()) {
                cache.lastMatched = false;
                cache.lastDocId = context.hitContext.docId();
            }

            Boolean skipIfLastMatched = (Boolean) getOption("skip_if_last_matched");
            return skipIfLastMatched != null && skipIfLastMatched && cache.lastMatched;
        }

        @SuppressWarnings("checkstyle:IllegalCatch")
        // We might be able to improve this a bit with AutoCloseable magic,
        // but not worth doing it unless we revisit that code.
        void cleanup() throws Exception {
            Exception lastCaught = null;
            try {
                if (defaultField != null) {
                    // If we throw an exception before defining default field
                    // then we can't clean it up!
                    defaultField.cleanup();
                }
            } catch (Exception e) {
                lastCaught = e;
            }
            if (extraFields != null) {
                for (FieldWrapper extra : extraFields) {
                    try {
                        extra.cleanup();
                    } catch (Exception e) {
                        if (lastCaught != null) {
                            e.addSuppressed(lastCaught);
                        }
                        lastCaught = e;
                    }
                }
            }
            if (lastCaught != null) {
                throw lastCaught;
            }
        }

        SegmenterFactory getSegmenterFactory() {
            if (segmenterFactory == null) {
                segmenterFactory = buildSegmenterFactory();
            }
            return segmenterFactory;
        }

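        /**
         * Raw value of a highlighter option, or {@code null} if no options were provided or the
         * option is unset.
         */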
        Object getOption(String key) {
            if (context.field.fieldOptions().options() == null) {
                return null;
            }
            return context.field.fieldOptions().options().get(key);
        }

        <T> T getOption(String key, T defaultValue) {
            @SuppressWarnings("unchecked")
            T value = (T) getOption(key);
            return value == null ? defaultValue : value;
        }

        boolean scoreMatters() {
            return scoreMatters;
        }

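        /**
         * Lazily builds the query weigher, reusing a cached instance when the same query has
         * already been weighed with the same flattening options during this request.
         */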
        private void ensureWeigher() {
            if (weigher != null) {
                return;
            }
            boolean phraseAsTerms = getOption("phrase_as_terms", FALSE);
            boolean removeHighFrequencyTermsFromCommonTerms = getOption("remove_high_freq_terms_from_common_terms", TRUE);
            int maxExpandedTerms = getOption("max_expanded_terms", 1024);
            // TODO simplify
            QueryCacheKey key = new QueryCacheKey(context.query, maxExpandedTerms, phraseAsTerms,
                    removeHighFrequencyTermsFromCommonTerms);
            weigher = cache.queryWeighers.get(key);
            if (weigher != null) {
                return;
            }
            // TODO recycle. But addReleasable doesn't seem to close it properly
            // later. I believe this is fixed in later Elasticsearch versions.
            BytesRefHashTermInfos infos = new BytesRefHashTermInfos(BigArrays.NON_RECYCLING_INSTANCE);
            // context.context.addReleasable(infos);
            weigher = new BasicQueryWeigher(new ElasticsearchQueryFlattener(maxExpandedTerms, phraseAsTerms,
                    removeHighFrequencyTermsFromCommonTerms), infos, context.hitContext.topLevelReader(), context.query);
            // Build the QueryWeigher with the top level reader to get all
            // the frequency information
            cache.queryWeighers.put(key, weigher);
        }

        /**
         * Builds the hit enum including any required wrappers.
         */
        private HitEnum buildHitEnum() throws IOException {
            HitEnum e = buildHitFindingHitEnum();

            // Merge any overlapping hits to support matched fields and
            // analyzers that make overlaps.
            e = new OverlapMergingHitEnumWrapper(e);

            if (getOption(OPTION_RETURN_DEBUG_GRAPH, FALSE)) {
                e = new GraphvizHitEnum(e);
            }
            return e;
        }

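        /**
         * Combines the hit finding enums (regex and/or Lucene based) into a single enum,
         * merging by offsets when there is more than one source of hits.
         */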
        private HitEnum buildHitFindingHitEnum() throws IOException {
            List<HitEnum> hitEnums = buildHitFindingHitEnums();
            switch (hitEnums.size()) {
                case 0:
                    return EmptyHitEnum.INSTANCE;
                case 1:
                    return hitEnums.get(0);
                default:
                    return new MergingHitEnum(hitEnums, HitEnum.LessThans.OFFSETS);
            }
        }

        /**
         * Builds the HitEnum that actually finds the hits in the first place.
         */
        private List<HitEnum> buildHitFindingHitEnums() throws IOException {
            Boolean skipQuery = (Boolean) getOption("skip_query");
            List<HitEnum> hitEnums = buildRegexHitEnums();
            if (skipQuery == null || !skipQuery) {
                hitEnums.addAll(buildLuceneHitFindingHitEnums());
            }
            return hitEnums;
        }

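        /**
         * Builds one {@link HitEnum} per configured regex, using either Lucene automata or
         * {@link Pattern} matching depending on the {@code regex_flavor} option.
         */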
        @SuppressWarnings("checkstyle:ModifiedControlVariable")
        // cleanup the re-assignment of `regex` if we revisit that code
        private List<HitEnum> buildRegexHitEnums() throws IOException {
            boolean luceneRegex = isLuceneRegexFlavor();
            if (luceneRegex) {
                cache.automatonHitEnumFactories = new HashMap<>();
            }
            Boolean caseInsensitiveOption = (Boolean) getOption("regex_case_insensitive");
            boolean caseInsensitive = caseInsensitiveOption == null ? false : caseInsensitiveOption;

            List<HitEnum> hitEnums = new ArrayList<>();
            List<String> fieldValues = defaultField.getFieldValues();
            if (fieldValues.isEmpty()) {
                return hitEnums;
            }

            for (String regex : getRegexes()) {
                if (luceneRegex) {
                    if (caseInsensitive) {
                        regex = regex.toLowerCase(getLocale());
                    }
                    AutomatonHitEnum.Factory factory = cache.automatonHitEnumFactories.get(regex);
                    if (factory == null) {
                        factory = buildFactoryForRegex(regex);
                        cache.automatonHitEnumFactories.put(regex, factory);
                    }
                    hitEnums.add(buildLuceneRegexHitEnumForRegex(factory, fieldValues, caseInsensitive));
                } else {
                    int options = 0;
                    if (caseInsensitive) {
                        options |= Pattern.CASE_INSENSITIVE;
                    }
                    hitEnums.add(buildJavaRegexHitEnumForRegex(Pattern.compile(regex, options), fieldValues));
                }
            }
            return hitEnums;
        }

        private AutomatonHitEnum.Factory buildFactoryForRegex(String regex) {
            return AutomatonHitEnum.factory(regex, getMaxDeterminizedStates());
        }

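        /**
         * Resolves the {@code max_determinized_states} option, defaulting to
         * {@value #DEFAULT_MAX_DETERMINIZED_STATES}, and caches the result.
         */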
        private int getMaxDeterminizedStates() {
            if (maxDeterminizedStates != 0) {
                return maxDeterminizedStates;
            }
            Integer maxDeterminizedStates = (Integer) getOption("max_determinized_states");
            if (maxDeterminizedStates == null) {
                this.maxDeterminizedStates = DEFAULT_MAX_DETERMINIZED_STATES;
            } else {
                this.maxDeterminizedStates = maxDeterminizedStates;
            }
            return this.maxDeterminizedStates;
        }

        /**
         * Get the list of regexes to highlight, or an empty list if there aren't any.
         */
        @SuppressWarnings("unchecked")
        private List<String> getRegexes() {
            Object regexes = getOption("regex");
            if (regexes == null) {
                return Collections.emptyList();
            }
            if (regexes instanceof String) {
                return Collections.singletonList((String) regexes);
            }
            return (List<String>) regexes;
        }

        private HitEnum buildLuceneRegexHitEnumForRegex(final AutomatonHitEnum.Factory factory, List<String> fieldValues,
                final boolean caseInsensitive) {
            final int positionGap = defaultField.getPositionGap();
            if (fieldValues.size() == 1) {
                String fieldValue = fieldValues.get(0);
                if (caseInsensitive) {
                    fieldValue = fieldValue.toLowerCase(getLocale());
                }
                return factory.build(fieldValue);
            } else {
                Iterator<HitEnumAndLength> hitEnumsFromStreams = fieldValues.stream().map(fieldValue -> {
                    if (caseInsensitive) {
                        fieldValue = fieldValue.toLowerCase(getLocale());
                    }
                    return new HitEnumAndLength(factory.build(fieldValue), fieldValue.length());
                }).iterator();
                return new ConcatHitEnum(hitEnumsFromStreams, positionGap, 1);
            }
        }

        private HitEnum buildJavaRegexHitEnumForRegex(final Pattern pattern, List<String> fieldValues) {
            final int positionGap = defaultField.getPositionGap();
            if (fieldValues.size() == 1) {
                return new RegexHitEnum(pattern.matcher(fieldValues.get(0)));
            } else {
                Iterator<HitEnumAndLength> hitEnumsFromStreams = fieldValues.stream()
                        .map(fieldValue -> new HitEnumAndLength(new RegexHitEnum(pattern.matcher(fieldValue)), fieldValue.length()))
                        .iterator();
                return new ConcatHitEnum(hitEnumsFromStreams, positionGap, 1);
            }
        }

        private boolean isLuceneRegexFlavor() {
            Object regexFlavor = getOption("regex_flavor");
            if (regexFlavor == null || "lucene".equals(regexFlavor)) {
                return true;
            }
            if ("java".equals(regexFlavor)) {
                return false;
            }
            throw new IllegalArgumentException("Unknown regex flavor: " + regexFlavor);
        }

        /**
         * Builds the HitEnum that finds the hits from Lucene.
         */
        private List<HitEnum> buildLuceneHitFindingHitEnums() throws IOException {
            Set<String> matchedFields = context.field.fieldOptions().matchedFields();
            if (matchedFields == null) {
                if (!defaultField.canProduceHits()) {
                    return Collections.emptyList();
                }
                return Collections.singletonList(defaultField.buildHitEnum());
            }
            List<HitEnum> hitEnums = new ArrayList<>(matchedFields.size());
            extraFields = new ArrayList<>(matchedFields.size());
            for (String field : matchedFields) {
                FieldWrapper wrapper;
                if (context.fieldName.equals(field)) {
                    wrapper = defaultField;
                } else {
                    wrapper = new FieldWrapper(this, context, weigher, field);
                    if (!wrapper.exists()) {
                        continue;
                    }
                }
                if (wrapper.canProduceHits()) {
                    hitEnums.add(wrapper.buildHitEnum());
                }
                extraFields.add(wrapper);
            }
            if (hitEnums.isEmpty()) {
                return Collections.emptyList();
            }
            return hitEnums;
        }

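        /**
         * Picks the snippet chooser: score ordered, top scoring in source order, or plain source
         * order, optionally wiring in the Graphviz debug hit builder.
         */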
        private SnippetChooser buildChooser() {
            HitBuilder hitBuilder = Snippet.DEFAULT_HIT_BUILDER;
            if (getOption(OPTION_RETURN_DEBUG_GRAPH, FALSE)) {
                hitBuilder = GraphvizHit.GRAPHVIZ_HIT_BUILDER;
            }
            if (context.field.fieldOptions().scoreOrdered()) {
                return buildScoreBasedSnippetChooser(true, hitBuilder);
            }
            Boolean topScoring = (Boolean) getOption("top_scoring");
            if (topScoring != null && topScoring) {
                return buildScoreBasedSnippetChooser(false, hitBuilder);
            }
            return new BasicSourceOrderSnippetChooser(hitBuilder);
        }

        private SnippetChooser buildScoreBasedSnippetChooser(boolean scoreOrdered, HitBuilder hitBuilder) {
            Integer maxFragmentsScored = (Integer) getOption("max_fragments_scored");
            if (maxFragmentsScored == null) {
                maxFragmentsScored = Integer.MAX_VALUE;
            }
            return new BasicScoreBasedSnippetChooser(scoreOrdered, buildSnippetWeigher(), hitBuilder, maxFragmentsScored);
        }

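        /**
         * Builds the snippet weigher from the {@code fragment_weigher} option, which may be the
         * string {@code "sum"} or {@code "exponential"}, or a map such as
         * {@code {"exponential": {"base": 1.1}}}; defaults to an exponential weigher.
         */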
        private SnippetWeigher buildSnippetWeigher() {
            float defaultBase = 1.1f;
            Object config = getOption("fragment_weigher");
            if (config == null) {
                return new ExponentialSnippetWeigher(defaultBase);
            }
            if (config.equals("sum")) {
                return new SumSnippetWeigher();
            }
            if (config.equals("exponential")) {
                return new ExponentialSnippetWeigher(defaultBase);
            }
            try {
                @SuppressWarnings("unchecked")
                Map<String, Object> map = (Map<String, Object>) config;
                if (map.containsKey("sum")) {
                    return new SumSnippetWeigher();
                }
                Object exponentialConfig = map.get("exponential");
                if (exponentialConfig != null) {
                    @SuppressWarnings("unchecked")
                    Map<String, Object> exponentialConfigMap = (Map<String, Object>) exponentialConfig;
                    Number base = (Number) exponentialConfigMap.get("base");
                    if (base == null) {
                        return new ExponentialSnippetWeigher(defaultBase);
                    }
                    return new ExponentialSnippetWeigher(base.floatValue());
                }
            } catch (ClassCastException e) {
                throw new IllegalArgumentException("Invalid snippet weigher config: " + config, e);
            }
            throw new IllegalArgumentException("Invalid snippet weigher config: " + config);
        }

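        /**
         * Renders the chosen snippets with the configured formatter and, when {@code fetch_fields}
         * is set, interleaves the corresponding field values after each snippet.
         */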
        private Text[] formatSnippets(List<Snippet> snippets) throws IOException {
            final SnippetFormatter formatter;
            if (getOption("return_offsets", FALSE)) {
                formatter = new OffsetSnippetFormatter();
            } else if (getOption(OPTION_RETURN_DEBUG_GRAPH, FALSE)) {
                formatter = new GraphvizSnippetFormatter(defaultField.buildSourceExtracter());
            } else if (getOption(OPTION_RETURN_SNIPPETS_WITH_OFFSET, FALSE)) {
                formatter = new OffsetAugmenterSnippetFormatter(
                        new SnippetFormatter.Default(
                                defaultField.buildSourceExtracter(),
                                context.field.fieldOptions().preTags()[0],
                                context.field.fieldOptions().postTags()[0]));
            } else {
                formatter = new SnippetFormatter.Default(defaultField.buildSourceExtracter(), context.field.fieldOptions().preTags()[0],
                        context.field.fieldOptions().postTags()[0]);
            }

            List<FieldWrapper> fetchFields = buildFetchFields();
            if (fetchFields == null) {
                Text[] result = new Text[snippets.size()];
                int i = 0;
                for (Snippet snippet : snippets) {
                    result[i++] = new Text(formatter.format(snippet));
                }
                return result;
            }

            int fieldsPerSnippet = 1 + fetchFields.size();
            Text[] result = new Text[snippets.size() * fieldsPerSnippet];
            FetchedFieldIndexPicker picker = segmenter.buildFetchedFieldIndexPicker();
            int i = 0;
            for (Snippet snippet : snippets) {
                result[i++] = new Text(formatter.format(snippet));
                int index = picker.index(snippet);
                for (FieldWrapper fetchField : fetchFields) {
                    List<String> values = fetchField.getFieldValues();
                    if (index >= 0 && index < values.size()) {
                        result[i++] = new Text(values.get(index));
                    } else {
                        result[i++] = EMPTY_STRING;
                    }
                }
            }
            return result;
        }

        /**
         * Return FieldWrappers for all fetch_fields or null if there aren't
         * any.
         */
        private List<FieldWrapper> buildFetchFields() {
            @SuppressWarnings("unchecked")
            List<String> fetchFields = (List<String>) getOption("fetch_fields");
            if (fetchFields == null) {
                return null;
            }
            List<FieldWrapper> fetchFieldWrappers = new ArrayList<>(fetchFields.size());
            List<FieldWrapper> newExtraFields = new ArrayList<>();
            try {
                for (String fetchField : fetchFields) {
                    boolean found = false;
                    if (extraFields != null) {
                        for (FieldWrapper extraField : extraFields) {
                            if (extraField.fieldName().equals(fetchField)) {
                                fetchFieldWrappers.add(extraField);
                                found = true;
                                break;
                            }
                        }
                    }
                    if (!found) {
                        FieldWrapper fieldWrapper = new FieldWrapper(this, context, weigher, fetchField);
                        newExtraFields.add(fieldWrapper);
                        fetchFieldWrappers.add(fieldWrapper);
                    }
                }
            } finally {
                if (extraFields == null) {
                    extraFields = newExtraFields;
                } else {
                    extraFields.addAll(newExtraFields);
                }
            }
            return fetchFieldWrappers;
        }

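        /**
         * Chooses the segmenter from the {@code fragmenter} option: {@code scan} (the default),
         * {@code sentence}, or {@code none}; a fragment count of zero forces whole-source
         * segmentation.
         */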
        private SegmenterFactory buildSegmenterFactory() {
            FieldOptions options = context.field.fieldOptions();
            if (options.numberOfFragments() == 0) {
                return new WholeSourceSegmenterFactory();
            }
            if (options.fragmenter() == null || options.fragmenter().equals("scan")) {
                // TODO boundaryChars
                return new CharScanningSegmenterFactory(options.fragmentCharSize(), options.boundaryMaxScan());
            }
            if (options.fragmenter().equals("sentence")) {
                return new SentenceIteratorSegmenterFactory(getLocale(), options.boundaryMaxScan());
            }
            if (options.fragmenter().equals("none")) {
                return new WholeSourceSegmenterFactory();
            }
            throw new IllegalArgumentException("Unknown fragmenter: '" + options.fragmenter() + "'. Options are 'scan', 'sentence', or 'none'.");
        }

        private Locale getLocale() {
            if (locale != null) {
                return locale;
            }
            String localeString = (String) getOption("locale");
            locale = localeString == null ? Locale.US : LocaleUtils.parse(localeString);
            return locale;
        }
    }
}