io.substrait.dsl.SubstraitBuilder Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of core Show documentation
Show all versions of core Show documentation
Create a well-defined, cross-language specification for data compute operations
package io.substrait.dsl;
import com.github.bsideup.jabel.Desugar;
import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.Expression.FailureBehavior;
import io.substrait.expression.Expression.IfClause;
import io.substrait.expression.Expression.IfThen;
import io.substrait.expression.Expression.SwitchClause;
import io.substrait.expression.FieldReference;
import io.substrait.expression.ImmutableExpression.Cast;
import io.substrait.expression.ImmutableExpression.SingleOrList;
import io.substrait.expression.ImmutableExpression.Switch;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.expression.WindowBound;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import io.substrait.function.ToTypeString;
import io.substrait.plan.ImmutablePlan;
import io.substrait.plan.ImmutableRoot;
import io.substrait.plan.Plan;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.Expand;
import io.substrait.relation.Fetch;
import io.substrait.relation.Filter;
import io.substrait.relation.Join;
import io.substrait.relation.NamedScan;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.ImmutableType;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class SubstraitBuilder {
static final TypeCreator R = TypeCreator.of(false);
static final TypeCreator N = TypeCreator.of(true);
private final SimpleExtension.ExtensionCollection extensions;
public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) {
this.extensions = extensions;
}
// Relations
public Aggregate.Measure measure(AggregateFunctionInvocation aggFn) {
return Aggregate.Measure.builder().function(aggFn).build();
}
public Aggregate.Measure measure(AggregateFunctionInvocation aggFn, Expression preMeasureFilter) {
return Aggregate.Measure.builder().function(aggFn).preMeasureFilter(preMeasureFilter).build();
}
public Aggregate aggregate(
Function groupingFn,
Function> measuresFn,
Rel input) {
Function> groupingsFn =
groupingFn.andThen(g -> Stream.of(g).collect(Collectors.toList()));
return aggregate(groupingsFn, measuresFn, Optional.empty(), input);
}
public Aggregate aggregate(
Function groupingFn,
Function> measuresFn,
Rel.Remap remap,
Rel input) {
Function> groupingsFn =
groupingFn.andThen(g -> Stream.of(g).collect(Collectors.toList()));
return aggregate(groupingsFn, measuresFn, Optional.of(remap), input);
}
private Aggregate aggregate(
Function> groupingsFn,
Function> measuresFn,
Optional remap,
Rel input) {
var groupings = groupingsFn.apply(input);
var measures = measuresFn.apply(input);
return Aggregate.builder()
.groupings(groupings)
.measures(measures)
.remap(remap)
.input(input)
.build();
}
public Cross cross(Rel left, Rel right) {
return cross(left, right, Optional.empty());
}
public Cross cross(Rel left, Rel right, Rel.Remap remap) {
return cross(left, right, Optional.of(remap));
}
private Cross cross(Rel left, Rel right, Optional remap) {
return Cross.builder().left(left).right(right).remap(remap).build();
}
public Fetch fetch(long offset, long count, Rel input) {
return fetch(offset, OptionalLong.of(count), Optional.empty(), input);
}
public Fetch fetch(long offset, long count, Rel.Remap remap, Rel input) {
return fetch(offset, OptionalLong.of(count), Optional.of(remap), input);
}
public Fetch limit(long limit, Rel input) {
return fetch(0, OptionalLong.of(limit), Optional.empty(), input);
}
public Fetch limit(long limit, Rel.Remap remap, Rel input) {
return fetch(0, OptionalLong.of(limit), Optional.of(remap), input);
}
public Fetch offset(long offset, Rel input) {
return fetch(offset, OptionalLong.empty(), Optional.empty(), input);
}
public Fetch offset(long offset, Rel.Remap remap, Rel input) {
return fetch(offset, OptionalLong.empty(), Optional.of(remap), input);
}
private Fetch fetch(long offset, OptionalLong count, Optional remap, Rel input) {
return Fetch.builder().offset(offset).count(count).input(input).remap(remap).build();
}
public Filter filter(Function conditionFn, Rel input) {
return filter(conditionFn, Optional.empty(), input);
}
public Filter filter(Function conditionFn, Rel.Remap remap, Rel input) {
return filter(conditionFn, Optional.of(remap), input);
}
private Filter filter(
Function conditionFn, Optional remap, Rel input) {
var condition = conditionFn.apply(input);
return Filter.builder().input(input).condition(condition).remap(remap).build();
}
@Desugar
public record JoinInput(Rel left, Rel right) {}
public Join innerJoin(Function conditionFn, Rel left, Rel right) {
return join(conditionFn, Join.JoinType.INNER, left, right);
}
public Join innerJoin(
Function conditionFn, Rel.Remap remap, Rel left, Rel right) {
return join(conditionFn, Join.JoinType.INNER, remap, left, right);
}
public Join join(
Function conditionFn, Join.JoinType joinType, Rel left, Rel right) {
return join(conditionFn, joinType, Optional.empty(), left, right);
}
public Join join(
Function conditionFn,
Join.JoinType joinType,
Rel.Remap remap,
Rel left,
Rel right) {
return join(conditionFn, joinType, Optional.of(remap), left, right);
}
private Join join(
Function conditionFn,
Join.JoinType joinType,
Optional remap,
Rel left,
Rel right) {
var condition = conditionFn.apply(new JoinInput(left, right));
return Join.builder()
.left(left)
.right(right)
.condition(condition)
.joinType(joinType)
.remap(remap)
.build();
}
public HashJoin hashJoin(
List leftKeys,
List rightKeys,
HashJoin.JoinType joinType,
Rel left,
Rel right) {
return hashJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right);
}
public HashJoin hashJoin(
List leftKeys,
List rightKeys,
HashJoin.JoinType joinType,
Optional remap,
Rel left,
Rel right) {
return HashJoin.builder()
.left(left)
.right(right)
.leftKeys(
this.fieldReferences(left, leftKeys.stream().mapToInt(Integer::intValue).toArray()))
.rightKeys(
this.fieldReferences(right, rightKeys.stream().mapToInt(Integer::intValue).toArray()))
.joinType(joinType)
.remap(remap)
.build();
}
public MergeJoin mergeJoin(
List leftKeys,
List rightKeys,
MergeJoin.JoinType joinType,
Rel left,
Rel right) {
return mergeJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right);
}
public MergeJoin mergeJoin(
List leftKeys,
List rightKeys,
MergeJoin.JoinType joinType,
Optional remap,
Rel left,
Rel right) {
return MergeJoin.builder()
.left(left)
.right(right)
.leftKeys(
this.fieldReferences(left, leftKeys.stream().mapToInt(Integer::intValue).toArray()))
.rightKeys(
this.fieldReferences(right, rightKeys.stream().mapToInt(Integer::intValue).toArray()))
.joinType(joinType)
.remap(remap)
.build();
}
public NestedLoopJoin nestedLoopJoin(
Function conditionFn,
NestedLoopJoin.JoinType joinType,
Rel left,
Rel right) {
return nestedLoopJoin(conditionFn, joinType, Optional.empty(), left, right);
}
private NestedLoopJoin nestedLoopJoin(
Function conditionFn,
NestedLoopJoin.JoinType joinType,
Optional remap,
Rel left,
Rel right) {
var condition = conditionFn.apply(new JoinInput(left, right));
return NestedLoopJoin.builder()
.left(left)
.right(right)
.condition(condition)
.joinType(joinType)
.remap(remap)
.build();
}
public NamedScan namedScan(
Iterable tableName, Iterable columnNames, Iterable types) {
return namedScan(tableName, columnNames, types, Optional.empty());
}
public NamedScan namedScan(
Iterable tableName,
Iterable columnNames,
Iterable types,
Rel.Remap remap) {
return namedScan(tableName, columnNames, types, Optional.of(remap));
}
private NamedScan namedScan(
Iterable tableName,
Iterable columnNames,
Iterable types,
Optional remap) {
var struct = Type.Struct.builder().addAllFields(types).nullable(false).build();
var namedStruct = NamedStruct.of(columnNames, struct);
return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build();
}
public Project project(Function> expressionsFn, Rel input) {
return project(expressionsFn, Optional.empty(), input);
}
public Project project(
Function> expressionsFn, Rel.Remap remap, Rel input) {
return project(expressionsFn, Optional.of(remap), input);
}
private Project project(
Function> expressionsFn,
Optional remap,
Rel input) {
var expressions = expressionsFn.apply(input);
return Project.builder().input(input).expressions(expressions).remap(remap).build();
}
public Expand expand(Function> fieldsFn, Rel input) {
return expand(fieldsFn, Optional.empty(), input);
}
public Expand expand(
Function> fieldsFn, Rel.Remap remap, Rel input) {
return expand(fieldsFn, Optional.of(remap), input);
}
private Expand expand(
Function> fieldsFn,
Optional remap,
Rel input) {
var fields = fieldsFn.apply(input);
return Expand.builder().input(input).fields(fields).remap(remap).build();
}
public Set set(Set.SetOp op, Rel... inputs) {
return set(op, Optional.empty(), inputs);
}
public Set set(Set.SetOp op, Rel.Remap remap, Rel... inputs) {
return set(op, Optional.of(remap), inputs);
}
private Set set(Set.SetOp op, Optional remap, Rel... inputs) {
return Set.builder().setOp(op).remap(remap).addAllInputs(Arrays.asList(inputs)).build();
}
public Sort sort(Function> sortFieldFn, Rel input) {
return sort(sortFieldFn, Optional.empty(), input);
}
public Sort sort(
Function> sortFieldFn,
Rel.Remap remap,
Rel input) {
return sort(sortFieldFn, Optional.of(remap), input);
}
private Sort sort(
Function> sortFieldFn,
Optional remap,
Rel input) {
var condition = sortFieldFn.apply(input);
return Sort.builder().input(input).sortFields(condition).remap(remap).build();
}
// Expressions
public Expression.BoolLiteral bool(boolean v) {
return Expression.BoolLiteral.builder().value(v).build();
}
public Expression.I32Literal i32(int v) {
return Expression.I32Literal.builder().value(v).build();
}
public Expression.FP64Literal fp64(double v) {
return Expression.FP64Literal.builder().value(v).build();
}
public Expression cast(Expression input, Type type) {
return Cast.builder()
.input(input)
.type(type)
.failureBehavior(FailureBehavior.UNSPECIFIED)
.build();
}
public FieldReference fieldReference(Rel input, int index) {
return ImmutableFieldReference.newInputRelReference(index, input);
}
public List fieldReferences(Rel input, int... indexes) {
return Arrays.stream(indexes)
.mapToObj(index -> fieldReference(input, index))
.collect(java.util.stream.Collectors.toList());
}
public FieldReference fieldReference(List inputs, int index) {
return ImmutableFieldReference.newInputRelReference(index, inputs);
}
public List fieldReferences(List inputs, int... indexes) {
return Arrays.stream(indexes)
.mapToObj(index -> fieldReference(inputs, index))
.collect(java.util.stream.Collectors.toList());
}
public IfThen ifThen(Iterable extends IfClause> ifClauses, Expression elseClause) {
return IfThen.builder().addAllIfClauses(ifClauses).elseClause(elseClause).build();
}
public IfClause ifClause(Expression condition, Expression then) {
return IfClause.builder().condition(condition).then(then).build();
}
public Expression singleOrList(Expression condition, Expression... options) {
return SingleOrList.builder().condition(condition).addOptions(options).build();
}
public Expression.InPredicate inPredicate(Rel haystack, Expression... needles) {
return Expression.InPredicate.builder()
.addAllNeedles(Arrays.asList(needles))
.haystack(haystack)
.build();
}
public List sortFields(Rel input, int... indexes) {
return Arrays.stream(indexes)
.mapToObj(
index ->
Expression.SortField.builder()
.expr(ImmutableFieldReference.newInputRelReference(index, input))
.direction(Expression.SortDirection.ASC_NULLS_LAST)
.build())
.collect(java.util.stream.Collectors.toList());
}
public Expression.SortField sortField(
Expression expression, Expression.SortDirection sortDirection) {
return Expression.SortField.builder().expr(expression).direction(sortDirection).build();
}
public SwitchClause switchClause(Expression.Literal condition, Expression then) {
return SwitchClause.builder().condition(condition).then(then).build();
}
public Switch switchExpression(
Expression match, Iterable extends SwitchClause> clauses, Expression defaultClause) {
return Switch.builder()
.match(match)
.addAllSwitchClauses(clauses)
.defaultClause(defaultClause)
.build();
}
// Aggregate Functions
public AggregateFunctionInvocation aggregateFn(
String namespace, String key, Type outputType, Expression... args) {
var declaration =
extensions.getAggregateFunction(SimpleExtension.FunctionAnchor.of(namespace, key));
return AggregateFunctionInvocation.builder()
.arguments(Arrays.stream(args).collect(java.util.stream.Collectors.toList()))
.outputType(outputType)
.declaration(declaration)
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build();
}
public Aggregate.Grouping grouping(Rel input, int... indexes) {
var columns = fieldReferences(input, indexes);
return Aggregate.Grouping.builder().addAllExpressions(columns).build();
}
public Aggregate.Grouping grouping(Expression... expressions) {
return Aggregate.Grouping.builder().addExpressions(expressions).build();
}
public Aggregate.Measure count(Rel input, int field) {
var declaration =
extensions.getAggregateFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_AGGREGATE_GENERIC, "count:any"));
return measure(
AggregateFunctionInvocation.builder()
.arguments(fieldReferences(input, field))
.outputType(R.I64)
.declaration(declaration)
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build());
}
public Aggregate.Measure min(Rel input, int field) {
return min(fieldReference(input, field));
}
public Aggregate.Measure min(Expression expr) {
return singleArgumentArithmeticAggregate(
expr,
"min",
// min output is always nullable
TypeCreator.asNullable(expr.getType()));
}
public Aggregate.Measure max(Rel input, int field) {
return max(fieldReference(input, field));
}
public Aggregate.Measure max(Expression expr) {
return singleArgumentArithmeticAggregate(
expr,
"max",
// max output is always nullable
TypeCreator.asNullable(expr.getType()));
}
public Aggregate.Measure avg(Rel input, int field) {
return avg(fieldReference(input, field));
}
public Aggregate.Measure avg(Expression expr) {
return singleArgumentArithmeticAggregate(
expr,
"avg",
// avg output is always nullable
TypeCreator.asNullable(expr.getType()));
}
public Aggregate.Measure sum(Rel input, int field) {
return sum(fieldReference(input, field));
}
public Aggregate.Measure sum(Expression expr) {
return singleArgumentArithmeticAggregate(
expr,
"sum",
// sum output is always nullable
TypeCreator.asNullable(expr.getType()));
}
public Aggregate.Measure sum0(Rel input, int field) {
return sum(fieldReference(input, field));
}
public Aggregate.Measure sum0(Expression expr) {
return singleArgumentArithmeticAggregate(
expr,
"sum0",
// sum0 output is always NOT NULL I64
R.I64);
}
private Aggregate.Measure singleArgumentArithmeticAggregate(
Expression expr, String functionName, Type outputType) {
String typeString = ToTypeString.apply(expr.getType());
var declaration =
extensions.getAggregateFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC,
String.format("%s:%s", functionName, typeString)));
return measure(
AggregateFunctionInvocation.builder()
.arguments(Arrays.asList(expr))
.outputType(outputType)
.declaration(declaration)
// INITIAL_TO_RESULT is the most restrictive aggregation phase type,
// as it does not allow decomposition. Use it as the default for now.
// TODO: set this per function
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build());
}
// Scalar Functions
public Expression.ScalarFunctionInvocation negate(Expression expr) {
// output type of negate is the same as the input type
var outputType = expr.getType();
return scalarFn(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC,
String.format("negate:%s", ToTypeString.apply(outputType)),
outputType,
expr);
}
public Expression.ScalarFunctionInvocation add(Expression left, Expression right) {
return arithmeticFunction("add", left, right);
}
public Expression.ScalarFunctionInvocation subtract(Expression left, Expression right) {
return arithmeticFunction("substract", left, right);
}
public Expression.ScalarFunctionInvocation multiply(Expression left, Expression right) {
return arithmeticFunction("multiply", left, right);
}
public Expression.ScalarFunctionInvocation divide(Expression left, Expression right) {
return arithmeticFunction("divide", left, right);
}
private Expression.ScalarFunctionInvocation arithmeticFunction(
String fname, Expression left, Expression right) {
var leftTypeStr = ToTypeString.apply(left.getType());
var rightTypeStr = ToTypeString.apply(right.getType());
var key = String.format("%s:%s_%s", fname, leftTypeStr, rightTypeStr);
var isOutputNullable = left.getType().nullable() || right.getType().nullable();
var outputType = left.getType();
outputType =
isOutputNullable
? TypeCreator.asNullable(outputType)
: TypeCreator.asNotNullable(outputType);
return scalarFn(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, key, outputType, left, right);
}
public Expression.ScalarFunctionInvocation equal(Expression left, Expression right) {
return scalarFn(
DefaultExtensionCatalog.FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right);
}
public Expression.ScalarFunctionInvocation or(Expression... args) {
// If any arg is nullable, the output of or is potentially nullable
// For example: false or null = null
var isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable());
var outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN;
return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "or:bool", outputType, args);
}
public Expression.ScalarFunctionInvocation scalarFn(
String namespace, String key, Type outputType, Expression... args) {
var declaration =
extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(namespace, key));
return Expression.ScalarFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.arguments(Arrays.stream(args).collect(java.util.stream.Collectors.toList()))
.build();
}
public Expression.WindowFunctionInvocation windowFn(
String namespace,
String key,
Type outputType,
Expression.AggregationPhase aggregationPhase,
Expression.AggregationInvocation invocation,
Expression.WindowBoundsType boundsType,
WindowBound lowerBound,
WindowBound upperBound,
Expression... args) {
var declaration =
extensions.getWindowFunction(SimpleExtension.FunctionAnchor.of(namespace, key));
return Expression.WindowFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.aggregationPhase(aggregationPhase)
.invocation(invocation)
.boundsType(boundsType)
.lowerBound(lowerBound)
.upperBound(upperBound)
.arguments(Arrays.stream(args).collect(java.util.stream.Collectors.toList()))
.build();
}
// Types
public Type.UserDefined userDefinedType(String namespace, String typeName) {
return ImmutableType.UserDefined.builder()
.uri(namespace)
.name(typeName)
.nullable(false)
.build();
}
// Misc
public Plan.Root root(Rel rel) {
return ImmutableRoot.builder().input(rel).build();
}
public Plan plan(Plan.Root root) {
return ImmutablePlan.builder().addRoots(root).build();
}
public Rel.Remap remap(Integer... fields) {
return Rel.Remap.of(Arrays.asList(fields));
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy