All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.table.runtime.arrow.ArrowUtils Maven / Gradle / Ivy

There is a newer version: 2.0-preview1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.arrow;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.RuntimeExecutionMode;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.ExecutionOptions;
import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.internal.TableEnvironmentImpl;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.data.ArrayData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.columnar.vector.ColumnVector;
import org.apache.flink.table.data.util.DataFormatConverters;
import org.apache.flink.table.operations.OutputConversionModifyOperation;
import org.apache.flink.table.runtime.arrow.sources.ArrowTableSource;
import org.apache.flink.table.runtime.arrow.vectors.ArrowArrayColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowBigIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowBinaryColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowBooleanColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowDateColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowDecimalColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowDoubleColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowFloatColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowMapColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowNullColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowRowColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowSmallIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowTimeColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowTimestampColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowTinyIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowVarBinaryColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowVarCharColumnVector;
import org.apache.flink.table.runtime.arrow.writers.ArrayWriter;
import org.apache.flink.table.runtime.arrow.writers.ArrowFieldWriter;
import org.apache.flink.table.runtime.arrow.writers.BigIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BinaryWriter;
import org.apache.flink.table.runtime.arrow.writers.BooleanWriter;
import org.apache.flink.table.runtime.arrow.writers.DateWriter;
import org.apache.flink.table.runtime.arrow.writers.DecimalWriter;
import org.apache.flink.table.runtime.arrow.writers.DoubleWriter;
import org.apache.flink.table.runtime.arrow.writers.FloatWriter;
import org.apache.flink.table.runtime.arrow.writers.IntWriter;
import org.apache.flink.table.runtime.arrow.writers.MapWriter;
import org.apache.flink.table.runtime.arrow.writers.NullWriter;
import org.apache.flink.table.runtime.arrow.writers.RowWriter;
import org.apache.flink.table.runtime.arrow.writers.SmallIntWriter;
import org.apache.flink.table.runtime.arrow.writers.TimeWriter;
import org.apache.flink.table.runtime.arrow.writers.TimestampWriter;
import org.apache.flink.table.runtime.arrow.writers.TinyIntWriter;
import org.apache.flink.table.runtime.arrow.writers.VarBinaryWriter;
import org.apache.flink.table.runtime.arrow.writers.VarCharWriter;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.ArrayType;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.BinaryType;
import org.apache.flink.table.types.logical.BooleanType;
import org.apache.flink.table.types.logical.CharType;
import org.apache.flink.table.types.logical.DateType;
import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.table.types.logical.DoubleType;
import org.apache.flink.table.types.logical.FloatType;
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LegacyTypeInformationType;
import org.apache.flink.table.types.logical.LocalZonedTimestampType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.MapType;
import org.apache.flink.table.types.logical.NullType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.SmallIntType;
import org.apache.flink.table.types.logical.TimeType;
import org.apache.flink.table.types.logical.TimestampType;
import org.apache.flink.table.types.logical.TinyIntType;
import org.apache.flink.table.types.logical.VarBinaryType;
import org.apache.flink.table.types.logical.VarCharType;
import org.apache.flink.table.types.logical.utils.LogicalTypeDefaultVisitor;
import org.apache.flink.table.types.utils.TypeConversions;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;
import org.apache.flink.util.Preconditions;

import org.apache.flink.shaded.guava31.com.google.common.collect.LinkedHashMultiset;

import org.apache.arrow.flatbuf.MessageHeader;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateDayVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.FixedSizeBinaryVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.NullVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeMicroVector;
import org.apache.arrow.vector.TimeMilliVector;
import org.apache.arrow.vector.TimeNanoVector;
import org.apache.arrow.vector.TimeSecVector;
import org.apache.arrow.vector.TimeStampVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageMetadataResult;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.FileInputStream;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;

/** Utilities for Arrow. */
@Internal
public final class ArrowUtils {

    private static final Logger LOG = LoggerFactory.getLogger(ArrowUtils.class);

    /** Root allocator shared by this class; created lazily by {@link #getRootAllocator()}. */
    private static RootAllocator rootAllocator;

    /**
     * Returns the shared {@link RootAllocator}, creating it on the first call.
     *
     * <p>The allocator is created with a limit of {@link Long#MAX_VALUE}. The method is
     * synchronized on the class, so callers going through it always observe one instance.
     */
    public static synchronized RootAllocator getRootAllocator() {
        if (rootAllocator != null) {
            return rootAllocator;
        }
        rootAllocator = new RootAllocator(Long.MAX_VALUE);
        return rootAllocator;
    }

    /**
     * Prepares the JVM so that Arrow can use Netty direct buffers.
     *
     * <p>If the system property {@code io.netty.tryReflectionSetAccessible} is unset, it is set
     * to {@code true}. If the property is already set but Netty reports that the {@code
     * DirectByteBuffer(long, int)} constructor is not available, an exception is thrown.
     *
     * @throws RuntimeException if the property is set and the constructor is not available
     */
    public static void checkArrowUsable() {
        // Arrow requires the property io.netty.tryReflectionSetAccessible to
        // be set to true for JDK >= 9. Please refer to ARROW-5412 for more details.
        if (System.getProperty("io.netty.tryReflectionSetAccessible") == null) {
            System.setProperty("io.netty.tryReflectionSetAccessible", "true");
        } else if (!io.netty.util.internal.PlatformDependent
                .hasDirectBufferNoCleanerConstructor()) {
            throw new RuntimeException(
                    "Arrow depends on DirectByteBuffer.<init>(long, int) which is not available. "
                            + "Please set the system property "
                            + "'io.netty.tryReflectionSetAccessible' to 'true'.");
        }
    }

