io.substrait.extension.SimpleExtension Maven / Gradle / Ivy
Show all versions of core Show documentation
package io.substrait.extension;
import com.fasterxml.jackson.annotation.JacksonInject;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.InjectableValues;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import io.substrait.expression.Expression;
import io.substrait.function.ParameterizedType;
import io.substrait.function.ToTypeString;
import io.substrait.function.TypeExpression;
import io.substrait.type.Deserializers;
import io.substrait.type.TypeExpressionEvaluator;
import io.substrait.util.Util;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.immutables.value.Value;
/** Classes used to deserialize YAML extension files. Handles functions and types. */
@Value.Enclosing
public class SimpleExtension {
/** Logger for {@link SimpleExtension}. */
static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(SimpleExtension.class);

/** Name under which the extension URI is registered in Jackson's {@link InjectableValues}. */
public static final String URI_LOCATOR_KEY = "uri";

/**
 * Builds a YAML-backed {@link ObjectMapper} for reading one extension file.
 *
 * <p>The given namespace is registered as an injectable value under {@link #URI_LOCATOR_KEY}, so
 * properties annotated with {@code @JacksonInject(URI_LOCATOR_KEY)} receive it. Single scalar
 * values are accepted where arrays are expected, and the JDK8 and Substrait type deserializer
 * modules are registered.
 *
 * @param namespace the URI of the extension being read
 * @return a newly configured mapper
 */
private static ObjectMapper objectMapper(String namespace) {
  InjectableValues.Std injectables = new InjectableValues.Std();
  injectables.addValue(URI_LOCATOR_KEY, namespace);

  ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
  mapper.enable(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY);
  mapper.registerModule(new Jdk8Module());
  mapper.registerModule(Deserializers.MODULE);
  mapper.setInjectableValues(injectables);
  return mapper;
}
/**
 * Nullability mode declared by a function variant; {@code Function#nullability()} defaults to
 * {@code MIRROR}.
 */
enum Nullability { MIRROR, DECLARED_OUTPUT, DISCRETE }
/**
 * Decomposability of an aggregate or window function variant, read from the {@code decomposable}
 * YAML property; the variants default to {@code NONE}.
 */
enum Decomposability { NONE, ONE, MANY }
/**
 * Kind of a window function variant, read from the {@code window_type} YAML property; defaults
 * to {@code PARTITION}.
 */
public enum WindowType { PARTITION, STREAMING }
/** Not instantiable: this class only holds static helpers and nested extension types. */
private SimpleExtension() {
  // intentionally empty
}
/**
 * One argument slot of a function signature.
 *
 * <p>The concrete subtype ({@link ValueArgument}, {@link TypeArgument} or {@link EnumArgument}) is
 * deduced by Jackson from the properties present in the YAML.
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION)
@JsonSubTypes({
  @JsonSubTypes.Type(ValueArgument.class),
  @JsonSubTypes.Type(TypeArgument.class),
  @JsonSubTypes.Type(EnumArgument.class)
})
public interface Argument {
  /** Short type string for this argument, as used when building a function key. */
  String toTypeString();

  /** Optional argument name from the YAML definition. */
  @JsonProperty
  @Nullable
  String name();

  /** Optional human-readable description from the YAML definition. */
  @JsonProperty
  @Nullable
  String description();

  /** Whether the argument must be supplied by a call. */
  boolean required();
}
/**
 * A named function option: the set of values it permits plus an optional description.
 *
 * <p>The element types are spelled out so Jackson and Immutables bind them precisely instead of
 * falling back to raw types.
 */
@JsonDeserialize(as = ImmutableSimpleExtension.Option.class)
@JsonSerialize(as = ImmutableSimpleExtension.Option.class)
@Value.Immutable
public interface Option {
  /** Optional human-readable description of the option. */
  Optional<String> getDescription();

  /** The values the option may take. */
  List<String> getValues();
}
/** An argument that takes a value of the declared (possibly parameterized) type. */
@JsonSerialize(as = ImmutableSimpleExtension.ValueArgument.class)
@JsonDeserialize(as = ImmutableSimpleExtension.ValueArgument.class)
@Value.Immutable
public abstract static class ValueArgument implements Argument {
  /** Declared type of the value; the {@code value} property is mandatory in the YAML. */
  @JsonProperty(required = true)
  public abstract ParameterizedType value();

  /** Optional {@code constant} flag from the YAML; {@code null} when absent. */
  @JsonProperty()
  @Nullable
  public abstract Boolean constant();

  /** Renders {@link #value()} with {@link ToTypeString}. */
  @Override
  public String toTypeString() {
    return value().accept(ToTypeString.INSTANCE);
  }

  /** Value arguments are always required. */
  @Override
  public boolean required() {
    return true;
  }

  public static ImmutableSimpleExtension.ValueArgument.Builder builder() {
    return ImmutableSimpleExtension.ValueArgument.builder();
  }
}
@JsonSerialize(as = ImmutableSimpleExtension.TypeArgument.class)
@JsonDeserialize(as = ImmutableSimpleExtension.TypeArgument.class)
@Value.Immutable
public abstract static class TypeArgument implements Argument {
@JsonProperty(required = true)
public abstract ParameterizedType type();
public String toTypeString() {
return "type";
}
public boolean required() {
return true;
}
public static ImmutableSimpleExtension.TypeArgument.Builder builder() {
return ImmutableSimpleExtension.TypeArgument.builder();
}
}
/**
 * An enum argument is required to be known within the plan (no field references can be used).
 * These are distinct from Enum data types, which are just like any other type and could be
 * reasonably expressed by referencing a preexisting column that contains those values.
 *
 * <p>For more details see the comments in the related Substrait issue. TODO(review): the issue
 * link was lost from this comment; restore it.
 */
@JsonSerialize(as = ImmutableSimpleExtension.EnumArgument.class)
@JsonDeserialize(as = ImmutableSimpleExtension.EnumArgument.class)
@Value.Immutable
public abstract static class EnumArgument implements Argument {
  /** The permitted enum values; the {@code options} property is mandatory in the YAML. */
  @JsonProperty(required = true)
  public abstract List<String> options();

  /** Enum arguments are always required. */
  @Override
  public boolean required() {
    return true;
  }

  /** Enum arguments render as {@code "req"} in function keys. */
  @Override
  public String toTypeString() {
    return "req";
  }

  public static ImmutableSimpleExtension.EnumArgument.Builder builder() {
    return ImmutableSimpleExtension.EnumArgument.builder();
  }
}
/**
 * Identifies an extension item by a namespace plus a key. {@code Function#getAnchor()} uses the
 * function's URI as the namespace.
 */
public interface Anchor {
/** Namespace the item was declared in (e.g. the extension URI). */
String namespace();
/** Key of the item within its namespace. */
String key();
}
@Value.Immutable
public interface FunctionAnchor extends Anchor {
static FunctionAnchor of(String namespace, String key) {
return ImmutableSimpleExtension.FunctionAnchor.builder()
.namespace(namespace)
.key(key)
.build();
}
}
/** {@link Anchor} for an extension type: the declaring namespace plus the type name. */
@Value.Immutable
public interface TypeAnchor extends Anchor {
  /** Creates the anchor whose key is the type {@code name} within {@code namespace}. */
  static TypeAnchor of(String namespace, String name) {
    return ImmutableSimpleExtension.TypeAnchor.builder()
        .namespace(namespace)
        .key(name)
        .build();
  }
}
/**
 * Variadic settings of a function variant, as bound to {@code Function#variadic()}.
 *
 * <p>{@code Function#getRange()} treats the last declared argument as repeatable between
 * {@link #getMin()} and {@link #getMax()} times; an absent max is treated as unbounded.
 */
@JsonDeserialize(as = ImmutableSimpleExtension.VariadicBehavior.class)
@JsonSerialize(as = ImmutableSimpleExtension.VariadicBehavior.class)
@Value.Immutable
public interface VariadicBehavior {
/** Minimum number of occurrences of the repeated argument. */
int getMin();
/** Maximum number of occurrences of the repeated argument; empty when unbounded. */
OptionalInt getMax();
/** Presumably whether the repeated arguments must share a type — confirm against the spec. */
enum ParameterConsistency {
CONSISTENT,
INCONSISTENT
}
/**
 * Always returns {@code CONSISTENT}.
 *
 * <p>NOTE(review): this is a plain default method without {@code @Value.Default}, so Immutables
 * presumably does not treat it as an attribute, and a value in the YAML may not be bound here.
 * Confirm.
 */
default ParameterConsistency parameterConsistency() {
return ParameterConsistency.CONSISTENT;
}
}
public abstract static class Function {
@Value.Default
public String name() {
// we can't use null detection here since we initially construct this with a parent name.
return "";
}
@Value.Default
public String uri() {
// we can't use null detection here since we initially construct this without a uri, then
// resolve later.
return "";
}
public abstract Optional variadic();
@Value.Default
@Nullable
public String description() {
return "";
}
public abstract List args();
public abstract Map options();
public List requiredArguments() {
return requiredArgsSupplier.get();
}
@Override
public String toString() {
return key();
}
@Value.Default
public Nullability nullability() {
return Nullability.MIRROR;
}
@Nullable
public abstract Boolean ordered();
public FunctionAnchor getAnchor() {
return anchorSupplier.get();
}
@JsonProperty(value = "return")
public abstract TypeExpression returnType();
private final Supplier anchorSupplier =
Util.memoize(() -> FunctionAnchor.of(uri(), key()));
private final Supplier keySupplier = Util.memoize(() -> constructKey(name(), args()));
private final Supplier> requiredArgsSupplier =
Util.memoize(
() -> {
return args().stream()
.filter(Argument::required)
.collect(java.util.stream.Collectors.toList());
});
public static String constructKeyFromTypes(
String name, List arguments) {
try {
return name
+ ":"
+ arguments.stream()
.map(t -> t.accept(ToTypeString.INSTANCE))
.collect(Collectors.joining("_"));
} catch (UnsupportedOperationException ex) {
throw new UnsupportedOperationException(
String.format("Failure converting types of function %s.", name), ex);
}
}
public static String constructKey(String name, List arguments) {
try {
return name
+ ":"
+ arguments.stream().map(Argument::toTypeString).collect(Collectors.joining("_"));
} catch (UnsupportedOperationException ex) {
throw new UnsupportedOperationException(
String.format("Failure converting types of function %s.", name), ex);
}
}
public Util.IntRange getRange() {
// end range is exclusive so add one to size.
long optionalCount = args().stream().filter(t -> !t.required()).count();
int max =
variadic()
.map(
t -> {
OptionalInt optionalMax = t.getMax();
IntStream stream =
optionalMax.isPresent()
? IntStream.of(optionalMax.getAsInt())
: IntStream.empty();
return stream
.map(x -> args().size() - 1 + x + 1)
.findFirst()
.orElse(Integer.MAX_VALUE);
})
.orElse(args().size() + 1);
int min =
variadic().map(t -> args().size() - 1 + t.getMin()).orElse(requiredArguments().size());
return Util.IntRange.of(min, max);
}
public void validateOutputType(
List argumentExpressions, io.substrait.type.Type outputType) {
// TODO: support advanced output type validation using return expressions, parameters, etc.
// The code below was too restrictive in the case of nullability conversion.
return;
// boolean makeNullable = nullability() == Nullability.MIRROR &&
// argumentExpressions.stream().filter(e ->
// e.getType().nullable()).findFirst().isPresent();
// if (returnType() instanceof Type && !outputType.equals(returnType())) {
//
// throw new IllegalArgumentException(String.format("Output type of %s doesn't match
// expected output
// type of %s for %s.", outputType, returnType(), this.key()));
// }
}
public String key() {
return keySupplier.get();
}
public io.substrait.type.Type resolveType(List argumentTypes) {
return TypeExpressionEvaluator.evaluateExpression(returnType(), args(), argumentTypes);
}
}
/** A named scalar function from an extension file, grouping its implementations. */
@JsonDeserialize(as = ImmutableSimpleExtension.ScalarFunction.class)
@JsonSerialize(as = ImmutableSimpleExtension.ScalarFunction.class)
@Value.Immutable
public abstract static class ScalarFunction {
  public abstract String name();

  @Nullable
  public abstract String description();

  /** The implementations (signatures) declared for this function. */
  public abstract List<ScalarFunctionVariant> impls();

  /**
   * Completes every implementation with the given URI and this function's name and description.
   *
   * @param uri the URI of the declaring extension
   * @return the resolved variants, in declaration order
   */
  public Stream<ScalarFunctionVariant> resolve(String uri) {
    return impls().stream().map(f -> f.resolve(uri, name(), description()));
  }
}
@JsonDeserialize(as = ImmutableSimpleExtension.ScalarFunctionVariant.class)
@JsonSerialize(as = ImmutableSimpleExtension.ScalarFunctionVariant.class)
@Value.Immutable
public abstract static class ScalarFunctionVariant extends Function {
public ScalarFunctionVariant resolve(String uri, String name, String description) {
return ImmutableSimpleExtension.ScalarFunctionVariant.builder()
.uri(uri)
.name(name)
.description(description)
.nullability(nullability())
.args(args())
.options(options())
.ordered(ordered())
.variadic(variadic())
.returnType(returnType())
.build();
}
}
/** A named aggregate function from an extension file, grouping its implementations. */
@JsonDeserialize(as = ImmutableSimpleExtension.AggregateFunction.class)
@JsonSerialize(as = ImmutableSimpleExtension.AggregateFunction.class)
@Value.Immutable
public abstract static class AggregateFunction {
  @Nullable
  public abstract String name();

  @Nullable
  public abstract String description();

  /** The implementations (signatures) declared for this function. */
  public abstract List<AggregateFunctionVariant> impls();

  /**
   * Completes every implementation with the given URI and this function's name and description.
   *
   * @param uri the URI of the declaring extension
   * @return the resolved variants, in declaration order
   */
  public Stream<AggregateFunctionVariant> resolve(String uri) {
    return impls().stream().map(f -> f.resolve(uri, name(), description()));
  }
}
/** A named window function from an extension file, grouping its implementations. */
@JsonDeserialize(as = ImmutableSimpleExtension.WindowFunction.class)
@JsonSerialize(as = ImmutableSimpleExtension.WindowFunction.class)
@Value.Immutable
public abstract static class WindowFunction {
  @Nullable
  public abstract String name();

  @Nullable
  public abstract String description();

  /** The implementations (signatures) declared for this function. */
  public abstract List<WindowFunctionVariant> impls();

  /**
   * Completes every implementation with the given URI and this function's name and description.
   *
   * @param uri the URI of the declaring extension
   * @return the resolved variants, in declaration order
   */
  public Stream<WindowFunctionVariant> resolve(String uri) {
    return impls().stream().map(f -> f.resolve(uri, name(), description()));
  }

  public static ImmutableSimpleExtension.WindowFunction.Builder builder() {
    return ImmutableSimpleExtension.WindowFunction.builder();
  }
}
@JsonDeserialize(as = ImmutableSimpleExtension.AggregateFunctionVariant.class)
@JsonSerialize(as = ImmutableSimpleExtension.AggregateFunctionVariant.class)
@Value.Immutable
public abstract static class AggregateFunctionVariant extends Function {
@Value.Default
@JsonProperty("decomposable")
public Decomposability decomposability() {
return Decomposability.NONE;
}
@Override
public String toString() {
return super.toString();
}
@Nullable
public abstract TypeExpression intermediate();
AggregateFunctionVariant resolve(String uri, String name, String description) {
return ImmutableSimpleExtension.AggregateFunctionVariant.builder()
.uri(uri)
.name(name)
.description(description)
.nullability(nullability())
.args(args())
.options(options())
.ordered(ordered())
.variadic(variadic())
.decomposability(decomposability())
.intermediate(intermediate())
.returnType(returnType())
.build();
}
}
/** One implementation (signature) of a window function. */
@JsonDeserialize(as = ImmutableSimpleExtension.WindowFunctionVariant.class)
@JsonSerialize(as = ImmutableSimpleExtension.WindowFunctionVariant.class)
@Value.Immutable
public abstract static class WindowFunctionVariant extends Function {
  /** Read from the {@code decomposable} YAML property; defaults to {@code NONE}. */
  @Value.Default
  @JsonProperty("decomposable")
  public Decomposability decomposability() {
    return Decomposability.NONE;
  }

  /** Optional intermediate type expression; {@code null} when absent. */
  @Nullable
  public abstract TypeExpression intermediate();

  /** Read from the {@code window_type} YAML property; defaults to {@code PARTITION}. */
  @Value.Default
  @JsonProperty("window_type")
  public WindowType windowType() {
    return WindowType.PARTITION;
  }

  /** Declared explicitly so the inherited key-based {@code toString} is kept. */
  @Override
  public String toString() {
    return super.toString();
  }

  /**
   * Returns a copy of this variant that carries the given URI, parent name and description. All
   * other attributes are copied unchanged.
   */
  WindowFunctionVariant resolve(String uri, String name, String description) {
    ImmutableSimpleExtension.WindowFunctionVariant.Builder resolved =
        ImmutableSimpleExtension.WindowFunctionVariant.builder()
            // identity supplied by the enclosing function definition
            .uri(uri)
            .name(name)
            .description(description);
    // signature and behavior carried over from this variant
    resolved
        .args(args())
        .variadic(variadic())
        .options(options())
        .nullability(nullability())
        .ordered(ordered())
        .returnType(returnType());
    // window-specific attributes
    resolved
        .decomposability(decomposability())
        .intermediate(intermediate())
        .windowType(windowType());
    return resolved.build();
  }

  public static ImmutableSimpleExtension.WindowFunctionVariant.Builder builder() {
    return ImmutableSimpleExtension.WindowFunctionVariant.builder();
  }
}
@JsonDeserialize(as = ImmutableSimpleExtension.Type.class)
@JsonSerialize(as = ImmutableSimpleExtension.Type.class)
@Value.Immutable
public abstract static class Type {
public abstract String name();
@JacksonInject(SimpleExtension.URI_LOCATOR_KEY)
public abstract String uri();
// TODO: Handle conversion of structure object to Named Struct representation
protected abstract Optional