/*
 * All downloads are free. The search and download features use the official Maven repository.
 *
 * com.clickzetta.platform.arrow.ArrowRecordBatchWriter Maven / Gradle / Ivy
 *
 * There is a newer version: 2.0.0
 * Show newest version
 */
package com.clickzetta.platform.arrow;

import com.clickzetta.platform.arrow.writer.*;
import com.clickzetta.platform.client.api.ArrowRow;
import com.clickzetta.platform.operator.Bytes;
import cz.proto.DataTypeCategory;
import cz.proto.ingestion.Ingestion;
import cz.proto.ingestion.v2.IngestionV2;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.OutOfMemoryException;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.memory.UnpooledRootAllocator;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.types.pojo.ArrowType;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.util.*;

/**
 * Buffers {@link ArrowRow}s into an Arrow {@link VectorSchemaRoot} and encodes the buffered
 * batch with an {@link ArrowStreamWriter}.
 *
 * <p>Typical call order, as implied by the methods below: {@link #write(ArrowRow)} for each row,
 * {@link #finish()}, then {@link #encodeIsSetBitMaps()} / {@link #encodeArrowRow()}.
 * {@link #encodeArrowRow()} always calls {@link #close()}, after which the root and allocator are
 * released and this instance must not be written to again.
 *
 * <p>The class keeps plain mutable state without any synchronization.
 */
public class ArrowRecordBatchWriter {

  private final ArrowTable arrowTable;
  private final ArrowSchema arrowSchema;
  /** One entry per written row: the row's "column is set" bit set. */
  private final List<BitSet> isSetBitMaps;
  // NOTE(review): the element type is whatever ArrowRow#getOperationType() returns; it is not
  // visible in this file, so the list is intentionally left raw. Confirm and parameterize.
  private final List operationTypes;

  // arrow memory & buffer allocator.
  private BufferAllocator allocator;
  private VectorSchemaRoot root;

  private List<ArrowFieldWriter> fieldWriters;
  private int count = 0;

  /**
   * Creates a writer backed by a {@link RootAllocator} with no initial capacity hint.
   *
   * @param arrowTable table whose Arrow schema defines the vectors to allocate
   */
  public ArrowRecordBatchWriter(ArrowTable arrowTable) {
    this(arrowTable, true, -1);
  }

  /**
   * Creates a writer and allocates one value vector plus one field writer per schema column.
   *
   * @param arrowTable table whose Arrow schema defines the vectors to allocate
   * @param pooledAllocator {@code true} to use a {@link RootAllocator}, {@code false} to use an
   *     {@link UnpooledRootAllocator}
   * @param targetRowSize expected number of rows, used as the initial capacity of the internal
   *     lists and of every vector; {@code -1} means "no hint"
   */
  public ArrowRecordBatchWriter(ArrowTable arrowTable, boolean pooledAllocator, int targetRowSize) {
    this.arrowTable = arrowTable;
    this.arrowSchema = arrowTable.getArrowSchema();
    this.isSetBitMaps = targetRowSize != -1 ? new ArrayList<>(targetRowSize) : new ArrayList<>();
    this.operationTypes = targetRowSize != -1 ? new ArrayList<>(targetRowSize) : new ArrayList<>();
    this.allocator = pooledAllocator ? new RootAllocator() : new UnpooledRootAllocator();
    try {
      this.root = VectorSchemaRoot.create(arrowSchema.getSchema(), allocator);
      initFieldVectorAllocate(targetRowSize);
    } catch (RuntimeException e) {
      // Any failure while building the root or the field writers (out of memory, unsupported
      // column type, ...) must not leak the allocator or partially allocated vectors.
      close();
      throw e;
    }
  }

  /**
   * Allocates every vector of {@link #root} and creates the matching field writer.
   *
   * @param initialRowSize initial capacity to request for each vector, or {@code -1} to keep the
   *     vector's default
   */
  private void initFieldVectorAllocate(int initialRowSize) {
    int columnCount = this.arrowSchema.getColumnCount();
    this.fieldWriters = new ArrayList<>(columnCount);
    for (int index = 0; index < columnCount; index++) {
      ValueVector vector = root.getVector(index);
      if (initialRowSize != -1) {
        // try to set initial capacity for arrow value vector.
        vector.setInitialCapacity(initialRowSize);
      }
      vector.allocateNew();
      this.fieldWriters.add(createFieldWriter(vector));
    }
  }

  /**
   * Builds the field writer matching the vector's data type category, recursing into the element,
   * key/value and child vectors of ARRAY, MAP and STRUCT columns.
   *
   * @param vector vector to wrap
   * @return writer for {@code vector}
   * @throws UnsupportedOperationException if the vector's data type category has no writer
   */
  private ArrowFieldWriter createFieldWriter(ValueVector vector) {
    ArrowType arrowType = vector.getField().getType();
    DataTypeCategory dataTypeCategory = ArrowSchemaConvert.fromArrowType(arrowType);
    switch (dataTypeCategory) {
      case BOOLEAN:
        return new BooleanWriter((BitVector) vector);
      case INT8:
        return new ByteWriter((TinyIntVector) vector);
      case INT16:
        return new ShortWriter((SmallIntVector) vector);
      case INT32:
        return new IntegerWriter((IntVector) vector);
      case INT64:
        return new LongWriter((BigIntVector) vector);
      case FLOAT32:
        return new FloatWriter((Float4Vector) vector);
      case FLOAT64:
        return new DoubleWriter((Float8Vector) vector);
      case DECIMAL:
        return new DecimalWriter((DecimalVector) vector);
      case VARCHAR:
      case CHAR:
      case STRING:
        return new StringWriter((VarCharVector) vector);
      case BINARY:
        return new BinaryWriter((VarBinaryVector) vector);
      case DATE:
        return new DateWriter((DateDayVector) vector);
      case TIMESTAMP_LTZ:
        return new TimestampWriter((TimeStampVector) vector);
      case ARRAY: {
        ListVector listVector = (ListVector) vector;
        ArrowFieldWriter elementWriter = createFieldWriter(listVector.getDataVector());
        return new ArrayWriter(listVector, elementWriter);
      }
      case MAP: {
        MapVector mapVector = (MapVector) vector;
        StructVector structVector = (StructVector) mapVector.getDataVector();
        ArrowFieldWriter keyWriter = createFieldWriter(structVector.getChild(MapVector.KEY_NAME));
        ArrowFieldWriter valueWriter = createFieldWriter(structVector.getChild(MapVector.VALUE_NAME));
        return new MapWriter(mapVector, structVector, keyWriter, valueWriter);
      }
      case STRUCT: {
        StructVector structVector = (StructVector) vector;
        List<ArrowFieldWriter> childFieldWriters = new ArrayList<>();
        for (int i = 0; i < structVector.size(); i++) {
          childFieldWriters.add(createFieldWriter(structVector.getChildByOrdinal(i)));
        }
        return new StructWriter(structVector, childFieldWriters);
      }
      default:
        throw new UnsupportedOperationException(String.format("not support valueVector with dataTypeCategory: %s", dataTypeCategory));
    }
  }

  /**
   * Appends one row to the batch.
   *
   * <p>Key columns are validated before anything is recorded, so a rejected row leaves the
   * writer's state (bit sets, operation types, field writers, row count) untouched.
   *
   * <p>For ACID tables the row's operation type is recorded once per row; for every other table
   * type each distinct operation type is recorded only once.
   *
   * @param row row to append
   * @throws IOException if a primary|sort|cluster|partition key column is null in {@code row}
   */
  public void write(ArrowRow row) throws IOException {
    int columnCount = this.fieldWriters.size();

    // first check primary|sort|cluster|partition key not nullable & arrow row value must set.
    for (int index = 0; index < columnCount; index++) {
      if (arrowSchema.getKeyColumnsIndex().contains(index) && row.isNullAt(index)) {
        throw new IOException(
            String.format("primary|sort|cluster|partition key [%s] must set value in row.",
                arrowSchema.getColumnByIndex(index).getName()));
      }
    }

    boolean acidTable = this.arrowTable.getIgsTableType() == Ingestion.IGSTableType.ACID;
    if (acidTable || !this.operationTypes.contains(row.getOperationType())) {
      this.operationTypes.add(row.getOperationType());
    }

    // if some not key columns is not nullable. delete Op can not only encode with key row.
    // so all Ops will encode all columns.
    BitSet columnsBitSet = row.getColumnsBitSet();
    this.isSetBitMaps.add(columnsBitSet);
    for (int index = 0; index < columnCount; index++) {
      this.fieldWriters.get(index).write(row, index, columnsBitSet.get(index));
    }
    count++;
  }

  /** Publishes the number of written rows to the root and finishes every field writer. */
  public void finish() {
    root.setRowCount(count);
    fieldWriters.forEach(ArrowFieldWriter::finish);
  }

  /** Drops all buffered rows: resets the row count, the field writers and the per-row lists. */
  public void reset() {
    root.setRowCount(0);
    count = 0;
    fieldWriters.forEach(ArrowFieldWriter::reset);
    isSetBitMaps.clear();
    operationTypes.clear();
  }

  /**
   * Concatenates the encoded "is set" bit sets of the buffered rows, in row order.
   *
   * <p>Rows whose bit set is {@code null} or empty contribute no bytes.
   *
   * @return the concatenated encoding; empty if no row contributed
   * @throws IOException never thrown by this implementation; declared for interface stability
   */
  public byte[] encodeIsSetBitMaps() throws IOException {
    int columnCount = this.arrowSchema.getColumnCount();
    ByteBuffer byteBuffer = ByteBuffer.allocate(this.isSetBitMaps.size() * Bytes.getBitSetSize(columnCount));
    for (BitSet bitSet : isSetBitMaps) {
      if (bitSet != null && !bitSet.isEmpty()) {
        byteBuffer.put(Bytes.fromBitSet(bitSet, columnCount));
      }
    }
    // After flip() the remaining bytes are exactly those that were put above.
    byteBuffer.flip();
    byte[] bytes = new byte[byteBuffer.remaining()];
    byteBuffer.get(bytes);
    return bytes;
  }

  /**
   * Encodes the current content of the root as an Arrow stream (schema plus one record batch) and
   * then releases this writer's resources via {@link #close()}, whether or not encoding succeeded.
   *
   * <p>The bytes are snapshotted right after the batch is written, i.e. before the stream writer
   * itself is closed, so anything the stream writer emits on close is not part of the result.
   *
   * @return the encoded stream bytes
   * @throws IOException wrapping any failure raised while encoding
   */
  public byte[] encodeArrowRow() throws IOException {
    try (ByteArrayOutputStream out = new ByteArrayOutputStream();
         ArrowStreamWriter writer = new ArrowStreamWriter(root, null, Channels.newChannel(out))) {
      writer.start();
      writer.writeBatch();
      return out.toByteArray();
    } catch (Throwable t) {
      throw new IOException(t);
    } finally {
      close();
    }
  }

  /**
   * @return the live internal list of recorded operation types (not a copy)
   */
  public List getOperationTypes() {
    return operationTypes;
  }

  /**
   * @return the underlying root, or {@code null} once {@link #close()} has run
   */
  public VectorSchemaRoot getRoot() {
    return root;
  }

  /**
   * @return number of rows written since construction or the last {@link #reset()}
   */
  public int getRowCount() {
    return count;
  }

  /**
   * Releases the root and the allocator. Best-effort: failures are swallowed and the method may
   * be called repeatedly. The root and the allocator are released independently so that a failure
   * while releasing the root cannot leak the allocator.
   */
  public void close() {
    try {
      if (root != null) {
        root.clear();
        root.close();
      }
    } catch (Throwable ignored) {
      // best-effort release; still fall through to the allocator below.
    } finally {
      root = null;
    }
    try {
      if (allocator != null) {
        allocator.close();
      }
    } catch (Throwable ignored) {
      // best-effort release.
    } finally {
      allocator = null;
    }
  }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy