All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.aliyun.odps.table.arrow.readers.ArrowBatchNonReusedReader Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.aliyun.odps.table.arrow.readers;

import com.aliyun.odps.table.arrow.ArrowReader;
import org.apache.arrow.flatbuf.MessageHeader;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.NoCompressionCodec;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.message.*;
import org.apache.arrow.vector.types.MetadataVersion;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;
import org.apache.arrow.vector.util.VectorBatchAppender;
import org.apache.arrow.vector.validate.MetadataV4UnionChecker;

import java.io.IOException;
import java.io.InputStream;
import java.nio.channels.Channels;
import java.util.*;

/**
 * Reads an Arrow IPC stream and materializes every record batch into a freshly allocated
 * {@link VectorSchemaRoot}, instead of reusing one root for the whole stream.
 *
 * <p>The stream must start with a Schema message. After that, DictionaryBatch and RecordBatch
 * messages may appear in any order:
 * <ul>
 *   <li>Dictionary batches are loaded into the reader-owned dictionary vectors. Delta batches are
 *       appended to the existing dictionary.</li>
 *   <li>Each record batch is loaded into a new root, exposed via {@link #getCurrentValue()}.</li>
 * </ul>
 *
 * <p>This reader never closes a root after handing it out through {@link #getCurrentValue()}.
 * NOTE(review): presumably the caller owns and closes each returned root — confirm against the
 * {@code ArrowReader} contract. {@link #close()} releases only the dictionary vectors and the
 * underlying message reader.
 */
public class ArrowBatchNonReusedReader implements ArrowReader {

    private final BufferAllocator allocator;
    private final MessageChannelReader messageReader;
    private final CompressionCodec.Factory compressionFactory;

    private boolean initialized = false;
    /** Number of dictionary messages (delta messages included) applied so far. */
    private int loadedDictionaryCount;
    /** Dictionaries collected while converting the schema, keyed by dictionary id. */
    private Map<Long, Dictionary> dictionaries;
    /** Root holding the most recently loaded record batch; null before the first batch and after EOS. */
    private VectorSchemaRoot currentBatch;
    private Schema originalSchema;
    /** Schema fields converted to their in-memory form (dictionary-encoded fields use the index type). */
    private List<Field> fieldList;

    /**
     * Creates a reader for an uncompressed stream.
     *
     * @param is        the IPC stream source
     * @param allocator allocator used for all vectors and message bodies
     */
    public ArrowBatchNonReusedReader(InputStream is,
                                     BufferAllocator allocator) {
        this(is, allocator, NoCompressionCodec.Factory.INSTANCE);
    }

    /**
     * Creates a reader whose record and dictionary batches are decoded with {@code compressionFactory}.
     *
     * @param is                 the IPC stream source
     * @param allocator          allocator used for all vectors and message bodies
     * @param compressionFactory factory that supplies the codec for compressed message bodies
     */
    public ArrowBatchNonReusedReader(InputStream is,
                                     BufferAllocator allocator,
                                     CompressionCodec.Factory compressionFactory) {
        this.allocator = allocator;
        this.compressionFactory = compressionFactory;
        this.messageReader = new MessageChannelReader(new ReadChannel(Channels.newChannel(is)),
                this.allocator);
        this.currentBatch = null;
    }

    /**
     * Returns the root loaded by the last successful {@link #nextBatch()} call.
     *
     * @return the current root, or null before the first batch, after EOS, or after a failed load
     */
    @Override
    public VectorSchemaRoot getCurrentValue() {
        return currentBatch;
    }

    /**
     * Advances to the next record batch, first applying any dictionary messages that precede it.
     *
     * @return true if a new batch is available via {@link #getCurrentValue()}, false at end of stream
     * @throws IOException on malformed input or read failure
     */
    @Override
    public boolean nextBatch() throws IOException {
        boolean hasNext = loadNextBatch();
        if (!hasNext) {
            this.currentBatch = null;
        }
        return hasNext;
    }