    /**
     * Returns the Arrow schema of the specified type.
     *
     * @param rowType the Flink row type; each of its fields becomes one top-level Arrow field
     * @return an Arrow {@link Schema} with the fields in the same order as {@code rowType}
     */
    public static Schema toArrowSchema(RowType rowType) {
        List<Field> fields =
                rowType.getFields().stream()
                        .map(f -> ArrowUtils.toArrowField(f.getName(), f.getType()))
                        .collect(Collectors.toList());
        return new Schema(fields);
    }

    /**
     * Converts one Flink field into an Arrow {@link Field}, recursing into the element type of
     * arrays, the field types of rows and the key/value types of maps.
     *
     * @param fieldName name of the resulting Arrow field
     * @param logicalType Flink type of the field; its nullability is carried over to Arrow
     * @throws IllegalArgumentException if a map type has a nullable key type
     */
    private static Field toArrowField(String fieldName, LogicalType logicalType) {
        FieldType fieldType =
                new FieldType(
                        logicalType.isNullable(),
                        logicalType.accept(LogicalTypeToArrowTypeConverter.INSTANCE),
                        null);
        List<Field> children = null;
        if (logicalType instanceof ArrayType) {
            children =
                    Collections.singletonList(
                            toArrowField("element", ((ArrayType) logicalType).getElementType()));
        } else if (logicalType instanceof RowType) {
            RowType rowType = (RowType) logicalType;
            children = new ArrayList<>(rowType.getFieldCount());
            for (RowType.RowField field : rowType.getFields()) {
                children.add(toArrowField(field.getName(), field.getType()));
            }
        } else if (logicalType instanceof MapType) {
            MapType mapType = (MapType) logicalType;
            Preconditions.checkArgument(
                    !mapType.getKeyType().isNullable(), "Map key type should be non-nullable");
            // Use the MapVector constants: writers and readers look the children up by them.
            List<Field> entryFields =
                    Arrays.asList(
                            toArrowField(MapVector.KEY_NAME, mapType.getKeyType()),
                            toArrowField(MapVector.VALUE_NAME, mapType.getValueType()));
            children =
                    Collections.singletonList(
                            new Field(
                                    "items",
                                    new FieldType(false, ArrowType.Struct.INSTANCE, null),
                                    entryFields));
        }
        return new Field(fieldName, fieldType, children);
    }

    /**
     * Creates an {@link ArrowWriter} for the specified {@link VectorSchemaRoot}.
     *
     * <p>Every field vector of {@code root} is allocated and paired with a field writer built
     * from the corresponding field type of {@code rowType}.
     *
     * @param root the vectors to write into
     * @param rowType the Flink row type whose i-th field describes the i-th vector of {@code root}
     * @return a writer that appends {@link RowData} rows to {@code root}
     */
    public static ArrowWriter<RowData> createRowDataArrowWriter(
            VectorSchemaRoot root, RowType rowType) {
        List<FieldVector> vectors = root.getFieldVectors();
        // Generic arrays cannot be created directly; the elements are all RowData writers.
        @SuppressWarnings("unchecked")
        ArrowFieldWriter<RowData>[] fieldWriters = new ArrowFieldWriter[vectors.size()];
        for (int i = 0; i < vectors.size(); i++) {
            FieldVector vector = vectors.get(i);
            vector.allocateNew();
            fieldWriters[i] = createArrowFieldWriterForRow(vector, rowType.getTypeAt(i));
        }

        return new ArrowWriter<>(root, fieldWriters);
    }

    /**
     * Creates a writer that fills {@code vector} from a field of {@link RowData} rows.
     *
     * <p>The writer is selected from the concrete class of {@code vector}. {@code fieldType}
     * supplies what the vector does not carry: the timestamp precision and the element, key/value
     * and field types of nested types.
     *
     * @throws UnsupportedOperationException if {@code vector} has an unsupported type
     */
    private static ArrowFieldWriter<RowData> createArrowFieldWriterForRow(
            ValueVector vector, LogicalType fieldType) {
        if (vector instanceof TinyIntVector) {
            return TinyIntWriter.forRow((TinyIntVector) vector);
        } else if (vector instanceof SmallIntVector) {
            return SmallIntWriter.forRow((SmallIntVector) vector);
        } else if (vector instanceof IntVector) {
            return IntWriter.forRow((IntVector) vector);
        } else if (vector instanceof BigIntVector) {
            return BigIntWriter.forRow((BigIntVector) vector);
        } else if (vector instanceof BitVector) {
            return BooleanWriter.forRow((BitVector) vector);
        } else if (vector instanceof Float4Vector) {
            return FloatWriter.forRow((Float4Vector) vector);
        } else if (vector instanceof Float8Vector) {
            return DoubleWriter.forRow((Float8Vector) vector);
        } else if (vector instanceof VarCharVector) {
            return VarCharWriter.forRow((VarCharVector) vector);
        } else if (vector instanceof FixedSizeBinaryVector) {
            return BinaryWriter.forRow((FixedSizeBinaryVector) vector);
        } else if (vector instanceof VarBinaryVector) {
            return VarBinaryWriter.forRow((VarBinaryVector) vector);
        } else if (vector instanceof DecimalVector) {
            DecimalVector decimalVector = (DecimalVector) vector;
            return DecimalWriter.forRow(
                    decimalVector, getPrecision(decimalVector), decimalVector.getScale());
        } else if (vector instanceof DateDayVector) {
            return DateWriter.forRow((DateDayVector) vector);
        } else if (vector instanceof TimeSecVector
                || vector instanceof TimeMilliVector
                || vector instanceof TimeMicroVector
                || vector instanceof TimeNanoVector) {
            return TimeWriter.forRow(vector);
        } else if (vector instanceof TimeStampVector
                && ((ArrowType.Timestamp) vector.getField().getType()).getTimezone() == null) {
            // Only timestamps without a time zone are supported; zoned ones reach the error below.
            int precision;
            if (fieldType instanceof LocalZonedTimestampType) {
                precision = ((LocalZonedTimestampType) fieldType).getPrecision();
            } else {
                precision = ((TimestampType) fieldType).getPrecision();
            }
            return TimestampWriter.forRow(vector, precision);
        } else if (vector instanceof MapVector) {
            // MapVector extends ListVector, so this branch must stay before the ListVector one.
            MapVector mapVector = (MapVector) vector;
            LogicalType keyType = ((MapType) fieldType).getKeyType();
            LogicalType valueType = ((MapType) fieldType).getValueType();
            StructVector structVector = (StructVector) mapVector.getDataVector();
            return MapWriter.forRow(
                    mapVector,
                    createArrowFieldWriterForArray(
                            structVector.getChild(MapVector.KEY_NAME), keyType),
                    createArrowFieldWriterForArray(
                            structVector.getChild(MapVector.VALUE_NAME), valueType));
        } else if (vector instanceof ListVector) {
            ListVector listVector = (ListVector) vector;
            LogicalType elementType = ((ArrayType) fieldType).getElementType();
            return ArrayWriter.forRow(
                    listVector,
                    createArrowFieldWriterForArray(listVector.getDataVector(), elementType));
        } else if (vector instanceof StructVector) {
            RowType rowType = (RowType) fieldType;
            // Generic arrays cannot be created directly; the elements are all RowData writers.
            @SuppressWarnings("unchecked")
            ArrowFieldWriter<RowData>[] fieldsWriters =
                    new ArrowFieldWriter[rowType.getFieldCount()];
            for (int i = 0; i < fieldsWriters.length; i++) {
                fieldsWriters[i] =
                        createArrowFieldWriterForRow(
                                ((StructVector) vector).getVectorById(i), rowType.getTypeAt(i));
            }
            return RowWriter.forRow((StructVector) vector, fieldsWriters);
        } else if (vector instanceof NullVector) {
            return new NullWriter<>((NullVector) vector);
        } else {
            throw new UnsupportedOperationException(
                    String.format("Unsupported type %s.", fieldType));
        }
    }

    /**
     * Creates a writer that fills {@code vector} from the elements of {@link ArrayData} values,
     * as used for array elements and map keys/values.
     *
     * <p>The writer is selected from the concrete class of {@code vector}. {@code fieldType}
     * supplies what the vector does not carry: the timestamp precision and the element, key/value
     * and field types of nested types. The fields of a nested struct are written with row-based
     * writers.
     *
     * @throws UnsupportedOperationException if {@code vector} has an unsupported type
     */
    private static ArrowFieldWriter<ArrayData> createArrowFieldWriterForArray(
            ValueVector vector, LogicalType fieldType) {
        if (vector instanceof TinyIntVector) {
            return TinyIntWriter.forArray((TinyIntVector) vector);
        } else if (vector instanceof SmallIntVector) {
            return SmallIntWriter.forArray((SmallIntVector) vector);
        } else if (vector instanceof IntVector) {
            return IntWriter.forArray((IntVector) vector);
        } else if (vector instanceof BigIntVector) {
            return BigIntWriter.forArray((BigIntVector) vector);
        } else if (vector instanceof BitVector) {
            return BooleanWriter.forArray((BitVector) vector);
        } else if (vector instanceof Float4Vector) {
            return FloatWriter.forArray((Float4Vector) vector);
        } else if (vector instanceof Float8Vector) {
            return DoubleWriter.forArray((Float8Vector) vector);
        } else if (vector instanceof VarCharVector) {
            return VarCharWriter.forArray((VarCharVector) vector);
        } else if (vector instanceof FixedSizeBinaryVector) {
            return BinaryWriter.forArray((FixedSizeBinaryVector) vector);
        } else if (vector instanceof VarBinaryVector) {
            return VarBinaryWriter.forArray((VarBinaryVector) vector);
        } else if (vector instanceof DecimalVector) {
            DecimalVector decimalVector = (DecimalVector) vector;
            return DecimalWriter.forArray(
                    decimalVector, getPrecision(decimalVector), decimalVector.getScale());
        } else if (vector instanceof DateDayVector) {
            return DateWriter.forArray((DateDayVector) vector);
        } else if (vector instanceof TimeSecVector
                || vector instanceof TimeMilliVector
                || vector instanceof TimeMicroVector
                || vector instanceof TimeNanoVector) {
            return TimeWriter.forArray(vector);
        } else if (vector instanceof TimeStampVector
                && ((ArrowType.Timestamp) vector.getField().getType()).getTimezone() == null) {
            // Only timestamps without a time zone are supported; zoned ones reach the error below.
            int precision;
            if (fieldType instanceof LocalZonedTimestampType) {
                precision = ((LocalZonedTimestampType) fieldType).getPrecision();
            } else {
                precision = ((TimestampType) fieldType).getPrecision();
            }
            return TimestampWriter.forArray(vector, precision);
        } else if (vector instanceof MapVector) {
            // MapVector extends ListVector, so this branch must stay before the ListVector one.
            MapVector mapVector = (MapVector) vector;
            LogicalType keyType = ((MapType) fieldType).getKeyType();
            LogicalType valueType = ((MapType) fieldType).getValueType();
            StructVector structVector = (StructVector) mapVector.getDataVector();
            return MapWriter.forArray(
                    mapVector,
                    createArrowFieldWriterForArray(
                            structVector.getChild(MapVector.KEY_NAME), keyType),
                    createArrowFieldWriterForArray(
                            structVector.getChild(MapVector.VALUE_NAME), valueType));
        } else if (vector instanceof ListVector) {
            ListVector listVector = (ListVector) vector;
            LogicalType elementType = ((ArrayType) fieldType).getElementType();
            return ArrayWriter.forArray(
                    listVector,
                    createArrowFieldWriterForArray(listVector.getDataVector(), elementType));
        } else if (vector instanceof StructVector) {
            RowType rowType = (RowType) fieldType;
            // Generic arrays cannot be created directly; the elements are all RowData writers.
            @SuppressWarnings("unchecked")
            ArrowFieldWriter<RowData>[] fieldsWriters =
                    new ArrowFieldWriter[rowType.getFieldCount()];
            for (int i = 0; i < fieldsWriters.length; i++) {
                fieldsWriters[i] =
                        createArrowFieldWriterForRow(
                                ((StructVector) vector).getVectorById(i), rowType.getTypeAt(i));
            }
            return RowWriter.forArray((StructVector) vector, fieldsWriters);
        } else if (vector instanceof NullVector) {
            return new NullWriter<>((NullVector) vector);
        } else {
            throw new UnsupportedOperationException(
                    String.format("Unsupported type %s.", fieldType));
        }
    }

    /**
     * Creates an {@link ArrowReader} for the specified {@link VectorSchemaRoot}.
     *
     * @param root the vectors to read from
     * @param rowType the Flink row type whose i-th field describes the i-th vector of {@code root}
     * @return a reader over one column vector per field vector of {@code root}
     */
    public static ArrowReader createArrowReader(VectorSchemaRoot root, RowType rowType) {
        List<FieldVector> fieldVectors = root.getFieldVectors();
        ColumnVector[] columnVectors = new ColumnVector[fieldVectors.size()];
        for (int i = 0; i < columnVectors.length; i++) {
            columnVectors[i] = createColumnVector(fieldVectors.get(i), rowType.getTypeAt(i));
        }

        return new ArrowReader(columnVectors);
    }

    /**
     * Wraps an Arrow vector into the Flink {@link ColumnVector} that reads it.
     *
     * <p>The implementation is selected from the concrete class of {@code vector}. {@code
     * fieldType} is only needed for nested vectors (map, list and struct), whose children are
     * wrapped recursively, and to describe an unsupported vector in the error message.
     *
     * @throws UnsupportedOperationException if {@code vector} has an unsupported type
     */
    public static ColumnVector createColumnVector(ValueVector vector, LogicalType fieldType) {
        if (vector instanceof TinyIntVector) {
            return new ArrowTinyIntColumnVector((TinyIntVector) vector);
        }
        if (vector instanceof SmallIntVector) {
            return new ArrowSmallIntColumnVector((SmallIntVector) vector);
        }
        if (vector instanceof IntVector) {
            return new ArrowIntColumnVector((IntVector) vector);
        }
        if (vector instanceof BigIntVector) {
            return new ArrowBigIntColumnVector((BigIntVector) vector);
        }
        if (vector instanceof BitVector) {
            return new ArrowBooleanColumnVector((BitVector) vector);
        }
        if (vector instanceof Float4Vector) {
            return new ArrowFloatColumnVector((Float4Vector) vector);
        }
        if (vector instanceof Float8Vector) {
            return new ArrowDoubleColumnVector((Float8Vector) vector);
        }
        if (vector instanceof VarCharVector) {
            return new ArrowVarCharColumnVector((VarCharVector) vector);
        }
        if (vector instanceof FixedSizeBinaryVector) {
            return new ArrowBinaryColumnVector((FixedSizeBinaryVector) vector);
        }
        if (vector instanceof VarBinaryVector) {
            return new ArrowVarBinaryColumnVector((VarBinaryVector) vector);
        }
        if (vector instanceof DecimalVector) {
            return new ArrowDecimalColumnVector((DecimalVector) vector);
        }
        if (vector instanceof DateDayVector) {
            return new ArrowDateColumnVector((DateDayVector) vector);
        }
        boolean isTimeVector =
                vector instanceof TimeSecVector
                        || vector instanceof TimeMilliVector
                        || vector instanceof TimeMicroVector
                        || vector instanceof TimeNanoVector;
        if (isTimeVector) {
            return new ArrowTimeColumnVector(vector);
        }
        // Only timestamps without a time zone are supported; zoned ones reach the error below.
        if (vector instanceof TimeStampVector
                && ((ArrowType.Timestamp) vector.getField().getType()).getTimezone() == null) {
            return new ArrowTimestampColumnVector(vector);
        }
        // MapVector extends ListVector, so it has to be matched before the ListVector case.
        if (vector instanceof MapVector) {
            return createMapColumnVector((MapVector) vector, fieldType);
        }
        if (vector instanceof ListVector) {
            ListVector list = (ListVector) vector;
            LogicalType elementType = ((ArrayType) fieldType).getElementType();
            return new ArrowArrayColumnVector(
                    list, createColumnVector(list.getDataVector(), elementType));
        }
        if (vector instanceof StructVector) {
            return createRowColumnVector((StructVector) vector, fieldType);
        }
        if (vector instanceof NullVector) {
            return ArrowNullColumnVector.INSTANCE;
        }
        throw new UnsupportedOperationException(String.format("Unsupported type %s.", fieldType));
    }

    /** Wraps a map vector, recursing into its key and value children. */
    private static ColumnVector createMapColumnVector(MapVector mapVector, LogicalType mapType) {
        LogicalType keyType = ((MapType) mapType).getKeyType();
        LogicalType valueType = ((MapType) mapType).getValueType();
        StructVector entries = (StructVector) mapVector.getDataVector();
        ColumnVector keys = createColumnVector(entries.getChild(MapVector.KEY_NAME), keyType);
        ColumnVector values =
                createColumnVector(entries.getChild(MapVector.VALUE_NAME), valueType);
        return new ArrowMapColumnVector(mapVector, keys, values);
    }

    /** Wraps a struct vector, recursing into each of its child vectors by position. */
    private static ColumnVector createRowColumnVector(
            StructVector structVector, LogicalType rowType) {
        ColumnVector[] children = new ColumnVector[structVector.size()];
        for (int pos = 0; pos < children.length; pos++) {
            LogicalType childType = ((RowType) rowType).getTypeAt(pos);
            children[pos] = createColumnVector(structVector.getVectorById(pos), childType);
        }
        return new ArrowRowColumnVector(structVector, children);
    }

    /**
     * Creates an {@link ArrowTableSource} from the Arrow record batches stored in a file.
     *
     * @param dataType data type handed to the created {@link ArrowTableSource}
     * @param fileName path of the file holding the serialized Arrow messages
     * @throws IOException if the file cannot be opened or read
     */
    public static ArrowTableSource createArrowTableSource(DataType dataType, String fileName)
            throws IOException {
        final byte[][] arrowBatches;
        try (FileInputStream input = new FileInputStream(fileName)) {
            arrowBatches = readArrowBatches(input.getChannel());
        }
        return new ArrowTableSource(dataType, arrowBatches);
    }

    /**
     * Reads all RecordBatch messages from {@code channel}, skipping other kinds of messages.
     *
     * @param channel channel positioned at the start of a sequence of Arrow messages
     * @return the serialized record batches in the order they were read; empty if there are none
     * @throws IOException if reading from the channel fails
     */
    public static byte[][] readArrowBatches(ReadableByteChannel channel) throws IOException {
        List<byte[]> results = new ArrayList<>();
        byte[] batch;
        while ((batch = readNextBatch(channel)) != null) {
            results.add(batch);
        }
        return results.toArray(new byte[0][]);
    }

    /**
     * Reads the next RecordBatch message from {@code channel}, skipping any other messages.
     *
     * @return the serialized message (prefix, metadata and body), or {@code null} once {@link
     *     MessageSerializer#readMessage} reports that there is no further message
     * @throws EOFException if the channel ends in the middle of a message body
     * @throws IOException if a record batch is too large to buffer, or reading fails
     */
    private static byte[] readNextBatch(ReadableByteChannel channel) throws IOException {
        // Loop instead of recursing so that many consecutive non-RecordBatch messages cannot
        // grow the call stack.
        while (true) {
            MessageMetadataResult metadata =
                    MessageSerializer.readMessage(new ReadChannel(channel));
            if (metadata == null) {
                return null;
            }

            long bodyLength = metadata.getMessageBodyLength();

            // Only care about RecordBatch messages and skip the other kind of messages
            if (metadata.getMessage().headerType() == MessageHeader.RecordBatch) {
                return readRecordBatch(channel, metadata, bodyLength);
            }
            skipFully(channel, bodyLength);
        }
    }

    /** Serializes a RecordBatch's metadata into a buffer and reads the message body after it. */
    private static byte[] readRecordBatch(
            ReadableByteChannel channel, MessageMetadataResult metadata, long bodyLength)
            throws IOException {
        // Buffer backed output large enough to hold 8-byte length + complete serialized message
        long totalLength = 8L + metadata.getMessageLength() + bodyLength;
        if (totalLength > Integer.MAX_VALUE) {
            // The old (int) cast silently wrapped around here.
            throw new IOException(
                    String.format(
                            "Arrow record batch of %d bytes is too large to be buffered.",
                            totalLength));
        }
        ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos((int) totalLength);

        // Write message metadata to ByteBuffer output stream
        MessageSerializer.writeMessageBuffer(
                new WriteChannel(Channels.newChannel(baos)),
                metadata.getMessageLength(),
                metadata.getMessageBuffer());

        baos.close();
        ByteBuffer result = ByteBuffer.wrap(baos.getBuf());
        result.position(baos.getPosition());
        result.limit(result.capacity());
        readFully(channel, result);
        return result.array();
    }

    /** Discards exactly {@code length} bytes from {@code channel}, failing on a premature end. */
    private static void skipFully(ReadableByteChannel channel, long length) throws IOException {
        java.io.InputStream in = Channels.newInputStream(channel);
        long remaining = length;
        while (remaining > 0) {
            // skip() may skip fewer bytes than requested, so keep going until done.
            long skipped = in.skip(remaining);
            if (skipped <= 0) {
                throw new EOFException(
                        String.format(
                                "Not enough bytes in channel to skip message body "
                                        + "(expected %d, skipped %d).",
                                length, length - remaining));
            }
            remaining -= skipped;
        }
    }

    /**
     * Reads from {@code channel} until {@code dst} has no space remaining.
     *
     * @throws EOFException if the channel ends before {@code dst} is full
     */
    private static void readFully(ReadableByteChannel channel, ByteBuffer dst) throws IOException {
        final int total = dst.remaining();
        while (dst.hasRemaining()) {
            int bytesRead = channel.read(dst);
            if (bytesRead < 0) {
                throw new EOFException(
                        String.format("Not enough bytes in channel (expected %d).", total));
            }
        }
    }

    /**
     * Convert Flink table to Pandas DataFrame.
     *
     * <p>Executes {@code table} and returns an iterator of serialized Arrow stream chunks. Each
     * {@code next()} call writes up to {@code maxArrowBatchSize} rows as one record batch. The
     * bytes returned by the first call are preceded by the stream schema. If the table is not
     * append-only, all results are consumed up front and retractions are applied before the
     * first batch is produced.
     *
     * <p>The Arrow memory behind the iterator is released once the iterator reports that no
     * rows remain, including when the result is empty. It is also released if setting up the
     * iterator fails.
     *
     * @param table the table to execute and collect
     * @param maxArrowBatchSize maximum number of rows per record batch; must be positive
     * @throws IllegalArgumentException if {@code maxArrowBatchSize} is not positive
     */
    public static CustomIterator collectAsPandasDataFrame(
            Table table, int maxArrowBatchSize) throws Exception {
        // A non-positive size would yield empty batches forever without consuming any rows.
        Preconditions.checkArgument(
                maxArrowBatchSize > 0,
                "maxArrowBatchSize must be positive, but was %s.",
                maxArrowBatchSize);
        checkArrowUsable();
        RowType rowType =
                (RowType) table.getResolvedSchema().toSourceRowDataType().getLogicalType();
        BufferAllocator allocator =
                getRootAllocator().newChildAllocator("collectAsPandasDataFrame", 0, Long.MAX_VALUE);
        VectorSchemaRoot root = null;
        try {
            root = VectorSchemaRoot.create(ArrowUtils.toArrowSchema(rowType), allocator);
            return createPandasDataFrameIterator(
                    table, rowType, maxArrowBatchSize, allocator, root);
        } catch (Throwable t) {
            // No iterator was handed out, so nothing else will release this memory.
            try {
                if (root != null) {
                    root.close();
                }
                allocator.close();
            } catch (Throwable closeError) {
                t.addSuppressed(closeError);
            }
            throw t;
        }
    }

    /**
     * Builds the batch iterator over {@code root}. The iterator takes ownership of {@code root}
     * and {@code allocator} and closes both once it observes that no rows remain.
     */
    private static CustomIterator createPandasDataFrameIterator(
            Table table,
            RowType rowType,
            int maxArrowBatchSize,
            BufferAllocator allocator,
            VectorSchemaRoot root)
            throws IOException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(root, null, baos);
        arrowStreamWriter.start();

        Iterator<Row> results = table.execute().collect();
        Iterator<Row> appendOnlyResults =
                isAppendOnlyTable(table) ? results : filterOutRetractRows(results);

        // Look the converter up once instead of once per row.
        DataType defaultRowDataType = TypeConversions.fromLogicalToDataType(rowType);
        DataFormatConverters.DataFormatConverter converter =
                DataFormatConverters.getConverterForDataType(defaultRowDataType);

        ArrowWriter<RowData> arrowWriter = createRowDataArrowWriter(root, rowType);
        Iterator<RowData> convertedResults =
                new Iterator<RowData>() {
                    @Override
                    public boolean hasNext() {
                        return appendOnlyResults.hasNext();
                    }

                    @Override
                    public RowData next() {
                        return (RowData) converter.toInternal(appendOnlyResults.next());
                    }
                };

        return new CustomIterator() {
            private boolean released;

            @Override
            public boolean hasNext() {
                boolean hasMore = convertedResults.hasNext();
                if (!hasMore) {
                    // Also covers empty results, for which next() is never called.
                    release();
                }
                return hasMore;
            }

            @Override
            public byte[] next() {
                try {
                    int i = 0;
                    while (convertedResults.hasNext() && i < maxArrowBatchSize) {
                        i++;
                        arrowWriter.write(convertedResults.next());
                    }
                    arrowWriter.finish();
                    arrowStreamWriter.writeBatch();
                    return baos.toByteArray();
                } catch (Throwable t) {
                    String msg = "Failed to serialize the data of the table";
                    LOG.error(msg, t);
                    throw new RuntimeException(msg, t);
                } finally {
                    arrowWriter.reset();
                    baos.reset();
                    // Releases the Arrow memory if this was the last batch.
                    hasNext();
                }
            }

            /** Closes the vectors and the allocator exactly once. */
            private void release() {
                if (!released) {
                    released = true;
                    root.close();
                    allocator.close();
                }
            }
        };
    }

    /**
     * Applies a changelog to produce its final content as insert-only rows.
     *
     * <p>INSERT and UPDATE_AFTER rows are added; every other kind removes one equal row. All rows
     * are normalized to {@link RowKind#INSERT}, which also mutates the incoming {@link Row}
     * objects. Rows compare by equality, which includes the kind. The whole input is consumed
     * before this method returns.
     *
     * @throws RuntimeException if a retraction does not match any previously added row
     */
    private static Iterator<Row> filterOutRetractRows(Iterator<Row> data) {
        LinkedHashMultiset<Row> result = LinkedHashMultiset.create();
        while (data.hasNext()) {
            Row element = data.next();
            RowKind kind = element.getKind();
            // Normalize before add/remove so that rows differing only in kind compare equal.
            element.setKind(RowKind.INSERT);
            if (kind == RowKind.INSERT || kind == RowKind.UPDATE_AFTER) {
                result.add(element);
            } else if (!result.remove(element)) {
                throw new RuntimeException(
                        String.format(
                                "Could not remove element '%s', should never happen.", element));
            }
        }
        return result.iterator();
    }

    /**
     * Returns whether {@code table}'s environment is configured for streaming execution.
     *
     * <p>Environments that are not a {@link TableEnvironmentImpl} are treated as non-streaming.
     *
     * @throws RuntimeException if the configured runtime mode is {@code AUTOMATIC}
     */
    private static boolean isStreamingMode(Table table) {
        final TableEnvironment environment = ((TableImpl) table).getTableEnvironment();
        if (!(environment instanceof TableEnvironmentImpl)) {
            return false;
        }
        final RuntimeExecutionMode executionMode =
                environment.getConfig().get(ExecutionOptions.RUNTIME_MODE);
        if (executionMode == RuntimeExecutionMode.AUTOMATIC) {
            throw new RuntimeException(
                    String.format(
                            "Runtime execution mode '%s' is not supported yet.", executionMode));
        }
        return executionMode == RuntimeExecutionMode.STREAMING;
    }

    /**
     * Returns whether the given table produces an append-only (insert-only) result.
     *
     * <p>If the table is not in streaming mode, it is considered append only. In streaming mode,
     * the table's query is trial-translated by the planner as an {@link
     * OutputConversionModifyOperation} in {@code APPEND} update mode. A translation error whose
     * message reports update consumption or a non-append-only table means the table is not append
     * only.
     *
     * @throws RuntimeException if the translation fails for any other reason
     */
    private static boolean isAppendOnlyTable(Table table) {
        if (!isStreamingMode(table)) {
            return true;
        }

        // isStreamingMode only returns true for a TableEnvironmentImpl, so this cast is safe.
        TableEnvironmentImpl tableEnv =
                (TableEnvironmentImpl) ((TableImpl) table).getTableEnvironment();
        try {
            OutputConversionModifyOperation modifyOperation =
                    new OutputConversionModifyOperation(
                            table.getQueryOperation(),
                            TypeConversions.fromLegacyInfoToDataType(
                                    TypeExtractor.createTypeInfo(Row.class)),
                            OutputConversionModifyOperation.UpdateMode.APPEND);
            tableEnv.getPlanner().translate(Collections.singletonList(modifyOperation));
        } catch (Throwable t) {
            // getMessage() may be null; guard it so the original failure is not masked by an NPE.
            final String message = t.getMessage();
            if (message != null
                    && (message.contains("doesn't support consuming update")
                            || message.contains("Table is not an append-only table"))) {
                return false;
            }
            throw new RuntimeException(
                    "Failed to determine whether the given table is append only.", t);
        }
        return true;
    }

    /**
     * A custom iterator to bypass the Py4J Java collection as the next method of
     * py4j.java_collections.JavaIterator will eat all the exceptions thrown in Java which makes it
     * difficult to debug in case of errors.
     */
    private interface CustomIterator {
        boolean hasNext();

        T next();
    }

    /**
     * Maps a Flink {@link LogicalType} to the {@link ArrowType} used for its Arrow vector.
     *
     * <p>Types without a dedicated {@code visit} overload fall through to {@link
     * #defaultMethod(LogicalType)}. That method only accepts legacy {@link BigDecimal} types and
     * rejects everything else with an {@link UnsupportedOperationException}.
     */
    private static class LogicalTypeToArrowTypeConverter
            extends LogicalTypeDefaultVisitor<ArrowType> {

        private static final LogicalTypeToArrowTypeConverter INSTANCE =
                new LogicalTypeToArrowTypeConverter();

        @Override
        public ArrowType visit(TinyIntType tinyIntType) {
            return new ArrowType.Int(8, true);
        }

        @Override
        public ArrowType visit(SmallIntType smallIntType) {
            return new ArrowType.Int(2 * 8, true);
        }

        @Override
        public ArrowType visit(IntType intType) {
            return new ArrowType.Int(4 * 8, true);
        }

        @Override
        public ArrowType visit(BigIntType bigIntType) {
            return new ArrowType.Int(8 * 8, true);
        }

        @Override
        public ArrowType visit(BooleanType booleanType) {
            return ArrowType.Bool.INSTANCE;
        }

        @Override
        public ArrowType visit(FloatType floatType) {
            return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);
        }

        @Override
        public ArrowType visit(DoubleType doubleType) {
            return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
        }

        @Override
        public ArrowType visit(CharType charType) {
            return ArrowType.Utf8.INSTANCE;
        }

        @Override
        public ArrowType visit(VarCharType varCharType) {
            return ArrowType.Utf8.INSTANCE;
        }

        @Override
        public ArrowType visit(BinaryType binaryType) {
            return new ArrowType.FixedSizeBinary(binaryType.getLength());
        }

        @Override
        public ArrowType visit(VarBinaryType varBinaryType) {
            return ArrowType.Binary.INSTANCE;
        }

        @Override
        public ArrowType visit(DecimalType decimalType) {
            return new ArrowType.Decimal(decimalType.getPrecision(), decimalType.getScale());
        }

        @Override
        public ArrowType visit(DateType dateType) {
            return new ArrowType.Date(DateUnit.DAY);
        }

        @Override
        public ArrowType visit(TimeType timeType) {
            TimeUnit unit = toTimeUnit(timeType.getPrecision());
            // SECOND/MILLISECOND time values use a 32-bit width; MICRO/NANOSECOND need 64 bits.
            int bitWidth = (unit == TimeUnit.SECOND || unit == TimeUnit.MILLISECOND) ? 32 : 64;
            return new ArrowType.Time(unit, bitWidth);
        }

        @Override
        public ArrowType visit(LocalZonedTimestampType localZonedTimestampType) {
            // No timezone is attached to the Arrow type (null timezone argument).
            return new ArrowType.Timestamp(
                    toTimeUnit(localZonedTimestampType.getPrecision()), null);
        }

        @Override
        public ArrowType visit(TimestampType timestampType) {
            return new ArrowType.Timestamp(toTimeUnit(timestampType.getPrecision()), null);
        }

        @Override
        public ArrowType visit(ArrayType arrayType) {
            return ArrowType.List.INSTANCE;
        }

        @Override
        public ArrowType visit(RowType rowType) {
            return ArrowType.Struct.INSTANCE;
        }

        @Override
        public ArrowType visit(MapType mapType) {
            return new ArrowType.Map(false);
        }

        @Override
        public ArrowType visit(NullType nullType) {
            return ArrowType.Null.INSTANCE;
        }

        @Override
        protected ArrowType defaultMethod(LogicalType logicalType) {
            if (logicalType instanceof LegacyTypeInformationType) {
                Class<?> typeClass =
                        ((LegacyTypeInformationType<?>) logicalType)
                                .getTypeInformation()
                                .getTypeClass();
                if (typeClass == BigDecimal.class) {
                    // Because we can't get precision and scale from legacy BIG_DEC_TYPE_INFO,
                    // we set the precision and scale to default value compatible with python.
                    return new ArrowType.Decimal(38, 18);
                }
            }
            throw new UnsupportedOperationException(
                    String.format(
                            "Python vectorized UDF doesn't support logical type %s currently.",
                            logicalType.asSummaryString()));
        }

        /**
         * Buckets a fractional-second precision into an Arrow time unit: 0 maps to seconds, 1-3 to
         * milliseconds, 4-6 to microseconds, and anything else to nanoseconds.
         */
        private static TimeUnit toTimeUnit(int precision) {
            if (precision == 0) {
                return TimeUnit.SECOND;
            } else if (precision >= 1 && precision <= 3) {
                return TimeUnit.MILLISECOND;
            } else if (precision >= 4 && precision <= 6) {
                return TimeUnit.MICROSECOND;
            } else {
                return TimeUnit.NANOSECOND;
            }
        }
    }

    /**
     * Reads the decimal precision of the given vector reflectively from its private {@code
     * precision} field.
     *
     * @param decimalVector the vector to inspect
     * @return the vector's precision, or {@code -1} if the field cannot be found or read
     */
    private static int getPrecision(DecimalVector decimalVector) {
        try {
            // Resolve the field on DecimalVector itself. getClass().getDeclaredField() only searches
            // the exact runtime class, so it would miss the field for any subclass instance.
            java.lang.reflect.Field precisionField =
                    DecimalVector.class.getDeclaredField("precision");
            precisionField.setAccessible(true);
            return (int) precisionField.get(decimalVector);
        } catch (NoSuchFieldException | IllegalAccessException ignored) {
            // Best effort: should not happen; signal "unknown precision" to the caller with -1.
            return -1;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy