// com.aliyun.datahub.client.impl.batch.arrow.ArrowSerializer (Maven / Gradle / Ivy)
// The newest version!
package com.aliyun.datahub.client.impl.batch.arrow;
import com.aliyun.datahub.client.exception.DatahubClientException;
import com.aliyun.datahub.client.exception.InvalidParameterException;
import com.aliyun.datahub.client.impl.batch.BatchConstants;
import com.aliyun.datahub.client.impl.batch.BatchSerializer;
import com.aliyun.datahub.client.impl.batch.BatchType;
import com.aliyun.datahub.client.impl.batch.BatchUtils;
import com.aliyun.datahub.client.impl.batch.header.BatchHeader;
import com.aliyun.datahub.client.impl.batch.header.BatchHeaderV1;
import com.aliyun.datahub.client.model.*;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.impl.UnionMapWriter;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
/**
 * {@link BatchSerializer} that encodes a batch of {@link RecordEntry} objects as a single Apache
 * Arrow IPC record-batch message.
 *
 * <p>Tuple records are written column by column according to their {@link RecordSchema}, blob
 * records go into the {@link BatchConstants#BLOB_COLUMN_NAME} column, and every record's
 * attributes go into the {@link BatchConstants#ATTRIBUTE_COLUMN_NAME} map column.
 *
 * <p>No per-call state is kept in instance fields: the scratch buffer used for attributes is
 * allocated and released inside each conversion.
 */
public class ArrowSerializer extends BatchSerializer {
    private final static Logger LOGGER = LoggerFactory.getLogger(ArrowSerializer.class);

    /**
     * Serializes {@code recordList} into Arrow IPC bytes.
     *
     * <p>A {@link VectorSchemaRoot} for the first record's schema is borrowed from
     * {@link ArrowObjectCache} and is always handed back, including when conversion or
     * serialization fails (so a failing batch cannot drain the cache).
     *
     * @param recordList records to serialize; must be non-empty
     * @return the serialized Arrow record-batch message
     * @throws DatahubClientException if the list is empty or any step of serialization fails
     */
    @Override
    public byte[] serializeRecord(List<RecordEntry> recordList) {
        try {
            if (recordList == null || recordList.isEmpty()) {
                throw new InvalidParameterException("recordList must not be empty");
            }
            RecordSchema recordSchema = getRecordSchema(recordList.get(0));
            ArrowObjectCache.ArrowCacheKey key =
                    new ArrowObjectCache.ArrowCacheKey(getProjectName(), getTopicName(), recordSchema);
            // Reusing cached Arrow objects here makes serialization several times faster.
            VectorSchemaRoot root = ArrowObjectCache.borrowArrowObject(key);
            try {
                convertRecord(recordList, root);
                return serializeByChannel(root);
            } finally {
                // Return the root even on failure. NOTE(review): this relies on the cache resetting
                // roots between uses, which it must already do for successfully-filled roots — confirm.
                ArrowObjectCache.returnArrowObject(key, root);
            }
        } catch (Exception e) {
            LOGGER.error("Serialize arrow record failed", e);
            throw new DatahubClientException(e.getMessage());
        }
    }

    /**
     * Returns a V1 batch header tagged with {@link BatchType#ARROW}.
     */
    @Override
    protected BatchHeader getHeader() {
        BatchHeaderV1 header = new BatchHeaderV1();
        header.setDataType(BatchType.ARROW);
        return header;
    }

    /**
     * Writes the payload and attributes of {@code recordList} into {@code root} and sets the root's
     * row count to {@code recordList.size()}.
     *
     * <p>Whether the payload is tuple or blob data is decided by the first record only; every other
     * record is cast to the same kind, so a mixed list fails with a {@link ClassCastException}.
     *
     * @param recordList non-empty list of records to convert
     * @param root       destination vectors; expected to match the records' schema
     * @throws Exception if a value cannot be written (e.g. an unsupported field type)
     */
    public void convertRecord(List<RecordEntry> recordList, VectorSchemaRoot root) throws Exception {
        if (recordList.get(0).getRecordData() instanceof TupleRecordData) {
            convertTupleData(recordList, root);
        } else {
            convertBlobData(recordList, root);
        }
        convertAttribute(recordList, root);
        root.setRowCount(recordList.size());
    }

    /** Writes each tuple record into the row with the same index. */
    private void convertTupleData(List<RecordEntry> recordList, VectorSchemaRoot root) {
        for (int idx = 0; idx < recordList.size(); ++idx) {
            RecordEntry entry = recordList.get(idx);
            setColumnValue((TupleRecordData) entry.getRecordData(), root, idx);
        }
    }

    /**
     * Copies every field of {@code data} into row {@code rowIdx} of {@code root}.
     *
     * <p>Null values are marked null. TIMESTAMP is stored in a BIGINT vector, and DECIMAL is stored
     * as its plain (non-scientific) string in a VarChar vector.
     *
     * @throws InvalidParameterException if a field has a type this serializer does not handle
     */
    private void setColumnValue(TupleRecordData data, VectorSchemaRoot root, int rowIdx) {
        RecordSchema schema = data.getRecordSchema();
        int columnCount = schema.getFields().size();
        for (int colIdx = 0; colIdx < columnCount; ++colIdx) {
            Object obj = data.getField(colIdx);
            if (obj == null) {
                root.getVector(colIdx).setNull(rowIdx);
                continue;
            }
            FieldType type = schema.getField(colIdx).getType();
            switch (type) {
                case BOOLEAN:
                    ((BitVector) root.getVector(colIdx)).setSafe(rowIdx, (boolean) obj ? 1 : 0);
                    break;
                case TINYINT:
                    ((TinyIntVector) root.getVector(colIdx)).setSafe(rowIdx, BatchUtils.getByteValue(obj));
                    break;
                case SMALLINT:
                    ((SmallIntVector) root.getVector(colIdx)).setSafe(rowIdx, BatchUtils.getShortValue(obj));
                    break;
                case INTEGER:
                    ((IntVector) root.getVector(colIdx)).setSafe(rowIdx, BatchUtils.getIntValue(obj));
                    break;
                case BIGINT:
                case TIMESTAMP:
                    ((BigIntVector) root.getVector(colIdx)).setSafe(rowIdx, BatchUtils.getLongValue(obj));
                    break;
                case FLOAT:
                    ((Float4Vector) root.getVector(colIdx)).setSafe(rowIdx, BatchUtils.getFloatValue(obj));
                    break;
                case DOUBLE:
                    ((Float8Vector) root.getVector(colIdx)).setSafe(rowIdx, BatchUtils.getDoubleValue(obj));
                    break;
                case STRING:
                case JSON:
                    ((VarCharVector) root.getVector(colIdx)).setSafe(rowIdx, obj.toString().getBytes(StandardCharsets.UTF_8));
                    break;
                case DECIMAL:
                    ((VarCharVector) root.getVector(colIdx)).setSafe(rowIdx, ((BigDecimal) obj).toPlainString().getBytes(StandardCharsets.UTF_8));
                    break;
                default:
                    throw new InvalidParameterException("Unknown value type: " + type);
            }
        }
    }

    /** Writes each blob record's bytes into the blob column, one row per record. */
    private void convertBlobData(List<RecordEntry> recordList, VectorSchemaRoot root) {
        VarBinaryVector vector = ((VarBinaryVector) root.getVector(BatchConstants.BLOB_COLUMN_NAME));
        for (int idx = 0; idx < recordList.size(); ++idx) {
            RecordEntry entry = recordList.get(idx);
            BlobRecordData data = (BlobRecordData) entry.getRecordData();
            vector.setSafe(idx, data.getData());
        }
    }

    /**
     * Allocates a scratch buffer large enough for the UTF-8 encoding of every attribute key and
     * value in {@code recordList}.
     *
     * <p>The size must be counted in UTF-8 bytes, not {@link String#length()} chars: the buffer has
     * a fixed capacity, and non-ASCII text (e.g. CJK, 3 bytes per char) would otherwise overflow it.
     */
    private ArrowBuf preAllocateAttributeBuffer(List<RecordEntry> recordList) {
        long bytes = 0;
        for (RecordEntry recordEntry : recordList) {
            Map<String, String> attrs = recordEntry.getAttributes();
            if (attrs == null || attrs.isEmpty()) {
                continue;
            }
            for (Map.Entry<String, String> entry : attrs.entrySet()) {
                bytes += entry.getKey().getBytes(StandardCharsets.UTF_8).length
                        + entry.getValue().getBytes(StandardCharsets.UTF_8).length;
            }
        }
        return ArrowUtils.getBufferAllocator().buffer(bytes);
    }

    /**
     * Writes every record's attributes as one map entry per row into the attribute map column.
     * Records with no attributes get an empty map.
     */
    private void convertAttribute(List<RecordEntry> recordList, VectorSchemaRoot root) {
        // The VarChar writers copy bytes out of this scratch buffer into the map vector's own
        // storage, so the buffer is released as soon as all attributes have been written.
        try (ArrowBuf attrBuffer = preAllocateAttributeBuffer(recordList)) {
            UnionMapWriter mapWriter = ((MapVector) root.getVector(BatchConstants.ATTRIBUTE_COLUMN_NAME)).getWriter();
            for (RecordEntry entry : recordList) {
                mapWriter.startMap();
                setAttribute(entry.getAttributes(), mapWriter, attrBuffer);
                mapWriter.endMap();
            }
        }
    }

    /**
     * Appends each key/value pair of {@code attrs} to the current map row, staging the UTF-8 bytes
     * in {@code buffer} at its writer index before handing the range to the writer.
     */
    private void setAttribute(Map<String, String> attrs, UnionMapWriter mapWriter, ArrowBuf buffer) {
        if (attrs == null || attrs.isEmpty()) {
            return;
        }
        for (Map.Entry<String, String> entry : attrs.entrySet()) {
            mapWriter.startEntry();
            byte[] keyBytes = entry.getKey().getBytes(StandardCharsets.UTF_8);
            int start = (int) buffer.writerIndex();
            buffer.writeBytes(keyBytes);
            mapWriter.key().varChar().writeVarChar(start, start + keyBytes.length, buffer);
            byte[] valBytes = entry.getValue().getBytes(StandardCharsets.UTF_8);
            start = (int) buffer.writerIndex();
            buffer.writeBytes(valBytes);
            mapWriter.value().varChar().writeVarChar(start, start + valBytes.length, buffer);
            mapWriter.endEntry();
        }
    }

    /**
     * Unloads {@code root} into an {@link ArrowRecordBatch} and writes it as one IPC message into an
     * in-memory buffer.
     *
     * @return the bytes of the serialized message
     * @throws IOException if writing the message fails
     */
    private byte[] serializeByChannel(VectorSchemaRoot root) throws IOException {
        ByteArrayOutputStream stream = new ByteArrayOutputStream(1024 * 1024);
        WritableByteChannel channel = Channels.newChannel(stream);
        VectorUnloader loader = new VectorUnloader(root);
        try (ArrowRecordBatch recordBatch = loader.getRecordBatch();
             WriteChannel writeChannel = new WriteChannel(channel)) {
            MessageSerializer.serialize(writeChannel, recordBatch);
            return stream.toByteArray();
        }
    }
}