All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.lucene.search.grouping.SearchGroup Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search.grouping;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.TreeSet;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;

/**
 * Represents a group that is found during the first pass search.
 *
 * @lucene.experimental
 */
public class SearchGroup {

  /** The value that defines this group */
  public T groupValue;

  /**
   * The sort values used during sorting. These are the groupSort field values of the highest rank
   * document (by the groupSort) within the group. Can be null if 
   * fillFields=false had been passed to {@link FirstPassGroupingCollector#getTopGroups}
   */
  public Object[] sortValues;

  @Override
  public String toString() {
    return ("SearchGroup(groupValue="
        + groupValue
        + " sortValues="
        + Arrays.toString(sortValues)
        + ")");
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;

    SearchGroup that = (SearchGroup) o;

    if (groupValue == null) {
      if (that.groupValue != null) {
        return false;
      }
    } else if (!groupValue.equals(that.groupValue)) {
      return false;
    }

    return true;
  }

  @Override
  public int hashCode() {
    return groupValue != null ? groupValue.hashCode() : 0;
  }

  private static class ShardIter {
    public final Iterator> iter;
    public final int shardIndex;

    public ShardIter(Collection> shard, int shardIndex) {
      this.shardIndex = shardIndex;
      iter = shard.iterator();
      assert iter.hasNext();
    }

    public SearchGroup next() {
      assert iter.hasNext();
      final SearchGroup group = iter.next();
      if (group.sortValues == null) {
        throw new IllegalArgumentException(
            "group.sortValues is null; you must pass fillFields=true to the first pass collector");
      }
      return group;
    }

    @Override
    public String toString() {
      return "ShardIter(shard=" + shardIndex + ")";
    }
  }

  // Holds all shards currently on the same group
  private static class MergedGroup {

    // groupValue may be null!
    public final T groupValue;

    public Object[] topValues;
    public final List> shards = new ArrayList<>();
    public int minShardIndex;
    public boolean processed;
    public boolean inQueue;

    public MergedGroup(T groupValue) {
      this.groupValue = groupValue;
    }

    // Only for assert
    private boolean neverEquals(Object _other) {
      if (_other instanceof MergedGroup) {
        MergedGroup other = (MergedGroup) _other;
        if (groupValue == null) {
          assert other.groupValue != null;
        } else {
          assert !groupValue.equals(other.groupValue);
        }
      }
      return true;
    }

    @Override
    public boolean equals(Object _other) {
      // We never have another MergedGroup instance with
      // same groupValue
      assert neverEquals(_other);

      if (_other instanceof MergedGroup) {
        MergedGroup other = (MergedGroup) _other;
        if (groupValue == null) {
          return other.groupValue == null;
        } else {
          return groupValue.equals(other.groupValue);
        }
      } else {
        return false;
      }
    }

    @Override
    public int hashCode() {
      if (groupValue == null) {
        return 0;
      } else {
        return groupValue.hashCode();
      }
    }
  }

  private static class GroupComparator implements Comparator> {

    @SuppressWarnings("rawtypes")
    public final FieldComparator[] comparators;

    public final int[] reversed;

    @SuppressWarnings({"unchecked", "rawtypes"})
    public GroupComparator(Sort groupSort) {
      final SortField[] sortFields = groupSort.getSort();
      comparators = new FieldComparator[sortFields.length];
      reversed = new int[sortFields.length];
      for (int compIDX = 0; compIDX < sortFields.length; compIDX++) {
        final SortField sortField = sortFields[compIDX];
        comparators[compIDX] = sortField.getComparator(1, Pruning.NONE);
        reversed[compIDX] = sortField.getReverse() ? -1 : 1;
      }
    }

    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public int compare(MergedGroup group, MergedGroup other) {
      if (group == other) {
        return 0;
      }
      // System.out.println("compare group=" + group + " other=" + other);
      final Object[] groupValues = group.topValues;
      final Object[] otherValues = other.topValues;
      // System.out.println("  groupValues=" + groupValues + " otherValues=" + otherValues);
      for (int compIDX = 0; compIDX < comparators.length; compIDX++) {
        final int c =
            reversed[compIDX]
                * comparators[compIDX].compareValues(groupValues[compIDX], otherValues[compIDX]);
        if (c != 0) {
          return c;
        }
      }

      // Tie break by min shard index:
      assert group.minShardIndex != other.minShardIndex;
      return group.minShardIndex - other.minShardIndex;
    }
  }

  private static class GroupMerger {

    private final GroupComparator groupComp;
    private final NavigableSet> queue;
    private final Map> groupsSeen;

    public GroupMerger(Sort groupSort) {
      groupComp = new GroupComparator<>(groupSort);
      queue = new TreeSet<>(groupComp);
      groupsSeen = new HashMap<>();
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private void updateNextGroup(int topN, ShardIter shard) {
      while (shard.iter.hasNext()) {
        final SearchGroup group = shard.next();
        MergedGroup mergedGroup = groupsSeen.get(group.groupValue);
        final boolean isNew = mergedGroup == null;
        // System.out.println("    next group=" + (group.groupValue == null ? "null" : ((BytesRef)
        // group.groupValue).utf8ToString()) + " sort=" + Arrays.toString(group.sortValues));

        if (isNew) {
          // Start a new group:
          // System.out.println("      new");
          mergedGroup = new MergedGroup<>(group.groupValue);
          mergedGroup.minShardIndex = shard.shardIndex;
          assert group.sortValues != null;
          mergedGroup.topValues = group.sortValues;
          groupsSeen.put(group.groupValue, mergedGroup);
          mergedGroup.inQueue = true;
          queue.add(mergedGroup);
        } else if (mergedGroup.processed) {
          // This shard produced a group that we already
          // processed; move on to next group...
          continue;
        } else {
          // System.out.println("      old");
          boolean competes = false;
          for (int compIDX = 0; compIDX < groupComp.comparators.length; compIDX++) {
            final int cmp =
                groupComp.reversed[compIDX]
                    * groupComp.comparators[compIDX].compareValues(
                        group.sortValues[compIDX], mergedGroup.topValues[compIDX]);
            if (cmp < 0) {
              // Definitely competes
              competes = true;
              break;
            } else if (cmp > 0) {
              // Definitely does not compete
              break;
            } else if (compIDX == groupComp.comparators.length - 1) {
              if (shard.shardIndex < mergedGroup.minShardIndex) {
                competes = true;
              }
            }
          }

          // System.out.println("      competes=" + competes);

          if (competes) {
            // Group's sort changed -- remove & re-insert
            if (mergedGroup.inQueue) {
              queue.remove(mergedGroup);
            }
            mergedGroup.topValues = group.sortValues;
            mergedGroup.minShardIndex = shard.shardIndex;
            queue.add(mergedGroup);
            mergedGroup.inQueue = true;
          }
        }

        mergedGroup.shards.add(shard);
        break;
      }

      // Prune un-competitive groups:
      while (queue.size() > topN) {
        final MergedGroup group = queue.pollLast();
        // System.out.println("PRUNE: " + group);
        group.inQueue = false;
      }
    }

    public Collection> merge(
        List>> shards, int offset, int topN) {

      final int maxQueueSize = offset + topN;

      // System.out.println("merge");
      // Init queue:
      for (int shardIDX = 0; shardIDX < shards.size(); shardIDX++) {
        final Collection> shard = shards.get(shardIDX);
        if (!shard.isEmpty()) {
          // System.out.println("  insert shard=" + shardIDX);
          updateNextGroup(maxQueueSize, new ShardIter<>(shard, shardIDX));
        }
      }

      // Pull merged topN groups:
      final List> newTopGroups = new ArrayList<>(topN);

      int count = 0;

      while (!queue.isEmpty()) {
        final MergedGroup group = queue.pollFirst();
        group.processed = true;
        // System.out.println("  pop: shards=" + group.shards + " group=" + (group.groupValue ==
        // null ? "null" : (((BytesRef) group.groupValue).utf8ToString())) + " sortValues=" +
        // Arrays.toString(group.topValues));
        if (count++ >= offset) {
          final SearchGroup newGroup = new SearchGroup<>();
          newGroup.groupValue = group.groupValue;
          newGroup.sortValues = group.topValues;
          newTopGroups.add(newGroup);
          if (newTopGroups.size() == topN) {
            break;
          }
          // } else {
          // System.out.println("    skip < offset");
        }

        // Advance all iters in this group:
        for (ShardIter shardIter : group.shards) {
          updateNextGroup(maxQueueSize, shardIter);
        }
      }

      if (newTopGroups.isEmpty()) {
        return null;
      } else {
        return newTopGroups;
      }
    }
  }

  /**
   * Merges multiple collections of top groups, for example obtained from separate index shards. The
   * provided groupSort must match how the groups were sorted, and the provided SearchGroups must
   * have been computed with fillFields=true passed to {@link
   * FirstPassGroupingCollector#getTopGroups}.
   *
   * 

NOTE: this returns null if the topGroups is empty. */ public static Collection> merge( List>> topGroups, int offset, int topN, Sort groupSort) { if (topGroups.isEmpty()) { return null; } else { return new GroupMerger(groupSort).merge(topGroups, offset, topN); } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy