All downloads are free. Search and download functionality uses the official Maven repository.

com.github.mjakubowski84.parquet4s.ScalaPBParquetRecordEncoder.scala Maven / Gradle / Ivy

The newest version!
package com.github.mjakubowski84.parquet4s

import org.apache.parquet.io.api.Binary
import scalapb.GeneratedMessage
import scalapb.descriptors.*
import com.github.mjakubowski84.parquet4s.ScalaPBImplicits.*

import scala.collection.mutable
import scala.util.control.NoStackTrace

/** Signals that a ScalaPB message (or one of its fields) could not be turned into a Parquet record.
  *
  * Mixes in [[scala.util.control.NoStackTrace]], so no stack trace is captured when it is thrown.
  *
  * @param msg
  *   description of the field/value that could not be encoded
  */
case class ScalaPBParquetEncodeException(msg: String)
    extends RuntimeException(msg)
    with NoStackTrace

/** A [[ParquetRecordEncoder]] for ScalaPB-generated messages.
  *
  * Each message is converted to its reflective form via `toPMessage`, and every field is translated into a parquet4s
  * `Value` according to the field's `ScalaType`. Enum fields are written as their value names. Each enum value written
  * is also recorded, so that [[getMetadata]] can expose the name-to-number mapping of every enum encountered.
  *
  * NOTE(review): the enum registry is an unsynchronized mutable map that grows across `encode` calls. Presumably an
  * instance is used by one writer at a time; confirm this before sharing an instance between threads.
  *
  * @tparam T
  *   the ScalaPB-generated message type being written
  */
class ScalaPBParquetRecordEncoder[T <: GeneratedMessage] extends ParquetRecordEncoder[T] {

  // Enum full name -> (value name -> value number) for every enum value encoded so far.
  private val enumMetadata: mutable.Map[String, mutable.Map[String, Int]] = mutable.Map.empty

  /** Encodes `entity` into a row. `resolver` and `configuration` are not consulted by this encoder. */
  override def encode(
      entity: T,
      resolver: EmptyRowParquetRecordResolver,
      configuration: ValueCodecConfiguration
  ): RowParquetRecord =
    encodeMessage(entity.toPMessage)

  /** Returns one entry per enum type seen by [[encode]] so far.
    *
    * Keys are `MetadataEnumPrefix` followed by the enum's full name. Values are comma-separated `name:number` pairs.
    * The order of the pairs follows the iteration order of a mutable hash map and is therefore unspecified.
    */
  override def getMetadata(): Map[String, String] =
    enumMetadata.iterator.map { case (key, value) =>
      val metadataEnumKey    = MetadataEnumPrefix + key
      val metadataEnumValues = value.map { case (name, number) => s"$name:$number" }.mkString(",")
      (metadataEnumKey, metadataEnumValues)
    }.toMap

  /** Encodes every field held in `msg.value` into a record column named after the field. */
  private def encodeMessage(msg: PMessage): RowParquetRecord =
    RowParquetRecord(msg.value.view.map { case (fd, v) => fd.name -> encodeField(fd, v) }.toSeq)

  /** Translates a single field value into a parquet4s `Value`.
    *
    * Case order matters: the map-field case must be tried before the generic `PRepeated` case. Otherwise, map fields
    * would be written as lists of entry records.
    *
    * @throws ScalaPBParquetEncodeException
    *   if the field's type and the value do not form a supported combination
    */
  private def encodeField(fd: FieldDescriptor, value: PValue): Value =
    (fd.scalaType, value) match {
      case (ScalaType.Boolean, PBoolean(b)) => BooleanValue(b)
      case (ScalaType.Int, PInt(i))         => IntValue(i)
      case (ScalaType.Long, PLong(l))       => LongValue(l)
      case (ScalaType.Float, PFloat(f))     => FloatValue(f)
      case (ScalaType.Double, PDouble(d))   => DoubleValue(d)
      case (ScalaType.String, PString(s))   => BinaryValue(Binary.fromString(s))
      case (ScalaType.ByteString, PByteString(bytes)) =>
        // ByteString is immutable, so the bytes behind this buffer never change. The "constant" factory lets Parquet
        // keep a reference instead of defensively copying. The "reused" variant declares that the backing bytes may
        // be overwritten later, which forces those copies.
        BinaryValue(Binary.fromConstantByteBuffer(bytes.asReadOnlyByteBuffer()))
      case (ScalaType.Message(_), msg: PMessage) => encodeMessage(msg)
      case (ScalaType.Message(md), PRepeated(entries)) if fd.isMapField && md.fields.size == 2 =>
        encodeMap(md, entries)
      // Any other repeated field: each element is encoded with the same field descriptor.
      case (_, PRepeated(elements))               => ListParquetRecord(elements.map(encodeField(fd, _))*)
      case (_, PEmpty)                            => NullValue
      case (ScalaType.Enum(_), PEnum(enumValue))  => encodeEnumValue(enumValue)
      case _ =>
        throw ScalaPBParquetEncodeException(
          s"Unsupported combination of field and value: ${fd.scalaType}, $value (field: ${fd.name})"
        )
    }

  /** Encodes the entries of a protobuf map field.
    *
    * The first field of the entry descriptor is treated as the key and the second as the value. This is the order in
    * which protobuf map-entry messages declare them. The caller guarantees that the descriptor has exactly two fields.
    *
    * @throws ScalaPBParquetEncodeException
    *   if an entry is not a message holding exactly two field values
    */
  private def encodeMap(md: Descriptor, values: Vector[PValue]) = {
    val keyFd   = md.fields(0)
    val valueFd = md.fields(1)
    val entries = values.map {
      case msg: PMessage if msg.value.size == 2 => encodeMapEntry(keyFd, valueFd, msg)
      case value                                => throw ScalaPBParquetEncodeException(s"Invalid map entry: $value")
    }
    MapParquetRecord(entries*)
  }

  /** Encodes one map entry as an encoded `(key, value)` pair.
    *
    * @throws ScalaPBParquetEncodeException
    *   if either the key field or the value field is absent from `msg.value`
    */
  private def encodeMapEntry(keyFd: FieldDescriptor, valueFd: FieldDescriptor, msg: PMessage) =
    (msg.value.get(keyFd), msg.value.get(valueFd)) match {
      case (Some(key), Some(value)) => (encodeField(keyFd, key), encodeField(valueFd, value))
      case _ =>
        throw ScalaPBParquetEncodeException(s"Invalid map entry: key field: $keyFd, value field: $valueFd, msg: $msg")
    }

  /** Writes an enum value as its name and records its name-to-number mapping under the enum's full name. */
  private def encodeEnumValue(value: EnumValueDescriptor) = {
    val enumName = value.containingEnum.fullName
    enumMetadata.getOrElseUpdate(enumName, mutable.Map.empty).put(value.name, value.number)
    BinaryValue(Binary.fromString(value.name))
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy