// All downloads are free. Search and download functionality uses the official Maven repository.
//
// org.apache.solr.search.grouping.distributed.shardresultserializer.TopGroupsResultTransformer
// (Maven / Gradle / Ivy)
//
// There is a newer version: 9.7.0
// Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.search.grouping.distributed.shardresultserializer;

import static org.apache.solr.common.params.CommonParams.ID;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.document.Document;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRefBuilder;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.ShardDoc;
import org.apache.solr.schema.FieldType;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.grouping.Command;
import org.apache.solr.search.grouping.distributed.command.QueryCommand;
import org.apache.solr.search.grouping.distributed.command.QueryCommandResult;
import org.apache.solr.search.grouping.distributed.command.TopGroupsFieldCommand;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Implementation for transforming {@link TopGroups} and {@link TopDocs} into a {@link NamedList}
 * structure and vice versa.
 */
public class TopGroupsResultTransformer
    implements ShardResultTransformer>, Map> {

  private final ResponseBuilder rb;

  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

  public TopGroupsResultTransformer(ResponseBuilder rb) {
    this.rb = rb;
  }

  @Override
  public NamedList> transform(List> data) throws IOException {
    NamedList> result = new NamedList<>();
    final IndexSchema schema = rb.req.getSearcher().getSchema();
    for (Command command : data) {
      NamedList commandResult;
      if (command instanceof TopGroupsFieldCommand) {
        TopGroupsFieldCommand fieldCommand = (TopGroupsFieldCommand) command;
        SchemaField groupField = schema.getField(fieldCommand.getKey());
        commandResult = serializeTopGroups(fieldCommand.result(), groupField);
      } else if (command instanceof QueryCommand) {
        QueryCommand queryCommand = (QueryCommand) command;
        commandResult = serializeTopDocs(queryCommand.result());
      } else {
        commandResult = null;
      }

      result.add(command.getKey(), commandResult);
    }
    return result;
  }

  @Override
  public Map transformToNative(
      NamedList> shardResponse, Sort groupSort, Sort withinGroupSort, String shard) {
    Map result = new HashMap<>();

    final IndexSchema schema = rb.req.getSearcher().getSchema();

    for (Map.Entry> entry : shardResponse) {
      String key = entry.getKey();
      NamedList commandResult = entry.getValue();
      Integer totalGroupedHitCount = (Integer) commandResult.get("totalGroupedHitCount");
      Number totalHits = (Number) commandResult.get("totalHits"); // previously Integer now Long
      if (totalHits != null) {
        Integer matches = (Integer) commandResult.get("matches");
        Float maxScore = (Float) commandResult.get("maxScore");
        if (maxScore == null) {
          maxScore = Float.NaN;
        }

        @SuppressWarnings("unchecked")
        List> documents =
            (List>) commandResult.get("documents");
        ScoreDoc[] scoreDocs = transformToNativeShardDoc(documents, groupSort, shard, schema);
        final TopDocs topDocs;
        if (withinGroupSort.equals(Sort.RELEVANCE)) {
          topDocs =
              new TopDocs(
                  new TotalHits(totalHits.longValue(), TotalHits.Relation.EQUAL_TO), scoreDocs);
        } else {
          topDocs =
              new TopFieldDocs(
                  new TotalHits(totalHits.longValue(), TotalHits.Relation.EQUAL_TO),
                  scoreDocs,
                  withinGroupSort.getSort());
        }
        result.put(key, new QueryCommandResult(topDocs, matches, maxScore));
        continue;
      }

      Integer totalHitCount = (Integer) commandResult.get("totalHitCount");

      List> groupDocs = new ArrayList<>();
      for (int i = 2; i < commandResult.size(); i++) {
        String groupValue = commandResult.getName(i);
        @SuppressWarnings("unchecked")
        NamedList groupResult = (NamedList) commandResult.getVal(i);
        // previously Integer now Long
        Number totalGroupHits = (Number) groupResult.get("totalHits");
        Float maxScore = (Float) groupResult.get("maxScore");
        if (maxScore == null) {
          maxScore = Float.NaN;
        }

        @SuppressWarnings("unchecked")
        List> documents = (List>) groupResult.get("documents");
        ScoreDoc[] scoreDocs = transformToNativeShardDoc(documents, withinGroupSort, shard, schema);

        BytesRef groupValueRef = groupValue != null ? new BytesRef(groupValue) : null;
        groupDocs.add(
            new GroupDocs<>(
                Float.NaN,
                maxScore,
                new TotalHits(totalGroupHits.longValue(), TotalHits.Relation.EQUAL_TO),
                scoreDocs,
                groupValueRef,
                null));
      }

      GroupDocs[] groupDocsArr = groupDocs.toArray(newGroupDocsArray(groupDocs.size()));
      TopGroups topGroups =
          new TopGroups<>(
              groupSort.getSort(),
              withinGroupSort.getSort(),
              totalHitCount,
              totalGroupedHitCount,
              groupDocsArr,
              Float.NaN);

      result.put(key, topGroups);
    }

    return result;
  }

  @SuppressWarnings("unchecked")
  private static GroupDocs[] newGroupDocsArray(int size) {
    return (GroupDocs[]) Array.newInstance(GroupDocs.class, size);
  }

  protected ScoreDoc[] transformToNativeShardDoc(
      List> documents, Sort groupSort, String shard, IndexSchema schema) {
    ScoreDoc[] scoreDocs = new ScoreDoc[documents.size()];
    int j = 0;
    for (NamedList document : documents) {
      Object docId = document.get(ID);
      if (docId != null) {
        docId = docId.toString();
      } else {
        log.error("doc {} has null 'id'", document);
      }
      Float score = (Float) document.get("score");
      if (score == null) {
        score = Float.NaN;
      }
      Object[] sortValues = null;
      Object sortValuesVal = document.get("sortValues");
      if (sortValuesVal != null) {
        sortValues = ((List) sortValuesVal).toArray();
        for (int k = 0; k < sortValues.length; k++) {
          SchemaField field =
              groupSort.getSort()[k].getField() != null
                  ? schema.getFieldOrNull(groupSort.getSort()[k].getField())
                  : null;
          sortValues[k] = ShardResultTransformerUtils.unmarshalSortValue(sortValues[k], field);
        }
      } else {
        log.debug("doc {} has null 'sortValues'", document);
      }
      scoreDocs[j++] = new ShardDoc(score, sortValues, docId, shard);
    }
    return scoreDocs;
  }

  protected NamedList serializeTopGroups(TopGroups data, SchemaField groupField)
      throws IOException {
    NamedList result = new NamedList<>();
    result.add("totalGroupedHitCount", data.totalGroupedHitCount);
    result.add("totalHitCount", data.totalHitCount);
    if (data.totalGroupCount != null) {
      result.add("totalGroupCount", data.totalGroupCount);
    }

    final IndexSchema schema = rb.req.getSearcher().getSchema();
    SchemaField uniqueField = schema.getUniqueKeyField();
    for (GroupDocs searchGroup : data.groups) {
      NamedList groupResult = new NamedList<>();
      assert searchGroup.totalHits.relation == TotalHits.Relation.EQUAL_TO;
      groupResult.add("totalHits", searchGroup.totalHits.value);
      if (!Float.isNaN(searchGroup.maxScore)) {
        groupResult.add("maxScore", searchGroup.maxScore);
      }

      List> documents = new ArrayList<>();
      for (int i = 0; i < searchGroup.scoreDocs.length; i++) {
        NamedList document = new NamedList<>();
        documents.add(document);

        Document doc = retrieveDocument(uniqueField, searchGroup.scoreDocs[i].doc);
        document.add(ID, uniqueField.getType().toExternal(doc.getField(uniqueField.getName())));
        if (!Float.isNaN(searchGroup.scoreDocs[i].score)) {
          document.add("score", searchGroup.scoreDocs[i].score);
        }
        if (!(searchGroup.scoreDocs[i] instanceof FieldDoc)) {
          continue; // thus don't add sortValues below
        }

        FieldDoc fieldDoc = (FieldDoc) searchGroup.scoreDocs[i];
        Object[] convertedSortValues = new Object[fieldDoc.fields.length];
        for (int j = 0; j < fieldDoc.fields.length; j++) {
          Object sortValue = fieldDoc.fields[j];
          Sort withinGroupSort = rb.getGroupingSpec().getWithinGroupSortSpec().getSort();
          SchemaField field =
              withinGroupSort.getSort()[j].getField() != null
                  ? schema.getFieldOrNull(withinGroupSort.getSort()[j].getField())
                  : null;
          if (field != null) {
            FieldType fieldType = field.getType();
            if (sortValue != null) {
              sortValue = fieldType.marshalSortValue(sortValue);
            }
          }
          convertedSortValues[j] = sortValue;
        }
        document.add("sortValues", convertedSortValues);
      }
      groupResult.add("documents", documents);
      String groupValue =
          searchGroup.groupValue != null
              ? groupField
                  .getType()
                  .indexedToReadable(searchGroup.groupValue, new CharsRefBuilder())
                  .toString()
              : null;
      result.add(groupValue, groupResult);
    }

    return result;
  }

  protected NamedList serializeTopDocs(QueryCommandResult result) throws IOException {
    NamedList queryResult = new NamedList<>();
    queryResult.add("matches", result.getMatches());
    TopDocs topDocs = result.getTopDocs();
    assert topDocs.totalHits.relation == TotalHits.Relation.EQUAL_TO;
    queryResult.add("totalHits", topDocs.totalHits.value);
    // debug: assert !Float.isNaN(result.getTopDocs().getMaxScore()) ==
    // rb.getGroupingSpec().isNeedScore();
    if (!Float.isNaN(result.getMaxScore())) {
      queryResult.add("maxScore", result.getMaxScore());
    }
    List> documents = new ArrayList<>();
    queryResult.add("documents", documents);

    final IndexSchema schema = rb.req.getSearcher().getSchema();
    SchemaField uniqueField = schema.getUniqueKeyField();
    for (ScoreDoc scoreDoc : result.getTopDocs().scoreDocs) {
      NamedList document = new NamedList<>();
      documents.add(document);

      Document doc = retrieveDocument(uniqueField, scoreDoc.doc);
      document.add(ID, uniqueField.getType().toExternal(doc.getField(uniqueField.getName())));
      if (!Float.isNaN(scoreDoc.score)) {
        document.add("score", scoreDoc.score);
      }
      if (!(scoreDoc instanceof FieldDoc)) {
        continue; // thus don't add sortValues below
      }

      FieldDoc fieldDoc = (FieldDoc) scoreDoc;
      Object[] convertedSortValues = new Object[fieldDoc.fields.length];
      for (int j = 0; j < fieldDoc.fields.length; j++) {
        Object sortValue = fieldDoc.fields[j];
        Sort groupSort = rb.getGroupingSpec().getGroupSortSpec().getSort();
        SchemaField field =
            groupSort.getSort()[j].getField() != null
                ? schema.getFieldOrNull(groupSort.getSort()[j].getField())
                : null;
        convertedSortValues[j] = ShardResultTransformerUtils.marshalSortValue(sortValue, field);
      }
      document.add("sortValues", convertedSortValues);
    }

    return queryResult;
  }

  private Document retrieveDocument(final SchemaField uniqueField, int doc) throws IOException {
    return rb.req.getSearcher().doc(doc, Collections.singleton(uniqueField.getName()));
  }
}