All downloads are free. Search and download functionality uses the official Maven repository.

tech.ytsaurus.spyt.format.batch.ArrowBatchReader.scala Maven / Gradle / Ivy

package tech.ytsaurus.spyt.format.batch

import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
import org.apache.arrow.vector.dictionary.Dictionary
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.slf4j.LoggerFactory
import tech.ytsaurus.core.tables.{ColumnSchema, ColumnValueType, TableSchema}
import tech.ytsaurus.spyt.serializers.SchemaConverter.MetadataFields
import tech.ytsaurus.spyt.wrapper.LogLazy
import tech.ytsaurus.spyt.serialization.IndexedDataType
import tech.ytsaurus.spyt.serializers.SchemaConverter
import tech.ytsaurus.spyt.wrapper.table.YtArrowInputStream

import scala.collection.JavaConverters._

/**
 * Decodes Arrow IPC data read from a [[YtArrowInputStream]] and exposes each Arrow record batch
 * as a Spark [[ColumnarBatch]] whose columns follow `schema`.
 *
 * Whenever the stream reports a new page, a fresh [[ArrowStreamReader]] (with its own allocators)
 * is created and the previous page's reader and allocators are released.
 *
 * @param stream   source of the Arrow data; closed by [[close]]
 * @param schema   Spark schema of the columns to produce, in output order
 * @param ytSchema YT table schema used to look up each column's [[ColumnValueType]]
 */
class ArrowBatchReader(stream: YtArrowInputStream, schema: StructType,
                       ytSchema: TableSchema) extends BatchReaderBase with LogLazy {
  private val log = LoggerFactory.getLogger(getClass)

  // Per-column type descriptors, computed once and reused for every batch.
  private val indexedSchema = schema.fields.map(f => SchemaConverter.indexedDataType(f.dataType))

  // Arrow state for the page currently being decoded. These are null before the first page is
  // opened and again after closeReader() has released them.
  private var _rootAllocator: BufferAllocator = _
  private var _allocator: BufferAllocator = _
  private var _reader: ArrowStreamReader = _
  private var _root: VectorSchemaRoot = _
  private var _dictionaries: java.util.Map[java.lang.Long, Dictionary] = _
  private var _columnVectors: Array[ColumnVector] = _

  // If the stream starts with an empty page, no reader is ever created and
  // nextBatchInternal always reports that there is no data.
  private val emptySchema: Boolean = stream.isEmptyPage

  initialize()

  /** Opens the first page and builds an initial batch, unless the stream is empty. */
  private def initialize(): Unit = {
    if (!emptySchema) {
      updateReader()
      updateBatch()
    }
  }

  /**
   * Loads the next Arrow record batch into `_batch`.
   *
   * @return true if a batch was loaded; false when there is no more data, in which case the
   *         reader's Arrow resources have been released
   */
  override protected def nextBatchInternal: Boolean = {
    if (emptySchema) {
      false
    } else {
      // On a page boundary, start a new reader (and allocators) for the new page.
      if (stream.isNextPage) updateReader()
      if (_reader == null) {
        // Resources were already released after the last batch; there is nothing left to load.
        false
      } else if (_reader.loadNextBatch()) {
        updateBatch()
        setNumRows(_root.getRowCount)
        true
      } else {
        closeReader()
        false
      }
    }
  }

  /**
   * Releases the current Arrow reader and its allocators, if any. Safe to call repeatedly.
   * The underlying `stream` is left open because later pages and finalRead() still read from it.
   */
  private def closeReader(): Unit = {
    // close(false): release the reader's vectors/dictionaries without closing `stream`.
    Option(_reader).foreach(_.close(false))
    // The child allocator is closed before the root allocator it was created from.
    Option(_allocator).foreach(_.close())
    Option(_rootAllocator).foreach(_.close())
    _reader = null
    _allocator = null
    _rootAllocator = null
  }

  /**
   * Checks that only trailer bytes remain in the stream.
   *
   * Reads into a 9-byte buffer and accepts the remainder only if at most 8 bytes were read and
   * the buffer is either all zeros or four 0xFF bytes followed by zeros (the layout of the
   * Arrow IPC end-of-stream marker).
   *
   * @throws IllegalStateException if unexpected bytes are found
   */
  override protected def finalRead(): Unit = {
    val bytes = new Array[Byte](9)
    val res = stream.read(bytes)
    val isAllowedBytes = bytes.forall(_ == 0) || (bytes.take(4).forall(_ == -1) && bytes.drop(4).forall(_ == 0))
    if (res > 8 || !isAllowedBytes) {
      throw new IllegalStateException(s"Final read failed." +
        s" Bytes read: $res; byte buffer: ${bytes.mkString("[", ", ", "]")}")
    }
  }

  /**
   * Releases any Arrow resources still held, including when the consumer stops before all
   * batches were read, and then closes the underlying stream.
   */
  override def close(): Unit = {
    try closeReader()
    finally stream.close()
  }

  /** Releases the previous page's resources and opens an [[ArrowStreamReader]] on `stream`. */
  private def updateReader(): Unit = {
    log.debugLazy(s"Update arrow reader, " +
      s"allocated ${Option(_allocator).map(_.getAllocatedMemory)}, " +
      s"peak allocated ${Option(_allocator).map(_.getPeakMemoryAllocation)}")
    closeReader()

    // Keep a handle on the root allocator as well, so closeReader() can close it.
    _rootAllocator = new RootAllocator()
    _allocator = _rootAllocator.newChildAllocator("arrow reader", 0, Long.MaxValue)
    _reader = new ArrowStreamReader(stream, _allocator)
    _root = _reader.getVectorSchemaRoot
    _dictionaries = _reader.getDictionaryVectors
  }

  /**
   * Wraps an Arrow vector as a Spark column vector, attaching its dictionary when the field is
   * dictionary-encoded.
   *
   * A vector whose null count equals its value count is treated as all-null. Such a vector may
   * lack its dictionary; any other vector may not.
   *
   * @throws UnsupportedOperationException if a non-null dictionary-encoded vector references a
   *                                       dictionary the reader has not loaded
   */
  private def createArrowColumnVector(vector: FieldVector, dataType: IndexedDataType,
                                      columnType: ColumnValueType): ArrowColumnVector = {
    val isNullVector = vector.getNullCount == vector.getValueCount
    val dict = Option(vector.getField.getDictionary).flatMap { encoding =>
      if (_dictionaries.containsKey(encoding.getId)) {
        Some(_dictionaries.get(encoding.getId))
      } else if (!isNullVector) {
        throw new UnsupportedOperationException(
          s"Dictionary ${encoding.getId} for column ${vector.getField.getName} is not available")
      } else None
    }
    new ArrowColumnVector(dataType, vector, dict, isNullVector, columnType)
  }

  /**
   * Looks up the YT value type of `fieldName` in `ytSchema`.
   *
   * The index from findColumn is checked before it is used, so a missing column surfaces as the
   * IllegalStateException below rather than as an indexing error.
   */
  private def ytColumnType(fieldName: String): ColumnValueType = {
    Some(ytSchema.findColumn(fieldName))
      .filter(_ >= 0)
      .flatMap(columnIndex => Option(ytSchema.getColumnSchema(columnIndex)))
      .getOrElse(throw new IllegalStateException(s"Column $fieldName not found in schema"))
      .getType
  }

  /**
   * Rebuilds `_batch` from the vectors currently held by `_root`.
   *
   * Each Spark field is matched to an Arrow vector by its original (pre-rename) name. Fields with
   * no matching Arrow vector are filled with a null column.
   */
  private def updateBatch(): Unit = {
    log.traceLazy(s"Read arrow batch, " +
      s"allocated ${Option(_allocator).map(_.getAllocatedMemory)}, " +
      s"peak allocated ${Option(_allocator).map(_.getPeakMemoryAllocation)}")

    _columnVectors = new Array[ColumnVector](schema.fields.length)

    val arrowSchema = _root.getSchema.getFields.asScala.map(_.getName)
    val arrowVectors = arrowSchema.zip(_root.getFieldVectors.asScala).toMap
    schema.fields.zipWithIndex.foreach { case (field, index) =>
      val dataType = indexedSchema(index)
      val fieldName = MetadataFields.getOriginalName(field)
      val arrowVector = arrowVectors.get(fieldName)
        .map(vec => createArrowColumnVector(vec, dataType, ytColumnType(fieldName)))
        .getOrElse(ArrowColumnVector.nullVector(dataType))
      _columnVectors(index) = arrowVector
    }
    _batch = new ColumnarBatch(_columnVectors)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy