
// org.apache.kylin.query.relnode.OLAPAggregateRel (Maven / Gradle / Ivy)
// Go to download.
// Show more of this group. Show more artifacts with this name.
// Show all versions of kylin-query. Show documentation.
// kylin-query: Kylin query engine based on Calcite.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kylin.query.relnode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.adapter.enumerable.EnumerableAggregate;
import org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.calcite.adapter.enumerable.EnumerableRel;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.InvalidRelException;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.schema.AggregateFunction;
import org.apache.calcite.schema.FunctionParameter;
import org.apache.calcite.schema.impl.AggregateFunctionImpl;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Util;
import org.apache.kylin.metadata.model.ColumnDesc;
import org.apache.kylin.metadata.model.FunctionDesc;
import org.apache.kylin.metadata.model.MeasureDesc;
import org.apache.kylin.metadata.model.ParameterDesc;
import org.apache.kylin.metadata.model.TableDesc;
import org.apache.kylin.metadata.model.TblColRef;
import org.apache.kylin.query.schema.OLAPTable;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
/**
*/
public class OLAPAggregateRel extends Aggregate implements OLAPRel {
private final static Map AGGR_FUNC_MAP = new HashMap();
static {
AGGR_FUNC_MAP.put("SUM", "SUM");
AGGR_FUNC_MAP.put("$SUM0", "SUM");
AGGR_FUNC_MAP.put("COUNT", "COUNT");
AGGR_FUNC_MAP.put("COUNT_DISTINCT", "COUNT_DISTINCT");
AGGR_FUNC_MAP.put("MAX", "MAX");
AGGR_FUNC_MAP.put("MIN", "MIN");
}
private static String getFuncName(AggregateCall aggCall) {
String aggName = aggCall.getAggregation().getName();
if (aggCall.isDistinct()) {
aggName = aggName + "_DISTINCT";
}
String funcName = AGGR_FUNC_MAP.get(aggName);
if (funcName == null) {
throw new IllegalStateException("Don't suppoprt aggregation " + aggName);
}
return funcName;
}
private OLAPContext context;
private ColumnRowType columnRowType;
private boolean afterAggregate;
private List rewriteAggCalls;
private List groups;
private List aggregations;
public OLAPAggregateRel(RelOptCluster cluster, RelTraitSet traits, RelNode child, ImmutableBitSet groupSet, List aggCalls) throws InvalidRelException {
super(cluster, traits, child, false, groupSet, asList(groupSet), aggCalls);
Preconditions.checkArgument(getConvention() == OLAPRel.CONVENTION);
this.afterAggregate = false;
this.rewriteAggCalls = aggCalls;
this.rowType = getRowType();
}
private static List asList(ImmutableBitSet groupSet) {
ArrayList l = new ArrayList(1);
l.add(groupSet);
return l;
}
@Override
public Aggregate copy(RelTraitSet traitSet, RelNode input, boolean indicator, ImmutableBitSet groupSet, List groupSets, List aggCalls) {
try {
return new OLAPAggregateRel(getCluster(), traitSet, input, groupSet, aggCalls);
} catch (InvalidRelException e) {
throw new IllegalStateException("Can't create OLAPAggregateRel!", e);
}
}
@Override
public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
return super.computeSelfCost(planner, mq).multiplyBy(.05);
}
@Override
public void implementOLAP(OLAPImplementor implementor) {
implementor.visitChild(getInput(), this);
this.context = implementor.getContext();
this.columnRowType = buildColumnRowType();
this.afterAggregate = this.context.afterAggregate;
// only translate the innermost aggregation
if (!this.afterAggregate) {
translateGroupBy();
this.context.aggregations.addAll(this.aggregations);
this.context.afterAggregate = true;
} else {
for (AggregateCall aggCall : aggCalls) {
// check if supported by kylin
if (aggCall.isDistinct()) {
throw new IllegalStateException("Distinct count is only allowed in innermost sub-query.");
}
}
}
}
private ColumnRowType buildColumnRowType() {
buildGroups();
buildAggregations();
ColumnRowType inputColumnRowType = ((OLAPRel) getInput()).getColumnRowType();
List columns = new ArrayList(this.rowType.getFieldCount());
columns.addAll(this.groups);
for (int i = 0; i < this.aggregations.size(); i++) {
FunctionDesc aggFunc = this.aggregations.get(i);
TblColRef aggCol = null;
if (aggFunc.needRewriteField()) {
aggCol = buildRewriteColumn(aggFunc);
} else {
AggregateCall aggCall = this.rewriteAggCalls.get(i);
if (!aggCall.getArgList().isEmpty()) {
int index = aggCall.getArgList().get(0);
aggCol = inputColumnRowType.getColumnByIndex(index);
}
}
columns.add(aggCol);
}
return new ColumnRowType(columns);
}
private TblColRef buildRewriteColumn(FunctionDesc aggFunc) {
TblColRef colRef;
if (aggFunc.needRewriteField()) {
ColumnDesc column = new ColumnDesc();
column.setName(aggFunc.getRewriteFieldName());
TableDesc table = this.context.firstTableScan.getOlapTable().getSourceTable();
column.setTable(table);
colRef = new TblColRef(column);
} else {
throw new IllegalStateException("buildRewriteColumn on a aggrFunc that does not need rewrite " + aggFunc);
}
return colRef;
}
private void buildGroups() {
ColumnRowType inputColumnRowType = ((OLAPRel) getInput()).getColumnRowType();
this.groups = new ArrayList();
for (int i = getGroupSet().nextSetBit(0); i >= 0; i = getGroupSet().nextSetBit(i + 1)) {
Set columns = inputColumnRowType.getSourceColumnsByIndex(i);
this.groups.addAll(columns);
}
}
private void buildAggregations() {
ColumnRowType inputColumnRowType = ((OLAPRel) getInput()).getColumnRowType();
this.aggregations = new ArrayList();
for (AggregateCall aggCall : this.rewriteAggCalls) {
ParameterDesc parameter = null;
if (!aggCall.getArgList().isEmpty()) {
int index = aggCall.getArgList().get(0);
TblColRef column = inputColumnRowType.getColumnByIndex(index);
if (!column.isInnerColumn()) {
parameter = new ParameterDesc();
parameter.setValue(column.getName());
parameter.setType(FunctionDesc.PARAMETER_TYPE_COLUMN);
parameter.setColRefs(Arrays.asList(column));
}
}
FunctionDesc aggFunc = new FunctionDesc();
String funcName = getFuncName(aggCall);
aggFunc.setExpression(funcName);
aggFunc.setParameter(parameter);
this.aggregations.add(aggFunc);
}
}
private void translateGroupBy() {
context.groupByColumns.addAll(this.groups);
}
@Override
public void implementRewrite(RewriteImplementor implementor) {
// only rewrite the innermost aggregation
if (!this.afterAggregate) {
translateAggregation();
buildRewriteFieldsAndMetricsColumns();
}
implementor.visitChild(this, getInput());
// only rewrite the innermost aggregation
if (!this.afterAggregate && RewriteImplementor.needRewrite(this.context)) {
// rewrite the aggCalls
this.rewriteAggCalls = new ArrayList(aggCalls.size());
for (int i = 0; i < this.aggCalls.size(); i++) {
AggregateCall aggCall = this.aggCalls.get(i);
FunctionDesc cubeFunc = this.context.aggregations.get(i);
if (cubeFunc.needRewrite()) {
aggCall = rewriteAggregateCall(aggCall, cubeFunc);
}
this.rewriteAggCalls.add(aggCall);
}
}
// rebuild rowType & columnRowType
this.rowType = this.deriveRowType();
this.columnRowType = this.buildColumnRowType();
}
private void translateAggregation() {
// now the realization is known, replace aggregations with what's defined on MeasureDesc
List measures = this.context.realization.getMeasures();
List newAggrs = Lists.newArrayList();
for (FunctionDesc aggFunc : this.aggregations) {
newAggrs.add(findInMeasures(aggFunc, measures));
}
this.aggregations.clear();
this.aggregations.addAll(newAggrs);
this.context.aggregations.clear();
this.context.aggregations.addAll(newAggrs);
}
private FunctionDesc findInMeasures(FunctionDesc aggFunc, List measures) {
for (MeasureDesc m : measures) {
if (aggFunc.equals(m.getFunction()))
return m.getFunction();
}
return aggFunc;
}
private void buildRewriteFieldsAndMetricsColumns() {
fillbackOptimizedColumn();
ColumnRowType inputColumnRowType = ((OLAPRel) getInput()).getColumnRowType();
RelDataTypeFactory typeFactory = getCluster().getTypeFactory();
for (int i = 0; i < this.aggregations.size(); i++) {
FunctionDesc aggFunc = this.aggregations.get(i);
if (aggFunc.isDimensionAsMetric()) {
this.context.groupByColumns.addAll(aggFunc.getParameter().getColRefs());
continue; // skip rewrite, let calcite handle
}
if (aggFunc.needRewriteField()) {
String rewriteFieldName = aggFunc.getRewriteFieldName();
RelDataType rewriteFieldType = OLAPTable.createSqlType(typeFactory, aggFunc.getRewriteFieldType(), true);
this.context.rewriteFields.put(rewriteFieldName, rewriteFieldType);
TblColRef column = buildRewriteColumn(aggFunc);
this.context.metricsColumns.add(column);
}
AggregateCall aggCall = this.rewriteAggCalls.get(i);
if (!aggCall.getArgList().isEmpty()) {
int index = aggCall.getArgList().get(0);
TblColRef column = inputColumnRowType.getColumnByIndex(index);
if (!column.isInnerColumn()) {
this.context.metricsColumns.add(column);
}
}
}
}
private void fillbackOptimizedColumn() {
// some aggcall will be optimized out in sub-query (e.g. tableau generated sql), we need to fill them back
RelDataType inputAggRow = getInput().getRowType();
RelDataType outputAggRow = getRowType();
if (inputAggRow.getFieldCount() != outputAggRow.getFieldCount()) {
for (RelDataTypeField inputField : inputAggRow.getFieldList()) {
String inputFieldName = inputField.getName();
if (outputAggRow.getField(inputFieldName, true, false) == null) {
TblColRef column = this.columnRowType.getColumnByIndex(inputField.getIndex());
this.context.metricsColumns.add(column);
}
}
}
}
private AggregateCall rewriteAggregateCall(AggregateCall aggCall, FunctionDesc func) {
// rebuild parameters
List newArgList = Lists.newArrayListWithCapacity(1);
if (func.needRewriteField()) {
RelDataTypeField field = getInput().getRowType().getField(func.getRewriteFieldName(), true, false);
newArgList.add(field.getIndex());
} else {
newArgList = aggCall.getArgList();
}
// rebuild function
RelDataType fieldType = aggCall.getType();
SqlAggFunction newAgg = aggCall.getAggregation();
if (func.isCount()) {
newAgg = SqlStdOperatorTable.SUM0;
} else if (func.getMeasureType().getRewriteCalciteAggrFunctionClass() != null) {
newAgg = createCustomAggFunction(func.getExpression(), fieldType, func.getMeasureType().getRewriteCalciteAggrFunctionClass());
}
// rebuild aggregate call
AggregateCall newAggCall = new AggregateCall(newAgg, false, newArgList, fieldType, newAgg.getName());
return newAggCall;
}
private SqlAggFunction createCustomAggFunction(String funcName, RelDataType returnType, Class> customAggFuncClz) {
RelDataTypeFactory typeFactory = getCluster().getTypeFactory();
SqlIdentifier sqlIdentifier = new SqlIdentifier(funcName, new SqlParserPos(1, 1));
AggregateFunction aggFunction = AggregateFunctionImpl.create(customAggFuncClz);
List argTypes = new ArrayList();
List typeFamilies = new ArrayList();
for (FunctionParameter o : aggFunction.getParameters()) {
final RelDataType type = o.getType(typeFactory);
argTypes.add(type);
typeFamilies.add(Util.first(type.getSqlTypeName().getFamily(), SqlTypeFamily.ANY));
}
return new SqlUserDefinedAggFunction(sqlIdentifier, ReturnTypes.explicit(returnType), InferTypes.explicit(argTypes), OperandTypes.family(typeFamilies), aggFunction);
}
@Override
public EnumerableRel implementEnumerable(List inputs) {
try {
return new EnumerableAggregate(getCluster(), getCluster().traitSetOf(EnumerableConvention.INSTANCE), //
sole(inputs), false, this.groupSet, this.groupSets, rewriteAggCalls);
} catch (InvalidRelException e) {
throw new IllegalStateException("Can't create EnumerableAggregate!", e);
}
}
@Override
public OLAPContext getContext() {
return context;
}
@Override
public ColumnRowType getColumnRowType() {
return columnRowType;
}
@Override
public boolean hasSubQuery() {
OLAPRel olapChild = (OLAPRel) getInput();
return olapChild.hasSubQuery();
}
@Override
public RelTraitSet replaceTraitSet(RelTrait trait) {
RelTraitSet oldTraitSet = this.traitSet;
this.traitSet = this.traitSet.replace(trait);
return oldTraitSet;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy