All downloads are free. Search and download functionality uses the official Maven repository.

org.lumongo.server.search.QueryCombiner Maven / Gradle / Ivy

The newest version!
package org.lumongo.server.search;

import org.apache.log4j.Logger;
import org.apache.lucene.util.FixedBitSet;
import org.lumongo.cluster.message.Lumongo;
import org.lumongo.cluster.message.Lumongo.*;
import org.lumongo.server.index.LumongoIndex;
import org.lumongo.server.index.analysis.TermFreq;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;

public class QueryCombiner {

	private final static Comparator scoreCompare = new ScoreCompare();

	private final static Logger log = Logger.getLogger(QueryCombiner.class);

	private final Map usedIndexMap;
	private final List responses;

	private final Map> indexToSegmentResponseMap;
	private final List segmentResponses;

	private final int amount;
	private final int start;
	private final LastResult lastResult;
	private final List analysisRequestList;

	private boolean isShort;
	private List results;
	private int resultsSize;

	private SortRequest sortRequest;

	private Lumongo.Query query;

	public QueryCombiner(Map usedIndexMap, QueryRequest request, List responses) {
		this.usedIndexMap = usedIndexMap;
		this.responses = responses;
		this.amount = request.getAmount() + request.getStart();
		this.indexToSegmentResponseMap = new HashMap<>();
		this.segmentResponses = new ArrayList<>();
		this.lastResult = request.getLastResult();
		this.sortRequest = request.getSortRequest();
		this.start = request.getStart();
		this.query = request.getQuery();
		this.analysisRequestList = request.getAnalysisRequestList();

		this.isShort = false;
		this.results = Collections.emptyList();
		this.resultsSize = 0;
	}

	public void validate() throws Exception {
		for (InternalQueryResponse iqr : responses) {

			for (IndexSegmentResponse isr : iqr.getIndexSegmentResponseList()) {
				String indexName = isr.getIndexName();
				if (!indexToSegmentResponseMap.containsKey(indexName)) {
					indexToSegmentResponseMap.put(indexName, new HashMap<>());
				}

				for (SegmentResponse sr : isr.getSegmentReponseList()) {
					int segmentNumber = sr.getSegmentNumber();

					Map segmentResponseMap = indexToSegmentResponseMap.get(indexName);

					if (segmentResponseMap.containsKey(segmentNumber)) {
						throw new Exception("Segment <" + segmentNumber + "> is repeated for <" + indexName + ">");
					}
					else {
						segmentResponseMap.put(segmentNumber, sr);
						segmentResponses.add(sr);
					}
				}

			}

		}

		for (String indexName : usedIndexMap.keySet()) {
			int numberOfSegments = usedIndexMap.get(indexName).getNumberOfSegments();
			Map segmentResponseMap = indexToSegmentResponseMap.get(indexName);

			if (segmentResponseMap == null) {
				throw new Exception("Missing index <" + indexName + "> in response");
			}

			if (segmentResponseMap.size() != numberOfSegments) {
				throw new Exception("Found <" + segmentResponseMap.size() + "> expected <" + numberOfSegments + ">");
			}

			for (int segmentNumber = 0; segmentNumber < numberOfSegments; segmentNumber++) {
				if (!segmentResponseMap.containsKey(segmentNumber)) {
					throw new Exception("Missing segment <" + segmentNumber + ">");
				}
			}
		}
	}

	public QueryResponse getQueryResponse() throws Exception {

		boolean sorting = (sortRequest != null && !sortRequest.getFieldSortList().isEmpty());

		long totalHits = 0;
		long returnedHits = 0;
		for (SegmentResponse sr : segmentResponses) {
			totalHits += sr.getTotalHits();
			returnedHits += sr.getScoredResultList().size();
		}

		QueryResponse.Builder builder = QueryResponse.newBuilder();
		builder.setTotalHits(totalHits);

		resultsSize = Math.min(amount, (int) returnedHits);

		results = Collections.emptyList();

		Map lastIndexResultMap = new HashMap<>();

		for (String indexName : indexToSegmentResponseMap.keySet()) {
			int numberOfSegments = usedIndexMap.get(indexName).getNumberOfSegments();
			lastIndexResultMap.put(indexName, new ScoredResult[numberOfSegments]);
		}

		for (LastIndexResult lir : lastResult.getLastIndexResultList()) {
			ScoredResult[] lastForSegmentArr = lastIndexResultMap.get(lir.getIndexName());
			// initialize with last results
			for (ScoredResult sr : lir.getLastForSegmentList()) {
				lastForSegmentArr[sr.getSegment()] = sr;
			}
		}

		Map> facetCountsMap = new HashMap<>();
		Map> segmentsReturnedMap = new HashMap<>();
		Map fullResultsMap = new HashMap<>();
		Map minForSegmentMap = new HashMap<>();

		Map> analysisRequestToTermMap = new HashMap<>();

		int segIndex = 0;

		for (SegmentResponse sr : segmentResponses) {

			for (FacetGroup fg : sr.getFacetGroupList()) {

				CountRequest countRequest = fg.getCountRequest();
				Map facetCounts = facetCountsMap.get(countRequest);
				Map segmentsReturned = segmentsReturnedMap.get(countRequest);
				FixedBitSet fullResults = fullResultsMap.get(countRequest);
				long[] minForSegment = minForSegmentMap.get(countRequest);

				if (facetCounts == null) {
					facetCounts = new HashMap<>();
					facetCountsMap.put(countRequest, facetCounts);

					segmentsReturned = new HashMap<>();
					segmentsReturnedMap.put(countRequest, segmentsReturned);

					fullResults = new FixedBitSet(segmentResponses.size());
					fullResultsMap.put(countRequest, fullResults);

					minForSegment = new long[segmentResponses.size()];
					minForSegmentMap.put(countRequest, minForSegment);
				}

				for (FacetCount fc : fg.getFacetCountList()) {
					String facet = fc.getFacet();
					AtomicLong facetSum = facetCounts.get(facet);
					FixedBitSet segmentSet = segmentsReturned.get(facet);

					if (facetSum == null) {
						facetSum = new AtomicLong();
						facetCounts.put(facet, facetSum);
						segmentSet = new FixedBitSet(segmentResponses.size());
						segmentsReturned.put(facet, segmentSet);
					}
					long count = fc.getCount();
					facetSum.addAndGet(count);
					segmentSet.set(segIndex);

					minForSegment[segIndex] = count;
				}

				int segmentFacets = countRequest.getSegmentFacets();
				int facetCountCount = fg.getFacetCountCount();
				if (facetCountCount < segmentFacets || (segmentFacets == 0)) {
					fullResults.set(segIndex);
					minForSegment[segIndex] = 0;
				}
			}

			for (Lumongo.AnalysisResult analysisResult : sr.getAnalysisResultList()) {

				AnalysisRequest analysisRequest = analysisResult.getAnalysisRequest();
				if (!analysisRequestToTermMap.containsKey(analysisRequest)) {
					analysisRequestToTermMap.put(analysisRequest, new HashMap<>());
				}

				Map termMap = analysisRequestToTermMap.get(analysisRequest);

				for (Lumongo.Term term : analysisResult.getTermsList()) {

					String key = term.getValue();
					if (!termMap.containsKey(key)) {
						termMap.put(key, Lumongo.Term.newBuilder().setValue(key).setDocFreq(0).setTermFreq(0));
					}
					Lumongo.Term.Builder termsBuilder = termMap.get(key);
					if (term.hasDocFreq()) {
						termsBuilder.setDocFreq(termsBuilder.getDocFreq() + term.getDocFreq());
					}
					if (term.hasScore()) {
						termsBuilder.setScore(termsBuilder.getScore() + term.getScore());
					}
					termsBuilder.setTermFreq(termsBuilder.getTermFreq() + term.getTermFreq());

				}
			}

			segIndex++;
		}

		for (AnalysisRequest analysisRequest : analysisRequestList) {
			Map termMap = analysisRequestToTermMap.get(analysisRequest);
			if (termMap != null) {
				List terms = new ArrayList<>(termMap.values());
				List topTerms = TermFreq.getTopTerms(terms, analysisRequest.getTopN(), analysisRequest.getTermSort());
				AnalysisResult.Builder analysisResultBuilder = AnalysisResult.newBuilder().setAnalysisRequest(analysisRequest);
				topTerms.forEach(analysisResultBuilder::addTerms);
				builder.addAnalysisResult(analysisResultBuilder);
			}
		}

		for (CountRequest countRequest : facetCountsMap.keySet()) {

			FacetGroup.Builder fg = FacetGroup.newBuilder();
			fg.setCountRequest(countRequest);
			Map facetCounts = facetCountsMap.get(countRequest);
			Map segmentsReturned = segmentsReturnedMap.get(countRequest);
			FixedBitSet fullResults = fullResultsMap.get(countRequest);
			long[] minForSegment = minForSegmentMap.get(countRequest);

			int numberOfSegments = segmentResponses.size();
			long maxValuePossibleMissing = 0;
			for (int i = 0; i < numberOfSegments; i++) {
				maxValuePossibleMissing += minForSegment[i];
			}

			boolean computeError = countRequest.getSegmentFacets() != 0 && countRequest.getComputeError();
			boolean computePossibleMissing = countRequest.getSegmentFacets() != 0 && countRequest.getComputePossibleMissed() && (maxValuePossibleMissing != 0);

			SortedSet sortedFacetResults = facetCounts.keySet().stream()
					.map(facet -> new FacetCountResult(facet, facetCounts.get(facet).get())).collect(Collectors.toCollection(TreeSet::new));

			Integer maxCount = countRequest.getMaxFacets();

			long minCountReturned = 0;

			int count = 0;
			for (FacetCountResult facet : sortedFacetResults) {

				FixedBitSet segCount = segmentsReturned.get(facet.getFacet());
				segCount.or(fullResults);

				FacetCount.Builder facetCountBuilder = FacetCount.newBuilder().setFacet(facet.getFacet()).setCount(facet.getCount());

				long maxWithError = 0;
				if (computeError) {
					long maxError = 0;
					if (segCount.cardinality() < numberOfSegments) {
						for (int i = 0; i < numberOfSegments; i++) {
							if (!segCount.get(i)) {
								maxError += minForSegment[i];
							}
						}
					}
					facetCountBuilder.setMaxError(maxError);
					maxWithError = maxError + facet.getCount();
				}

				count++;

				if (maxCount > 0 && count > maxCount) {

					if (computePossibleMissing) {
						if (maxWithError > maxValuePossibleMissing) {
							maxValuePossibleMissing = maxWithError;
						}
					}
					else {
						break;
					}
				}
				else {
					fg.addFacetCount(facetCountBuilder);
					minCountReturned = facet.getCount();
				}
			}

			if (!sortedFacetResults.isEmpty()) {
				if (maxValuePossibleMissing > minCountReturned) {
					fg.setPossibleMissing(true);
					fg.setMaxValuePossibleMissing(maxValuePossibleMissing);
				}
			}

			builder.addFacetGroup(fg);
		}

		List mergedResults = new ArrayList<>((int) returnedHits);
		for (SegmentResponse sr : segmentResponses) {
			mergedResults.addAll(sr.getScoredResultList());
		}

		Comparator myCompare = scoreCompare;

		if (sorting) {
			final List fieldSortList = sortRequest.getFieldSortList();

			final HashMap sortTypeMap = new HashMap<>();

			for (FieldSort fieldSort : fieldSortList) {
				String sortField = fieldSort.getSortField();

				for (String indexName : usedIndexMap.keySet()) {
					LumongoIndex index = usedIndexMap.get(indexName);
					FieldConfig.FieldType currentSortType = sortTypeMap.get(sortField);

					FieldConfig.FieldType indexSortType = index.getSortFieldType(sortField);
					if (currentSortType == null) {
						sortTypeMap.put(sortField, indexSortType);
					}
					else {
						if (!currentSortType.equals(indexSortType)) {
							log.error("Sort fields must be defined the same in all indexes searched in a single query");
							String message =
									"Cannot sort on field <" + sortField + ">: found type: <" + currentSortType + "> then type: <" + indexSortType + ">";
							log.error(message);

							throw new Exception(message);
						}
					}
				}
			}

			myCompare = (o1, o2) -> {
				int compare = 0;

				int sortValueIndex = 0;

				Lumongo.SortValues sortValues1 = o1.getSortValues();
				Lumongo.SortValues sortValues2 = o2.getSortValues();
				for (FieldSort fs : fieldSortList) {
					String sortField = fs.getSortField();

					FieldConfig.FieldType sortType = sortTypeMap.get(sortField);

					if (FieldConfig.FieldType.NUMERIC_INT.equals(sortType)) {
						Integer a = null;
						Integer b = null;
						a = sortValues1.getSortValue(sortValueIndex).getIntegerValue();
						b = sortValues2.getSortValue(sortValueIndex).getIntegerValue();

						compare = Comparator.nullsLast(Integer::compareTo).compare(a, b);
					}
					else if (FieldConfig.FieldType.NUMERIC_LONG.equals(sortType) || FieldConfig.FieldType.DATE.equals(sortType)) {
						Long a = null;
						Long b = null;
						a = sortValues1.getSortValue(sortValueIndex).getLongValue();
						b = sortValues2.getSortValue(sortValueIndex).getLongValue();

						compare = Comparator.nullsLast(Long::compareTo).compare(a, b);
					}
					else if (FieldConfig.FieldType.NUMERIC_FLOAT.equals(sortType)) {
						Float a = null;
						Float b = null;
						a = sortValues1.getSortValue(sortValueIndex).getFloatValue();
						b = sortValues2.getSortValue(sortValueIndex).getFloatValue();

						compare = Comparator.nullsLast(Float::compareTo).compare(a, b);
					}
					else if (FieldConfig.FieldType.NUMERIC_DOUBLE.equals(sortType)) {
						Double a = null;
						Double b = null;
						a = sortValues1.getSortValue(sortValueIndex).getDoubleValue();
						b = sortValues2.getSortValue(sortValueIndex).getDoubleValue();

						compare = Comparator.nullsLast(Double::compareTo).compare(a, b);
					}
					else {
						String a = null;
						String b = null;
						a = sortValues1.getSortValue(sortValueIndex).getStringValue();
						b = sortValues2.getSortValue(sortValueIndex).getStringValue();

						compare = Comparator.nullsLast(String::compareTo).compare(a, b);
					}

					if (FieldSort.Direction.DESCENDING.equals(fs.getDirection())) {
						compare *= -1;
					}

					if (compare != 0) {
						return compare;
					}

					sortValueIndex++;

				}

				return compare;
			};
		}

		if (!mergedResults.isEmpty()) {
			Collections.sort(mergedResults, myCompare);

			results = mergedResults.subList(0, resultsSize);

			for (ScoredResult sr : results) {
				ScoredResult[] lastForSegmentArr = lastIndexResultMap.get(sr.getIndexName());
				lastForSegmentArr[sr.getSegment()] = sr;
			}

			outside:
			for (String indexName : usedIndexMap.keySet()) {
				ScoredResult[] lastForSegmentArr = lastIndexResultMap.get(indexName);
				ScoredResult lastForIndex = null;
				for (ScoredResult sr : lastForSegmentArr) {
					if (sr != null) {
						if (lastForIndex == null) {
							lastForIndex = sr;
						}
						else {
							if (myCompare.compare(sr, lastForIndex) > 0) {
								lastForIndex = sr;
							}
						}
					}
				}

				if (lastForIndex == null) {
					//this happen when amount from index is zero
					continue;
				}

				double segmentTolerance = usedIndexMap.get(indexName).getSegmentTolerance();

				int numberOfSegments = usedIndexMap.get(indexName).getNumberOfSegments();
				Map segmentResponseMap = indexToSegmentResponseMap.get(indexName);
				for (int segmentNumber = 0; segmentNumber < numberOfSegments; segmentNumber++) {
					SegmentResponse sr = segmentResponseMap.get(segmentNumber);
					if (sr.hasNext()) {
						ScoredResult next = sr.getNext();
						int compare = myCompare.compare(lastForIndex, next);
						if (compare > 0) {

							if (sorting) {
								String msg = "Result set did not return the most relevant sorted documents for index <" + indexName + ">\n";
								msg += "    Last for index from segment <" + lastForIndex.getSegment() + "> has sort values <" + lastForIndex.getSortValues()
										+ ">\n";
								msg += "    Next for segment <" + next.getSegment() + ">  has sort values <" + next.getSortValues() + ">\n";
								msg += "    Last for segments: \n";
								msg += "      " + Arrays.toString(lastForSegmentArr) + "\n";
								msg += "    Results: \n";
								msg += "      " + results + "\n";
								msg += "    If this happens frequently increase requestFactor or minSegmentRequest\n";
								msg += "    Retrying with full request.\n";
								log.error(msg);

								isShort = true;
								break outside;
							}

							double diff = (Math.abs(lastForIndex.getScore() - next.getScore()));
							if (diff > segmentTolerance) {
								String msg = "Result set did not return the most relevant documents for index <" + indexName + "> with segment tolerance <"
										+ segmentTolerance + ">\n";
								msg += "    Query <" + query + ">\n";
								msg += "    Last for index from segment <" + lastForIndex.getSegment() + "> has score <" + lastForIndex.getScore() + ">\n";
								msg += "    Next for segment <" + next.getSegment() + "> has score <" + next.getScore() + ">\n";
								msg += "    Last for segments: \n";
								msg += "      " + Arrays.toString(lastForSegmentArr) + "\n";
								msg += "    Results: \n";
								msg += "      " + results + "\n";
								msg += "    If this happens frequently increase requestFactor, minSegmentRequest, or segmentTolerance\n";
								msg += "    Retrying with full request.\n";
								log.error(msg);

								isShort = true;
								break outside;
							}
						}
					}
				}
			}

		}

		int i = 0;
		for (ScoredResult scoredResult : results) {
			if (i >= start) {
				builder.addResults(scoredResult);
			}
			i++;
		}

		LastResult.Builder newLastResultBuilder = LastResult.newBuilder();
		for (String indexName : lastIndexResultMap.keySet()) {
			ScoredResult[] lastForSegmentArr = lastIndexResultMap.get(indexName);
			int numberOfSegments = usedIndexMap.get(indexName).getNumberOfSegments();
			List indexList = new ArrayList<>();
			for (int seg = 0; seg < numberOfSegments; seg++) {
				if (lastForSegmentArr[seg] != null) {
					ScoredResult.Builder minimalSR = ScoredResult.newBuilder(lastForSegmentArr[seg]);
					minimalSR = minimalSR.clearUniqueId().clearIndexName().clearResultIndex().clearTimestamp().clearResultDocument();
					indexList.add(minimalSR.build());
				}
			}
			if (!indexList.isEmpty()) {
				LastIndexResult lastIndexResult = LastIndexResult.newBuilder().addAllLastForSegment(indexList).setIndexName(indexName).build();
				newLastResultBuilder.addLastIndexResult(lastIndexResult);
			}
		}

		builder.setLastResult(newLastResultBuilder.build());

		return builder.build();
	}

	public boolean isShort() {
		return isShort;
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy