// com.hazelcast.jet.sql.impl.opt.physical.AggregateAbstractPhysicalRule (Maven / Gradle / Ivy)
/*
* Copyright 2021 Hazelcast Inc.
*
* Licensed under the Hazelcast Community License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://hazelcast.com/hazelcast-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hazelcast.jet.sql.impl.opt.physical;
import com.hazelcast.function.BiConsumerEx;
import com.hazelcast.function.FunctionEx;
import com.hazelcast.function.SupplierEx;
import com.hazelcast.jet.aggregate.AggregateOperation;
import com.hazelcast.jet.impl.execution.init.Contexts;
import com.hazelcast.jet.sql.impl.ExpressionUtil;
import com.hazelcast.jet.sql.impl.JetSqlSerializerHook;
import com.hazelcast.jet.sql.impl.aggregate.AvgSqlAggregations;
import com.hazelcast.jet.sql.impl.aggregate.CountSqlAggregations;
import com.hazelcast.jet.sql.impl.aggregate.JsonObjectAggAggregation;
import com.hazelcast.jet.sql.impl.aggregate.MaxSqlAggregation;
import com.hazelcast.jet.sql.impl.aggregate.MinSqlAggregation;
import com.hazelcast.jet.sql.impl.aggregate.OrderedJsonArrayAggAggregation;
import com.hazelcast.jet.sql.impl.aggregate.SqlAggregation;
import com.hazelcast.jet.sql.impl.aggregate.SumSqlAggregations;
import com.hazelcast.jet.sql.impl.aggregate.UnorderedJsonArrayAggAggregation;
import com.hazelcast.jet.sql.impl.aggregate.ValueSqlAggregation;
import com.hazelcast.jet.sql.impl.aggregate.function.HazelcastJsonArrayAggFunction;
import com.hazelcast.jet.sql.impl.aggregate.function.HazelcastJsonObjectAggFunction;
import com.hazelcast.jet.sql.impl.opt.OptUtils;
import com.hazelcast.nio.ObjectDataInput;
import com.hazelcast.nio.ObjectDataOutput;
import com.hazelcast.nio.serialization.IdentifiedDataSerializable;
import com.hazelcast.sql.impl.QueryException;
import com.hazelcast.sql.impl.row.JetSqlRow;
import com.hazelcast.sql.impl.type.QueryDataType;
import com.hazelcast.org.apache.calcite.plan.RelRule;
import com.hazelcast.org.apache.calcite.plan.RelRule.Config;
import com.hazelcast.org.apache.calcite.rel.RelFieldCollation;
import com.hazelcast.org.apache.calcite.rel.core.AggregateCall;
import com.hazelcast.org.apache.calcite.rel.type.RelDataType;
import com.hazelcast.org.apache.calcite.sql.SqlKind;
import com.hazelcast.org.apache.calcite.util.ImmutableBitSet;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static com.hazelcast.jet.sql.impl.opt.FieldCollation.convertCollation;
public abstract class AggregateAbstractPhysicalRule extends RelRule {
/**
 * Creates the rule and hands the given configuration to the {@code RelRule}
 * superclass constructor.
 *
 * @param config rule configuration passed through unchanged
 */