    /**
     * Releases the dictionary vectors (if the schema was read) and always closes the message reader.
     * Roots previously returned by {@link #getCurrentValue()} are not touched.
     */
    @Override
    public void close() throws IOException {
        try {
            if (initialized) {
                for (Dictionary dictionary : dictionaries.values()) {
                    dictionary.getVector().close();
                }
            }
        } finally {
            // Close the channel even if releasing a dictionary vector failed.
            messageReader.close();
        }
    }

    @Override
    public long bytesRead() {
        return messageReader.bytesRead();
    }

    /**
     * Loads the next record batch into a newly prepared root.
     *
     * <p>If no batch is loaded, either because the stream ended or because loading failed, the
     * root prepared for this call is closed and {@code currentBatch} is reset to null. This keeps
     * an unused or partially loaded root from leaking.
     *
     * @return true if a batch was read, false on EOS
     * @throws IOException on error
     */
    private boolean loadNextBatch() throws IOException {
        prepareLoadNextBatch();
        boolean loaded = false;
        try {
            loaded = readUntilRecordBatch();
        } finally {
            if (!loaded) {
                currentBatch.close();
                currentBatch = null;
            }
        }
        return loaded;
    }

    /**
     * Reads messages until a record batch has been loaded into {@code currentBatch} or the stream ends.
     * Dictionary messages encountered along the way are applied to the dictionary vectors.
     *
     * <p>This is a loop rather than recursion, so a long run of dictionary messages neither grows
     * the stack nor allocates a new root per message.
     *
     * @return true if a record batch was loaded, false on EOS
     * @throws IOException on an unexpected message type, a missing dictionary, or a read failure
     */
    private boolean readUntilRecordBatch() throws IOException {
        while (true) {
            MessageResult result = messageReader.readNext();
            if (result == null) {
                // Reached EOS
                return false;
            }

            byte headerType = result.getMessage().headerType();
            if (headerType == MessageHeader.RecordBatch) {
                loadRecordBatch(result);
                checkDictionaries();
                return true;
            }
            if (headerType != MessageHeader.DictionaryBatch) {
                throw new IOException("Expected RecordBatch or DictionaryBatch but header was " +
                        headerType);
            }
            loadDictionary(readDictionary(result));
            loadedDictionaryCount++;
        }
    }

    /**
     * Deserializes a RecordBatch message and loads it into {@code currentBatch}.
     * The deserialized batch is always closed afterwards.
     */
    private void loadRecordBatch(MessageResult result) throws IOException {
        ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(result.getMessage(), bodyOf(result));
        try {
            new VectorLoader(currentBatch, compressionFactory).load(batch);
        } finally {
            batch.close();
        }
    }

    /** Returns the message body, or the allocator's empty buffer for zero-length messages. */
    private ArrowBuf bodyOf(MessageResult result) {
        ArrowBuf bodyBuffer = result.getBodyBuffer();
        // For zero-length batches, need an empty buffer to deserialize the batch
        return bodyBuffer != null ? bodyBuffer : allocator.getEmpty();
    }

    /**
     * Ensures the reader is initialized. Then creates a fresh, empty root (row count 0) for the
     * next batch and stores it in {@code currentBatch}.
     *
     * @throws IOException on error while reading the schema
     */
    private void prepareLoadNextBatch() throws IOException {
        if (!initialized) {
            initialize();
            this.initialized = true;
        }
        List<FieldVector> vectors = new ArrayList<>(fieldList.size());
        for (Field field : fieldList) {
            vectors.add(field.createVector(allocator));
        }
        Schema schema = new Schema(fieldList, originalSchema.getCustomMetadata());
        this.currentBatch = new VectorSchemaRoot(schema, vectors, 0);
        currentBatch.setRowCount(0);
    }

    /**
     * Reads the schema and builds {@code fieldList}. Dictionary-encoded fields are converted to
     * their in-memory (index) form, and the dictionaries collected during conversion are stored
     * in an unmodifiable map.
     */
    private void initialize() throws IOException {
        this.originalSchema = readSchema();
        this.fieldList = new ArrayList<>(originalSchema.getFields().size());
        Map<Long, Dictionary> declared = new HashMap<>();

        // Convert fields with dictionaries to have the index type
        for (Field field : originalSchema.getFields()) {
            Field updated = DictionaryUtility.toMemoryFormat(field, allocator, declared);
            this.fieldList.add(updated);
        }
        this.dictionaries = Collections.unmodifiableMap(declared);
    }

    /**
     * Reads the first message of the stream, which must be a Schema message.
     *
     * @throws IOException if the stream is empty or starts with a different message type
     */
    private Schema readSchema() throws IOException {
        MessageResult result = messageReader.readNext();

        if (result == null) {
            throw new IOException("Unexpected end of input. Missing schema.");
        }

        if (result.getMessage().headerType() != MessageHeader.Schema) {
            throw new IOException("Expected schema but header was " + result.getMessage().headerType());
        }

        final Schema schema = MessageSerializer.deserializeSchema(result.getMessage());
        MetadataV4UnionChecker.checkRead(schema, MetadataVersion.fromFlatbufID(result.getMessage().version()));
        return schema;
    }

    /** Deserializes a DictionaryBatch message; the caller is responsible for closing the result. */
    private ArrowDictionaryBatch readDictionary(MessageResult result) throws IOException {
        return MessageSerializer.deserializeDictionaryBatch(result.getMessage(), bodyOf(result));
    }

    /**
     * Applies a dictionary batch to the dictionary vector with the same id.
     *
     * <p>A delta batch is loaded into a temporary vector and then appended to the existing
     * dictionary. Otherwise the batch is loaded directly into the dictionary vector.
     * {@code dictionaryBatch} is always closed, including when the id is unknown.
     *
     * @throws IllegalArgumentException if the id was not declared by the schema
     */
    private void loadDictionary(ArrowDictionaryBatch dictionaryBatch) {
        try {
            long id = dictionaryBatch.getDictionaryId();
            Dictionary dictionary = dictionaries.get(id);
            if (dictionary == null) {
                throw new IllegalArgumentException("Dictionary ID " + id + " not defined in schema");
            }
            FieldVector vector = dictionary.getVector();
            // if is deltaVector, concat it with non-delta vector with the same ID.
            if (dictionaryBatch.isDelta()) {
                try (FieldVector deltaVector = vector.getField().createVector(allocator)) {
                    load(dictionaryBatch, deltaVector);
                    VectorBatchAppender.batchAppend(vector, deltaVector);
                }
            } else {
                load(dictionaryBatch, vector);
            }
        } finally {
            dictionaryBatch.close();
        }
    }

    /**
     * Loads the dictionary's record batch into {@code vector} through a single-vector root.
     * It uses the same compression factory as record batches, so compressed dictionaries decode
     * correctly. Does not close {@code dictionaryBatch}.
     */
    private void load(ArrowDictionaryBatch dictionaryBatch, FieldVector vector) {
        VectorSchemaRoot root = new VectorSchemaRoot(
                Collections.singletonList(vector.getField()),
                Collections.singletonList(vector), 0);
        VectorLoader loader = new VectorLoader(root, compressionFactory);
        loader.load(dictionaryBatch.getDictionary());
    }

    /**
     * After a record batch is loaded, checks that every dictionary-encoded top-level vector with
     * at least one non-null value has a known dictionary.
     *
     * <p>NOTE(review): {@code dictionaries} is filled from the schema in {@link #initialize()}. If
     * {@code toMemoryFormat} registers every declared id, this {@code containsKey} check cannot
     * fail for top-level fields. A stricter check would track which ids have actually been
     * loaded — confirm intent.
     */
    private void checkDictionaries() throws IOException {
        // if all dictionaries are loaded, return.
        if (loadedDictionaryCount == dictionaries.size()) {
            return;
        }
        for (FieldVector vector : this.currentBatch.getFieldVectors()) {
            DictionaryEncoding encoding = vector.getField().getDictionary();
            if (encoding != null) {
                // if the dictionaries it needs is not available and the vector is not all null, something was wrong.
                if (!dictionaries.containsKey(encoding.getId()) && vector.getNullCount() < vector.getValueCount()) {
                    throw new IOException("The dictionary was not available, id was:" + encoding.getId());
                }
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy