// All downloads are FREE. Search and download functionality uses the official Maven repository.

// org.apache.solr.search.grouping.distributed.responseprocessor.TopGroupsShardResponseProcessor Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.search.grouping.distributed.responseprocessor;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.util.BytesRef;
import org.apache.solr.client.solrj.SolrServerException;
import org.apache.solr.common.params.ShardParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.util.SimpleOrderedMap;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.ShardDoc;
import org.apache.solr.handler.component.ShardRequest;
import org.apache.solr.handler.component.ShardResponse;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.search.Grouping;
import org.apache.solr.search.SortSpec;
import org.apache.solr.search.grouping.distributed.ShardResponseProcessor;
import org.apache.solr.search.grouping.distributed.command.QueryCommandResult;
import org.apache.solr.search.grouping.distributed.shardresultserializer.TopGroupsResultTransformer;

/**
 * Concrete implementation for merging {@link TopGroups} instances from shard responses.
 */
public class TopGroupsShardResponseProcessor implements ShardResponseProcessor {

  @Override
  @SuppressWarnings("unchecked")
  public void process(ResponseBuilder rb, ShardRequest shardRequest) {
    Sort groupSort = rb.getGroupingSpec().getGroupSortSpec().getSort();
    String[] fields = rb.getGroupingSpec().getFields();
    String[] queries = rb.getGroupingSpec().getQueries();
    SortSpec withinGroupSortSpec = rb.getGroupingSpec().getWithinGroupSortSpec();
    Sort withinGroupSort = withinGroupSortSpec.getSort();
    assert withinGroupSort != null;

    boolean simpleOrMain = rb.getGroupingSpec().getResponseFormat() == Grouping.Format.simple ||
        rb.getGroupingSpec().isMain();

    // If group.format=simple group.offset doesn't make sense
    int groupOffsetDefault;
    if (simpleOrMain) {
      groupOffsetDefault = 0;
    } else {
      groupOffsetDefault = withinGroupSortSpec.getOffset();
    }
    int docsPerGroupDefault = withinGroupSortSpec.getCount();

    Map>> commandTopGroups = new HashMap<>();
    for (String field : fields) {
      commandTopGroups.put(field, new ArrayList<>());
    }

    Map> commandTopDocs = new HashMap<>();
    for (String query : queries) {
      commandTopDocs.put(query, new ArrayList<>());
    }

    TopGroupsResultTransformer serializer = new TopGroupsResultTransformer(rb);

    NamedList shardInfo = null;
    if (rb.req.getParams().getBool(ShardParams.SHARDS_INFO, false)) {
      shardInfo = new SimpleOrderedMap<>();
      rb.rsp.getValues().add(ShardParams.SHARDS_INFO, shardInfo);
    }

    for (ShardResponse srsp : shardRequest.responses) {
      SimpleOrderedMap individualShardInfo = null;
      if (shardInfo != null) {
        individualShardInfo = new SimpleOrderedMap<>();

        if (srsp.getException() != null) {
          Throwable t = srsp.getException();
          if (t instanceof SolrServerException && ((SolrServerException) t).getCause() != null) {
            t = ((SolrServerException) t).getCause();
          }
          individualShardInfo.add("error", t.toString());
          StringWriter trace = new StringWriter();
          t.printStackTrace(new PrintWriter(trace));
          individualShardInfo.add("trace", trace.toString());
        } else {
          // summary for successful shard response is added down below
        }
        if (srsp.getSolrResponse() != null) {
          individualShardInfo.add("time", srsp.getSolrResponse().getElapsedTime());
        }
        if (srsp.getShardAddress() != null) {
          individualShardInfo.add("shardAddress", srsp.getShardAddress());
        }
        shardInfo.add(srsp.getShard(), individualShardInfo);
      }
      if (ShardParams.getShardsTolerantAsBool(rb.req.getParams()) && srsp.getException() != null) {
        rb.rsp.getResponseHeader().asShallowMap().put(SolrQueryResponse.RESPONSE_HEADER_PARTIAL_RESULTS_KEY, Boolean.TRUE);
        continue; // continue if there was an error and we're tolerant.  
      }
      NamedList secondPhaseResult = (NamedList) srsp.getSolrResponse().getResponse().get("secondPhase");
      if(secondPhaseResult == null)
        continue;
      Map result = serializer.transformToNative(secondPhaseResult, groupSort, withinGroupSort, srsp.getShard());
      int numFound = 0;
      float maxScore = Float.NaN;
      for (Map.Entry>> entry : commandTopGroups.entrySet()) {
        TopGroups topGroups = (TopGroups) result.get(entry.getKey());
        if (topGroups == null) {
          continue;
        }
        if (individualShardInfo != null) { // keep track of this when shards.info=true
          numFound += topGroups.totalHitCount;
          if (Float.isNaN(maxScore) || topGroups.maxScore > maxScore) maxScore = topGroups.maxScore;
        }
        entry.getValue().add(topGroups);
      }
      for (String query : queries) {
        QueryCommandResult queryCommandResult = (QueryCommandResult) result.get(query);
        if (individualShardInfo != null) { // keep track of this when shards.info=true
          numFound += queryCommandResult.getMatches();
          float thisMax = queryCommandResult.getMaxScore();
          if (Float.isNaN(maxScore) || thisMax > maxScore) maxScore = thisMax;
        }
        commandTopDocs.get(query).add(queryCommandResult);
      }
      if (individualShardInfo != null) { // when shards.info=true
        individualShardInfo.add("numFound", numFound);
        individualShardInfo.add("maxScore", maxScore);
      }
    }
    for (Map.Entry>> entry : commandTopGroups.entrySet()) {
      List> topGroups = entry.getValue();
      if (topGroups.isEmpty()) {
        continue;
      }

      TopGroups[] topGroupsArr = new TopGroups[topGroups.size()];
      int docsPerGroup = docsPerGroupDefault;
      if (docsPerGroup < 0) {
        docsPerGroup = 0;
        for (TopGroups subTopGroups : topGroups) {
          docsPerGroup += subTopGroups.totalGroupedHitCount;
        }
      }
      rb.mergedTopGroups.put(entry.getKey(), TopGroups.merge(topGroups.toArray(topGroupsArr), groupSort, withinGroupSort, groupOffsetDefault, docsPerGroup, TopGroups.ScoreMergeMode.None));
    }

    // calculate topN and start for group.query
    int topN = docsPerGroupDefault >= 0? docsPerGroupDefault: Integer.MAX_VALUE;
    int start = groupOffsetDefault;
    if (simpleOrMain) {
      // use start and rows here
      start = rb.getGroupingSpec().getGroupSortSpec().getOffset();
      int limit = rb.getGroupingSpec().getGroupSortSpec().getCount();
      topN = limit >= 0? limit: Integer.MAX_VALUE;
    }

    for (Map.Entry> entry : commandTopDocs.entrySet()) {
      List queryCommandResults = entry.getValue();
      List topDocs = new ArrayList<>(queryCommandResults.size());
      int mergedMatches = 0;
      float maxScore = Float.NaN;
      for (QueryCommandResult queryCommandResult : queryCommandResults) {
        TopDocs thisTopDocs = queryCommandResult.getTopDocs();
        topDocs.add(thisTopDocs);
        mergedMatches += queryCommandResult.getMatches();
        if (thisTopDocs.scoreDocs.length > 0) {
          float thisMaxScore = queryCommandResult.getMaxScore();
          if (Float.isNaN(maxScore) || thisMaxScore > maxScore) {
            maxScore = thisMaxScore;
          }
        }
      }

      final TopDocs mergedTopDocs;
      if (withinGroupSort.equals(Sort.RELEVANCE)) {
        mergedTopDocs = TopDocs.merge(
            start, topN, topDocs.toArray(new TopDocs[topDocs.size()]), true);
      } else {
        mergedTopDocs = TopDocs.merge(
            withinGroupSort, start, topN, topDocs.toArray(new TopFieldDocs[topDocs.size()]), true);
      }
      rb.mergedQueryCommandResults.put(entry.getKey(), new QueryCommandResult(mergedTopDocs, mergedMatches, maxScore));
    }
    fillResultIds(rb);
  }

  /**
   * Fill the {@link ResponseBuilder}'s resultIds field.
   * @param rb the response builder
   */
  static void fillResultIds(ResponseBuilder rb) {
    Map resultIds = new HashMap<>();
    int i = 0;
    for (TopGroups topGroups : rb.mergedTopGroups.values()) {
      for (GroupDocs group : topGroups.groups) {
        for (ScoreDoc scoreDoc : group.scoreDocs) {
          ShardDoc solrDoc = (ShardDoc) scoreDoc;
          // Include the first if there are duplicate IDs
          if ( ! resultIds.containsKey(solrDoc.id)) {
            solrDoc.positionInResponse = i++;
            resultIds.put(solrDoc.id, solrDoc);
          }
        }
      }
    }
    for (QueryCommandResult queryCommandResult : rb.mergedQueryCommandResults.values()) {
      for (ScoreDoc scoreDoc : queryCommandResult.getTopDocs().scoreDocs) {
        ShardDoc solrDoc = (ShardDoc) scoreDoc;
        solrDoc.positionInResponse = i++;
        resultIds.put(solrDoc.id, solrDoc);
      }
    }

    rb.resultIds = resultIds;
  }
}