/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.adapter.spark;

import org.apache.calcite.DataContext;
import org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.calcite.adapter.enumerable.JavaRowFormat;
import org.apache.calcite.adapter.enumerable.PhysType;
import org.apache.calcite.adapter.enumerable.PhysTypeImpl;
import org.apache.calcite.adapter.enumerable.RexImpTable;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.calcite.linq4j.tree.BlockStatement;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.linq4j.tree.Primitive;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.plan.Convention;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.convert.ConverterRule;
import org.apache.calcite.rel.core.Values;
import org.apache.calcite.rel.logical.LogicalCalc;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexMultisetUtil;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.sql.validate.SqlConformance;
import org.apache.calcite.sql.validate.SqlConformanceEnum;
import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;

import org.checkerframework.checker.nullness.qual.Nullable;

import scala.Tuple2;

import java.lang.reflect.Type;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Random;

/**
 * Rules for the {@link SparkRel#CONVENTION Spark calling convention}.
 *
 * @see JdbcToSparkConverterRule
 */
public abstract class SparkRules {
  private SparkRules() {}

  /** Rule that converts from Spark to enumerable convention. */
  public static final SparkToEnumerableConverterRule SPARK_TO_ENUMERABLE =
      SparkToEnumerableConverterRule.DEFAULT_CONFIG
          .toRule(SparkToEnumerableConverterRule.class);

  /** Rule that converts from enumerable to Spark convention. */
  public static final EnumerableToSparkConverterRule ENUMERABLE_TO_SPARK =
      EnumerableToSparkConverterRule.DEFAULT_CONFIG
          .toRule(EnumerableToSparkConverterRule.class);

  /** Rule that converts a {@link org.apache.calcite.rel.logical.LogicalCalc}
   * to a {@link org.apache.calcite.adapter.spark.SparkRules.SparkCalc}. */
  public static final SparkCalcRule SPARK_CALC_RULE =
      SparkCalcRule.DEFAULT_CONFIG.toRule(SparkCalcRule.class);

  /** Rule that implements VALUES operator in Spark convention. */
  public static final SparkValuesRule SPARK_VALUES_RULE =
      SparkValuesRule.DEFAULT_CONFIG.toRule(SparkValuesRule.class);

  public static List<RelOptRule> rules() {
    return ImmutableList.of(
        // TODO: add SparkProjectRule, SparkFilterRule, SparkProjectToCalcRule,
        // SparkFilterToCalcRule, and remove the following 2 rules.
        CoreRules.PROJECT_TO_CALC,
        CoreRules.FILTER_TO_CALC,
        ENUMERABLE_TO_SPARK,
        SPARK_TO_ENUMERABLE,
        SPARK_VALUES_RULE,
        SPARK_CALC_RULE);
  }
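
  // Usage sketch (hedged): "planner" below is a hypothetical RelOptPlanner
  // supplied by the caller, not something defined in this file. Callers
  // typically register the whole rule set before planning:
  //
  //   for (RelOptRule rule : SparkRules.rules()) {
  //     planner.addRule(rule);
  //   }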

  /** Planner rule that converts from enumerable to Spark convention.
   *
   * @see #ENUMERABLE_TO_SPARK */
  static class EnumerableToSparkConverterRule extends ConverterRule {
    /** Default configuration. */
    static final Config DEFAULT_CONFIG = Config.INSTANCE
        .withConversion(RelNode.class, EnumerableConvention.INSTANCE,
            SparkRel.CONVENTION, "EnumerableToSparkConverterRule")
        .withRuleFactory(EnumerableToSparkConverterRule::new);

    EnumerableToSparkConverterRule(Config config) {
      super(config);
    }

    @Override public RelNode convert(RelNode rel) {
      return new EnumerableToSparkConverter(rel.getCluster(),
          rel.getTraitSet().replace(SparkRel.CONVENTION), rel);
    }
  }

  /** Planner rule that converts from Spark to enumerable convention.
   *
   * @see #SPARK_TO_ENUMERABLE */
  static class SparkToEnumerableConverterRule extends ConverterRule {
    static final Config DEFAULT_CONFIG = Config.INSTANCE
        .withConversion(RelNode.class, SparkRel.CONVENTION,
            EnumerableConvention.INSTANCE, "SparkToEnumerableConverterRule")
        .withRuleFactory(SparkToEnumerableConverterRule::new);

    SparkToEnumerableConverterRule(Config config) {
      super(config);
    }

    @Override public RelNode convert(RelNode rel) {
      return new SparkToEnumerableConverter(rel.getCluster(),
          rel.getTraitSet().replace(EnumerableConvention.INSTANCE), rel);
    }
  }

  /** Planner rule that implements VALUES operator in Spark convention.
   *
   * @see #SPARK_VALUES_RULE */
  public static class SparkValuesRule extends ConverterRule {
    /** Default configuration. */
    static final Config DEFAULT_CONFIG = Config.INSTANCE
        .withConversion(LogicalValues.class, Convention.NONE,
            SparkRel.CONVENTION, "SparkValuesRule")
        .withRuleFactory(SparkValuesRule::new);

    SparkValuesRule(Config config) {
      super(config);
    }

    @Override public RelNode convert(RelNode rel) {
      LogicalValues values = (LogicalValues) rel;
      return new SparkValues(
          values.getCluster(),
          values.getRowType(),
          values.getTuples(),
          values.getTraitSet().replace(getOutTrait()));
    }
  }

  /** VALUES construct implemented in Spark. */
  public static class SparkValues extends Values implements SparkRel {
    SparkValues(
        RelOptCluster cluster,
        RelDataType rowType,
        ImmutableList<ImmutableList<RexLiteral>> tuples,
        RelTraitSet traitSet) {
      super(cluster, rowType, tuples, traitSet);
    }

    @Override public RelNode copy(
        RelTraitSet traitSet, List<RelNode> inputs) {
      assert inputs.isEmpty();
      return new SparkValues(
          getCluster(), rowType, tuples, traitSet);
    }

    @Override public Result implementSpark(Implementor implementor) {
/*
            return Linq4j.asSpark(
                new Object[][] {
                    new Object[] {1, 2},
                    new Object[] {3, 4}
                });
*/
      final JavaTypeFactory typeFactory =
          (JavaTypeFactory) getCluster().getTypeFactory();
      final BlockBuilder builder = new BlockBuilder();
      final PhysType physType =
          PhysTypeImpl.of(implementor.getTypeFactory(),
              getRowType(),
              JavaRowFormat.CUSTOM);
      final Type rowClass = physType.getJavaRowType();

      final List<Expression> expressions = new ArrayList<>();
      final List<RelDataTypeField> fields = rowType.getFieldList();
      for (List<RexLiteral> tuple : tuples) {
        final List<Expression> literals = new ArrayList<>();
        for (Pair<RelDataTypeField, RexLiteral> pair
            : Pair.zip(fields, tuple)) {
          literals.add(
              RexToLixTranslator.translateLiteral(
                  pair.right,
                  pair.left.getType(),
                  typeFactory,
                  RexImpTable.NullAs.NULL));
        }
        expressions.add(physType.record(literals));
      }
      builder.add(
          Expressions.return_(null,
              Expressions.call(SparkMethod.ARRAY_TO_RDD.method,
                  Expressions.call(SparkMethod.GET_SPARK_CONTEXT.method,
                      implementor.getRootExpression()),
                  Expressions.newArrayInit(Primitive.box(rowClass),
                      expressions))));
      return implementor.result(physType, builder.toBlock());
    }
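
    // What the block above evaluates to, roughly (a sketch, not generated
    // verbatim; "Row" stands for whatever physType.getJavaRowType() resolves
    // to, and the method names follow SparkMethod's mapping to SparkRuntime):
    //
    //   return SparkRuntime.createRdd(
    //       SparkRuntime.getSparkContext(root),
    //       new Row[] {record0, record1, ...});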
  }

  /**
   * Rule to convert a {@link org.apache.calcite.rel.logical.LogicalCalc} to an
   * {@link org.apache.calcite.adapter.spark.SparkRules.SparkCalc}.
   *
   * @see #SPARK_CALC_RULE
   */
  private static class SparkCalcRule extends ConverterRule {
    /** Default configuration. */
    static final Config DEFAULT_CONFIG = Config.INSTANCE
        .withConversion(LogicalCalc.class, Convention.NONE, SparkRel.CONVENTION,
            "SparkCalcRule")
        .withRuleFactory(SparkCalcRule::new);

    SparkCalcRule(Config config) {
      super(config);
    }

    @Override public RelNode convert(RelNode rel) {
      final LogicalCalc calc = (LogicalCalc) rel;

      // If there's a multiset, let FarragoMultisetSplitter work on it
      // first.
      final RexProgram program = calc.getProgram();
      if (RexMultisetUtil.containsMultiset(program)
          || program.containsAggs()) {
        return null;
      }

      return new SparkCalc(
          rel.getCluster(),
          rel.getTraitSet().replace(SparkRel.CONVENTION),
          convert(calc.getInput(),
              calc.getInput().getTraitSet().replace(SparkRel.CONVENTION)),
          program);
    }
  }

  /** Implementation of {@link org.apache.calcite.rel.core.Calc}
   * in Spark convention. */
  public static class SparkCalc extends SingleRel implements SparkRel {
    private final RexProgram program;

    public SparkCalc(RelOptCluster cluster,
        RelTraitSet traitSet,
        RelNode input,
        RexProgram program) {
      super(cluster, traitSet, input);
      assert getConvention() == SparkRel.CONVENTION;
      assert !program.containsAggs();
      this.program = program;
      this.rowType = program.getOutputRowType();
    }

    @Deprecated // to be removed before 2.0
    public SparkCalc(RelOptCluster cluster, RelTraitSet traitSet, RelNode input,
        RexProgram program, int flags) {
      this(cluster, traitSet, input, program);
      Util.discard(flags);
    }

    @Override public RelWriter explainTerms(RelWriter pw) {
      return program.explainCalc(super.explainTerms(pw));
    }

    @Override public double estimateRowCount(RelMetadataQuery mq) {
      return RelMdUtil.estimateFilteredRows(getInput(), program, mq);
    }

    @Override public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner,
        RelMetadataQuery mq) {
      double dRows = mq.getRowCount(this);
      double dCpu = mq.getRowCount(getInput())
          * program.getExprCount();
      double dIo = 0;
      return planner.getCostFactory().makeCost(dRows, dCpu, dIo);
    }
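
    // Worked example with hypothetical numbers: for 1,000 input rows, an
    // estimated 500 output rows, and a program containing 4 expressions,
    // the cost is {rows: 500, cpu: 1,000 * 4 = 4,000, io: 0}.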

    @Override public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
      return new SparkCalc(
          getCluster(),
          traitSet,
          sole(inputs),
          program);
    }

    @Deprecated // to be removed before 2.0
    public int getFlags() {
      return 1;
    }

    @Override public Result implementSpark(Implementor implementor) {
      final JavaTypeFactory typeFactory = implementor.getTypeFactory();
      final BlockBuilder builder = new BlockBuilder();
      final SparkRel child = (SparkRel) getInput();

      final Result result = implementor.visitInput(this, 0, child);

      final PhysType physType =
          PhysTypeImpl.of(
              typeFactory, getRowType(), JavaRowFormat.CUSTOM);

      // final RDD<Employee> inputRdd = <<child adapter>>;
      // return inputRdd.flatMap(
      //   new FlatMapFunction<Employee, X>() {
      //          public List<X> call(Employee e) {
      //              if (!(e.empno < 10)) {
      //                  return Collections.emptyList();
      //              }
      //              return Collections.singletonList(
      //                  new X(...)));
      //          }
      //      })


      Type outputJavaType = physType.getJavaRowType();
      @SuppressWarnings("unused")
      final Type rddType =
          Types.of(
              JavaRDD.class, outputJavaType);
      Type inputJavaType = result.physType.getJavaRowType();
      final Expression inputRdd_ =
          builder.append(
              "inputRdd",
              result.block);

      BlockBuilder builder2 = new BlockBuilder();

      final ParameterExpression e_ =
          Expressions.parameter(inputJavaType, "e");
      if (program.getCondition() != null) {
        Expression condition =
            RexToLixTranslator.translateCondition(
                program,
                typeFactory,
                builder2,
                new RexToLixTranslator.InputGetterImpl(e_, result.physType),
                null, implementor.getConformance());
        builder2.add(
            Expressions.ifThen(
                Expressions.not(condition),
                Expressions.return_(null,
                    Expressions.call(
                        BuiltInMethod.COLLECTIONS_EMPTY_LIST.method))));
      }

      final SqlConformance conformance = SqlConformanceEnum.DEFAULT;
      List<Expression> expressions =
          RexToLixTranslator.translateProjects(
              program,
              typeFactory,
              conformance,
              builder2,
              null,
              null,
              DataContext.ROOT,
              new RexToLixTranslator.InputGetterImpl(e_, result.physType),
              null);
      builder2.add(
          Expressions.return_(null,
              Expressions.convert_(
                  Expressions.call(
                      BuiltInMethod.COLLECTIONS_SINGLETON_LIST.method,
                      physType.record(expressions)),
                  List.class)));
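      // The statements added to builder2 above form the body of the generated
      // flat-map function: an optional early
      // "return Collections.emptyList()" for rows that fail the condition,
      // then a "return Collections.singletonList(...)" of the projected
      // record.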

      final BlockStatement callBody = builder2.toBlock();
      builder.add(
          Expressions.return_(
              null,
              Expressions.call(
                  inputRdd_,
                  SparkMethod.RDD_FLAT_MAP.method,
                  Expressions.lambda(
                      SparkRuntime.CalciteFlatMapFunction.class,
                      callBody,
                      e_))));
      return implementor.result(physType, builder.toBlock());
    }
  }

  // Play area

  public static void main(String[] args) {
    final JavaSparkContext sc = new JavaSparkContext("local[1]", "calcite");
    final JavaRDD<String> file = sc.textFile("/usr/share/dict/words");
    System.out.println(
        file.map(s -> s.substring(0, Math.min(s.length(), 1)))
            .distinct().count());
    file.cache();
    String s =
        file.groupBy((Function<String, String>) s1 -> s1.substring(0, Math.min(s1.length(), 1))
            //CHECKSTYLE: IGNORE 1
        ).map((Function<Tuple2<String, Iterable<String>>, Object>) pair ->
            pair._1() + ":" + Iterables.size(pair._2())).collect().toString();
    System.out.print(s);

    final JavaRDD<Integer> rdd = sc.parallelize(
        new AbstractList<Integer>() {
          final Random random = new Random();
          @Override public Integer get(int index) {
            System.out.println("get(" + index + ")");
            return random.nextInt(100);
          }

          @Override public int size() {
            System.out.println("size");
            return 10;
          }
        });
    System.out.println(
        rdd.groupBy((Function<Integer, Integer>) integer -> integer % 2)
            .collect().toString());
    System.out.println(
        file.flatMap((FlatMapFunction<String, Pair<String, Integer>>) x -> {
          if (!x.startsWith("a")) {
            return Collections.emptyIterator();
          }
          return Collections.singletonList(
              Pair.of(x.toUpperCase(Locale.ROOT), x.length())).iterator();
        })
            .take(5)
            .toString());
  }
}