package com.amazonaws.athena.connector.lambda.handlers;
/*-
* #%L
* Amazon Athena Query Federation SDK
* %%
* Copyright (C) 2019 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import com.amazonaws.athena.connector.lambda.data.Block;
import com.amazonaws.athena.connector.lambda.data.BlockAllocator;
import com.amazonaws.athena.connector.lambda.data.BlockAllocatorImpl;
import com.amazonaws.athena.connector.lambda.data.BlockUtils;
import com.amazonaws.athena.connector.lambda.data.FieldResolver;
import com.amazonaws.athena.connector.lambda.data.projectors.ArrowValueProjector;
import com.amazonaws.athena.connector.lambda.data.projectors.ProjectorUtils;
import com.amazonaws.athena.connector.lambda.data.writers.GeneratedRowWriter;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.BigIntExtractor;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.BitExtractor;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.DateDayExtractor;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.DateMilliExtractor;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.DecimalExtractor;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.Extractor;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.Float4Extractor;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.Float8Extractor;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.IntExtractor;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.SmallIntExtractor;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.TinyIntExtractor;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.VarBinaryExtractor;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.VarCharExtractor;
import com.amazonaws.athena.connector.lambda.data.writers.fieldwriters.FieldWriterFactory;
import com.amazonaws.athena.connector.lambda.data.writers.holders.NullableDecimalHolder;
import com.amazonaws.athena.connector.lambda.data.writers.holders.NullableVarBinaryHolder;
import com.amazonaws.athena.connector.lambda.data.writers.holders.NullableVarCharHolder;
import com.amazonaws.athena.connector.lambda.domain.predicate.ConstraintProjector;
import com.amazonaws.athena.connector.lambda.request.FederationRequest;
import com.amazonaws.athena.connector.lambda.request.FederationResponse;
import com.amazonaws.athena.connector.lambda.request.PingRequest;
import com.amazonaws.athena.connector.lambda.request.PingResponse;
import com.amazonaws.athena.connector.lambda.serde.VersionedObjectMapperFactory;
import com.amazonaws.athena.connector.lambda.udf.UserDefinedFunctionRequest;
import com.amazonaws.athena.connector.lambda.udf.UserDefinedFunctionResponse;
import com.amazonaws.athena.connector.lambda.udf.UserDefinedFunctionType;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestStreamHandler;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.complex.reader.FieldReader;
import org.apache.arrow.vector.holders.NullableBigIntHolder;
import org.apache.arrow.vector.holders.NullableBitHolder;
import org.apache.arrow.vector.holders.NullableDateDayHolder;
import org.apache.arrow.vector.holders.NullableDateMilliHolder;
import org.apache.arrow.vector.holders.NullableFloat4Holder;
import org.apache.arrow.vector.holders.NullableFloat8Holder;
import org.apache.arrow.vector.holders.NullableIntHolder;
import org.apache.arrow.vector.holders.NullableSmallIntHolder;
import org.apache.arrow.vector.holders.NullableTinyIntHolder;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import static com.amazonaws.athena.connector.lambda.handlers.FederationCapabilities.CAPABILITIES;
import static com.amazonaws.athena.connector.lambda.handlers.SerDeVersion.SERDE_VERSION;
import static com.google.common.base.Preconditions.checkState;
/**
* Athena UDF users are expected to extend this class to create UDFs.
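 *
 * <p>A minimal sketch of a handler subclass (the class name, source tag, and UDF method below are
 * illustrative, not part of the SDK):
 *
 * <pre>{@code
 * public class MyUserDefinedFunctions
 *         extends UserDefinedFunctionHandler
 * {
 *     public MyUserDefinedFunctions()
 *     {
 *         super("my_udf_source");  // sourceType tag used in log lines
 *     }
 *
 *     // Matched by name and argument types against the SQL declaration, e.g.
 *     // USING EXTERNAL FUNCTION to_upper(input VARCHAR) RETURNS VARCHAR LAMBDA 'my-udf-lambda'
 *     public String to_upper(String input)
 *     {
 *         return input == null ? null : input.toUpperCase();
 *     }
 * }
 * }</pre>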
*/
public abstract class UserDefinedFunctionHandler
implements RequestStreamHandler
{
private static final Logger logger = LoggerFactory.getLogger(UserDefinedFunctionHandler.class);
private static final int RETURN_COLUMN_COUNT = 1;
//Used to tag log lines generated by this connector for diagnostic purposes when interacting with Athena.
private final String sourceType;
public UserDefinedFunctionHandler(String sourceType)
{
this.sourceType = sourceType;
}
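    /**
     * Entry point invoked by AWS Lambda. Deserializes the incoming request, answers {@link PingRequest}s
     * directly, and delegates {@link UserDefinedFunctionRequest}s to {@link #doHandleRequest}.
     */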
@Override
public final void handleRequest(InputStream inputStream, OutputStream outputStream, Context context)
{
try (BlockAllocator allocator = new BlockAllocatorImpl()) {
ObjectMapper objectMapper = VersionedObjectMapperFactory.create(allocator);
try (FederationRequest rawRequest = objectMapper.readValue(inputStream, FederationRequest.class)) {
if (rawRequest instanceof PingRequest) {
try (PingResponse response = doPing((PingRequest) rawRequest)) {
assertNotNull(response);
objectMapper.writeValue(outputStream, response);
}
return;
}
if (!(rawRequest instanceof UserDefinedFunctionRequest)) {
throw new RuntimeException("Expected a UserDefinedFunctionRequest but found "
+ rawRequest.getClass());
}
doHandleRequest(allocator, objectMapper, (UserDefinedFunctionRequest) rawRequest, outputStream);
}
catch (Exception ex) {
throw (ex instanceof RuntimeException) ? (RuntimeException) ex : new RuntimeException(ex);
}
}
}
protected final void doHandleRequest(BlockAllocator allocator,
ObjectMapper objectMapper,
UserDefinedFunctionRequest req,
OutputStream outputStream)
throws Exception
{
logger.info("doHandleRequest: request[{}]", req);
try (UserDefinedFunctionResponse response = processFunction(allocator, req)) {
logger.info("doHandleRequest: response[{}]", response);
assertNotNull(response);
objectMapper.writeValue(outputStream, response);
}
}
@VisibleForTesting
UserDefinedFunctionResponse processFunction(BlockAllocator allocator, UserDefinedFunctionRequest req)
throws Exception
{
UserDefinedFunctionType functionType = req.getFunctionType();
switch (functionType) {
case SCALAR:
return processScalarFunction(allocator, req);
default:
throw new UnsupportedOperationException("Unsupported function type " + functionType);
}
}
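    /**
     * Resolves the scalar UDF's Java method via reflection and applies it to every row of the input block.
     */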
private UserDefinedFunctionResponse processScalarFunction(BlockAllocator allocator, UserDefinedFunctionRequest req)
throws Exception
{
Method udfMethod = extractScalarFunctionMethod(req);
Block inputRecords = req.getInputRecords();
Schema outputSchema = req.getOutputSchema();
Block outputRecords = processRows(allocator, udfMethod, inputRecords, outputSchema);
return new UserDefinedFunctionResponse(outputRecords, udfMethod.getName());
}
/**
     * Processes a group of rows. This method takes in a block of data (containing multiple rows), processes them,
     * and returns multiple rows of the output column in a block.
     *
     * <p>UDF methods are invoked row by row in a for loop. Arrow values are converted to Java objects and then
     * passed into the UDF Java method. This is not very efficient because it can involve a lot of data copying.
     * Advanced users can override this method and work with the Arrow data directly to achieve better performance.
*
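     * <p>Conceptually, each row is handled roughly as follows (a simplified sketch of what this implementation
     * does; the actual writing goes through a GeneratedRowWriter rather than direct assignment):
     *
     * <pre>{@code
     * Object[] args = new Object[valueProjectors.size()];
     * for (int row = 0; row < inputRecords.getRowCount(); row++) {
     *     for (int col = 0; col < valueProjectors.size(); col++) {
     *         args[col] = valueProjectors.get(col).project(row);  // Arrow value -> Java object
     *     }
     *     Object result = udfMethod.invoke(this, args);           // invoke the UDF for this row
     *     // the row writer then writes 'result' into the single output column of the output Block
     * }
     * }</pre>
     *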
* @param allocator arrow memory allocator
* @param udfMethod the extracted java method matching the User-Defined-Function defined in Athena.
* @param inputRecords input data in Arrow format
* @param outputSchema output data schema in Arrow format
* @return output data in Arrow format
*/
protected Block processRows(BlockAllocator allocator, Method udfMethod, Block inputRecords, Schema outputSchema)
throws Exception
{
int rowCount = inputRecords.getRowCount();
        List<ArrowValueProjector> valueProjectors = Lists.newArrayList();
for (Field field : inputRecords.getFields()) {
FieldReader fieldReader = inputRecords.getFieldReader(field.getName());
ArrowValueProjector arrowValueProjector = ProjectorUtils.createArrowValueProjector(fieldReader);
valueProjectors.add(arrowValueProjector);
}
Field outputField = outputSchema.getFields().get(0);
GeneratedRowWriter outputRowWriter = createOutputRowWriter(outputField, valueProjectors, udfMethod);
Block outputRecords = allocator.createBlock(outputSchema);
outputRecords.setRowCount(rowCount);
try {
for (int rowNum = 0; rowNum < rowCount; ++rowNum) {
outputRowWriter.writeRow(outputRecords, rowNum, rowNum);
}
}
catch (Throwable t) {
try {
outputRecords.close();
}
catch (Exception e) {
logger.error("Error closing output block", e);
}
throw t;
}
return outputRecords;
}
/**
     * Uses reflection to find the Java method that matches the UDF defined in Athena SQL.
*
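     * <p>For example (illustrative names; this assumes Athena's INT maps to {@code java.lang.Integer}),
     * a UDF declared in SQL as {@code multiply(a INT, b INT) RETURNS INT} would be expected to resolve
     * to a public method on this handler such as:
     *
     * <pre>{@code
     * public Integer multiply(Integer a, Integer b)
     * {
     *     return (a == null || b == null) ? null : a * b;
     * }
     * }</pre>
     *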
* @param req UDF request
     * @return the Java method matching the UDF defined in the Athena query.
*/
private Method extractScalarFunctionMethod(UserDefinedFunctionRequest req)
{
String methodName = req.getMethodName();
Class[] argumentTypes = extractJavaTypes(req.getInputRecords().getSchema());
Class[] returnTypes = extractJavaTypes(req.getOutputSchema());
checkState(returnTypes.length == RETURN_COLUMN_COUNT,
String.format("Expecting %d return columns, found %d in method signature.",
RETURN_COLUMN_COUNT, returnTypes.length));
Class returnType = returnTypes[0];
Method udfMethod;
try {
udfMethod = this.getClass().getMethod(methodName, argumentTypes);
            logger.info(String.format("Found UDF method %s with input types %s and return type %s",
                    methodName, Arrays.toString(argumentTypes), returnType.getName()));
}
catch (NoSuchMethodException e) {
            String msg = "Failed to find UDF method. " + e.getMessage()
                    + " Please make sure the method name contains only lowercase characters and that the method"
                    + " signature (name and argument types) in Lambda matches the function signature defined in SQL.";
throw new RuntimeException(msg, e);
}
if (!returnType.equals(udfMethod.getReturnType())) {
throw new IllegalArgumentException("signature return type " + returnType +
" does not match udf implementation return type " + udfMethod.getReturnType());
}
return udfMethod;
}
private Class[] extractJavaTypes(Schema schema)
{
Class[] types = new Class[schema.getFields().size()];
        List<Field> fields = schema.getFields();
for (int i = 0; i < fields.size(); ++i) {
Types.MinorType minorType = Types.getMinorTypeForArrowType(fields.get(i).getType());
types[i] = BlockUtils.getJavaType(minorType);
}
return types;
}
private final PingResponse doPing(PingRequest request)
{
PingResponse response = new PingResponse(request.getCatalogName(), request.getQueryId(), sourceType, CAPABILITIES, SERDE_VERSION);
try {
onPing(request);
}
catch (Exception ex) {
logger.warn("doPing: encountered an exception while delegating onPing.", ex);
}
return response;
}
protected void onPing(PingRequest request)
{
//NoOp
}
private void assertNotNull(FederationResponse response)
{
if (response == null) {
throw new RuntimeException("Response was null");
}
}
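    /**
     * Builds the GeneratedRowWriter for the single output column, preferring a primitive Extractor when
     * {@link #makeExtractor} supports the output type and otherwise falling back to a FieldWriterFactory
     * for complex types.
     */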
    private GeneratedRowWriter createOutputRowWriter(Field outputField, List<ArrowValueProjector> valueProjectors, Method udfMethod)
{
GeneratedRowWriter.RowWriterBuilder builder = GeneratedRowWriter.newBuilder();
Extractor extractor = makeExtractor(outputField, valueProjectors, udfMethod);
if (extractor != null) {
builder.withExtractor(outputField.getName(), extractor);
}
else {
builder.withFieldWriterFactory(outputField.getName(), makeFactory(outputField, valueProjectors, udfMethod));
}
return builder.build();
}
    /**
     * Creates an Extractor that invokes the UDF for each input row and writes the result to the given output field.
     *
     * @param outputField the field describing the single output column.
     * @param valueProjectors projectors used to read the input data.
     * @param udfMethod the UDF method to invoke for each row.
     * @return an Extractor for the output type, or null if the type is not covered by a primitive Extractor
     *         and must be handled by a FieldWriterFactory instead (e.g. LIST, STRUCT).
     */
    private Extractor makeExtractor(Field outputField, List<ArrowValueProjector> valueProjectors, Method udfMethod)
{
Types.MinorType fieldType = Types.getMinorTypeForArrowType(outputField.getType());
Object[] arguments = new Object[valueProjectors.size()];
switch (fieldType) {
case INT:
return (IntExtractor) (Object inputRowNum, NullableIntHolder dst) ->
{
Object result = invokeMethod(udfMethod, arguments, (int) inputRowNum, valueProjectors);
if (result == null) {
dst.isSet = 0;
}
else {
dst.isSet = 1;
dst.value = (int) result;
}
};
case DATEMILLI:
return (DateMilliExtractor) (Object inputRowNum, NullableDateMilliHolder dst) ->
{
Object result = invokeMethod(udfMethod, arguments, (int) inputRowNum, valueProjectors);
if (result == null) {
dst.isSet = 0;
}
else {
dst.isSet = 1;
dst.value = ((LocalDateTime) result).atZone(BlockUtils.UTC_ZONE_ID).toInstant().toEpochMilli();
}
};
case DATEDAY:
return (DateDayExtractor) (Object inputRowNum, NullableDateDayHolder dst) ->
{
Object result = invokeMethod(udfMethod, arguments, (int) inputRowNum, valueProjectors);
if (result == null) {
dst.isSet = 0;
}
else {
dst.isSet = 1;
dst.value = (int) ((LocalDate) result).toEpochDay();
}
};
case TINYINT:
return (TinyIntExtractor) (Object inputRowNum, NullableTinyIntHolder dst) ->
{
Object result = invokeMethod(udfMethod, arguments, (int) inputRowNum, valueProjectors);
if (result == null) {
dst.isSet = 0;
}
else {
dst.isSet = 1;
dst.value = (byte) result;
}
};
case SMALLINT:
return (SmallIntExtractor) (Object inputRowNum, NullableSmallIntHolder dst) ->
{
Object result = invokeMethod(udfMethod, arguments, (int) inputRowNum, valueProjectors);
if (result == null) {
dst.isSet = 0;
}
else {
dst.isSet = 1;
dst.value = (short) result;
}
};
case FLOAT4:
return (Float4Extractor) (Object inputRowNum, NullableFloat4Holder dst) ->
{
Object result = invokeMethod(udfMethod, arguments, (int) inputRowNum, valueProjectors);
if (result == null) {
dst.isSet = 0;
}
else {
dst.isSet = 1;
dst.value = (float) result;
}
};
case FLOAT8:
return (Float8Extractor) (Object inputRowNum, NullableFloat8Holder dst) ->
{
Object result = invokeMethod(udfMethod, arguments, (int) inputRowNum, valueProjectors);
if (result == null) {
dst.isSet = 0;
}
else {
dst.isSet = 1;
dst.value = (double) result;
}
};
case DECIMAL:
return (DecimalExtractor) (Object inputRowNum, NullableDecimalHolder dst) ->
{
Object result = invokeMethod(udfMethod, arguments, (int) inputRowNum, valueProjectors);
if (result == null) {
dst.isSet = 0;
}
else {
dst.isSet = 1;
dst.value = ((BigDecimal) result);
}
};
case BIT:
return (BitExtractor) (Object inputRowNum, NullableBitHolder dst) ->
{
Object result = invokeMethod(udfMethod, arguments, (int) inputRowNum, valueProjectors);
if (result == null) {
dst.isSet = 0;
}
else {
dst.isSet = 1;
dst.value = ((boolean) result) ? 1 : 0;
}
};
case BIGINT:
return (BigIntExtractor) (Object inputRowNum, NullableBigIntHolder dst) ->
{
Object result = invokeMethod(udfMethod, arguments, (int) inputRowNum, valueProjectors);
if (result == null) {
dst.isSet = 0;
}
else {
dst.isSet = 1;
dst.value = (long) result;
}
};
case VARCHAR:
return (VarCharExtractor) (Object inputRowNum, NullableVarCharHolder dst) ->
{
Object result = invokeMethod(udfMethod, arguments, (int) inputRowNum, valueProjectors);
if (result == null) {
dst.isSet = 0;
}
else {
dst.isSet = 1;
dst.value = ((String) result);
}
};
case VARBINARY:
return (VarBinaryExtractor) (Object inputRowNum, NullableVarBinaryHolder dst) ->
{
Object result = invokeMethod(udfMethod, arguments, (int) inputRowNum, valueProjectors);
if (result == null) {
dst.isSet = 0;
}
else {
dst.isSet = 1;
dst.value = (byte[]) result;
}
};
default:
return null;
}
}
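    /**
     * Creates a FieldWriterFactory for complex output types (LIST and STRUCT) that are not handled by a
     * primitive Extractor. The generated writer invokes the UDF for each row and writes the result with
     * BlockUtils.setComplexValue.
     */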
    private FieldWriterFactory makeFactory(Field field, List<ArrowValueProjector> valueProjectors, Method udfMethod)
{
Object[] arguments = new Object[valueProjectors.size()];
Types.MinorType fieldType = Types.getMinorTypeForArrowType(field.getType());
switch (fieldType) {
case LIST:
case STRUCT:
return (FieldVector vector, Extractor extractor, ConstraintProjector ignored) ->
(Object inputRowNum, int outputRowNum) -> {
Object result = invokeMethod(udfMethod, arguments, (int) inputRowNum, valueProjectors);
BlockUtils.setComplexValue(vector, outputRowNum, FieldResolver.DEFAULT, result);
return true; // push-down does not apply in UDFs
};
default:
throw new IllegalArgumentException("Unsupported type " + fieldType);
}
}
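    /**
     * Projects the given input row's Arrow values into Java objects and reflectively invokes the UDF
     * method with those arguments, unwrapping the cause of any InvocationTargetException.
     */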
private Object invokeMethod(Method udfMethod,
Object[] arguments,
int inputRowNum,
            List<ArrowValueProjector> valueProjectors)
{
for (int col = 0; col < valueProjectors.size(); ++col) {
arguments[col] = valueProjectors.get(col).project(inputRowNum);
}
try {
return udfMethod.invoke(this, arguments);
}
catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
        catch (InvocationTargetException e) {
            if (Objects.isNull(e.getCause())) {
                throw new RuntimeException(e);
            }
            throw new RuntimeException(e.getCause());
        }
catch (IllegalArgumentException e) {
String msg = String.format("%s. Expected function types %s, got types %s",
e.getMessage(),
Arrays.stream(udfMethod.getParameterTypes()).map(clazz -> clazz.getName()).collect(Collectors.toList()),
                    Arrays.stream(arguments).map(arg -> arg == null ? "null" : arg.getClass().getName()).collect(Collectors.toList()));
throw new RuntimeException(msg, e);
}
}
}