protected AggregateAbstractPhysicalRule(Config config) {
super(config);
}
protected static AggregateOperation, JetSqlRow> aggregateOperation(
RelDataType inputType,
ImmutableBitSet groupSet,
List aggregateCalls
) {
List operandTypes = OptUtils.schema(inputType).getTypes();
List> aggregationProviders = new ArrayList<>();
List> valueProviders = new ArrayList<>();
for (Integer groupIndex : groupSet.toList()) {
aggregationProviders.add(ValueSqlAggregation::new);
// getMaybeSerialized is safe for ValueAggr because it only passes the value on
valueProviders.add(new RowGetMaybeSerializedFn(groupIndex));
}
for (AggregateCall aggregateCall : aggregateCalls) {
boolean distinct = aggregateCall.isDistinct();
List aggregateCallArguments = aggregateCall.getArgList();
SqlKind kind = aggregateCall.getAggregation().getKind();
switch (kind) {
case COUNT:
if (distinct) {
int countIndex = aggregateCallArguments.get(0);
aggregationProviders.add(new AggregateCountSupplier(true, true));
// getMaybeSerialized is safe for COUNT because the aggregation only looks whether it is null or not
valueProviders.add(new RowGetMaybeSerializedFn(countIndex));
} else if (aggregateCallArguments.size() == 1) {
int countIndex = aggregateCallArguments.get(0);
aggregationProviders.add(new AggregateCountSupplier(true, false));
valueProviders.add(new RowGetMaybeSerializedFn(countIndex));
} else {
aggregationProviders.add(new AggregateCountSupplier(false, false));
valueProviders.add(NullFunction.INSTANCE);
}
break;
case MIN:
int minIndex = aggregateCallArguments.get(0);
aggregationProviders.add(MinSqlAggregation::new);
valueProviders.add(new RowGetFn(minIndex));
break;
case MAX:
int maxIndex = aggregateCallArguments.get(0);
aggregationProviders.add(MaxSqlAggregation::new);
valueProviders.add(new RowGetFn(maxIndex));
break;
case SUM:
int sumIndex = aggregateCallArguments.get(0);
QueryDataType sumOperandType = operandTypes.get(sumIndex);
aggregationProviders.add(new AggregateSumSupplier(distinct, sumOperandType));
valueProviders.add(new RowGetFn(sumIndex));
break;
case AVG:
int avgIndex = aggregateCallArguments.get(0);
QueryDataType avgOperandType = operandTypes.get(avgIndex);
aggregationProviders.add(new AggregateAvgSupplier(distinct, avgOperandType));
valueProviders.add(new RowGetFn(avgIndex));
break;
case JSON_ARRAYAGG:
int arrayAggIndex = aggregateCallArguments.get(0);
List colls = aggregateCall.getCollation().getFieldCollations();
if (colls.size() > 0) {
ExpressionUtil.SqlRowComparator comparator = new ExpressionUtil.SqlRowComparator(convertCollation(colls));
HazelcastJsonArrayAggFunction agg = (HazelcastJsonArrayAggFunction) aggregateCall.getAggregation();
aggregationProviders.add(new AggregateArrayAggSupplier(comparator, agg.isAbsentOnNull(), arrayAggIndex));
valueProviders.add(new RowIdentityFn());
} else {
HazelcastJsonArrayAggFunction agg = (HazelcastJsonArrayAggFunction) aggregateCall.getAggregation();
aggregationProviders.add(new AggregateArrayAggSupplier(agg.isAbsentOnNull()));
valueProviders.add(new RowGetFn(arrayAggIndex));
}
break;
case JSON_OBJECTAGG:
int keyIndex = aggregateCallArguments.get(0);
int valueIndex = aggregateCallArguments.get(1);
HazelcastJsonObjectAggFunction objAgg = (HazelcastJsonObjectAggFunction) aggregateCall.getAggregation();
aggregationProviders.add(new AggregateObjectAggSupplier(keyIndex, valueIndex, objAgg.isAbsentOnNull()));
valueProviders.add(new RowIdentityFn());
break;
default:
throw QueryException.error("Unsupported aggregation function: " + kind);
}
}
return AggregateOperation
.withCreate(new AggregateCreateSupplier(aggregationProviders))
.andAccumulate(new AggregateAccumulateFunction(valueProviders))
.andCombine(AggregateCombineFunction.INSTANCE)
.andExport(AggregateExportFunction.INSTANCE)
.andFinish(AggregateFinishFunction.INSTANCE);
}
/**
 * Serializable supplier of the accumulator for the SQL {@code AVG} aggregate.
 * Each call to {@link #getEx()} asks {@code AvgSqlAggregations.from} for an
 * accumulator matching the stored operand type and distinct flag.
 */
public static class AggregateAvgSupplier implements IdentifiedDataSerializable,
        SupplierEx<SqlAggregation> {
    private boolean distinct;
    private QueryDataType avgOperandType;

    /** Creates an empty instance; the fields are filled in by {@link #readData}. */
    public AggregateAvgSupplier() {
    }

    /**
     * @param distinct       whether the aggregate was declared {@code DISTINCT}
     * @param avgOperandType type of the averaged operand
     */
    public AggregateAvgSupplier(boolean distinct, QueryDataType avgOperandType) {
        this.distinct = distinct;
        this.avgOperandType = avgOperandType;
    }

    @Override
    public SqlAggregation getEx() {
        return AvgSqlAggregations.from(avgOperandType, distinct);
    }

    @Override
    public void writeData(ObjectDataOutput out) throws IOException {
        out.writeBoolean(distinct);
        out.writeObject(avgOperandType);
    }

    // Reads the fields in the same order writeData emits them.
    @Override
    public void readData(ObjectDataInput in) throws IOException {
        distinct = in.readBoolean();
        avgOperandType = in.readObject();
    }

    @Override
    public int getFactoryId() {
        return JetSqlSerializerHook.F_ID;
    }

    @Override
    public int getClassId() {
        return JetSqlSerializerHook.AGGREGATE_AVG_SUPPLIER;
    }
}
/**
 * Serializable supplier of the accumulator for the SQL {@code SUM} aggregate.
 * Each call to {@link #getEx()} asks {@code SumSqlAggregations.from} for an
 * accumulator matching the stored operand type and distinct flag.
 */
public static class AggregateSumSupplier implements IdentifiedDataSerializable,
        SupplierEx<SqlAggregation> {
    private boolean distinct;
    private QueryDataType sumOperandType;

    /** Creates an empty instance; the fields are filled in by {@link #readData}. */
    public AggregateSumSupplier() {
    }

    /**
     * @param distinct       whether the aggregate was declared {@code DISTINCT}
     * @param sumOperandType type of the summed operand
     */
    public AggregateSumSupplier(boolean distinct, QueryDataType sumOperandType) {
        this.distinct = distinct;
        this.sumOperandType = sumOperandType;
    }

    @Override
    public SqlAggregation getEx() {
        return SumSqlAggregations.from(sumOperandType, distinct);
    }

    @Override
    public void writeData(ObjectDataOutput out) throws IOException {
        out.writeBoolean(distinct);
        out.writeObject(sumOperandType);
    }

    // Reads the fields in the same order writeData emits them.
    @Override
    public void readData(ObjectDataInput in) throws IOException {
        distinct = in.readBoolean();
        sumOperandType = in.readObject();
    }

    @Override
    public int getFactoryId() {
        return JetSqlSerializerHook.F_ID;
    }

    @Override
    public int getClassId() {
        return JetSqlSerializerHook.AGGREGATE_SUM_SUPPLIER;
    }
}
/**
 * Serializable supplier of the accumulator for {@code JSON_ARRAYAGG}.
 * <p>
 * Two variants exist. The unordered one only carries the absent-on-null flag.
 * The ordered one additionally carries a row comparator and the index of the
 * aggregated column. The {@code ordered} flag is written first on the wire so
 * that {@link #readData} knows which fields follow.
 */
public static class AggregateArrayAggSupplier implements IdentifiedDataSerializable,
        SupplierEx<SqlAggregation> {
    private boolean ordered;
    private ExpressionUtil.SqlRowComparator comparator;
    private boolean isAbsentOnNull;
    private int aggIndex;

    /** Creates an empty instance; the fields are filled in by {@link #readData}. */
    public AggregateArrayAggSupplier() {
    }

    /**
     * Creates the unordered variant.
     *
     * @param isAbsentOnNull flag forwarded to the created aggregation
     */
    public AggregateArrayAggSupplier(boolean isAbsentOnNull) {
        this.ordered = false;
        this.isAbsentOnNull = isAbsentOnNull;
    }

    /**
     * Creates the ordered variant.
     *
     * @param comparator     comparator forwarded to the created aggregation
     * @param isAbsentOnNull flag forwarded to the created aggregation
     * @param aggIndex       index of the aggregated column in the input row
     */
    public AggregateArrayAggSupplier(
            ExpressionUtil.SqlRowComparator comparator,
            boolean isAbsentOnNull,
            int aggIndex
    ) {
        this.ordered = true;
        this.comparator = comparator;
        this.isAbsentOnNull = isAbsentOnNull;
        this.aggIndex = aggIndex;
    }

    @Override
    public SqlAggregation getEx() {
        // The presence of a comparator, not the ordered flag, selects the variant.
        if (comparator == null) {
            return UnorderedJsonArrayAggAggregation.create(isAbsentOnNull);
        } else {
            return OrderedJsonArrayAggAggregation.create(comparator, isAbsentOnNull, aggIndex);
        }
    }

    @Override
    public void writeData(ObjectDataOutput out) throws IOException {
        out.writeBoolean(ordered);
        if (ordered) {
            out.writeObject(comparator);
            out.writeBoolean(isAbsentOnNull);
            out.writeInt(aggIndex);
        } else {
            out.writeBoolean(isAbsentOnNull);
        }
    }

    // Mirrors writeData: the leading boolean decides which fields follow.
    @Override
    public void readData(ObjectDataInput in) throws IOException {
        ordered = in.readBoolean();
        if (ordered) {
            comparator = in.readObject(ExpressionUtil.SqlRowComparator.class);
            isAbsentOnNull = in.readBoolean();
            aggIndex = in.readInt();
        } else {
            isAbsentOnNull = in.readBoolean();
        }
    }

    @Override
    public int getFactoryId() {
        return JetSqlSerializerHook.F_ID;
    }

    @Override
    public int getClassId() {
        return JetSqlSerializerHook.AGGREGATE_JSON_ARRAY_AGG_SUPPLIER;
    }
}
/**
 * Serializable supplier of the accumulator for {@code JSON_OBJECTAGG}.
 * Each call to {@link #getEx()} creates a new {@link JsonObjectAggAggregation}
 * from the stored key index, value index and absent-on-null flag.
 */
public static final class AggregateObjectAggSupplier
        implements IdentifiedDataSerializable, SupplierEx<SqlAggregation> {
    private int keyIndex;
    private int valueIndex;
    private boolean isAbsentOnNull;

    /** Creates an empty instance; the fields are filled in by {@link #readData}. */
    public AggregateObjectAggSupplier() {
    }

    /**
     * @param keyIndex       index of the key column in the input row
     * @param valueIndex     index of the value column in the input row
     * @param isAbsentOnNull flag forwarded to the created aggregation
     */
    public AggregateObjectAggSupplier(int keyIndex, int valueIndex, boolean isAbsentOnNull) {
        this.keyIndex = keyIndex;
        this.valueIndex = valueIndex;
        this.isAbsentOnNull = isAbsentOnNull;
    }

    @Override
    public SqlAggregation getEx() {
        return new JsonObjectAggAggregation(keyIndex, valueIndex, isAbsentOnNull);
    }

    @Override
    public void writeData(ObjectDataOutput out) throws IOException {
        out.writeInt(keyIndex);
        out.writeInt(valueIndex);
        out.writeBoolean(isAbsentOnNull);
    }

    // Reads the fields in the same order writeData emits them.
    @Override
    public void readData(ObjectDataInput in) throws IOException {
        keyIndex = in.readInt();
        valueIndex = in.readInt();
        isAbsentOnNull = in.readBoolean();
    }

    @Override
    public int getFactoryId() {
        return JetSqlSerializerHook.F_ID;
    }

    @Override
    public int getClassId() {
        return JetSqlSerializerHook.AGGREGATE_JSON_OBJECT_AGG_SUPPLIER;
    }
}
/**
 * Serializable supplier of the accumulator for the SQL {@code COUNT} aggregate.
 * Each call to {@link #getEx()} asks {@code CountSqlAggregations.from} for an
 * accumulator matching the stored flags.
 */
public static final class AggregateCountSupplier implements IdentifiedDataSerializable,
        SupplierEx<SqlAggregation> {
    private boolean ignoreNulls;
    private boolean distinct;

    /** Creates an empty instance; the fields are filled in by {@link #readData}. */
    public AggregateCountSupplier() {
    }

    /**
     * @param ignoreNulls flag forwarded to {@code CountSqlAggregations.from}
     * @param distinct    whether the aggregate was declared {@code DISTINCT}
     */
    public AggregateCountSupplier(boolean ignoreNulls, boolean distinct) {
        this.ignoreNulls = ignoreNulls;
        this.distinct = distinct;
    }

    @Override
    public SqlAggregation getEx() {
        return CountSqlAggregations.from(ignoreNulls, distinct);
    }

    @Override
    public void writeData(ObjectDataOutput out) throws IOException {
        out.writeBoolean(ignoreNulls);
        out.writeBoolean(distinct);
    }

    // Reads the fields in the same order writeData emits them.
    @Override
    public void readData(ObjectDataInput in) throws IOException {
        ignoreNulls = in.readBoolean();
        distinct = in.readBoolean();
    }

    @Override
    public int getFactoryId() {
        return JetSqlSerializerHook.F_ID;
    }

    @Override
    public int getClassId() {
        return JetSqlSerializerHook.AGGREGATE_COUNT_SUPPLIER;
    }
}
public static class AggregateCreateSupplier implements IdentifiedDataSerializable, SupplierEx> {
private List> aggregationProviders;
public AggregateCreateSupplier() {
}
public AggregateCreateSupplier(List> aggregationProviders) {
this.aggregationProviders = aggregationProviders;
}
@Override
public List getEx() {
List aggregations = new ArrayList<>(aggregationProviders.size());
for (SupplierEx aggregationProvider : aggregationProviders) {
aggregations.add(aggregationProvider.get());
}
return aggregations;
}
@Override
public void writeData(ObjectDataOutput out) throws IOException {
out.writeInt(aggregationProviders.size());
for (SupplierEx aggregationProvider : aggregationProviders) {
out.writeObject(aggregationProvider);
}
}
@Override
public void readData(ObjectDataInput in) throws IOException {
int aggregationProvidersSize = in.readInt();
aggregationProviders = new ArrayList<>(aggregationProvidersSize);
for (int i = 0; i < aggregationProvidersSize; i++) {
aggregationProviders.add(in.readObject());
}
}
@Override
public int getFactoryId() {
return JetSqlSerializerHook.F_ID;
}
@Override
public int getClassId() {
return JetSqlSerializerHook.AGGREGATE_CREATE_SUPPLIER;
}
}
public static class AggregateAccumulateFunction implements IdentifiedDataSerializable,
BiConsumerEx, JetSqlRow> {
private List> valueProviders;
public AggregateAccumulateFunction() {
}
public AggregateAccumulateFunction(List> valueProviders) {
this.valueProviders = valueProviders;
}
@Override
public void acceptEx(List aggregations, JetSqlRow row) {
for (int i = 0; i < aggregations.size(); i++) {
aggregations.get(i).accumulate(valueProviders.get(i).apply(row));
}
}
@Override
public void writeData(ObjectDataOutput out) throws IOException {
out.writeInt(valueProviders.size());
for (FunctionEx aggregationProvider : valueProviders) {
out.writeObject(aggregationProvider);
}
}
@Override
public void readData(ObjectDataInput in) throws IOException {
int aggregationProvidersSize = in.readInt();
valueProviders = new ArrayList<>(aggregationProvidersSize);
for (int i = 0; i < aggregationProvidersSize; i++) {
valueProviders.add(in.readObject());
}
}
@Override
public int getFactoryId() {
return JetSqlSerializerHook.F_ID;
}
@Override
public int getClassId() {
return JetSqlSerializerHook.AGGREGATE_ACCUMULATE_FUNCTION;
}
}
public static final class AggregateCombineFunction implements IdentifiedDataSerializable,
BiConsumerEx, List> {
public static final AggregateCombineFunction INSTANCE = new AggregateCombineFunction();
private AggregateCombineFunction() {
}
@Override
public void acceptEx(List lefts, List rights) {
assert lefts.size() == rights.size();
for (int i = 0; i < lefts.size(); i++) {
lefts.get(i).combine(rights.get(i));
}
}
@Override
public void writeData(ObjectDataOutput out) throws IOException {
}
@Override
public void readData(ObjectDataInput in) throws IOException {
}
@Override
public int getFactoryId() {
return JetSqlSerializerHook.F_ID;
}
@Override
public int getClassId() {
return JetSqlSerializerHook.AGGREGATE_COMBINE_FUNCTION;
}
}
public static final class AggregateFinishFunction implements IdentifiedDataSerializable,
FunctionEx, JetSqlRow> {
public static final AggregateFinishFunction INSTANCE = new AggregateFinishFunction();
private AggregateFinishFunction() {
}
@Override
public JetSqlRow applyEx(List aggregations) {
Object[] row = new Object[aggregations.size()];
for (int i = 0; i < aggregations.size(); i++) {
row[i] = aggregations.get(i).collect();
}
return new JetSqlRow(Contexts.getCastedThreadContext().serializationService(), row);
}
@Override
public void writeData(ObjectDataOutput out) throws IOException {
}
@Override
public void readData(ObjectDataInput in) throws IOException {
}
@Override
public int getFactoryId() {
return JetSqlSerializerHook.F_ID;
}
@Override
public int getClassId() {
return JetSqlSerializerHook.AGGREGATE_FINISH_FUNCTION;
}
}
public static final class AggregateExportFunction implements IdentifiedDataSerializable,
FunctionEx, JetSqlRow> {
public static final AggregateExportFunction INSTANCE = new AggregateExportFunction();
private AggregateExportFunction() {
}
@Override
public JetSqlRow applyEx(List aggregations) {
throw new UnsupportedOperationException("Export function should not be called");
}
@Override
public void writeData(ObjectDataOutput out) throws IOException {
}
@Override
public void readData(ObjectDataInput in) throws IOException {
}
@Override
public int getFactoryId() {
return JetSqlSerializerHook.F_ID;
}
@Override
public int getClassId() {
return JetSqlSerializerHook.AGGREGATE_EXPORT_FUNCTION;
}
}
/**
 * Serializable value extractor that returns
 * {@code row.getMaybeSerialized(groupIndex)} for the stored column index.
 */
public static class RowGetMaybeSerializedFn
        implements IdentifiedDataSerializable, FunctionEx<JetSqlRow, Object> {
    // Kept boxed and written with writeObject: changing this would alter the wire format.
    private Integer groupIndex;

    /** Creates an empty instance; the index is filled in by {@link #readData}. */
    public RowGetMaybeSerializedFn() {
    }

    /**
     * @param groupIndex index of the column to read from each row
     */
    public RowGetMaybeSerializedFn(Integer groupIndex) {
        this.groupIndex = groupIndex;
    }

    @Override
    public Object applyEx(JetSqlRow row) {
        return row.getMaybeSerialized(groupIndex);
    }

    @Override
    public void writeData(ObjectDataOutput out) throws IOException {
        out.writeObject(groupIndex);
    }

    @Override
    public void readData(ObjectDataInput in) throws IOException {
        groupIndex = in.readObject();
    }

    @Override
    public int getFactoryId() {
        return JetSqlSerializerHook.F_ID;
    }

    @Override
    public int getClassId() {
        return JetSqlSerializerHook.ROW_GET_MAYBE_SERIALIZED_FN;
    }
}
/**
 * Stateless value extractor that returns the input row itself, for
 * aggregations that need the whole row rather than a single column.
 */
public static class RowIdentityFn
        implements IdentifiedDataSerializable, FunctionEx<JetSqlRow, Object> {

    public RowIdentityFn() {
    }

    @Override
    public Object applyEx(JetSqlRow row) {
        return row;
    }

    @Override
    public void writeData(ObjectDataOutput out) throws IOException {
        // Stateless: nothing to write.
    }

    @Override
    public void readData(ObjectDataInput in) throws IOException {
        // Stateless: nothing to read.
    }

    @Override
    public int getFactoryId() {
        return JetSqlSerializerHook.F_ID;
    }

    @Override
    public int getClassId() {
        return JetSqlSerializerHook.ROW_IDENTITY_FN;
    }
}
/**
 * Serializable value extractor that returns {@code row.get(index)} for the
 * stored column index.
 */
public static class RowGetFn
        implements IdentifiedDataSerializable, FunctionEx<JetSqlRow, Object> {
    private int index;

    /** Creates an empty instance; the index is filled in by {@link #readData}. */
    public RowGetFn() {
    }

    /**
     * @param index index of the column to read from each row; it is unboxed
     *              on assignment, so it must not be {@code null}
     */
    public RowGetFn(Integer index) {
        this.index = index;
    }

    @Override
    public Object applyEx(JetSqlRow row) {
        return row.get(index);
    }

    @Override
    public void writeData(ObjectDataOutput out) throws IOException {
        out.writeInt(index);
    }

    @Override
    public void readData(ObjectDataInput in) throws IOException {
        index = in.readInt();
    }

    @Override
    public int getFactoryId() {
        return JetSqlSerializerHook.F_ID;
    }

    @Override
    public int getClassId() {
        return JetSqlSerializerHook.ROW_GET_FN;
    }
}
/**
 * Stateless value extractor that ignores the row and always returns
 * {@code null}.
 */
public static final class NullFunction
        implements IdentifiedDataSerializable, FunctionEx<JetSqlRow, Object> {
    public static final NullFunction INSTANCE = new NullFunction();

    private NullFunction() {
    }

    @Override
    public Object applyEx(JetSqlRow row) {
        return null;
    }

    @Override
    public void writeData(ObjectDataOutput out) throws IOException {
        // Stateless: nothing to write.
    }

    @Override
    public void readData(ObjectDataInput in) throws IOException {
        // Stateless: nothing to read.
    }

    @Override
    public int getFactoryId() {
        return JetSqlSerializerHook.F_ID;
    }

    @Override
    public int getClassId() {
        return JetSqlSerializerHook.NULL_FUNCTION;
    }
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy