All downloads are free. Search and download functionality uses the official Maven repository.

com.github.mjakubowski84.parquet4s.ScalaPBImplicits.scala Maven / Gradle / Ivy

The newest version!
package com.github.mjakubowski84.parquet4s

import scalapb.descriptors.{Descriptor, FieldDescriptor, ScalaType}
import scalapb.{GeneratedMessage, GeneratedMessageCompanion}
import org.apache.parquet.schema.LogicalTypeAnnotation.{enumType, listType, mapType, stringType}
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.*
import org.apache.parquet.schema.Types.{Builder, GroupBuilder}
import org.apache.parquet.schema.{PrimitiveType, Type, Types}

import scala.util.matching.Regex

/** Implicit type-class instances and schema-building helpers that connect ScalaPB generated messages with Parquet4s.
  *
  * Importing the members of this object brings into scope an encoder, a decoder and a schema resolver for any type
  * extending `GeneratedMessage`.
  */
object ScalaPBImplicits {

  /** Provides a [[ParquetRecordEncoder]] for any ScalaPB message type.
    *
    * @tparam T
    *   ScalaPB generated message type
    */
  implicit def scalapbParquetRecordEncoder[T <: GeneratedMessage]: ParquetRecordEncoder[T] =
    new ScalaPBParquetRecordEncoder[T]

  /** Provides a [[ParquetRecordDecoder]] for any ScalaPB message type whose companion is implicitly available.
    *
    * @tparam T
    *   ScalaPB generated message type
    */
  implicit def scalaPBParquetRecordDecoder[T <: GeneratedMessage: GeneratedMessageCompanion]: ParquetRecordDecoder[T] =
    new ScalaPBParquetRecordDecoder[T]

  /** Provides a [[ParquetSchemaResolver]] for any ScalaPB message type whose companion is implicitly available.
    *
    * @tparam T
    *   ScalaPB generated message type
    */
  implicit def scalapbParquetSchemaResolver[T <: GeneratedMessage: GeneratedMessageCompanion]
      : ParquetSchemaResolver[T] =
    new ScalaPBParquetSchemaResolver[T]

  /** Prefix of metadata keys related to protobuf enums. */
  private[parquet4s] val MetadataEnumPrefix: String = "parquet.proto.enum."

  /** Matches a single `name:number` pair terminated either by a comma or by the end of the input. */
  private[parquet4s] val EnumNameNumberPairRe: Regex = "(.*?):(.*?)(?:,|$)".r

  /** Extension methods for Parquet's `GroupBuilder` that append fields described by ScalaPB descriptors.
    *
    * @param builder
    *   the group builder being extended
    * @tparam T
    *   parent type of the wrapped builder
    */
  implicit private[parquet4s] class RichGroupBuilder[T](private val builder: GroupBuilder[T]) extends AnyVal {

    /** Starts the definition of a single field inside the wrapped group.
      *
      * Message-typed fields become a map, a list of groups or a nested group; all remaining fields become a primitive
      * or a list of primitives. The returned builder is not yet named - the caller is expected to finish it with
      * `named`.
      *
      * @param fd
      *   descriptor of the field to add
      * @return
      *   a builder of the new field whose parent is the wrapped group builder
      * @throws IllegalArgumentException
      *   if `fd` is a map field whose entry descriptor lacks a key or a value, or whose key is a message
      */
    def addField(fd: FieldDescriptor): Builder[? <: Builder[?, GroupBuilder[T]], GroupBuilder[T]] =
      fd.scalaType match {
        case ScalaType.Message(md) =>
          if (fd.isMapField) addMapField(md.fields)
          else if (fd.isRepeated) addRepeatedMessage(md)
          else builder.group(repetition(fd)).addFields(md.fields)
        case _ =>
          val rawType = primitiveType(fd)
          if (fd.isRepeated) addRepeatedPrimitive(rawType)
          else builder.primitive(rawType.getPrimitiveTypeName, repetition(fd)).as(rawType.getLogicalTypeAnnotation)
      }

    /** Starts an optional group annotated as a Parquet map, containing a repeated key-value group.
      *
      * @param fields
      *   fields of the map entry message; the first one is used as the key and the second one as the value
      * @return
      *   the builder of the (still unnamed) map group
      * @throws IllegalArgumentException
      *   if fewer than two fields are provided or if the key field is a message
      */
    def addMapField(fields: Vector[FieldDescriptor]): GroupBuilder[GroupBuilder[T]] =
      fields match {
        case keyFd +: valFd +: _ =>
          val keyType = primitiveType(keyFd)
          builder
            .optionalGroup()
            .as(mapType())
            .addMapKey(keyType)
            .addField(valFd)
            .named(MapParquetRecord.ValueFieldName)
            .named(MapParquetRecord.MapKeyValueFieldName)
        case _ =>
          // Fail with a descriptive message instead of an opaque error raised by `head` / `tail`.
          throw new IllegalArgumentException(
            s"Map entry must declare a key field and a value field but ${fields.size} field(s) were found"
          )
      }

    /** Opens a repeated group inside the wrapped group and adds a required primitive key field to it.
      *
      * @param keyType
      *   source of the primitive type name and the logical type annotation of the key
      * @return
      *   the builder of the (still unnamed) repeated group that already contains the key
      */
    def addMapKey(keyType: PrimitiveType): GroupBuilder[GroupBuilder[T]] =
      builder
        .repeatedGroup()
        .primitive(keyType.getPrimitiveTypeName, Type.Repetition.REQUIRED)
        .as(keyType.getLogicalTypeAnnotation)
        .named(MapParquetRecord.KeyFieldName)

    /** Starts an optional group annotated as a Parquet list whose elements are optional groups built from the fields
      * of the given message.
      *
      * @param fd
      *   descriptor of the element message
      * @return
      *   the builder of the (still unnamed) list group
      */
    def addRepeatedMessage(fd: Descriptor): GroupBuilder[GroupBuilder[T]] =
      builder
        .optionalGroup()
        .as(listType())
        .repeatedGroup()
        .optionalGroup()
        .addFields(fd.fields)
        .named(ListParquetRecord.ElementName.Element)
        .named(ListParquetRecord.ListFieldName)

    /** Starts an optional group annotated as a Parquet list whose elements are required primitives.
      *
      * @param elementType
      *   source of the primitive type name and the logical type annotation of the elements
      * @return
      *   the builder of the (still unnamed) list group
      */
    def addRepeatedPrimitive(elementType: PrimitiveType): GroupBuilder[GroupBuilder[T]] =
      builder
        .optionalGroup()
        .as(listType())
        .repeatedGroup()
        .primitive(elementType.getPrimitiveTypeName, Type.Repetition.REQUIRED)
        .as(elementType.getLogicalTypeAnnotation)
        .named(ListParquetRecord.ElementName.Element)
        .named(ListParquetRecord.ListFieldName)

    /** Adds all given fields, in order, to the wrapped group. Each field is named after the descriptor's `name`.
      *
      * @param fields
      *   descriptors of the fields to add
      * @return
      *   the group builder with all the fields added
      */
    def addFields(fields: Vector[FieldDescriptor]): GroupBuilder[T] =
      // NOTE(review): the Parquet field id is set to the descriptor's `index`, not to its protobuf field `number` -
      // confirm that this is intended.
      fields.foldLeft(builder)((builder, fd) => builder.addField(fd).id(fd.index).named(fd.name))

    /** Translates a non-message field into a Parquet primitive type named after the field.
      *
      * Enums and strings are stored as `BINARY` with the enum and string logical type annotation respectively; byte
      * strings are stored as plain `BINARY`.
      *
      * @param fd
      *   descriptor of the field
      * @return
      *   primitive type with the repetition derived from `fd`
      * @throws IllegalArgumentException
      *   if the field is of a message type
      */
    def primitiveType(fd: FieldDescriptor): PrimitiveType = {
      val rep = repetition(fd)
      fd.scalaType match {
        case ScalaType.Boolean    => Types.primitive(BOOLEAN, rep).named(fd.name)
        case ScalaType.Enum(_)    => Types.primitive(BINARY, rep).as(enumType()).named(fd.name)
        case ScalaType.Int        => Types.primitive(INT32, rep).named(fd.name)
        case ScalaType.Long       => Types.primitive(INT64, rep).named(fd.name)
        case ScalaType.Float      => Types.primitive(FLOAT, rep).named(fd.name)
        case ScalaType.Double     => Types.primitive(DOUBLE, rep).named(fd.name)
        case ScalaType.String     => Types.primitive(BINARY, rep).as(stringType()).named(fd.name)
        case ScalaType.ByteString => Types.primitive(BINARY, rep).named(fd.name)
        case ScalaType.Message(_) => throw new IllegalArgumentException(s"Field is not primitive type: ${fd.fullName}")
      }
    }

    /** Derives Parquet repetition from the field descriptor: required, then repeated, otherwise optional.
      *
      * @param fd
      *   descriptor of the field
      * @return
      *   repetition of the field
      */
    def repetition(fd: FieldDescriptor): Type.Repetition =
      if (fd.isRequired) Type.Repetition.REQUIRED
      else if (fd.isRepeated) Type.Repetition.REPEATED
      else Type.Repetition.OPTIONAL
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy