com.google.cloud.firestore.AggregateQuery Maven / Gradle / Ivy
Show all versions of google-cloud-firestore Show documentation
/*
* Copyright 2022 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.firestore;
import static com.google.cloud.firestore.telemetry.TraceUtil.ATTRIBUTE_KEY_ATTEMPT;
import static com.google.cloud.firestore.telemetry.TraceUtil.SPAN_NAME_RUN_AGGREGATION_QUERY;
import com.google.api.core.ApiFuture;
import com.google.api.core.InternalExtensionOnly;
import com.google.api.core.SettableApiFuture;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.StatusCode;
import com.google.api.gax.rpc.StreamController;
import com.google.cloud.Timestamp;
import com.google.cloud.firestore.telemetry.TraceUtil;
import com.google.cloud.firestore.telemetry.TraceUtil.Scope;
import com.google.cloud.firestore.v1.FirestoreSettings;
import com.google.common.collect.ImmutableMap;
import com.google.firestore.v1.RunAggregationQueryRequest;
import com.google.firestore.v1.RunAggregationQueryResponse;
import com.google.firestore.v1.RunQueryRequest;
import com.google.firestore.v1.StructuredAggregationQuery;
import com.google.firestore.v1.StructuredAggregationQuery.Aggregation;
import com.google.firestore.v1.StructuredQuery;
import com.google.firestore.v1.Value;
import com.google.protobuf.ByteString;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
/** A query that calculates aggregations over an underlying query. */
@InternalExtensionOnly
public class AggregateQuery {
@Nonnull private final Query query;
@Nonnull private final List aggregateFieldList;
@Nonnull private final Map aliasMap;
AggregateQuery(@Nonnull Query query, @Nonnull List aggregateFields) {
this.query = query;
this.aggregateFieldList = aggregateFields;
this.aliasMap = new HashMap<>();
}
@Nonnull
private TraceUtil getTraceUtil() {
return query.getFirestore().getOptions().getTraceUtil();
}
/** Returns the query whose aggregations will be calculated by this object. */
@Nonnull
public Query getQuery() {
return query;
}
/**
* Executes this query.
*
* @return An {@link ApiFuture} that will be resolved with the results of the query.
*/
@Nonnull
public ApiFuture get() {
return get(null, null);
}
/**
* Plans and optionally executes this query. Returns an ApiFuture that will be resolved with the
* planner information, statistics from the query execution (if any), and the query results (if
* any).
*
* @return An ApiFuture that will be resolved with the planner information, statistics from the
* query execution (if any), and the query results (if any).
*/
@Nonnull
public ApiFuture> explain(ExplainOptions options) {
TraceUtil.Span span = getTraceUtil().startSpan(TraceUtil.SPAN_NAME_AGGREGATION_QUERY_GET);
try (Scope ignored = span.makeCurrent()) {
AggregateQueryExplainResponseDeliverer responseDeliverer =
new AggregateQueryExplainResponseDeliverer(
/* transactionId= */ null,
/* readTime= */ null,
/* startTimeNanos= */ query.rpcContext.getClock().nanoTime(),
/* explainOptions= */ options);
runQuery(responseDeliverer, /* attempt */ 0);
ApiFuture> result = responseDeliverer.getFuture();
span.endAtFuture(result);
return result;
} catch (Exception error) {
span.end(error);
throw error;
}
}
@Nonnull
ApiFuture get(
@Nullable final ByteString transactionId, @Nullable com.google.protobuf.Timestamp readTime) {
TraceUtil.Span span =
getTraceUtil()
.startSpan(
transactionId == null
? TraceUtil.SPAN_NAME_AGGREGATION_QUERY_GET
: TraceUtil.SPAN_NAME_TRANSACTION_GET_AGGREGATION_QUERY);
try (Scope ignored = span.makeCurrent()) {
AggregateQueryResponseDeliverer responseDeliverer =
new AggregateQueryResponseDeliverer(
transactionId,
readTime,
/* startTimeNanos= */ query.rpcContext.getClock().nanoTime());
runQuery(responseDeliverer, /* attempt= */ 0);
ApiFuture result = responseDeliverer.getFuture();
span.endAtFuture(result);
return result;
} catch (Exception error) {
span.end(error);
throw error;
}
}
private void runQuery(ResponseDeliverer responseDeliverer, int attempt) {
RunAggregationQueryRequest request =
toProto(
responseDeliverer.getTransactionId(),
responseDeliverer.getReadTime(),
responseDeliverer.getExplainOptions());
AggregateQueryResponseObserver responseObserver =
new AggregateQueryResponseObserver(responseDeliverer, attempt);
ServerStreamingCallable callable =
query.rpcContext.getClient().runAggregationQueryCallable();
query.rpcContext.streamRequest(request, responseObserver, callable);
}
@Nonnull
private Map convertServerAggregateFieldsMapToClientAggregateFieldsMap(
@Nonnull Map data) {
ImmutableMap.Builder builder = ImmutableMap.builder();
data.forEach((serverAlias, value) -> builder.put(aliasMap.get(serverAlias), value));
return builder.build();
}
private abstract static class ResponseDeliverer {
private final @Nullable ByteString transactionId;
private final @Nullable com.google.protobuf.Timestamp readTime;
private final long startTimeNanos;
private final SettableApiFuture future = SettableApiFuture.create();
ResponseDeliverer(
@Nullable ByteString transactionId,
@Nullable com.google.protobuf.Timestamp readTime,
long startTimeNanos) {
this.transactionId = transactionId;
this.readTime = readTime;
this.startTimeNanos = startTimeNanos;
}
@Nullable
ByteString getTransactionId() {
return transactionId;
}
@Nullable
com.google.protobuf.Timestamp getReadTime() {
return readTime;
}
long getStartTimeNanos() {
return startTimeNanos;
}
@Nullable
ExplainOptions getExplainOptions() {
return null;
}
ApiFuture getFuture() {
return future;
}
protected void setFuture(T value) {
future.set(value);
}
void deliverError(Throwable throwable) {
future.setException(throwable);
}
abstract void deliverResult(
@Nullable Map serverData,
Timestamp readTime,
@Nullable ExplainMetrics metrics);
}
private class AggregateQueryResponseDeliverer extends ResponseDeliverer {
AggregateQueryResponseDeliverer(
@Nullable ByteString transactionId,
@Nullable com.google.protobuf.Timestamp readTime,
long startTimeNanos) {
super(transactionId, readTime, startTimeNanos);
}
@Override
void deliverResult(
@Nullable Map serverData,
Timestamp readTime,
@Nullable ExplainMetrics metrics) {
if (serverData == null) {
deliverError(new RuntimeException("Did not receive any aggregate query results."));
return;
}
setFuture(
new AggregateQuerySnapshot(
AggregateQuery.this,
readTime,
convertServerAggregateFieldsMapToClientAggregateFieldsMap(serverData)));
}
}
private final class AggregateQueryExplainResponseDeliverer
extends ResponseDeliverer> {
private final @Nullable ExplainOptions explainOptions;
AggregateQueryExplainResponseDeliverer(
@Nullable ByteString transactionId,
@Nullable com.google.protobuf.Timestamp readTime,
long startTimeNanos,
@Nullable ExplainOptions explainOptions) {
super(transactionId, readTime, startTimeNanos);
this.explainOptions = explainOptions;
}
@Override
@Nullable
ExplainOptions getExplainOptions() {
return explainOptions;
}
@Override
void deliverResult(
@Nullable Map serverData,
Timestamp readTime,
@Nullable ExplainMetrics metrics) {
// The server is required to provide ExplainMetrics for explain queries.
if (metrics == null) {
deliverError(new RuntimeException("Did not receive any metrics for explain query."));
return;
}
AggregateQuerySnapshot snapshot =
serverData == null
? null
: new AggregateQuerySnapshot(
AggregateQuery.this,
readTime,
convertServerAggregateFieldsMapToClientAggregateFieldsMap(serverData));
setFuture(new ExplainResults<>(metrics, snapshot));
}
}
private final class AggregateQueryResponseObserver
implements ResponseObserver {
private final ResponseDeliverer responseDeliverer;
private Timestamp readTime = Timestamp.MAX_VALUE;
@Nullable private Map aggregateFieldsMap = null;
@Nullable private ExplainMetrics metrics = null;
private int attempt;
AggregateQueryResponseObserver(ResponseDeliverer responseDeliverer, int attempt) {
this.responseDeliverer = responseDeliverer;
this.attempt = attempt;
}
Map getAttemptAttributes() {
return Collections.singletonMap(ATTRIBUTE_KEY_ATTEMPT, attempt);
}
private boolean isExplainQuery() {
return this.responseDeliverer.getExplainOptions() != null;
}
@Override
public void onStart(StreamController streamController) {
getTraceUtil()
.currentSpan()
.addEvent(SPAN_NAME_RUN_AGGREGATION_QUERY + " Stream started.", getAttemptAttributes());
}
@Override
public void onResponse(RunAggregationQueryResponse response) {
getTraceUtil()
.currentSpan()
.addEvent(
SPAN_NAME_RUN_AGGREGATION_QUERY + " Response Received.", getAttemptAttributes());
if (response.hasReadTime()) {
readTime = Timestamp.fromProto(response.getReadTime());
}
if (response.hasResult()) {
aggregateFieldsMap = response.getResult().getAggregateFieldsMap();
}
if (response.hasExplainMetrics()) {
metrics = new ExplainMetrics(response.getExplainMetrics());
}
if (!isExplainQuery()) {
// Deliver the result; even though the `RunAggregationQuery` RPC is a "streaming" RPC,
// meaning that `onResponse()` can be called multiple times, it _should_ only be called
// once for non-explain queries. But even if it is called more than once,
// `responseDeliverer` will drop superfluous results. For explain queries, there will
// be more than one response, and the last response will contain the metrics.
onComplete();
}
}
@Override
public void onError(Throwable throwable) {
if (shouldRetry(throwable)) {
getTraceUtil()
.currentSpan()
.addEvent(
SPAN_NAME_RUN_AGGREGATION_QUERY + ": Retryable Error",
Collections.singletonMap("error.message", throwable.getMessage()));
runQuery(responseDeliverer, attempt + 1);
} else {
getTraceUtil()
.currentSpan()
.addEvent(
SPAN_NAME_RUN_AGGREGATION_QUERY + ": Error",
Collections.singletonMap("error.message", throwable.getMessage()));
responseDeliverer.deliverError(throwable);
}
}
private boolean shouldRetry(Throwable throwable) {
// Do not retry EXPLAIN requests because it'd be executing
// multiple queries. This means stats would have to be aggregated,
// and that may not even make sense for many statistics.
if (isExplainQuery()) {
return false;
}
Set retryableCodes =
FirestoreSettings.newBuilder().runAggregationQuerySettings().getRetryableCodes();
return query.shouldRetryQuery(
throwable,
responseDeliverer.getTransactionId(),
responseDeliverer.getStartTimeNanos(),
retryableCodes);
}
@Override
public void onComplete() {
responseDeliverer.deliverResult(aggregateFieldsMap, readTime, metrics);
}
}
/**
* Returns the {@link RunAggregationQueryRequest} that this AggregateQuery instance represents.
* The request contain the serialized form of all aggregations and Query constraints.
*
* @return the serialized RunAggregationQueryRequest
*/
@Nonnull
public RunAggregationQueryRequest toProto() {
return toProto(/* transactionId= */ null, /* readTime= */ null, /* explainOptions= */ null);
}
@Nonnull
RunAggregationQueryRequest toProto(
@Nullable final ByteString transactionId,
@Nullable final com.google.protobuf.Timestamp readTime,
@Nullable ExplainOptions explainOptions) {
RunQueryRequest runQueryRequest = query.toProto();
RunAggregationQueryRequest.Builder request = RunAggregationQueryRequest.newBuilder();
request.setParent(runQueryRequest.getParent());
if (transactionId != null) {
request.setTransaction(transactionId);
}
if (readTime != null) {
request.setReadTime(readTime);
}
if (explainOptions != null) {
request.setExplainOptions(explainOptions.toProto());
}
StructuredAggregationQuery.Builder structuredAggregationQuery =
request.getStructuredAggregationQueryBuilder();
structuredAggregationQuery.setStructuredQuery(runQueryRequest.getStructuredQuery());
// We use this set to remove duplicate aggregates.
// For example, `aggregate(sum("foo"), sum("foo"))`
HashSet uniqueAggregates = new HashSet<>();
List aggregations = new ArrayList<>();
int aggregationNum = 0;
for (AggregateField aggregateField : aggregateFieldList) {
// `getAlias()` provides a unique representation of an AggregateField.
boolean isNewAggregateField = uniqueAggregates.add(aggregateField.getAlias());
if (!isNewAggregateField) {
// This is a duplicate AggregateField. We don't need to include it in the request.
continue;
}
// If there's a field for this aggregation, build its proto.
StructuredQuery.FieldReference field = null;
if (!aggregateField.getFieldPath().isEmpty()) {
field =
StructuredQuery.FieldReference.newBuilder()
.setFieldPath(aggregateField.getFieldPath())
.build();
}
// Build the aggregation proto.
Aggregation.Builder aggregation = Aggregation.newBuilder();
if (aggregateField instanceof AggregateField.CountAggregateField) {
aggregation.setCount(Aggregation.Count.getDefaultInstance());
} else if (aggregateField instanceof AggregateField.SumAggregateField) {
aggregation.setSum(Aggregation.Sum.newBuilder().setField(field).build());
} else if (aggregateField instanceof AggregateField.AverageAggregateField) {
aggregation.setAvg(Aggregation.Avg.newBuilder().setField(field).build());
} else {
throw new RuntimeException("Unsupported aggregation");
}
// Map all client-side aliases to a unique short-form alias.
// This avoids issues with client-side aliases that exceed the 1500-byte string size limit.
String serverAlias = "aggregate_" + aggregationNum++;
aliasMap.put(serverAlias, aggregateField.getAlias());
aggregation.setAlias(serverAlias);
aggregations.add(aggregation.build());
}
structuredAggregationQuery.addAllAggregations(aggregations);
return request.build();
}
/**
* Returns an AggregateQuery instance that can be used to execute the provided {@link
* RunAggregationQueryRequest}.
*
* Only RunAggregationQueryRequests that pertain to the same project as the Firestore instance
* can be deserialized.
*
* @param firestore a Firestore instance to apply the query to.
* @param proto the serialized RunAggregationQueryRequest.
* @return a AggregateQuery instance that can be used to execute the RunAggregationQueryRequest.
*/
@Nonnull
public static AggregateQuery fromProto(Firestore firestore, RunAggregationQueryRequest proto) {
RunQueryRequest runQueryRequest =
RunQueryRequest.newBuilder()
.setParent(proto.getParent())
.setStructuredQuery(proto.getStructuredAggregationQuery().getStructuredQuery())
.build();
Query query = Query.fromProto(firestore, runQueryRequest);
List aggregateFields = new ArrayList<>();
List aggregations = proto.getStructuredAggregationQuery().getAggregationsList();
aggregations.forEach(
aggregation -> {
if (aggregation.hasCount()) {
aggregateFields.add(AggregateField.count());
} else if (aggregation.hasAvg()) {
aggregateFields.add(
AggregateField.average(aggregation.getAvg().getField().getFieldPath()));
} else if (aggregation.hasSum()) {
aggregateFields.add(AggregateField.sum(aggregation.getSum().getField().getFieldPath()));
} else {
throw new RuntimeException("Unsupported aggregation.");
}
});
return new AggregateQuery(query, aggregateFields);
}
/**
* Calculates and returns the hash code for this object.
*
* @return the hash code for this object.
*/
@Override
public int hashCode() {
return Objects.hash(query, aggregateFieldList);
}
/**
* Compares this object with the given object for equality.
*
* This object is considered "equal" to the other object if and only if all of the following
* conditions are satisfied:
*
*
* - {@code object} is a non-null instance of {@link AggregateQuery}.
*
- {@code object} performs the same aggregations as this {@link AggregateQuery}.
*
- The underlying {@link Query} of {@code object} compares equal to that of this object.
*
*
* @param object The object to compare to this object for equality.
* @return {@code true} if this object is "equal" to the given object, as defined above, or {@code
* false} otherwise.
*/
@Override
public boolean equals(Object object) {
if (object == this) {
return true;
} else if (!(object instanceof AggregateQuery)) {
return false;
}
AggregateQuery other = (AggregateQuery) object;
return query.equals(other.query) && aggregateFieldList.equals(other.aggregateFieldList);
}
}