All downloads are free. The search and download functionality uses the official Maven repository.

com.hazelcast.org.apache.calcite.adapter.enumerable.EnumerableWindow Maven / Gradle / Ivy

There is a newer version: 5.5.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.hazelcast.org.apache.calcite.adapter.enumerable;

import com.hazelcast.org.apache.calcite.adapter.enumerable.impl.WinAggAddContextImpl;
import com.hazelcast.org.apache.calcite.adapter.enumerable.impl.WinAggResetContextImpl;
import com.hazelcast.org.apache.calcite.adapter.enumerable.impl.WinAggResultContextImpl;
import com.hazelcast.org.apache.calcite.adapter.java.JavaTypeFactory;
import com.hazelcast.org.apache.calcite.config.CalciteSystemProperty;
import com.hazelcast.org.apache.calcite.linq4j.tree.BinaryExpression;
import com.hazelcast.org.apache.calcite.linq4j.tree.BlockBuilder;
import com.hazelcast.org.apache.calcite.linq4j.tree.BlockStatement;
import com.hazelcast.org.apache.calcite.linq4j.tree.DeclarationStatement;
import com.hazelcast.org.apache.calcite.linq4j.tree.Expression;
import com.hazelcast.org.apache.calcite.linq4j.tree.Expressions;
import com.hazelcast.org.apache.calcite.linq4j.tree.ParameterExpression;
import com.hazelcast.org.apache.calcite.linq4j.tree.Primitive;
import com.hazelcast.org.apache.calcite.linq4j.tree.Statement;
import com.hazelcast.org.apache.calcite.linq4j.tree.Types;
import com.hazelcast.org.apache.calcite.plan.RelOptCluster;
import com.hazelcast.org.apache.calcite.plan.RelOptCost;
import com.hazelcast.org.apache.calcite.plan.RelOptPlanner;
import com.hazelcast.org.apache.calcite.plan.RelTraitSet;
import com.hazelcast.org.apache.calcite.rel.RelFieldCollation;
import com.hazelcast.org.apache.calcite.rel.RelNode;
import com.hazelcast.org.apache.calcite.rel.core.AggregateCall;
import com.hazelcast.org.apache.calcite.rel.core.Window;
import com.hazelcast.org.apache.calcite.rel.metadata.RelMetadataQuery;
import com.hazelcast.org.apache.calcite.rel.type.RelDataType;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeFactory;
import com.hazelcast.org.apache.calcite.rex.RexInputRef;
import com.hazelcast.org.apache.calcite.rex.RexLiteral;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.rex.RexWindowBound;
import com.hazelcast.org.apache.calcite.runtime.SortedMultiMap;
import com.hazelcast.org.apache.calcite.sql.SqlAggFunction;
import com.hazelcast.org.apache.calcite.sql.validate.SqlConformance;
import com.hazelcast.org.apache.calcite.util.BuiltInMethod;
import com.hazelcast.org.apache.calcite.util.ImmutableBitSet;
import com.hazelcast.org.apache.calcite.util.Pair;
import com.hazelcast.org.apache.calcite.util.Util;

import com.hazelcast.com.google.common.collect.ImmutableList;

import com.hazelcast.org.checkerframework.checker.nullness.qual.Nullable;

import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

import static com.hazelcast.org.apache.calcite.linq4j.Nullness.castNonNull;

import static java.util.Objects.requireNonNull;

/** Implementation of {@link com.hazelcast.org.apache.calcite.rel.core.Window} in
 * {@link com.hazelcast.org.apache.calcite.adapter.enumerable.EnumerableConvention enumerable calling convention}. */
public class EnumerableWindow extends Window implements EnumerableRel {
  /** Creates an EnumerableWindow; every argument is forwarded unchanged to
   * the {@link Window} constructor.
   *
   * @param cluster Cluster passed to the base class
   * @param traits Trait set passed to the base class
   * @param child Input relational expression
   * @param constants Literal constants of the window
   * @param rowType Row type passed to the base class
   * @param groups Window groups
   */
  EnumerableWindow(RelOptCluster cluster, RelTraitSet traits, RelNode child,
      List<RexLiteral> constants, RelDataType rowType, List<Group> groups) {
    super(cluster, traits, child, constants, rowType, groups);
  }

  /** Returns a new {@code EnumerableWindow} over the single element of
   * {@code inputs}, using {@code traitSet} together with this node's cluster,
   * constants, row type and groups.
   *
   * @param traitSet Trait set of the copy
   * @param inputs Inputs of the copy; reduced to one input via {@code sole}
   * @return the new window node
   */
  @Override public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
    return new EnumerableWindow(getCluster(), traitSet, sole(inputs),
        constants, getRowType(), groups);
  }

  /** Scales the cost computed by the base class by
   * {@link EnumerableConvention#COST_MULTIPLIER}.
   *
   * @return the scaled cost, or {@code null} when the base class reports no
   *     cost
   */
  @Override public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner,
      RelMetadataQuery mq) {
    final RelOptCost baseCost = super.computeSelfCost(planner, mq);
    return baseCost == null
        ? null
        : baseCost.multiplyBy(EnumerableConvention.COST_MULTIPLIER);
  }

  /** Implementation of {@link RexToLixTranslator.InputGetter}
   * suitable for generating implementations of windowed aggregate
   * functions.
   *
   * <p>Field indexes below {@code actualInputFieldCount} are read from the
   * row expression; indexes at or above it are looked up in the list of
   * constant expressions. */
  private static class WindowRelInputGetter
      implements RexToLixTranslator.InputGetter {
    /** Expression that yields the row being read. */
    private final Expression row;
    /** Physical type used to build field references into {@link #row}. */
    private final PhysType rowPhysType;
    /** Number of leading indexes that refer to row fields. */
    private final int actualInputFieldCount;
    /** Expressions returned for indexes past the row fields. */
    private final List<Expression> constants;

    private WindowRelInputGetter(Expression row,
        PhysType rowPhysType, int actualInputFieldCount,
        List<Expression> constants) {
      this.row = row;
      this.rowPhysType = rowPhysType;
      this.actualInputFieldCount = actualInputFieldCount;
      this.constants = constants;
    }

    /** Returns a reference to field {@code index} of the row, or the constant
     * at position {@code index - actualInputFieldCount} when the index lies
     * past the row fields. */
    @Override public Expression field(BlockBuilder list, int index, @Nullable Type storageType) {
      if (index < actualInputFieldCount) {
        // Append the row to the block so the field reference uses a variable.
        Expression current = list.append("current", row);
        return rowPhysType.fieldReference(current, index, storageType);
      }
      return constants.get(index - actualInputFieldCount);
    }
  }

  /** Illustrative sketch of the code that {@code implement} generates.
   *
   * <p>This method exists for documentation only: the iterator is
   * {@code null}, so it is not meant to be called. The "builderN" comments
   * name the {@code BlockBuilder} in {@code implement} that produces the
   * corresponding part of the generated code. */
  @SuppressWarnings({"unused", "nullness"})
  private static void sampleOfTheGeneratedWindowedAggregate() {
    // Here is an overview of the generated code.
    // For each list of rows that have the same partitioning key, evaluate
    // all of the windowed aggregate functions.

    // builder
    Iterator<Integer[]> iterator = null;

    // builder3
    Integer[] rows = iterator.next();

    int prevStart = -1;
    int prevEnd = -1;

    for (int i = 0; i < rows.length; i++) {
      // builder4
      Integer row = rows[i];

      int start = 0;
      int end = 100;
      if (start != prevStart || end != prevEnd) {
        // builder5
        int actualStart = 0;
        if (start != prevStart || end < prevEnd) {
          // builder6
          // recompute
          actualStart = start;
          // implementReset
        } else { // must be start == prevStart && end > prevEnd
          actualStart = prevEnd + 1;
        }
        prevStart = start;
        prevEnd = end;

        if (start != -1) {
          for (int j = actualStart; j <= end; j++) {
            // builder7
            // implementAdd
          }
        }
        // implementResult
        // list.add(new Xxx(row.deptno, row.empid, sum, count));
      }
    }
    // multiMap.clear(); // allows gc
    // source = Linq4j.asEnumerable(list);
  }

  @Override public Result implement(EnumerableRelImplementor implementor, Prefer pref) {
    final JavaTypeFactory typeFactory = implementor.getTypeFactory();
    final EnumerableRel child = (EnumerableRel) getInput();
    final BlockBuilder builder = new BlockBuilder();
    final Result result = implementor.visitChild(this, 0, child, pref);
    Expression source_ = builder.append("source", result.block);

    final List translatedConstants =
        new ArrayList<>(constants.size());
    for (RexLiteral constant : constants) {
      translatedConstants.add(
          RexToLixTranslator.translateLiteral(constant, constant.getType(),
              typeFactory, RexImpTable.NullAs.NULL));
    }

    PhysType inputPhysType = result.physType;

    ParameterExpression prevStart =
        Expressions.parameter(int.class, builder.newName("prevStart"));
    ParameterExpression prevEnd =
        Expressions.parameter(int.class, builder.newName("prevEnd"));

    builder.add(Expressions.declare(0, prevStart, null));
    builder.add(Expressions.declare(0, prevEnd, null));

    for (int windowIdx = 0; windowIdx < groups.size(); windowIdx++) {
      Group group = groups.get(windowIdx);
      // Comparator:
      // final Comparator comparator =
      //    new Comparator() {
      //      public int compare(JdbcTest.Employee o1,
      //          JdbcTest.Employee o2) {
      //        return Integer.compare(o1.empid, o2.empid);
      //      }
      //    };
      final Expression comparator_ =
          builder.append(
              "comparator",
              inputPhysType.generateComparator(
                  group.collation()));

      Pair partitionIterator =
          getPartitionIterator(builder, source_, inputPhysType, group,
              comparator_);
      final Expression collectionExpr = partitionIterator.left;
      final Expression iterator_ = partitionIterator.right;

      List aggs = new ArrayList<>();
      List aggregateCalls = group.getAggregateCalls(this);
      for (int aggIdx = 0; aggIdx < aggregateCalls.size(); aggIdx++) {
        AggregateCall call = aggregateCalls.get(aggIdx);
        if (call.ignoreNulls()) {
          throw new UnsupportedOperationException("IGNORE NULLS not supported");
        }
        aggs.add(new AggImpState(aggIdx, call, true));
      }

      // The output from this stage is the input plus the aggregate functions.
      final RelDataTypeFactory.Builder typeBuilder = typeFactory.builder();
      typeBuilder.addAll(inputPhysType.getRowType().getFieldList());
      for (AggImpState agg : aggs) {
        // CALCITE-4326
        String name = requireNonNull(agg.call.name,
            () -> "agg.call.name for " + agg.call);
        typeBuilder.add(name, agg.call.type);
      }
      RelDataType outputRowType = typeBuilder.build();
      final PhysType outputPhysType =
          PhysTypeImpl.of(
              typeFactory, outputRowType, pref.prefer(result.format));

      final Expression list_ =
          builder.append(
              "list",
              Expressions.new_(
                  ArrayList.class,
                  Expressions.call(
                      collectionExpr, BuiltInMethod.COLLECTION_SIZE.method)),
              false);

      Pair<@Nullable Expression, @Nullable Expression> collationKey =
          getRowCollationKey(builder, inputPhysType, group, windowIdx);
      Expression keySelector = collationKey.left;
      Expression keyComparator = collationKey.right;
      final BlockBuilder builder3 = new BlockBuilder();
      final Expression rows_ =
          builder3.append(
              "rows",
              Expressions.convert_(
                  Expressions.call(
                      iterator_, BuiltInMethod.ITERATOR_NEXT.method),
                  Object[].class),
              false);

      builder3.add(
          Expressions.statement(
              Expressions.assign(prevStart, Expressions.constant(-1))));
      builder3.add(
          Expressions.statement(
              Expressions.assign(prevEnd,
                  Expressions.constant(Integer.MAX_VALUE))));

      final BlockBuilder builder4 = new BlockBuilder();

      final ParameterExpression i_ =
          Expressions.parameter(int.class, builder4.newName("i"));

      final Expression row_ =
          builder4.append(
              "row",
              EnumUtils.convert(
                  Expressions.arrayIndex(rows_, i_),
                  inputPhysType.getJavaRowType()));

      final RexToLixTranslator.InputGetter inputGetter =
          new WindowRelInputGetter(row_, inputPhysType,
              result.physType.getRowType().getFieldCount(),
              translatedConstants);

      final RexToLixTranslator translator =
          RexToLixTranslator.forAggregation(typeFactory, builder4,
              inputGetter, implementor.getConformance());

      final List outputRow = new ArrayList<>();
      int fieldCountWithAggResults =
          inputPhysType.getRowType().getFieldCount();
      for (int i = 0; i < fieldCountWithAggResults; i++) {
        outputRow.add(
            inputPhysType.fieldReference(
                row_, i,
                outputPhysType.getJavaFieldType(i)));
      }

      declareAndResetState(typeFactory, builder, result, windowIdx, aggs,
          outputPhysType, outputRow);

      // There are assumptions that minX==0. If ever change this, look for
      // frameRowCount, bounds checking, etc
      final Expression minX = Expressions.constant(0);
      final Expression partitionRowCount =
          builder3.append("partRows", Expressions.field(rows_, "length"));
      final Expression maxX = builder3.append("maxX",
          Expressions.subtract(
              partitionRowCount, Expressions.constant(1)));

      final Expression startUnchecked = builder4.append("start",
          translateBound(translator, i_, row_, minX, maxX, rows_,
              group, true, inputPhysType, keySelector, keyComparator));
      final Expression endUnchecked = builder4.append("end",
          translateBound(translator, i_, row_, minX, maxX, rows_,
              group, false, inputPhysType, keySelector, keyComparator));

      final Expression startX;
      final Expression endX;
      final Expression hasRows;
      if (group.isAlwaysNonEmpty()) {
        startX = startUnchecked;
        endX = endUnchecked;
        hasRows = Expressions.constant(true);
      } else {
        Expression startTmp =
            group.lowerBound.isUnbounded() || startUnchecked == i_
                ? startUnchecked
                : builder4.append("startTmp",
                    Expressions.call(null, BuiltInMethod.MATH_MAX.method,
                        startUnchecked, minX));
        Expression endTmp =
            group.upperBound.isUnbounded() || endUnchecked == i_
                ? endUnchecked
                : builder4.append("endTmp",
                    Expressions.call(null, BuiltInMethod.MATH_MIN.method,
                        endUnchecked, maxX));

        ParameterExpression startPe = Expressions.parameter(0, int.class,
            builder4.newName("startChecked"));
        ParameterExpression endPe = Expressions.parameter(0, int.class,
            builder4.newName("endChecked"));
        builder4.add(Expressions.declare(Modifier.FINAL, startPe, null));
        builder4.add(Expressions.declare(Modifier.FINAL, endPe, null));

        hasRows = builder4.append("hasRows",
            Expressions.lessThanOrEqual(startTmp, endTmp));
        builder4.add(
            Expressions.ifThenElse(hasRows,
                Expressions.block(
                    Expressions.statement(
                        Expressions.assign(startPe, startTmp)),
                    Expressions.statement(
                      Expressions.assign(endPe, endTmp))),
            Expressions.block(
                Expressions.statement(
                    Expressions.assign(startPe, Expressions.constant(-1))),
                Expressions.statement(
                    Expressions.assign(endPe, Expressions.constant(-1))))));
        startX = startPe;
        endX = endPe;
      }

      final BlockBuilder builder5 = new BlockBuilder(true, builder4);

      BinaryExpression rowCountWhenNonEmpty = Expressions.add(
          startX == minX ? endX : Expressions.subtract(endX, startX),
          Expressions.constant(1));

      final Expression frameRowCount;

      if (hasRows.equals(Expressions.constant(true))) {
        frameRowCount =
            builder4.append("totalRows", rowCountWhenNonEmpty);
      } else {
        frameRowCount =
            builder4.append("totalRows",
                Expressions.condition(hasRows, rowCountWhenNonEmpty,
                    Expressions.constant(0)));
      }

      ParameterExpression actualStart = Expressions.parameter(
          0, int.class, builder5.newName("actualStart"));

      final BlockBuilder builder6 = new BlockBuilder(true, builder5);
      builder6.add(
          Expressions.statement(Expressions.assign(actualStart, startX)));

      for (final AggImpState agg : aggs) {
        List aggState = requireNonNull(agg.state, "agg.state");
        agg.implementor.implementReset(requireNonNull(agg.context, "agg.context"),
            new WinAggResetContextImpl(builder6, aggState, i_, startX, endX,
                hasRows, frameRowCount, partitionRowCount));
      }

      Expression lowerBoundCanChange =
          group.lowerBound.isUnbounded() && group.lowerBound.isPreceding()
          ? Expressions.constant(false)
          : Expressions.notEqual(startX, prevStart);
      Expression needRecomputeWindow = Expressions.orElse(
          lowerBoundCanChange,
          Expressions.lessThan(endX, prevEnd));

      BlockStatement resetWindowState = builder6.toBlock();
      if (resetWindowState.statements.size() == 1) {
        builder5.add(
            Expressions.declare(0, actualStart,
                Expressions.condition(needRecomputeWindow, startX,
                    Expressions.add(prevEnd, Expressions.constant(1)))));
      } else {
        builder5.add(
            Expressions.declare(0, actualStart, null));
        builder5.add(
            Expressions.ifThenElse(needRecomputeWindow,
                resetWindowState,
                Expressions.statement(
                    Expressions.assign(actualStart,
                    Expressions.add(prevEnd, Expressions.constant(1))))));
      }

      if (lowerBoundCanChange instanceof BinaryExpression) {
        builder5.add(
            Expressions.statement(Expressions.assign(prevStart, startX)));
      }
      builder5.add(
          Expressions.statement(Expressions.assign(prevEnd, endX)));

      final BlockBuilder builder7 = new BlockBuilder(true, builder5);
      final DeclarationStatement jDecl =
          Expressions.declare(0, "j", actualStart);

      final PhysType inputPhysTypeFinal = inputPhysType;
      final Function
          resultContextBuilder =
          getBlockBuilderWinAggFrameResultContextFunction(typeFactory,
              implementor.getConformance(), result, translatedConstants,
              comparator_, rows_, i_, startX, endX, minX, maxX,
              hasRows, frameRowCount, partitionRowCount,
              jDecl, inputPhysTypeFinal);

      final Function> rexArguments = agg -> {
        List argList = agg.call.getArgList();
        List inputTypes =
            EnumUtils.fieldRowTypes(
                result.physType.getRowType(),
                constants,
                argList);
        List args = new ArrayList<>(inputTypes.size());
        for (int i = 0; i < argList.size(); i++) {
          Integer idx = argList.get(i);
          args.add(new RexInputRef(idx, inputTypes.get(i)));
        }
        return args;
      };

      implementAdd(aggs, builder7, resultContextBuilder, rexArguments, jDecl);

      BlockStatement forBlock = builder7.toBlock();
      if (!forBlock.statements.isEmpty()) {
        // For instance, row_number does not use for loop to compute the value
        Statement forAggLoop = Expressions.for_(
            Arrays.asList(jDecl),
            Expressions.lessThanOrEqual(jDecl.parameter, endX),
            Expressions.preIncrementAssign(jDecl.parameter),
            forBlock);
        if (!hasRows.equals(Expressions.constant(true))) {
          forAggLoop = Expressions.ifThen(hasRows, forAggLoop);
        }
        builder5.add(forAggLoop);
      }

      if (implementResult(aggs, builder5, resultContextBuilder, rexArguments,
              true)) {
        builder4.add(
            Expressions.ifThen(
                Expressions.orElse(lowerBoundCanChange,
                    Expressions.notEqual(endX, prevEnd)),
                builder5.toBlock()));
      }

      implementResult(aggs, builder4, resultContextBuilder, rexArguments,
          false);

      builder4.add(
          Expressions.statement(
              Expressions.call(
                  list_,
                  BuiltInMethod.COLLECTION_ADD.method,
                  outputPhysType.record(outputRow))));

      builder3.add(
          Expressions.for_(
              Expressions.declare(0, i_, Expressions.constant(0)),
              Expressions.lessThan(
                  i_,
                  Expressions.field(rows_, "length")),
              Expressions.preIncrementAssign(i_),
              builder4.toBlock()));

      builder.add(
          Expressions.while_(
              Expressions.call(
                  iterator_,
                  BuiltInMethod.ITERATOR_HAS_NEXT.method),
              builder3.toBlock()));
      builder.add(
          Expressions.statement(
              Expressions.call(
                  collectionExpr,
                  BuiltInMethod.MAP_CLEAR.method)));

      // We're not assigning to "source". For each group, create a new
      // final variable called "source" or "sourceN".
      source_ =
          builder.append(
              "source",
              Expressions.call(
                  BuiltInMethod.AS_ENUMERABLE.method, list_));

      inputPhysType = outputPhysType;
    }

    //   return Linq4j.asEnumerable(list);
    builder.add(
        Expressions.return_(null, source_));
    return implementor.result(inputPhysType, builder.toBlock());
  }

  /**
   * Returns a factory that, given the {@link BlockBuilder} currently being
   * filled, creates a {@link WinAggFrameResultContext} bound to the supplied
   * frame expressions.
   *
   * <p>The returned context appends any helper variables it needs
   * ({@code idx}, {@code rowInFrame}, {@code jRow}) to that block.
   *
   * @param typeFactory Type factory handed to the row translator
   * @param conformance SQL conformance handed to the row translator
   * @param result Result of the child; its row type supplies the input field
   *     count
   * @param translatedConstants Expressions for the window constants
   * @param comparator_ Comparator expression used by {@code compareRows}
   * @param rows_ Expression for the array of rows of the current partition
   * @param i_ Index of the current row
   * @param startX Index of the first row of the frame
   * @param endX Index of the last row of the frame
   * @param minX Lowest row index of the partition
   * @param maxX Highest row index of the partition
   * @param hasRows Expression telling whether the frame is non-empty
   * @param frameRowCount Expression for the number of rows in the frame
   * @param partitionRowCount Expression for the number of rows in the partition
   * @param jDecl Declaration of the aggregation loop variable
   * @param inputPhysType Physical type of the rows in {@code rows_}
   */
  private static Function<BlockBuilder, WinAggFrameResultContext>
      getBlockBuilderWinAggFrameResultContextFunction(
      final JavaTypeFactory typeFactory, final SqlConformance conformance,
      final Result result, final List<Expression> translatedConstants,
      final Expression comparator_,
      final Expression rows_, final ParameterExpression i_,
      final Expression startX, final Expression endX,
      final Expression minX, final Expression maxX,
      final Expression hasRows, final Expression frameRowCount,
      final Expression partitionRowCount,
      final DeclarationStatement jDecl,
      final PhysType inputPhysType) {
    return block -> new WinAggFrameResultContext() {
      @Override public RexToLixTranslator rowTranslator(Expression rowIndex) {
        Expression row =
            getRow(rowIndex);
        final RexToLixTranslator.InputGetter inputGetter =
            new WindowRelInputGetter(row, inputPhysType,
                result.physType.getRowType().getFieldCount(),
                translatedConstants);

        return RexToLixTranslator.forAggregation(typeFactory,
            block, inputGetter, conformance);
      }

      @Override public Expression computeIndex(Expression offset,
          WinAggImplementor.SeekType seekType) {
        // Pick the base index that the seek type refers to.
        Expression index;
        if (seekType == WinAggImplementor.SeekType.AGG_INDEX) {
          index = jDecl.parameter;
        } else if (seekType == WinAggImplementor.SeekType.SET) {
          index = i_;
        } else if (seekType == WinAggImplementor.SeekType.START) {
          index = startX;
        } else if (seekType == WinAggImplementor.SeekType.END) {
          index = endX;
        } else {
          throw new IllegalArgumentException("SeekSet " + seekType
              + " is not supported");
        }
        // A constant zero offset leaves the base index untouched, which lets
        // checkBounds recognize it by identity.
        if (!Expressions.constant(0).equals(offset)) {
          index = block.append("idx", Expressions.add(index, offset));
        }
        return index;
      }

      /** Returns an expression that is true when {@code rowIndex} lies in
       * {@code [minIndex, maxIndex]} and the frame has rows. */
      private Expression checkBounds(Expression rowIndex,
          Expression minIndex, Expression maxIndex) {
        if (rowIndex == i_ || rowIndex == startX || rowIndex == endX) {
          // No additional bounds check required
          return hasRows;
        }

        //noinspection UnnecessaryLocalVariable
        Expression res = block.append("rowInFrame",
            Expressions.foldAnd(
                ImmutableList.of(hasRows,
                    Expressions.greaterThanOrEqual(rowIndex, minIndex),
                    Expressions.lessThanOrEqual(rowIndex, maxIndex))));

        return res;
      }

      @Override public Expression rowInFrame(Expression rowIndex) {
        return checkBounds(rowIndex, startX, endX);
      }

      @Override public Expression rowInPartition(Expression rowIndex) {
        return checkBounds(rowIndex, minX, maxX);
      }

      @Override public Expression compareRows(Expression a, Expression b) {
        return Expressions.call(comparator_,
            BuiltInMethod.COMPARATOR_COMPARE.method,
            getRow(a), getRow(b));
      }

      /** Appends to the block, and returns, the row at {@code rowIndex}
       * converted to the Java row type of the input. */
      public Expression getRow(Expression rowIndex) {
        return block.append(
            "jRow",
            EnumUtils.convert(
                Expressions.arrayIndex(rows_, rowIndex),
                inputPhysType.getJavaRowType()));
      }

      @Override public Expression index() {
        return i_;
      }

      @Override public Expression startIndex() {
        return startX;
      }

      @Override public Expression endIndex() {
        return endX;
      }

      @Override public Expression hasRows() {
        return hasRows;
      }

      @Override public Expression getFrameRowCount() {
        return frameRowCount;
      }

      @Override public Expression getPartitionRowCount() {
        return partitionRowCount;
      }
    };
  }

  /**
   * Generates the code that collects the rows of {@code source_} into
   * partitions and returns the expressions needed to walk them.
   *
   * <p>When the group has no partition keys, all rows go into a single
   * temporary list. Otherwise every row is put into a
   * {@link SortedMultiMap} under its partition key.
   *
   * @param builder Block that receives the generated declarations
   * @param source_ Expression for the enumerable to partition
   * @param inputPhysType Physical type of the rows of {@code source_}
   * @param group Window group whose keys define the partitions
   * @param comparator_ Comparator expression passed to the iterator factory
   * @return pair of (collection holding the rows, iterator over the
   *     partitions)
   */
  private static Pair<Expression, Expression> getPartitionIterator(
      BlockBuilder builder,
      Expression source_,
      PhysType inputPhysType,
      Group group,
      Expression comparator_) {
    // Populate map of lists, one per partition
    //   final Map<Integer, List<Employee>> multiMap =
    //     new SortedMultiMap<Integer, List<Employee>>();
    //    source.foreach(
    //      new Function1<Employee, Void>() {
    //        public Void apply(Employee v) {
    //          final Integer k = v.deptno;
    //          multiMap.putMulti(k, v);
    //          return null;
    //        }
    //      });
    //   final List<Xxx> list = new ArrayList<Xxx>(multiMap.size());
    //   Iterator<Employee[]> iterator = multiMap.arrays(comparator);
    //
    if (group.keys.isEmpty()) {
      // If partition key is empty, no need to partition.
      //
      //   final List<Employee> tempList =
      //       source.into(new ArrayList<Employee>());
      //   Iterator<Employee[]> iterator =
      //       SortedMultiMap.singletonArrayIterator(comparator, tempList);
      //   final List<Xxx> list = new ArrayList<Xxx>(tempList.size());

      final Expression tempList_ = builder.append(
          "tempList",
          Expressions.convert_(
              Expressions.call(
                  source_,
                  BuiltInMethod.INTO.method,
                  Expressions.new_(ArrayList.class)),
              List.class));
      return Pair.of(tempList_,
          builder.append(
              "iterator",
              Expressions.call(
                  null,
                  BuiltInMethod.SORTED_MULTI_MAP_SINGLETON.method,
                  comparator_,
                  tempList_)));
    }
    Expression multiMap_ =
        builder.append(
            "multiMap", Expressions.new_(SortedMultiMap.class));
    // builder2: body of the lambda applied to every source row.
    final BlockBuilder builder2 = new BlockBuilder();
    final ParameterExpression v_ =
        Expressions.parameter(inputPhysType.getJavaRowType(),
            builder2.newName("v"));

    Pair<Type, List<Expression>> selector =
        inputPhysType.selector(v_, group.keys.asList(), JavaRowFormat.CUSTOM);
    final ParameterExpression key_;
    if (selector.left instanceof Types.RecordType) {
      // Record key: create the record, then assign its fields one by one.
      Types.RecordType keyJavaType = (Types.RecordType) selector.left;
      List<Expression> initExpressions = selector.right;
      key_ = Expressions.parameter(keyJavaType, "key");
      builder2.add(Expressions.declare(0, key_, null));
      builder2.add(
          Expressions.statement(
              Expressions.assign(key_, Expressions.new_(keyJavaType))));
      List<Types.RecordField> fieldList = keyJavaType.getRecordFields();
      for (int i = 0; i < initExpressions.size(); i++) {
        Expression right = initExpressions.get(i);
        builder2.add(
            Expressions.statement(
                Expressions.assign(
                    Expressions.field(key_, fieldList.get(i)), right)));
      }
    } else {
      // Non-record key: the first selector expression is the key itself.
      DeclarationStatement declare =
          Expressions.declare(0, "key", selector.right.get(0));
      builder2.add(declare);
      key_ = declare.parameter;
    }
    builder2.add(
        Expressions.statement(
            Expressions.call(
                multiMap_,
                BuiltInMethod.SORTED_MULTI_MAP_PUT_MULTI.method,
                key_,
                v_)));
    builder2.add(
        Expressions.return_(
            null, Expressions.constant(null)));

    builder.add(
        Expressions.statement(
            Expressions.call(
                source_,
                BuiltInMethod.ENUMERABLE_FOREACH.method,
                Expressions.lambda(
                    builder2.toBlock(), v_))));

    return Pair.of(multiMap_,
        builder.append(
            "iterator",
            Expressions.call(
                multiMap_,
                BuiltInMethod.SORTED_MULTI_MAP_ARRAYS.method,
                comparator_)));
  }

  /**
   * Declares, when needed, the collation key selector and key comparator for
   * a window group.
   *
   * <p>No key is generated when the group is row-based or when both of its
   * bounds are unbounded; the returned pair then holds two {@code null}s.
   *
   * @param builder Block that receives the two declarations
   * @param inputPhysType Physical type that generates the collation key
   * @param group Window group whose collation is used
   * @param windowIdx Index of the group, used as suffix of the variable names
   * @return pair of (key selector, key comparator), both possibly
   *     {@code null}
   */
  private static Pair<@Nullable Expression, @Nullable Expression> getRowCollationKey(
      BlockBuilder builder, PhysType inputPhysType,
      Group group, int windowIdx) {
    if (!(group.isRows
        || (group.upperBound.isUnbounded() && group.lowerBound.isUnbounded()))) {
      Pair<Expression, Expression> pair =
          inputPhysType.generateCollationKey(
              group.collation().getFieldCollations());
      // optimize=false to prevent inlining of object create into for-loops
      return Pair.of(
          builder.append("keySelector" + windowIdx, pair.left, false),
          builder.append("keyComparator" + windowIdx, pair.right, false));
    } else {
      return Pair.of(null, null);
    }
  }

  /**
   * Prepares every aggregate of a window group for code generation.
   *
   * <p>For each aggregate this method creates its {@link WinAggContext},
   * declares one variable per state slot and one result variable in
   * {@code builder}, stores them in the {@link AggImpState}, appends the
   * result variable to {@code outputRow}, and emits the implementor's reset
   * code.
   *
   * @param typeFactory Type factory used to map SQL types to Java types
   * @param builder Block that receives the declarations and reset code
   * @param result Result of the child; its row type is used to resolve the
   *     aggregate argument types
   * @param windowIdx Index of the group, used in generated variable names
   * @param aggs Aggregates of the group; updated in place
   * @param outputPhysType Physical type of the output row
   * @param outputRow Expressions of the output row; one entry is appended per
   *     aggregate
   */
  private void declareAndResetState(final JavaTypeFactory typeFactory,
      BlockBuilder builder, final Result result, int windowIdx,
      List<AggImpState> aggs, PhysType outputPhysType,
      List<Expression> outputRow) {
    for (final AggImpState agg : aggs) {
      agg.context =
          new WinAggContext() {
            @Override public SqlAggFunction aggregation() {
              return agg.call.getAggregation();
            }

            @Override public RelDataType returnRelType() {
              return agg.call.type;
            }

            @Override public Type returnType() {
              return EnumUtils.javaClass(typeFactory, returnRelType());
            }

            @Override public List<? extends Type> parameterTypes() {
              return EnumUtils.fieldTypes(typeFactory,
                  parameterRelTypes());
            }

            @Override public List<? extends RelDataType> parameterRelTypes() {
              return EnumUtils.fieldRowTypes(result.physType.getRowType(),
                  constants, agg.call.getArgList());
            }

            // The key-related methods are not supported by this context.
            @Override public List<ImmutableBitSet> groupSets() {
              throw new UnsupportedOperationException();
            }

            @Override public List<Integer> keyOrdinals() {
              throw new UnsupportedOperationException();
            }

            @Override public List<? extends RelDataType> keyRelTypes() {
              throw new UnsupportedOperationException();
            }

            @Override public List<? extends Type> keyTypes() {
              throw new UnsupportedOperationException();
            }
          };
      String aggName = "a" + agg.aggIdx;
      if (CalciteSystemProperty.DEBUG.value()) {
        // In debug mode, prefix the variable name with the aggregate's name.
        aggName = Util.toJavaId(agg.call.getAggregation().getName(), 0)
            .substring("ID$0$".length()) + aggName;
      }
      List<Type> state = agg.implementor.getStateType(agg.context);
      final List<Expression> decls = new ArrayList<>(state.size());
      for (int i = 0; i < state.size(); i++) {
        Type type = state.get(i);
        ParameterExpression pe =
            Expressions.parameter(type,
                builder.newName(aggName
                    + "s" + i + "w" + windowIdx));
        builder.add(Expressions.declare(0, pe, null));
        decls.add(pe);
      }
      agg.state = decls;
      Type aggHolderType = agg.context.returnType();
      // The result of this aggregate lands at the next position of outputRow.
      Type aggStorageType =
          outputPhysType.getJavaFieldType(outputRow.size());
      if (Primitive.is(aggHolderType) && !Primitive.is(aggStorageType)) {
        aggHolderType = Primitive.box(aggHolderType);
      }
      ParameterExpression aggRes = Expressions.parameter(0,
          aggHolderType,
          builder.newName(aggName + "w" + windowIdx));

      // Initialize with the primitive default value, or null otherwise.
      builder.add(
          Expressions.declare(0, aggRes,
              Expressions.constant(
                  Optional.ofNullable(Primitive.of(aggRes.getType()))
                      .map(x -> x.defaultValue)
                      .orElse(null),
                  aggRes.getType())));
      agg.result = aggRes;
      outputRow.add(aggRes);
      // No frame expressions exist at declaration time, so nulls are passed.
      agg.implementor.implementReset(agg.context,
          new WinAggResetContextImpl(builder, agg.state,
              castNonNull(null), castNonNull(null), castNonNull(null), castNonNull(null),
              castNonNull(null), castNonNull(null)));
    }
  }

  /**
   * Generates the "add" step for each aggregate in {@code aggs}.
   *
   * <p>For every aggregate an add-context is built over {@code builder7},
   * the aggregate's state variables and {@code frame}; the context reports
   * {@code jDecl.parameter} as the current position and obtains its
   * arguments from {@code rexArguments}. The context is then passed to the
   * aggregate implementor's {@code implementAdd}.
   *
   * @param aggs aggregates to process
   * @param builder7 block builder handed to each add-context
   * @param frame frame function handed to each add-context
   * @param rexArguments supplies the argument list for a given aggregate
   * @param jDecl declaration whose parameter is the current position
   * @throws NullPointerException if an aggregate's {@code state} or
   *     {@code context} is null
   */
  private static void implementAdd(List aggs,
      final BlockBuilder builder7,
      final Function frame,
      final Function> rexArguments,
      final DeclarationStatement jDecl) {
    for (final AggImpState aggState : aggs) {
      // The state is checked first (while building the context), the
      // aggregate context second (when the implementor is invoked).
      final WinAggAddContext ctx =
          new WinAggAddContextImpl(
              builder7,
              requireNonNull(aggState.state, "agg.state"),
              frame) {
            @Override public List rexArguments() {
              return rexArguments.apply(aggState);
            }

            @Override public @Nullable RexNode rexFilterArgument() {
              return null; // REVIEW
            }

            @Override public Expression currentPosition() {
              return jDecl.parameter;
            }
          };
      aggState.implementor.implementAdd(
          requireNonNull(aggState.context, "agg.context"),
          ctx);
    }
  }

  /**
   * Generates the "result" step for the aggregates in {@code aggs} that
   * belong to the kind of block currently being generated.
   *
   * <p>An aggregate is handled only when its caching requirement equals
   * {@code cachedBlock}. The requirement is {@code true} unless the
   * implementor is a {@link WinAggImplementor}, in which case it is whatever
   * {@code needCacheWhenFrameIntact()} returns. For each handled aggregate
   * the implementor's result expression is converted to the type of the
   * aggregate's result variable and assigned to it in {@code builder}.
   *
   * @param aggs aggregates to process
   * @param builder block builder that receives the generated statements
   * @param frame frame function handed to each result-context
   * @param rexArguments supplies the argument list for a given aggregate
   * @param cachedBlock which caching requirement to handle in this pass
   * @return {@code true} if at least one aggregate was handled
   * @throws NullPointerException if a handled aggregate's {@code context},
   *     {@code state} or {@code result} is null
   */
  private static boolean implementResult(List aggs,
      final BlockBuilder builder,
      final Function frame,
      final Function> rexArguments,
      boolean cachedBlock) {
    boolean emitted = false;
    for (final AggImpState aggState : aggs) {
      // Implementors that are not WinAggImplementor always need caching;
      // window implementors decide for themselves.
      final boolean cacheNeeded =
          !(aggState.implementor instanceof WinAggImplementor)
              || ((WinAggImplementor) aggState.implementor)
                  .needCacheWhenFrameIntact();
      if (cacheNeeded != cachedBlock) {
        // Regular aggregates do not change while the windowing frame stays
        // the same, so this aggregate is left to the other kind of block.
        continue;
      }
      emitted = true;

      final Expression computed =
          aggState.implementor.implementResult(
              requireNonNull(aggState.context, "agg.context"),
              new WinAggResultContextImpl(
                  builder,
                  requireNonNull(aggState.state, "agg.state"),
                  frame) {
                @Override public List rexArguments() {
                  return rexArguments.apply(aggState);
                }
              });

      // Several count(a) and count(b) might share the result
      final Expression target =
          requireNonNull(aggState.result,
              () -> "agg.result for " + aggState.call);
      final String resName = "a" + aggState.aggIdx + "res";
      final Expression converted =
          EnumUtils.convert(computed, target.getType());
      final Expression appended = builder.append(resName, converted);
      builder.add(
          Expressions.statement(
              Expressions.assign(target, appended)));
    }
    return emitted;
  }

  /**
   * Builds the expression for one bound (lower or upper) of a window frame.
   *
   * <ul>
   *   <li>Unbounded bound: {@code min_} when preceding, otherwise
   *       {@code max_}.
   *   <li>{@code group.isRows}: {@code i_} for the current row, otherwise
   *       {@code i_} plus (following) or minus (not following) the bound's
   *       offset converted to {@code int}.
   *   <li>Otherwise: a call to one of the {@code BINARY_SEARCH5_*} or
   *       {@code BINARY_SEARCH6_*} built-in methods over {@code rows_}.
   * </ul>
   *
   * @param translator translates the bound offset and the order-key reference
   * @param i_ expression for the current row index
   * @param row_ expression for the current row
   * @param min_ expression returned for an unbounded preceding bound; also
   *     the default lower limit of the search
   * @param max_ expression returned for an unbounded non-preceding bound;
   *     also the default upper limit of the search
   * @param rows_ expression for the rows being searched
   * @param group window group that supplies the bounds, the
   *     {@code isRows} flag and the collation
   * @param lower {@code true} to translate {@code group.lowerBound},
   *     {@code false} for {@code group.upperBound}
   * @param physType physical type from which the order-key type is read
   * @param keySelector key selector passed to the binary search; must be
   *     non-null whenever a binary search is generated
   * @param keyComparator key comparator passed to the binary search; must be
   *     non-null whenever a binary search is generated
   * @return expression that evaluates to the requested bound
   */
  private static Expression translateBound(RexToLixTranslator translator,
      ParameterExpression i_, Expression row_, Expression min_, Expression max_,
      Expression rows_, Group group, boolean lower, PhysType physType,
      @Nullable Expression keySelector, @Nullable Expression keyComparator) {
    final RexWindowBound bound;
    if (lower) {
      bound = group.lowerBound;
    } else {
      bound = group.upperBound;
    }

    if (bound.isUnbounded()) {
      if (bound.isPreceding()) {
        return min_;
      }
      return max_;
    }

    if (group.isRows) {
      if (bound.isCurrentRow()) {
        return i_;
      }
      // Floating offset does not make sense since we refer to array index.
      // Nulls do not make sense as well.
      final Expression rowOffset =
          EnumUtils.convert(translator.translate(bound.getOffset()),
              int.class);
      return bound.isFollowing()
          ? Expressions.add(i_, rowOffset)
          : Expressions.subtract(i_, rowOffset);
    }

    // For a CURRENT ROW bound, the current index replaces one search limit:
    // the upper limit for a lower bound, the lower limit for an upper bound.
    Expression searchLower = min_;
    Expression searchUpper = max_;
    if (bound.isCurrentRow()) {
      if (lower) {
        searchUpper = i_;
      } else {
        searchLower = i_;
      }
    }

    List fieldCollations =
        group.collation().getFieldCollations();
    if (bound.isCurrentRow() && fieldCollations.size() != 1) {
      // Search by the current row itself rather than by a single key value.
      final BuiltInMethod rowSearch = lower
          ? BuiltInMethod.BINARY_SEARCH5_LOWER
          : BuiltInMethod.BINARY_SEARCH5_UPPER;
      return Expressions.call(
          rowSearch.method,
          rows_, row_, searchLower, searchUpper,
          requireNonNull(keySelector, "keySelector"),
          requireNonNull(keyComparator, "keyComparator"));
    }
    assert fieldCollations.size() == 1
        : "When using range window specification, ORDER BY should have"
        + " exactly one expression."
        + " Actual collation is " + group.collation();

    // isRange: search by the value of the single ORDER BY key.
    int orderKey =
        fieldCollations.get(0).getFieldIndex();
    RelDataType keyType =
        physType.getRowType().getFieldList().get(orderKey).getType();
    final Type plainKeyType = translator.typeFactory.getJavaClass(keyType);
    // Without an offset the key is translated to the boxed type.
    final Type desiredKeyType = bound.getOffset() == null
        ? Primitive.box(plainKeyType)
        : plainKeyType;
    final Expression keyValue = translator.translate(
        new RexInputRef(orderKey, keyType), desiredKeyType);

    final Expression searchValue;
    if (bound.isCurrentRow()) {
      searchValue = keyValue;
    } else {
      final Expression rangeOffset = translator.translate(bound.getOffset());
      // TODO: support date + interval somehow
      searchValue = bound.isFollowing()
          ? Expressions.add(keyValue, rangeOffset)
          : Expressions.subtract(keyValue, rangeOffset);
    }

    final BuiltInMethod valueSearch = lower
        ? BuiltInMethod.BINARY_SEARCH6_LOWER
        : BuiltInMethod.BINARY_SEARCH6_UPPER;
    return Expressions.call(
        valueSearch.method,
        rows_, searchValue, searchLower, searchUpper,
        requireNonNull(keySelector, "keySelector"),
        requireNonNull(keyComparator, "keyComparator"));
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy