All downloads are FREE. Search and download functionality uses the official Maven repository.

com.aliyun.datahub.client.impl.batch.arrow.ArrowDeserializer Maven / Gradle / Ivy

The newest version!
package com.aliyun.datahub.client.impl.batch.arrow;

import com.aliyun.datahub.client.exception.DatahubClientException;
import com.aliyun.datahub.client.impl.batch.BatchConstants;
import com.aliyun.datahub.client.impl.batch.BatchDeserializer;
import com.aliyun.datahub.client.impl.batch.header.BatchHeader;
import com.aliyun.datahub.client.model.*;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.impl.UnionMapReader;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.message.ArrowMessage;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.MessageChannelReader;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.InputStream;
import java.math.BigDecimal;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

public class ArrowDeserializer extends BatchDeserializer {
    private final static Logger LOGGER = LoggerFactory.getLogger(ArrowDeserializer.class);

    @Override
    public List deserializeRecord(InputStream inputStream, BatchHeader header) {
        try {
            VectorSchemaRoot root = deserializeByChannel(inputStream, header.getSchemaVersion());
            List recordEntryList = convertRecord(root, header.getSchemaVersion());
            root.close();
            return recordEntryList;
        } catch (Exception e) {
            LOGGER.error("Deserialize arrow record failed", e);
            throw new DatahubClientException(e.getMessage());
        }
    }

    private VectorSchemaRoot deserializeByChannel(InputStream inputStream, int schemaVersion) throws Exception {
        RecordSchema batchSchema = getSchema(schemaVersion);
        Schema arrowSchema = ArrowSchemaCache.getSchema(batchSchema);
        // arrow反序列化使用对象cache反倒会慢,所以直接创建
        VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.getBufferAllocator());

        ReadableByteChannel channel = Channels.newChannel(inputStream);
        ReadChannel readChannel = new ReadChannel(channel);

        try (MessageChannelReader messageReader = new MessageChannelReader(readChannel, ArrowUtils.getBufferAllocator());
             ArrowMessage deserializeMessageBatch = MessageSerializer.deserializeMessageBatch(messageReader)) {
            if (deserializeMessageBatch == null) {
                LOGGER.error("Deserialize arrow batch failed");
                throw new DatahubClientException("Deserialize arrow batch failed");
            }

            VectorLoader loader = new VectorLoader(root);
            loader.load((ArrowRecordBatch) deserializeMessageBatch);
            return root;
        }
    }

    public List convertRecord(VectorSchemaRoot root, int schemaVersion) {
        RecordSchema dhSchema = getSchema(schemaVersion);
        List recordEntryList = new ArrayList<>(root.getRowCount());

        for (int rowIdx = 0; rowIdx < root.getRowCount(); ++rowIdx) {
            RecordEntry entry = new RecordEntry();
            if (dhSchema != null) {
                TupleRecordData data = new TupleRecordData(dhSchema);
                setColumnValue(root, data, rowIdx);
                entry.setRecordData(data);
            } else {
                VarBinaryVector vector = ((VarBinaryVector) root.getVector(BatchConstants.BLOB_COLUMN_NAME));
                entry.setRecordData(new BlobRecordData(vector.get(rowIdx)));
            }

            setAttribute(root, entry, rowIdx);

            // 因为binary的反序列化schemaVersion可能不一致,所以这个需要放到子类中
            // 如果可以保证binary中的schemaVersion全部一致,这个逻辑可以放到基类中
            entry.innerSetSegmentInfo(schemaVersion, 0, rowIdx);
            recordEntryList.add(entry);
        }
        return recordEntryList;
    }

    private void setColumnValue(VectorSchemaRoot root, TupleRecordData data, int rowIdx) {
        for (int colIdx = 0; colIdx < data.getRecordSchema().getFields().size(); ++colIdx) {
            if (root.getVector(colIdx).isNull(rowIdx)) {
                continue;
            }

            FieldType type = data.getRecordSchema().getField(colIdx).getType();
            switch (type) {
                case BOOLEAN:
                    int bv = ((BitVector) root.getVector(colIdx)).get(rowIdx);
                    data.setField(colIdx, bv == 1);
                    break;
                case TINYINT:
                    data.setField(colIdx, ((TinyIntVector) root.getVector(colIdx)).get(rowIdx));
                    break;
                case SMALLINT:
                    data.setField(colIdx, ((SmallIntVector) root.getVector(colIdx)).get(rowIdx));
                    break;
                case INTEGER:
                    data.setField(colIdx, ((IntVector) root.getVector(colIdx)).get(rowIdx));
                    break;
                case BIGINT:
                case TIMESTAMP:
                    data.setField(colIdx, ((BigIntVector) root.getVector(colIdx)).get(rowIdx));
                    break;
                case FLOAT:
                    data.setField(colIdx, ((Float4Vector) root.getVector(colIdx)).get(rowIdx));
                    break;
                case DOUBLE:
                    data.setField(colIdx, ((Float8Vector) root.getVector(colIdx)).get(rowIdx));
                    break;
                case STRING:
                case JSON:
                    String sVal = new String(((VarCharVector) root.getVector(colIdx)).get(rowIdx), StandardCharsets.UTF_8);
                    data.setField(colIdx, sVal);
                    break;
                case DECIMAL:
                    String dVal = new String(((VarCharVector) root.getVector(colIdx)).get(rowIdx), StandardCharsets.UTF_8);
                    data.setField(colIdx, new BigDecimal(dVal));
                    break;
                default:
                    throw new IllegalStateException("Unknown value type: " + type);
            }
        }
    }

    private void setAttribute(VectorSchemaRoot root, RecordEntry recordEntry, int rowIdx) {
        MapVector vector = ((MapVector) root.getVector(BatchConstants.ATTRIBUTE_COLUMN_NAME));
        UnionMapReader reader = vector.getReader();
        reader.setPosition(rowIdx);
        while (reader.next()) {
            recordEntry.addAttribute(reader.key().readText().toString(), reader.value().readText().toString());
        }
    }
}







© 2015 - 2024 Weber Informatics LLC | Privacy Policy