All downloads are free. The search and download functionality uses the official Maven repository.

io.substrait.relation.RelProtoConverter Maven / Gradle / Ivy

Go to download

Create a well-defined, cross-language specification for data compute operations

There is a newer version: 0.46.1
Show newest version
package io.substrait.relation;

import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.expression.proto.ExpressionProtoConverter.BoundConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.proto.AggregateFunction;
import io.substrait.proto.AggregateRel;
import io.substrait.proto.ConsistentPartitionWindowRel;
import io.substrait.proto.CrossRel;
import io.substrait.proto.ExpandRel;
import io.substrait.proto.ExtensionLeafRel;
import io.substrait.proto.ExtensionMultiRel;
import io.substrait.proto.ExtensionSingleRel;
import io.substrait.proto.FetchRel;
import io.substrait.proto.FilterRel;
import io.substrait.proto.HashJoinRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.MergeJoinRel;
import io.substrait.proto.NestedLoopJoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
import io.substrait.proto.Rel;
import io.substrait.proto.RelCommon;
import io.substrait.proto.SetRel;
import io.substrait.proto.SortField;
import io.substrait.proto.SortRel;
import io.substrait.relation.files.FileOrFiles;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/** Converts from {@link io.substrait.relation.Rel} to {@link io.substrait.proto.Rel} */
public class RelProtoConverter implements RelVisitor {

  private final ExpressionProtoConverter exprProtoConverter;
  private final TypeProtoConverter typeProtoConverter;
  private final ExtensionCollector functionCollector;

  public RelProtoConverter(ExtensionCollector functionCollector) {
    this.functionCollector = functionCollector;
    this.exprProtoConverter = new ExpressionProtoConverter(functionCollector, this);
    this.typeProtoConverter = new TypeProtoConverter(functionCollector);
  }

  private List toProto(Collection expressions) {
    return expressions.stream().map(this::toProto).collect(Collectors.toList());
  }

  private io.substrait.proto.Expression toProto(Expression expression) {
    return expression.accept(exprProtoConverter);
  }

  public io.substrait.proto.Rel toProto(io.substrait.relation.Rel rel) {
    return rel.accept(this);
  }

  private io.substrait.proto.Type toProto(io.substrait.type.Type type) {
    return type.accept(typeProtoConverter);
  }

  private List toProtoS(Collection sorts) {
    return sorts.stream()
        .map(
            s -> {
              return SortField.newBuilder()
                  .setDirection(s.direction().toProto())
                  .setExpr(toProto(s.expr()))
                  .build();
            })
        .collect(Collectors.toList());
  }

  private io.substrait.proto.Expression.FieldReference toProto(FieldReference fieldReference) {
    return fieldReference.accept(exprProtoConverter).getSelection();
  }

  @Override
  public Rel visit(Aggregate aggregate) throws RuntimeException {
    var builder =
        AggregateRel.newBuilder()
            .setInput(toProto(aggregate.getInput()))
            .setCommon(common(aggregate))
            .addAllGroupings(
                aggregate.getGroupings().stream().map(this::toProto).collect(Collectors.toList()))
            .addAllMeasures(
                aggregate.getMeasures().stream().map(this::toProto).collect(Collectors.toList()));

    aggregate.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setAggregate(builder).build();
  }

  private AggregateRel.Measure toProto(Aggregate.Measure measure) {
    var argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter);
    var args = measure.getFunction().arguments();
    var aggFuncDef = measure.getFunction().declaration();

    var func =
        AggregateFunction.newBuilder()
            .setPhase(measure.getFunction().aggregationPhase().toProto())
            .setInvocation(measure.getFunction().invocation().toProto())
            .setOutputType(toProto(measure.getFunction().getType()))
            .addAllArguments(
                IntStream.range(0, args.size())
                    .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor))
                    .collect(Collectors.toList()))
            .addAllSorts(toProtoS(measure.getFunction().sort()))
            .setFunctionReference(
                functionCollector.getFunctionReference(measure.getFunction().declaration()))
            .addAllOptions(
                measure.getFunction().options().stream()
                    .map(ExpressionProtoConverter::from)
                    .collect(Collectors.toList()));

    var builder = AggregateRel.Measure.newBuilder().setMeasure(func);

    measure.getPreMeasureFilter().ifPresent(f -> builder.setFilter(toProto(f)));
    return builder.build();
  }

  private AggregateRel.Grouping toProto(Aggregate.Grouping grouping) {
    return AggregateRel.Grouping.newBuilder()
        .addAllGroupingExpressions(toProto(grouping.getExpressions()))
        .build();
  }

  @Override
  public Rel visit(EmptyScan emptyScan) throws RuntimeException {
    return Rel.newBuilder()
        .setRead(
            ReadRel.newBuilder()
                .setCommon(common(emptyScan))
                .setVirtualTable(ReadRel.VirtualTable.newBuilder().build())
                .setBaseSchema(emptyScan.getInitialSchema().toProto(typeProtoConverter))
                .build())
        .build();
  }

  @Override
  public Rel visit(Fetch fetch) throws RuntimeException {
    var builder =
        FetchRel.newBuilder()
            .setCommon(common(fetch))
            .setInput(toProto(fetch.getInput()))
            .setOffset(fetch.getOffset())
            // -1 is used as a sentinel value to signal LIMIT ALL
            .setCount(fetch.getCount().orElse(-1));

    fetch.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setFetch(builder).build();
  }

  @Override
  public Rel visit(Filter filter) throws RuntimeException {
    var builder =
        FilterRel.newBuilder()
            .setCommon(common(filter))
            .setInput(toProto(filter.getInput()))
            .setCondition(filter.getCondition().accept(exprProtoConverter));

    filter.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setFilter(builder).build();
  }

  @Override
  public Rel visit(Join join) throws RuntimeException {
    var builder =
        JoinRel.newBuilder()
            .setCommon(common(join))
            .setLeft(toProto(join.getLeft()))
            .setRight(toProto(join.getRight()))
            .setType(join.getJoinType().toProto());

    join.getCondition().ifPresent(t -> builder.setExpression(toProto(t)));

    join.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t)));

    join.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setJoin(builder).build();
  }

  @Override
  public Rel visit(Set set) throws RuntimeException {
    var builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto());
    set.getInputs()
        .forEach(
            inputRel -> {
              builder.addInputs(toProto(inputRel));
            });

    set.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setSet(builder).build();
  }

  @Override
  public Rel visit(NamedScan namedScan) throws RuntimeException {
    var builder =
        ReadRel.newBuilder()
            .setCommon(common(namedScan))
            .setNamedTable(ReadRel.NamedTable.newBuilder().addAllNames(namedScan.getNames()))
            .setBaseSchema(namedScan.getInitialSchema().toProto(typeProtoConverter));

    namedScan.getFilter().ifPresent(f -> builder.setFilter(toProto(f)));

    namedScan.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setRead(builder).build();
  }

  @Override
  public Rel visit(LocalFiles localFiles) throws RuntimeException {
    var builder =
        ReadRel.newBuilder()
            .setCommon(common(localFiles))
            .setLocalFiles(
                ReadRel.LocalFiles.newBuilder()
                    .addAllItems(
                        localFiles.getItems().stream()
                            .map(FileOrFiles::toProto)
                            .collect(Collectors.toList()))
                    .build())
            .setBaseSchema(localFiles.getInitialSchema().toProto(typeProtoConverter));
    localFiles.getFilter().ifPresent(t -> builder.setFilter(toProto(t)));

    localFiles.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setRead(builder.build()).build();
  }

  @Override
  public Rel visit(ExtensionTable extensionTable) throws RuntimeException {
    ReadRel.ExtensionTable.Builder extensionTableBuilder =
        ReadRel.ExtensionTable.newBuilder().setDetail(extensionTable.getDetail().toProto());
    var builder =
        ReadRel.newBuilder()
            .setCommon(common(extensionTable))
            .setBaseSchema(extensionTable.getInitialSchema().toProto(typeProtoConverter))
            .setExtensionTable(extensionTableBuilder);

    extensionTable.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setRead(builder).build();
  }

  @Override
  public Rel visit(HashJoin hashJoin) throws RuntimeException {
    var builder =
        HashJoinRel.newBuilder()
            .setCommon(common(hashJoin))
            .setLeft(toProto(hashJoin.getLeft()))
            .setRight(toProto(hashJoin.getRight()))
            .setType(hashJoin.getJoinType().toProto());

    List leftKeys = hashJoin.getLeftKeys();
    List rightKeys = hashJoin.getRightKeys();

    if (leftKeys.size() != rightKeys.size()) {
      throw new RuntimeException("Number of left and right keys must be equal.");
    }

    builder.addAllLeftKeys(leftKeys.stream().map(this::toProto).collect(Collectors.toList()));
    builder.addAllRightKeys(rightKeys.stream().map(this::toProto).collect(Collectors.toList()));

    hashJoin.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t)));

    hashJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setHashJoin(builder).build();
  }

  @Override
  public Rel visit(MergeJoin mergeJoin) throws RuntimeException {
    var builder =
        MergeJoinRel.newBuilder()
            .setCommon(common(mergeJoin))
            .setLeft(toProto(mergeJoin.getLeft()))
            .setRight(toProto(mergeJoin.getRight()))
            .setType(mergeJoin.getJoinType().toProto());

    List leftKeys = mergeJoin.getLeftKeys();
    List rightKeys = mergeJoin.getRightKeys();

    if (leftKeys.size() != rightKeys.size()) {
      throw new RuntimeException("Number of left and right keys must be equal.");
    }

    builder.addAllLeftKeys(leftKeys.stream().map(this::toProto).collect(Collectors.toList()));
    builder.addAllRightKeys(rightKeys.stream().map(this::toProto).collect(Collectors.toList()));

    mergeJoin.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t)));

    mergeJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setMergeJoin(builder).build();
  }

  @Override
  public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
    var builder =
        NestedLoopJoinRel.newBuilder()
            .setCommon(common(nestedLoopJoin))
            .setLeft(toProto(nestedLoopJoin.getLeft()))
            .setRight(toProto(nestedLoopJoin.getRight()))
            .setExpression(toProto(nestedLoopJoin.getCondition()))
            .setType(nestedLoopJoin.getJoinType().toProto());

    nestedLoopJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setNestedLoopJoin(builder).build();
  }

  @Override
  public Rel visit(ConsistentPartitionWindow consistentPartitionWindow) throws RuntimeException {
    var builder =
        ConsistentPartitionWindowRel.newBuilder()
            .setCommon(common(consistentPartitionWindow))
            .setInput(toProto(consistentPartitionWindow.getInput()))
            .addAllSorts(toProtoS(consistentPartitionWindow.getSorts()))
            .addAllPartitionExpressions(
                toProto(consistentPartitionWindow.getPartitionExpressions()))
            .addAllWindowFunctions(
                toProtoWindowRelFunctions(consistentPartitionWindow.getWindowFunctions()));

    consistentPartitionWindow
        .getExtension()
        .ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));

    return Rel.newBuilder().setWindow(builder).build();
  }

  private List toProtoWindowRelFunctions(
      Collection
          windowRelFunctionInvocations) {

    return windowRelFunctionInvocations.stream()
        .map(
            f -> {
              var argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter);
              var args = f.arguments();
              var aggFuncDef = f.declaration();

              var arguments =
                  IntStream.range(0, args.size())
                      .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor))
                      .collect(Collectors.toList());
              var options =
                  f.options().stream()
                      .map(ExpressionProtoConverter::from)
                      .collect(Collectors.toList());

              return ConsistentPartitionWindowRel.WindowRelFunction.newBuilder()
                  .setInvocation(f.invocation().toProto())
                  .setPhase(f.aggregationPhase().toProto())
                  .setOutputType(toProto(f.outputType()))
                  .addAllArguments(arguments)
                  .addAllOptions(options)
                  .setFunctionReference(functionCollector.getFunctionReference(f.declaration()))
                  .setBoundsType(f.boundsType().toProto())
                  .setLowerBound(BoundConverter.convert(f.lowerBound()))
                  .setUpperBound(BoundConverter.convert(f.upperBound()))
                  .build();
            })
        .collect(Collectors.toList());
  }

  @Override
  public Rel visit(Project project) throws RuntimeException {
    var builder =
        ProjectRel.newBuilder()
            .setCommon(common(project))
            .setInput(toProto(project.getInput()))
            .addAllExpressions(
                project.getExpressions().stream().map(this::toProto).collect(Collectors.toList()));

    project.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setProject(builder).build();
  }

  @Override
  public Rel visit(Expand expand) throws RuntimeException {
    var builder =
        ExpandRel.newBuilder().setCommon(common(expand)).setInput(toProto(expand.getInput()));

    expand
        .getFields()
        .forEach(
            expandField -> {
              if (expandField instanceof Expand.ConsistentField cf) {
                builder.addFields(
                    ExpandRel.ExpandField.newBuilder()
                        .setConsistentField(toProto(cf.getExpression()))
                        .build());

              } else if (expandField instanceof Expand.SwitchingField sf) {
                builder.addFields(
                    ExpandRel.ExpandField.newBuilder()
                        .setSwitchingField(
                            ExpandRel.SwitchingField.newBuilder()
                                .addAllDuplicates(
                                    sf.getDuplicates().stream()
                                        .map(this::toProto)
                                        .collect(Collectors.toList())))
                        .build());
              } else {
                throw new RuntimeException(
                    "Consistent or Switching fields must be set for the Expand relation.");
              }
            });
    return Rel.newBuilder().setExpand(builder).build();
  }

  @Override
  public Rel visit(Sort sort) throws RuntimeException {
    var builder =
        SortRel.newBuilder()
            .setCommon(common(sort))
            .setInput(toProto(sort.getInput()))
            .addAllSorts(toProtoS(sort.getSortFields()));

    sort.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setSort(builder).build();
  }

  @Override
  public Rel visit(Cross cross) throws RuntimeException {
    var builder =
        CrossRel.newBuilder()
            .setCommon(common(cross))
            .setLeft(toProto(cross.getLeft()))
            .setRight(toProto(cross.getRight()));

    cross.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setCross(builder).build();
  }

  @Override
  public Rel visit(VirtualTableScan virtualTableScan) throws RuntimeException {
    var builder =
        ReadRel.newBuilder()
            .setCommon(common(virtualTableScan))
            .setVirtualTable(
                ReadRel.VirtualTable.newBuilder()
                    .addAllValues(
                        virtualTableScan.getRows().stream()
                            .map(this::toProto)
                            .map(t -> t.getLiteral().getStruct())
                            .collect(Collectors.toList()))
                    .build())
            .setBaseSchema(virtualTableScan.getInitialSchema().toProto(typeProtoConverter));

    virtualTableScan.getFilter().ifPresent(f -> builder.setFilter(toProto(f)));

    virtualTableScan.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
    return Rel.newBuilder().setRead(builder).build();
  }

  @Override
  public Rel visit(ExtensionLeaf extensionLeaf) throws RuntimeException {
    var builder =
        ExtensionLeafRel.newBuilder()
            .setCommon(common(extensionLeaf))
            .setDetail(extensionLeaf.getDetail().toProto());
    return Rel.newBuilder().setExtensionLeaf(builder).build();
  }

  @Override
  public Rel visit(ExtensionSingle extensionSingle) throws RuntimeException {
    var builder =
        ExtensionSingleRel.newBuilder()
            .setCommon(common(extensionSingle))
            .setInput(toProto(extensionSingle.getInput()))
            .setDetail(extensionSingle.getDetail().toProto());
    return Rel.newBuilder().setExtensionSingle(builder).build();
  }

  @Override
  public Rel visit(ExtensionMulti extensionMulti) throws RuntimeException {
    List inputs =
        extensionMulti.getInputs().stream().map(this::toProto).collect(Collectors.toList());
    var builder =
        ExtensionMultiRel.newBuilder()
            .setCommon(common(extensionMulti))
            .addAllInputs(inputs)
            .setDetail(extensionMulti.getDetail().toProto());
    return Rel.newBuilder().setExtensionMulti(builder).build();
  }

  private RelCommon common(io.substrait.relation.Rel rel) {
    var builder = RelCommon.newBuilder();
    rel.getCommonExtension()
        .ifPresent(extension -> builder.setAdvancedExtension(extension.toProto()));

    var remap = rel.getRemap().orElse(null);
    if (remap != null) {
      builder.setEmit(RelCommon.Emit.newBuilder().addAllOutputMapping(remap.indices()));
    } else {
      builder.setDirect(RelCommon.Direct.getDefaultInstance());
    }

    rel.getHint().ifPresent(md -> builder.setHint(md.toProto()));

    return builder.build();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy