io.zulia.server.search.QueryCombiner Maven / Gradle / Ivy
package io.zulia.server.search;
import io.zulia.message.ZuliaBase.Term;
import io.zulia.message.ZuliaIndex.FieldConfig;
import io.zulia.message.ZuliaQuery.AnalysisRequest;
import io.zulia.message.ZuliaQuery.AnalysisResult;
import io.zulia.message.ZuliaQuery.CountRequest;
import io.zulia.message.ZuliaQuery.FacetCount;
import io.zulia.message.ZuliaQuery.FacetGroup;
import io.zulia.message.ZuliaQuery.FieldSort;
import io.zulia.message.ZuliaQuery.IndexShardResponse;
import io.zulia.message.ZuliaQuery.LastIndexResult;
import io.zulia.message.ZuliaQuery.LastResult;
import io.zulia.message.ZuliaQuery.Query;
import io.zulia.message.ZuliaQuery.ScoredResult;
import io.zulia.message.ZuliaQuery.ShardQueryResponse;
import io.zulia.message.ZuliaQuery.SortRequest;
import io.zulia.message.ZuliaQuery.SortValues;
import io.zulia.message.ZuliaServiceOuterClass.InternalQueryResponse;
import io.zulia.message.ZuliaServiceOuterClass.QueryRequest;
import io.zulia.message.ZuliaServiceOuterClass.QueryResponse;
import io.zulia.server.analysis.frequency.TermFreq;
import io.zulia.server.index.ZuliaIndex;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Logger;
import java.util.stream.Collectors;
public class QueryCombiner {
private final static Comparator scoreCompare = new ScoreCompare();
private final static Logger log = Logger.getLogger(QueryCombiner.class.getSimpleName());
private final List responses;
private final Map> indexToShardQueryResponseMap;
private final List shardResponses;
private final int amount;
private final int start;
private final LastResult lastResult;
private final List analysisRequestList;
private boolean isShort;
private List results;
private int resultsSize;
private SortRequest sortRequest;
private Query query;
private final Collection indexes;
private final Map indexToShardCount;
public QueryCombiner(Collection indexes, QueryRequest request, List responses) {
this.indexToShardCount = new HashMap<>();
for (ZuliaIndex zuliaIndex : indexes) {
indexToShardCount.put(zuliaIndex.getIndexName(), zuliaIndex.getNumberOfShards());
}
this.indexes = indexes;
this.responses = responses;
this.amount = request.getAmount() + request.getStart();
this.indexToShardQueryResponseMap = new HashMap<>();
this.shardResponses = new ArrayList<>();
this.lastResult = request.getLastResult();
this.sortRequest = request.getSortRequest();
this.start = request.getStart();
this.query = request.getQuery();
this.analysisRequestList = request.getAnalysisRequestList();
this.isShort = false;
this.results = Collections.emptyList();
this.resultsSize = 0;
}
private void validate() throws Exception {
for (InternalQueryResponse iqr : responses) {
for (IndexShardResponse isr : iqr.getIndexShardResponseList()) {
String indexName = isr.getIndexName();
if (!indexToShardQueryResponseMap.containsKey(indexName)) {
indexToShardQueryResponseMap.put(indexName, new HashMap<>());
}
for (ShardQueryResponse sr : isr.getShardQueryResponseList()) {
int shardNumber = sr.getShardNumber();
Map shardResponseMap = indexToShardQueryResponseMap.get(indexName);
if (shardResponseMap.containsKey(shardNumber)) {
throw new Exception("Shard <" + shardNumber + "> is repeated for <" + indexName + ">");
}
else {
shardResponseMap.put(shardNumber, sr);
shardResponses.add(sr);
}
}
}
}
for (ZuliaIndex index : indexes) {
int numberOfShards = index.getNumberOfShards();
Map shardResponseMap = indexToShardQueryResponseMap.get(index.getIndexName());
if (shardResponseMap == null) {
throw new Exception("Missing index <" + index.getIndexName() + "> in response");
}
if (shardResponseMap.size() != numberOfShards) {
throw new Exception("Found <" + shardResponseMap.size() + "> expected <" + numberOfShards + ">");
}
for (int shardNumber = 0; shardNumber < numberOfShards; shardNumber++) {
if (!shardResponseMap.containsKey(shardNumber)) {
throw new Exception("Missing shard <" + shardNumber + ">");
}
}
}
}
public QueryResponse getQueryResponse() throws Exception {
validate();
boolean sorting = (sortRequest != null && !sortRequest.getFieldSortList().isEmpty());
long totalHits = 0;
long returnedHits = 0;
for (ShardQueryResponse sr : shardResponses) {
totalHits += sr.getTotalHits();
returnedHits += sr.getScoredResultList().size();
}
QueryResponse.Builder builder = QueryResponse.newBuilder();
builder.setTotalHits(totalHits);
resultsSize = Math.min(amount, (int) returnedHits);
results = Collections.emptyList();
Map lastIndexResultMap = new HashMap<>();
for (String indexName : indexToShardQueryResponseMap.keySet()) {
int numberOfShards = indexToShardCount.get(indexName);
lastIndexResultMap.put(indexName, new ScoredResult[numberOfShards]);
}
for (LastIndexResult lir : lastResult.getLastIndexResultList()) {
ScoredResult[] lastForShardArr = lastIndexResultMap.get(lir.getIndexName());
// initialize with last results
for (ScoredResult sr : lir.getLastForShardList()) {
lastForShardArr[sr.getShard()] = sr;
}
}
Map> facetCountsMap = new HashMap<>();
Map> shardsReturnedMap = new HashMap<>();
Map fullResultsMap = new HashMap<>();
Map minForShardMap = new HashMap<>();
Map> analysisRequestToTermMap = new HashMap<>();
int shardIndex = 0;
for (ShardQueryResponse sr : shardResponses) {
for (FacetGroup fg : sr.getFacetGroupList()) {
CountRequest countRequest = fg.getCountRequest();
Map facetCounts = facetCountsMap.get(countRequest);
Map shardsReturned = shardsReturnedMap.get(countRequest);
FixedBitSet fullResults = fullResultsMap.get(countRequest);
long[] minForShard = minForShardMap.get(countRequest);
if (facetCounts == null) {
facetCounts = new HashMap<>();
facetCountsMap.put(countRequest, facetCounts);
shardsReturned = new HashMap<>();
shardsReturnedMap.put(countRequest, shardsReturned);
fullResults = new FixedBitSet(shardResponses.size());
fullResultsMap.put(countRequest, fullResults);
minForShard = new long[shardResponses.size()];
minForShardMap.put(countRequest, minForShard);
}
for (FacetCount fc : fg.getFacetCountList()) {
String facet = fc.getFacet();
AtomicLong facetSum = facetCounts.get(facet);
FixedBitSet shardSet = shardsReturned.get(facet);
if (facetSum == null) {
facetSum = new AtomicLong();
facetCounts.put(facet, facetSum);
shardSet = new FixedBitSet(shardResponses.size());
shardsReturned.put(facet, shardSet);
}
long count = fc.getCount();
facetSum.addAndGet(count);
shardSet.set(shardIndex);
minForShard[shardIndex] = count;
}
int shardFacets = countRequest.getShardFacets();
int facetCountCount = fg.getFacetCountCount();
if (facetCountCount < shardFacets || (shardFacets == -1)) {
fullResults.set(shardIndex);
minForShard[shardIndex] = 0;
}
}
for (AnalysisResult analysisResult : sr.getAnalysisResultList()) {
AnalysisRequest analysisRequest = analysisResult.getAnalysisRequest();
if (!analysisRequestToTermMap.containsKey(analysisRequest)) {
analysisRequestToTermMap.put(analysisRequest, new HashMap<>());
}
Map termMap = analysisRequestToTermMap.get(analysisRequest);
for (Term term : analysisResult.getTermsList()) {
String key = term.getValue();
if (!termMap.containsKey(key)) {
termMap.put(key, Term.newBuilder().setValue(key).setDocFreq(0).setTermFreq(0));
}
Term.Builder termsBuilder = termMap.get(key);
termsBuilder.setDocFreq(termsBuilder.getDocFreq() + term.getDocFreq());
termsBuilder.setScore(termsBuilder.getScore() + term.getScore());
termsBuilder.setTermFreq(termsBuilder.getTermFreq() + term.getTermFreq());
}
}
shardIndex++;
}
for (AnalysisRequest analysisRequest : analysisRequestList) {
Map termMap = analysisRequestToTermMap.get(analysisRequest);
if (termMap != null) {
List terms = new ArrayList<>(termMap.values());
List topTerms = TermFreq.getTopTerms(terms, analysisRequest.getTopN(), analysisRequest.getTermSort());
AnalysisResult.Builder analysisResultBuilder = AnalysisResult.newBuilder().setAnalysisRequest(analysisRequest);
topTerms.forEach(analysisResultBuilder::addTerms);
builder.addAnalysisResult(analysisResultBuilder);
}
}
for (CountRequest countRequest : facetCountsMap.keySet()) {
FacetGroup.Builder fg = FacetGroup.newBuilder();
fg.setCountRequest(countRequest);
Map facetCounts = facetCountsMap.get(countRequest);
Map shardsReturned = shardsReturnedMap.get(countRequest);
FixedBitSet fullResults = fullResultsMap.get(countRequest);
long[] minForShard = minForShardMap.get(countRequest);
int numberOfShards = shardResponses.size();
long maxValuePossibleMissing = 0;
for (int i = 0; i < numberOfShards; i++) {
maxValuePossibleMissing += minForShard[i];
}
boolean computeError = countRequest.getMaxFacets() > 0 && countRequest.getShardFacets() > 0 && numberOfShards > 1;
boolean computePossibleMissing = computeError && (maxValuePossibleMissing != 0);
SortedSet sortedFacetResults = facetCounts.keySet().stream()
.map(facet -> new FacetCountResult(facet, facetCounts.get(facet).get())).collect(Collectors.toCollection(TreeSet::new));
Integer maxCount = countRequest.getMaxFacets();
long minCountReturned = 0;
int count = 0;
for (FacetCountResult facet : sortedFacetResults) {
FixedBitSet shardCount = shardsReturned.get(facet.getFacet());
shardCount.or(fullResults);
FacetCount.Builder facetCountBuilder = FacetCount.newBuilder().setFacet(facet.getFacet()).setCount(facet.getCount());
long maxWithError = 0;
if (computeError) {
long maxError = 0;
if (shardCount.cardinality() < numberOfShards) {
for (int i = 0; i < numberOfShards; i++) {
if (!shardCount.get(i)) {
maxError += minForShard[i];
}
}
}
facetCountBuilder.setMaxError(maxError);
maxWithError = maxError + facet.getCount();
}
count++;
if (maxCount > 0 && count > maxCount) {
if (computePossibleMissing) {
if (maxWithError > maxValuePossibleMissing) {
maxValuePossibleMissing = maxWithError;
}
}
else {
break;
}
}
else {
fg.addFacetCount(facetCountBuilder);
minCountReturned = facet.getCount();
}
}
if (!sortedFacetResults.isEmpty()) {
if (maxValuePossibleMissing > minCountReturned) {
fg.setPossibleMissing(true);
fg.setMaxValuePossibleMissing(maxValuePossibleMissing);
}
}
builder.addFacetGroup(fg);
}
List mergedResults = new ArrayList<>((int) returnedHits);
for (ShardQueryResponse sr : shardResponses) {
mergedResults.addAll(sr.getScoredResultList());
}
Comparator myCompare = scoreCompare;
if (sorting) {
final List fieldSortList = sortRequest.getFieldSortList();
final HashMap sortTypeMap = new HashMap<>();
for (FieldSort fieldSort : fieldSortList) {
String sortField = fieldSort.getSortField();
for (ZuliaIndex index : indexes) {
FieldConfig.FieldType currentSortType = sortTypeMap.get(sortField);
FieldConfig.FieldType indexSortType = index.getSortFieldType(sortField);
if (currentSortType == null) {
sortTypeMap.put(sortField, indexSortType);
}
else {
if (!currentSortType.equals(indexSortType)) {
log.severe("Sort fields must be defined the same in all indexes searched in a single query");
String message =
"Cannot sort on field <" + sortField + ">: found type: <" + currentSortType + "> then type: <" + indexSortType + ">";
log.severe(message);
throw new Exception(message);
}
}
}
}
myCompare = (o1, o2) -> {
int compare = 0;
int sortValueIndex = 0;
SortValues sortValues1 = o1.getSortValues();
SortValues sortValues2 = o2.getSortValues();
for (FieldSort fs : fieldSortList) {
String sortField = fs.getSortField();
FieldConfig.FieldType sortType = sortTypeMap.get(sortField);
if (FieldConfig.FieldType.NUMERIC_INT.equals(sortType)) {
Integer a = sortValues1.getSortValue(sortValueIndex).getIntegerValue();
Integer b = sortValues2.getSortValue(sortValueIndex).getIntegerValue();
compare = Comparator.nullsLast(Integer::compareTo).compare(a, b);
}
else if (FieldConfig.FieldType.NUMERIC_LONG.equals(sortType) || FieldConfig.FieldType.DATE.equals(sortType)) {
Long a = sortValues1.getSortValue(sortValueIndex).getLongValue();
Long b = sortValues2.getSortValue(sortValueIndex).getLongValue();
compare = Comparator.nullsLast(Long::compareTo).compare(a, b);
}
else if (FieldConfig.FieldType.NUMERIC_FLOAT.equals(sortType)) {
Float a = sortValues1.getSortValue(sortValueIndex).getFloatValue();
Float b = sortValues2.getSortValue(sortValueIndex).getFloatValue();
compare = Comparator.nullsLast(Float::compareTo).compare(a, b);
}
else if (FieldConfig.FieldType.NUMERIC_DOUBLE.equals(sortType)) {
Double a = sortValues1.getSortValue(sortValueIndex).getDoubleValue();
Double b = sortValues2.getSortValue(sortValueIndex).getDoubleValue();
compare = Comparator.nullsLast(Double::compareTo).compare(a, b);
}
else {
String a = sortValues1.getSortValue(sortValueIndex).getStringValue();
String b = sortValues2.getSortValue(sortValueIndex).getStringValue();
//compare = Comparator.nullsLast(String::compareTo).compare(a, b);
compare = Comparator.nullsLast(BytesRef::compareTo).compare(new BytesRef(a), new BytesRef(b));
}
if (FieldSort.Direction.DESCENDING.equals(fs.getDirection())) {
compare *= -1;
}
if (compare != 0) {
return compare;
}
sortValueIndex++;
}
return compare;
};
}
if (!mergedResults.isEmpty()) {
mergedResults.sort(myCompare);
results = mergedResults.subList(0, resultsSize);
for (ScoredResult sr : results) {
ScoredResult[] lastForShardArr = lastIndexResultMap.get(sr.getIndexName());
lastForShardArr[sr.getShard()] = sr;
}
outside:
for (ZuliaIndex index : indexes) {
String indexName = index.getIndexName();
ScoredResult[] lastForShardArr = lastIndexResultMap.get(indexName);
ScoredResult lastForIndex = null;
for (ScoredResult sr : lastForShardArr) {
if (sr != null) {
if (lastForIndex == null) {
lastForIndex = sr;
}
else {
if (myCompare.compare(sr, lastForIndex) > 0) {
lastForIndex = sr;
}
}
}
}
if (lastForIndex == null) {
//this happen when amount from index is zero
continue;
}
double shardTolerance = index.getShardTolerance();
int numberOfShards = index.getNumberOfShards();
Map shardResponseMap = indexToShardQueryResponseMap.get(indexName);
for (int shardNumber = 0; shardNumber < numberOfShards; shardNumber++) {
ShardQueryResponse sr = shardResponseMap.get(shardNumber);
if (sr.hasNext()) {
ScoredResult next = sr.getNext();
int compare = myCompare.compare(lastForIndex, next);
if (compare > 0) {
if (sorting) {
String msg = "Result set did not return the most relevant sorted documents for index <" + indexName + ">\n";
msg += " Last for index from shard <" + lastForIndex.getShard() + "> has sort values <" + lastForIndex.getSortValues()
+ ">\n";
msg += " Next for shard <" + next.getShard() + "> has sort values <" + next.getSortValues() + ">\n";
msg += " Last for shards: \n";
msg += " " + Arrays.toString(lastForShardArr) + "\n";
msg += " Results: \n";
msg += " " + results + "\n";
msg += " If this happens frequently increase requestFactor or minShardRequest\n";
msg += " Retrying with full request.\n";
log.severe(msg);
isShort = true;
break outside;
}
double diff = (Math.abs(lastForIndex.getScore() - next.getScore()));
if (diff > shardTolerance) {
String msg = "Result set did not return the most relevant documents for index <" + indexName + "> with shard tolerance <"
+ shardTolerance + ">\n";
msg += " Query <" + query + ">\n";
msg += " Last for index from shard <" + lastForIndex.getShard() + "> has score <" + lastForIndex.getScore() + ">\n";
msg += " Next for shard <" + next.getShard() + "> has score <" + next.getScore() + ">\n";
msg += " Last for shards: \n";
msg += " " + Arrays.toString(lastForShardArr) + "\n";
msg += " Results: \n";
msg += " " + results + "\n";
msg += " If this happens frequently increase requestFactor, minShardRequest, or shardTolerance\n";
msg += " Retrying with full request.\n";
log.severe(msg);
isShort = true;
break outside;
}
}
}
}
}
}
int i = 0;
for (ScoredResult scoredResult : results) {
if (i >= start) {
builder.addResults(scoredResult);
}
i++;
}
LastResult.Builder newLastResultBuilder = LastResult.newBuilder();
for (String indexName : lastIndexResultMap.keySet()) {
ScoredResult[] lastForShardArr = lastIndexResultMap.get(indexName);
int numberOfShards = indexToShardCount.get(indexName);
List indexList = new ArrayList<>();
for (int shard = 0; shard < numberOfShards; shard++) {
if (lastForShardArr[shard] != null) {
ScoredResult.Builder minimalSR = ScoredResult.newBuilder(lastForShardArr[shard]);
minimalSR = minimalSR.clearUniqueId().clearIndexName().clearResultIndex().clearTimestamp().clearResultDocument();
indexList.add(minimalSR.build());
}
}
if (!indexList.isEmpty()) {
LastIndexResult lastIndexResult = LastIndexResult.newBuilder().addAllLastForShard(indexList).setIndexName(indexName).build();
newLastResultBuilder.addLastIndexResult(lastIndexResult);
}
}
builder.setLastResult(newLastResultBuilder.build());
return builder.build();
}
public boolean isShort() {
return isShort;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy