package com.aliyun.datahub.client.impl.batch.arrow;

import com.aliyun.datahub.client.exception.DatahubClientException;
import com.aliyun.datahub.client.impl.batch.BatchConstants;
import com.aliyun.datahub.client.impl.batch.BatchDeserializer;
import com.aliyun.datahub.client.impl.batch.header.BatchHeader;
import com.aliyun.datahub.client.model.*;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.impl.UnionMapReader;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.message.ArrowMessage;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.MessageChannelReader;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.InputStream;
import java.math.BigDecimal;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
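
/**
 * Deserializes DataHub batch payloads encoded with Apache Arrow back into
 * {@link RecordEntry} objects. Each batch carries a schema version in its
 * header; tuple records are decoded column by column against the matching
 * {@link RecordSchema}, while schemaless batches fall back to blob records.
 *
 * <p>A minimal usage sketch, assuming {@code stream} and {@code header} come
 * from a fetched batch (the surrounding names are illustrative, not part of
 * this class):
 * <pre>{@code
 * ArrowDeserializer deserializer = ...; // obtained from the batch client internals
 * List<RecordEntry> entries = deserializer.deserializeRecord(stream, header);
 * }</pre>
 */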
public class ArrowDeserializer extends BatchDeserializer {
    private static final Logger LOGGER = LoggerFactory.getLogger(ArrowDeserializer.class);
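
    /**
     * Deserializes a batch of records from the given input stream using the
     * schema version carried in the batch header. The backing Arrow buffers
     * are released before this method returns.
     */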
    @Override
    public List<RecordEntry> deserializeRecord(InputStream inputStream, BatchHeader header) {
        // try-with-resources ensures the Arrow buffers are released even if conversion fails
        try (VectorSchemaRoot root = deserializeByChannel(inputStream, header.getSchemaVersion())) {
            return convertRecord(root, header.getSchemaVersion());
        } catch (Exception e) {
            LOGGER.error("Deserialize arrow record failed", e);
            throw new DatahubClientException(e.getMessage());
        }
    }
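
    /**
     * Reads one Arrow record batch from the stream and loads it into a
     * VectorSchemaRoot built from the cached Arrow schema for the given
     * schema version. The caller is responsible for closing the returned root.
     */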
    private VectorSchemaRoot deserializeByChannel(InputStream inputStream, int schemaVersion) throws Exception {
        RecordSchema batchSchema = getSchema(schemaVersion);
        Schema arrowSchema = ArrowSchemaCache.getSchema(batchSchema);
        // Caching the VectorSchemaRoot is actually slower for Arrow deserialization, so create one directly
        VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.getBufferAllocator());
        ReadableByteChannel channel = Channels.newChannel(inputStream);
        ReadChannel readChannel = new ReadChannel(channel);
        try (MessageChannelReader messageReader = new MessageChannelReader(readChannel, ArrowUtils.getBufferAllocator());
             ArrowMessage deserializeMessageBatch = MessageSerializer.deserializeMessageBatch(messageReader)) {
            if (deserializeMessageBatch == null) {
                LOGGER.error("Deserialize arrow batch failed");
                throw new DatahubClientException("Deserialize arrow batch failed");
            }
            VectorLoader loader = new VectorLoader(root);
            loader.load((ArrowRecordBatch) deserializeMessageBatch);
            return root;
        }
    }
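
    /**
     * Converts every row in the Arrow root into a RecordEntry. Rows are
     * decoded as TupleRecordData when a tuple schema exists for the given
     * schema version, otherwise as BlobRecordData read from the blob column.
     */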
    public List<RecordEntry> convertRecord(VectorSchemaRoot root, int schemaVersion) {
        RecordSchema dhSchema = getSchema(schemaVersion);
        List<RecordEntry> recordEntryList = new ArrayList<>(root.getRowCount());
        for (int rowIdx = 0; rowIdx < root.getRowCount(); ++rowIdx) {
            RecordEntry entry = new RecordEntry();
            if (dhSchema != null) {
                TupleRecordData data = new TupleRecordData(dhSchema);
                setColumnValue(root, data, rowIdx);
                entry.setRecordData(data);
            } else {
                VarBinaryVector vector = ((VarBinaryVector) root.getVector(BatchConstants.BLOB_COLUMN_NAME));
                entry.setRecordData(new BlobRecordData(vector.get(rowIdx)));
            }
            setAttribute(root, entry, rowIdx);
            // The schemaVersion may differ between binary batches, so this must stay in the subclass.
            // If every binary batch were guaranteed to share the same schemaVersion, this logic could move to the base class.
            entry.innerSetSegmentInfo(schemaVersion, 0, rowIdx);
            recordEntryList.add(entry);
        }
        return recordEntryList;
    }
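
    /**
     * Copies one row of column values from the Arrow vectors into the tuple
     * record, skipping null cells. BIGINT and TIMESTAMP share the Arrow
     * BigIntVector representation; DECIMAL is transported as a UTF-8 string.
     */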
    private void setColumnValue(VectorSchemaRoot root, TupleRecordData data, int rowIdx) {
        for (int colIdx = 0; colIdx < data.getRecordSchema().getFields().size(); ++colIdx) {
            if (root.getVector(colIdx).isNull(rowIdx)) {
                continue;
            }
            FieldType type = data.getRecordSchema().getField(colIdx).getType();
            switch (type) {
                case BOOLEAN:
                    int bv = ((BitVector) root.getVector(colIdx)).get(rowIdx);
                    data.setField(colIdx, bv == 1);
                    break;
                case TINYINT:
                    data.setField(colIdx, ((TinyIntVector) root.getVector(colIdx)).get(rowIdx));
                    break;
                case SMALLINT:
                    data.setField(colIdx, ((SmallIntVector) root.getVector(colIdx)).get(rowIdx));
                    break;
                case INTEGER:
                    data.setField(colIdx, ((IntVector) root.getVector(colIdx)).get(rowIdx));
                    break;
                case BIGINT:
                case TIMESTAMP:
                    data.setField(colIdx, ((BigIntVector) root.getVector(colIdx)).get(rowIdx));
                    break;
                case FLOAT:
                    data.setField(colIdx, ((Float4Vector) root.getVector(colIdx)).get(rowIdx));
                    break;
                case DOUBLE:
                    data.setField(colIdx, ((Float8Vector) root.getVector(colIdx)).get(rowIdx));
                    break;
                case STRING:
                case JSON:
                    String sVal = new String(((VarCharVector) root.getVector(colIdx)).get(rowIdx), StandardCharsets.UTF_8);
                    data.setField(colIdx, sVal);
                    break;
                case DECIMAL:
                    String dVal = new String(((VarCharVector) root.getVector(colIdx)).get(rowIdx), StandardCharsets.UTF_8);
                    data.setField(colIdx, new BigDecimal(dVal));
                    break;
                default:
                    throw new IllegalStateException("Unknown value type: " + type);
            }
        }
    }
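
    /**
     * Reads the key/value attribute map column for the given row and attaches
     * each pair to the record entry.
     */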
    private void setAttribute(VectorSchemaRoot root, RecordEntry recordEntry, int rowIdx) {
        MapVector vector = ((MapVector) root.getVector(BatchConstants.ATTRIBUTE_COLUMN_NAME));
        UnionMapReader reader = vector.getReader();
        reader.setPosition(rowIdx);
        while (reader.next()) {
            recordEntry.addAttribute(reader.key().readText().toString(), reader.value().readText().toString());
        }
    }
}