All Downloads are FREE. Search and download functionalities are using the official Maven repository.

scalapb.json4s.JsonFormat.scala Maven / Gradle / Ivy

The newest version!
package scalapb.json4s

import com.fasterxml.jackson.core.Base64Variants
import com.google.protobuf.ByteString
import com.google.protobuf.descriptor.FieldDescriptorProto
import com.google.protobuf.duration.Duration
import com.google.protobuf.struct.NullValue
import com.google.protobuf.timestamp.Timestamp
import scalapb.json4s.JsonFormat.GenericCompanion
import scalapb._
import org.json4s.JsonAST._
import org.json4s.{Reader, Writer, MappingException}

import scala.collection.mutable
import scala.collection.concurrent.TrieMap
import scala.language.existentials
import scala.reflect.ClassTag
import _root_.scalapb.descriptors._
import com.google.protobuf.descriptor.FieldDescriptorProto.Type
import com.google.protobuf.field_mask.FieldMask

import scala.util.Try

class JsonFormatException(msg: String, cause: Exception)
    extends MappingException(msg, cause) {
  def this(msg: String) = this(msg, null)
}

case class Formatter[T](
    writer: (Printer, T) => JValue,
    parser: (Parser, JValue) => T
)

case class EnumFormatter[T](
    writer: (Printer, T) => JValue,
    parser: (Parser, JValue) => Option[T]
)

case class FormatRegistry(
    messageFormatters: Map[GeneratedMessageCompanion[_], Formatter[_]] =
      Map.empty,
    enumFormatters: Map[EnumDescriptor, EnumFormatter[EnumValueDescriptor]] =
      Map.empty,
    registeredCompanions: Seq[GenericCompanion] = Seq.empty
) {
  private def companionForClass[T <: GeneratedMessage](
      ct: Class[_]
  ): GeneratedMessageCompanion[T] = {
    ct
      .getMethod("companion")
      .getReturnType()
      .getField("MODULE$")
      .get(null)
      .asInstanceOf[GeneratedMessageCompanion[T]]
  }

  def registerMessageFormatter[T <: GeneratedMessage](
      writer: (Printer, T) => JValue,
      parser: (Parser, JValue) => T
  )(implicit ct: ClassTag[T]): FormatRegistry = {
    copy(
      messageFormatters =
        messageFormatters + (companionForClass[T](ct.runtimeClass) -> Formatter(
          writer,
          parser
        ))
    )
  }

  def registerEnumFormatter[E <: GeneratedEnum](
      writer: (Printer, EnumValueDescriptor) => JValue,
      parser: (Parser, JValue) => Option[EnumValueDescriptor]
  )(implicit cmp: GeneratedEnumCompanion[E]): FormatRegistry = {
    copy(
      enumFormatters = enumFormatters + (cmp.scalaDescriptor -> EnumFormatter(
        writer,
        parser
      ))
    )
  }

  def registerWriter[T <: GeneratedMessage: ClassTag](
      writer: T => JValue,
      parser: JValue => T
  ): FormatRegistry = {
    registerMessageFormatter(
      (p: Printer, t: T) => writer(t),
      ((p: Parser, v: JValue) => parser(v))
    )
  }

  def getMessageWriter[T <: GeneratedMessage](implicit
      cmp: GeneratedMessageCompanion[T]
  ): Option[(Printer, T) => JValue] = {
    messageFormatters
      .get(cmp)
      .asInstanceOf[Option[Formatter[T]]]
      .map(_.writer)
  }

  @deprecated(
    "Use getMessageWriter that takes an implicit GeneratedMessageCompanion",
    "0.10.2"
  )
  def getMessageWriter[T](
      klass: Class[_ <: T]
  ): Option[(Printer, T) => JValue] = {
    messageFormatters
      .get(companionForClass(klass))
      .asInstanceOf[Option[Formatter[T]]]
      .map(_.writer)
  }

  def getMessageParser[T <: GeneratedMessage](implicit
      cmp: GeneratedMessageCompanion[T]
  ): Option[(Parser, JValue) => T] = {
    messageFormatters
      .get(cmp)
      .asInstanceOf[Option[Formatter[T]]]
      .map(_.parser)
  }

  @deprecated(
    "Use getMessageParser that takes an implicit GeneratedMessageCompanion",
    "0.10.2"
  )
  def getMessageParser[T](
      klass: Class[_ <: T]
  ): Option[(Parser, JValue) => T] = {
    messageFormatters
      .get(companionForClass(klass))
      .asInstanceOf[Option[Formatter[T]]]
      .map(_.parser)
  }

  def getEnumWriter(
      descriptor: EnumDescriptor
  ): Option[(Printer, EnumValueDescriptor) => JValue] = {
    enumFormatters.get(descriptor).map(_.writer)
  }

  def getEnumParser(
      descriptor: EnumDescriptor
  ): Option[(Parser, JValue) => Option[EnumValueDescriptor]] = {
    enumFormatters.get(descriptor).map(_.parser)
  }
}

/** TypeRegistry is used to map the @type field in Any messages to a ScalaPB
  * generated message.
  *
  * TypeRegistries are added to Printers and Parsers to enable printing and
  * parsing of Any messages.
  */
case class TypeRegistry(
    companions: Map[String, GenericCompanion] = Map.empty,
    private val filesSeen: Set[String] = Set.empty
) {
  def addMessage[T <: GeneratedMessage](implicit
      cmp: GeneratedMessageCompanion[T]
  ): TypeRegistry = {
    addMessageByCompanion(cmp)
  }

  def addFile(file: GeneratedFileObject): TypeRegistry = {
    if (filesSeen.contains(file.scalaDescriptor.fullName)) this
    else {
      val withFileSeen = copy(
        filesSeen = filesSeen + (file.scalaDescriptor.fullName)
      )

      val withDeps: TypeRegistry =
        file.dependencies.foldLeft(withFileSeen)((r, f) => r.addFile(f))

      file.messagesCompanions.foldLeft(withDeps)((r, mc) =>
        r.addMessageByCompanion(mc.asInstanceOf[GenericCompanion])
      )
    }
  }

  def addMessageByCompanion(cmp: GenericCompanion): TypeRegistry = {
    // TODO: need to add contained file to follow JsonFormat
    val withNestedMessages =
      cmp.nestedMessagesCompanions.foldLeft(this)((r, mc) =>
        r.addMessageByCompanion(mc.asInstanceOf[GenericCompanion])
      )
    copy(
      companions =
        withNestedMessages.companions + ((TypeRegistry.TypePrefix + cmp.scalaDescriptor.fullName) -> cmp)
    )
  }

  def findType(typeName: String): Option[GenericCompanion] =
    companions.get(typeName)
}

object TypeRegistry {
  private val TypePrefix = "type.googleapis.com/"

  def empty = TypeRegistry(Map.empty)

  val default = TypeRegistry()
    .addFile(com.google.protobuf.struct.StructProto)
    .addFile(com.google.protobuf.wrappers.WrappersProto)
    .addFile(com.google.protobuf.field_mask.FieldMaskProto)
    .addMessage(com.google.protobuf.any.Any)
    .addMessage(com.google.protobuf.timestamp.Timestamp)
    .addMessage(com.google.protobuf.duration.Duration)
    .addMessage(com.google.protobuf.empty.Empty)
}

object Printer {
  private final case class PrinterConfig(
      isIncludingDefaultValueFields: Boolean,
      isPreservingProtoFieldNames: Boolean,
      isFormattingLongAsNumber: Boolean,
      isFormattingEnumsAsNumber: Boolean,
      isFormattingMapEntriesAsKeyValuePairs: Boolean,
      formatRegistry: FormatRegistry,
      typeRegistry: TypeRegistry
  )

  private def defaultConfig =
    PrinterConfig(
      isIncludingDefaultValueFields = false,
      isPreservingProtoFieldNames = false,
      isFormattingLongAsNumber = false,
      isFormattingEnumsAsNumber = false,
      isFormattingMapEntriesAsKeyValuePairs = false,
      formatRegistry = JsonFormat.DefaultRegistry,
      typeRegistry = TypeRegistry.default
    )
}

class Printer private (config: Printer.PrinterConfig) {
  def this() = this(Printer.defaultConfig)

  @deprecated(
    "Use new Printer() and chain includingDefaultValueFields, preservingProtoFieldNames, etc.",
    "0.7.2"
  )
  def this(
      includingDefaultValueFields: Boolean = false,
      preservingProtoFieldNames: Boolean = false,
      formattingLongAsNumber: Boolean = false,
      formattingEnumsAsNumber: Boolean = false,
      mapEntriesAsKeyValuePairs: Boolean = false,
      formatRegistry: FormatRegistry = JsonFormat.DefaultRegistry,
      typeRegistry: TypeRegistry = TypeRegistry.empty
  ) =
    this(
      Printer.PrinterConfig(
        isIncludingDefaultValueFields = includingDefaultValueFields,
        isPreservingProtoFieldNames = preservingProtoFieldNames,
        isFormattingLongAsNumber = formattingLongAsNumber,
        isFormattingEnumsAsNumber = formattingEnumsAsNumber,
        isFormattingMapEntriesAsKeyValuePairs = mapEntriesAsKeyValuePairs,
        formatRegistry = formatRegistry,
        typeRegistry = typeRegistry
      )
    )

  def includingDefaultValueFields: Printer =
    new Printer(config.copy(isIncludingDefaultValueFields = true))

  def preservingProtoFieldNames: Printer =
    new Printer(config.copy(isPreservingProtoFieldNames = true))

  def formattingLongAsNumber: Printer =
    new Printer(config.copy(isFormattingLongAsNumber = true))

  def formattingEnumsAsNumber: Printer =
    new Printer(config.copy(isFormattingEnumsAsNumber = true))

  def mapEntriesAsKeyValuePairs: Printer =
    new Printer(config.copy(isFormattingMapEntriesAsKeyValuePairs = true))

  def withFormatRegistry(formatRegistry: FormatRegistry): Printer =
    new Printer(config.copy(formatRegistry = formatRegistry))

  def withTypeRegistry(typeRegistry: TypeRegistry): Printer =
    new Printer(config.copy(typeRegistry = typeRegistry))

  def typeRegistry: TypeRegistry = config.typeRegistry

  def print[A](m: GeneratedMessage): String = {
    import org.json4s.jackson.JsonMethods._
    compact(toJson(m))
  }

  private type FieldBuilder = mutable.Builder[JField, List[JField]]

  private def serializeMessageField(
      fd: FieldDescriptor,
      name: String,
      value: Any,
      b: FieldBuilder
  ): Unit = {
    value match {
      case null =>
      // We are never printing empty optional messages to prevent infinite recursion.
      case Nil =>
        if (config.isIncludingDefaultValueFields) {
          b += JField(
            name,
            if (fd.isMapField && !config.isFormattingMapEntriesAsKeyValuePairs)
              JObject()
            else JArray(Nil)
          )
        }
      case xs: Iterable[GeneratedMessage] @unchecked =>
        if (fd.isMapField && !config.isFormattingMapEntriesAsKeyValuePairs) {
          val mapEntryDescriptor =
            fd.scalaType.asInstanceOf[ScalaType.Message].descriptor
          val keyDescriptor = mapEntryDescriptor.findFieldByNumber(1).get
          val valueDescriptor = mapEntryDescriptor.findFieldByNumber(2).get
          b += JField(
            name,
            JObject(xs.map { x =>
              val key = x.getField(keyDescriptor) match {
                case PBoolean(v) => v.toString
                case PDouble(v)  => v.toString
                case PFloat(v)   => v.toString
                case PInt(v)     => v.toString
                case PLong(v)    => v.toString
                case PString(v)  => v
                case v =>
                  throw new JsonFormatException(s"Unexpected value for key: $v")
              }
              val value = if (valueDescriptor.protoType.isTypeMessage) {
                toJson(
                  x.getFieldByNumber(valueDescriptor.number)
                    .asInstanceOf[GeneratedMessage]
                )
              } else {
                serializeSingleValue(
                  valueDescriptor,
                  x.getField(valueDescriptor)
                )
              }
              key -> value
            }.toSeq: _*)
          )
        } else {
          b += JField(name, JArray(xs.map(toJson).toList))
        }
      case msg: GeneratedMessage =>
        b += JField(name, toJson(msg))
      case v =>
        throw new JsonFormatException(v.toString)
    }
  }

  private def serializeNonMessageField(
      fd: FieldDescriptor,
      name: String,
      value: PValue,
      b: FieldBuilder
  ) = {
    value match {
      case PEmpty =>
        if (
          config.isIncludingDefaultValueFields && fd.containingOneof.isEmpty
        ) {
          b += JField(name, defaultJValue(fd))
        }
      case PRepeated(xs) =>
        if (xs.nonEmpty || config.isIncludingDefaultValueFields) {
          b += JField(
            name,
            JArray(
              xs.map(
                serializeSingleValue(fd, _)
              ).toList
            )
          )
        }
      case v =>
        if (
          config.isIncludingDefaultValueFields ||
          !fd.isOptional ||
          !fd.file.isProto3 ||
          (v != JsonFormat.defaultValue(fd)) ||
          fd.containingOneof.isDefined
        ) {
          b += JField(
            name,
            serializeSingleValue(fd, v)
          )
        }
    }
  }

  def toJson[A <: GeneratedMessage](m: A): JValue = {
    config.formatRegistry.getMessageWriter[A](
      m.companion.asInstanceOf[GeneratedMessageCompanion[A]]
    ) match {
      case Some(f) => f(this, m)
      case None =>
        val b = List.newBuilder[JField]
        val descriptor = m.companion.scalaDescriptor
        b.sizeHint(descriptor.fields.size)
        descriptor.fields.foreach { f =>
          if (f.asProto.getType != FieldDescriptorProto.Type.TYPE_GROUP) {
            val name =
              if (config.isPreservingProtoFieldNames) f.name
              else JsonFormat.jsonName(f)
            if (f.protoType.isTypeMessage) {
              serializeMessageField(f, name, m.getFieldByNumber(f.number), b)
            } else {
              serializeNonMessageField(f, name, m.getField(f), b)
            }
          }
        }
        JObject(b.result())
    }
  }

  private def defaultJValue(fd: FieldDescriptor): JValue =
    serializeSingleValue(
      fd,
      JsonFormat.defaultValue(fd)
    )

  private def unsignedInt(n: Int): Long = n & 0x00000000ffffffffL
  private def unsignedLong(n: Long): BigInt =
    if (n < 0) BigInt(n & 0x7fffffffffffffffL).setBit(63) else BigInt(n)

  private def formatLong(
      n: Long,
      protoType: FieldDescriptorProto.Type,
      formattingLongAsNumber: Boolean
  ): JValue = {
    val v: BigInt =
      if (protoType.isTypeUint64 || protoType.isTypeFixed64) unsignedLong(n)
      else BigInt(n)
    if (formattingLongAsNumber) JInt(v) else JString(v.toString())
  }

  def serializeSingleValue(fd: FieldDescriptor, value: PValue): JValue = {
    serializeSingleValue(fd, value, config.isFormattingLongAsNumber)
  }

  def serializeSingleValue(
      fd: FieldDescriptor,
      value: PValue,
      formattingLongAsNumber: Boolean
  ): JValue =
    value match {
      case PEnum(e) =>
        config.formatRegistry.getEnumWriter(e.containingEnum) match {
          case Some(writer) => writer(this, e)
          case None =>
            if (config.isFormattingEnumsAsNumber || e.isUnrecognized)
              JInt(e.number)
            else JString(e.name)
        }
      case PInt(v) if fd.protoType.isTypeUint32  => JInt(unsignedInt(v))
      case PInt(v) if fd.protoType.isTypeFixed32 => JInt(unsignedInt(v))
      case PInt(v)                               => JInt(v)
      case PLong(v)   => formatLong(v, fd.protoType, formattingLongAsNumber)
      case PDouble(v) => JDouble(v)
      case PFloat(v) =>
        if (!v.isNaN && !v.isInfinite) JDecimal(BigDecimal.decimal(v))
        else JDouble(v)
      case PBoolean(v) => JBool(v)
      case PString(v)  => JString(v)
      case PByteString(v) =>
        JString(Base64Variants.getDefaultVariant.encode(v.toByteArray))
      case _: PMessage | PRepeated(_) | PEmpty =>
        throw new RuntimeException("Should not happen")
    }
}

object Parser {
  private final case class ParserConfig(
      isIgnoringUnknownFields: Boolean,
      isIgnoringOverlappingOneofFields: Boolean,
      mapEntriesAsKeyValuePairs: Boolean,
      formatRegistry: FormatRegistry,
      typeRegistry: TypeRegistry
  )
}

class Parser private (config: Parser.ParserConfig) {
  def this() =
    this(
      Parser.ParserConfig(
        isIgnoringUnknownFields = false,
        isIgnoringOverlappingOneofFields = false,
        mapEntriesAsKeyValuePairs = false,
        JsonFormat.DefaultRegistry,
        TypeRegistry.default
      )
    )

  @deprecated(
    "Use new Parser() and chain with usingTypeRegistry or formatRegistry",
    "0.7.1"
  )
  def this(
      preservingProtoFieldNames: Boolean = false,
      formatRegistry: FormatRegistry = JsonFormat.DefaultRegistry,
      typeRegistry: TypeRegistry = TypeRegistry.empty
  ) =
    this(
      Parser.ParserConfig(
        isIgnoringUnknownFields = false,
        isIgnoringOverlappingOneofFields = false,
        mapEntriesAsKeyValuePairs = false,
        formatRegistry,
        typeRegistry
      )
    )

  def ignoringUnknownFields: Parser =
    new Parser(config.copy(isIgnoringUnknownFields = true))

  def ignoringOverlappingOneofFields: Parser =
    new Parser(config.copy(isIgnoringOverlappingOneofFields = true))

  def mapEntriesAsKeyValuePairs: Parser =
    new Parser(config.copy(mapEntriesAsKeyValuePairs = true))

  def withFormatRegistry(formatRegistry: FormatRegistry) =
    new Parser(config.copy(formatRegistry = formatRegistry))

  def withTypeRegistry(typeRegistry: TypeRegistry) =
    new Parser(config.copy(typeRegistry = typeRegistry))

  def typeRegistry: TypeRegistry = config.typeRegistry

  def fromJsonString[A <: GeneratedMessage: GeneratedMessageCompanion](
      str: String
  ): A = {
    import org.json4s.jackson.JsonMethods._
    fromJson(parse(str, useBigDecimalForDouble = true))
  }

  def fromJson[A <: GeneratedMessage: GeneratedMessageCompanion](
      value: JValue
  ): A = {
    fromJson(value, false)
  }

  private[json4s] def fromJson[A <: GeneratedMessage](
      value: JValue,
      skipTypeUrl: Boolean
  )(implicit cmp: GeneratedMessageCompanion[A]): A = {
    cmp.messageReads.read(fromJsonToPMessage(cmp, value, skipTypeUrl))
  }

  private def allowNullValueForField(fd: FieldDescriptor): Boolean = {
    fd.scalaType match {
      case ScalaType.Message(t) =>
        t == com.google.protobuf.struct.Value.scalaDescriptor
      case _ => false
    }
  }

  private def fromJsonToPMessage(
      cmp: GeneratedMessageCompanion[_],
      value: JValue,
      skipTypeUrl: Boolean
  ): PMessage = {
    def parseValue(fd: FieldDescriptor, value: JValue): PValue = {
      if (fd.isMapField && !config.mapEntriesAsKeyValuePairs) {
        value match {
          case JObject(vals) =>
            val mapEntryDesc =
              fd.scalaType.asInstanceOf[ScalaType.Message].descriptor
            val keyDescriptor = mapEntryDesc.findFieldByNumber(1).get
            val valueDescriptor = mapEntryDesc.findFieldByNumber(2).get
            PRepeated(vals.iterator.map { case (key, jValue) =>
              val keyObj = keyDescriptor.scalaType match {
                case ScalaType.Boolean =>
                  PBoolean(java.lang.Boolean.valueOf(key))
                case ScalaType.Double =>
                  PDouble(java.lang.Double.valueOf(key))
                case ScalaType.Float  => PFloat(java.lang.Float.valueOf(key))
                case ScalaType.Int    => PInt(java.lang.Integer.valueOf(key))
                case ScalaType.Long   => PLong(java.lang.Long.valueOf(key))
                case ScalaType.String => PString(key)
                case _ =>
                  throw new RuntimeException(
                    s"Unsupported type for key for ${fd.name}"
                  )
              }
              PMessage(
                Map(
                  keyDescriptor -> keyObj,
                  valueDescriptor -> parseSingleValue(
                    cmp.messageCompanionForFieldNumber(fd.number),
                    valueDescriptor,
                    jValue
                  )
                )
              )
            }.toVector)
          case _ =>
            throw new JsonFormatException(
              s"Expected an object for map field ${fd.name} of ${fd.containingMessage.name}"
            )
        }
      } else if (fd.isRepeated) {
        value match {
          case JArray(vals) =>
            PRepeated(vals.map(parseSingleValue(cmp, fd, _)).toVector)
          case _ =>
            throw new JsonFormatException(
              s"Expected an array for repeated field ${fd.name} of ${fd.containingMessage.name}"
            )
        }
      } else parseSingleValue(cmp, fd, value)
    }

    config.formatRegistry.getMessageParser(
      cmp.asInstanceOf[GeneratedMessageCompanion[GeneratedMessage]]
    ) match {
      case Some(p) => p(this, value).toPMessage
      case None =>
        value match {
          case JObject(fields) =>
            val usedOneofs = mutable.Set[OneofDescriptor]()
            val fieldMap = JsonFormat.MemorizedFieldNameMap(cmp.scalaDescriptor)
            val valueMapBuilder = Map.newBuilder[FieldDescriptor, PValue]
            fields.foreach { case (name: String, jValue: JValue) =>
              if (fieldMap.contains(name)) {
                val fd = fieldMap(name)
                if (jValue != JNull || allowNullValueForField(fd)) {
                  fd.containingOneof.foreach(o =>
                    if (
                      !config.isIgnoringOverlappingOneofFields && !usedOneofs
                        .add(o)
                    ) {
                      throw new JsonFormatException(
                        s"Overlapping field '${name}' in oneof"
                      )
                    }
                  )
                  valueMapBuilder += (fd -> parseValue(fd, jValue))
                }
              } else if (
                !config.isIgnoringUnknownFields && !(skipTypeUrl && name == "@type")
              ) {
                throw new JsonFormatException(
                  s"Cannot find field: ${name} in message ${cmp.scalaDescriptor.fullName}"
                )
              }
            }

            PMessage(valueMapBuilder.result())
          case _ =>
            throw new JsonFormatException(s"Expected an object, found ${value}")
        }
    }
  }

  def defaultEnumParser(
      enumDescriptor: EnumDescriptor,
      value: JValue
  ): Option[EnumValueDescriptor] = {
    def enumValueFromInt(v: Int): Option[EnumValueDescriptor] =
      if (enumDescriptor.file.isProto3)
        Some(enumDescriptor.findValueByNumberCreatingIfUnknown(v))
      else
        enumDescriptor.findValueByNumber(v)

    def fail() =
      throw new JsonFormatException(
        s"Invalid enum value: $value for enum type: ${enumDescriptor.fullName}"
      )

    def defaultValue =
      if (config.isIgnoringUnknownFields && enumDescriptor.file.isProto3)
        enumDescriptor.findValueByNumber(0)
      else None

    val res = value match {
      case JDecimal(v) =>
        try {
          enumValueFromInt(v.bigDecimal.intValueExact)
        } catch {
          case e: ArithmeticException => defaultValue
        }
      case JInt(v) =>
        enumValueFromInt(v.toInt)
      case JString(s) =>
        enumDescriptor.values
          .find(_.name == s)
          .orElse {
            try {
              enumValueFromInt(new java.math.BigDecimal(s).intValueExact)
            } catch {
              case _: ArithmeticException   => None
              case _: NumberFormatException => None
            }
          }
          .orElse(defaultValue)
      case _ => fail()
    }

    if (res.isEmpty && !config.isIgnoringUnknownFields) {
      fail()
    }

    res
  }

  protected def parseSingleValue(
      containerCompanion: GeneratedMessageCompanion[_],
      fd: FieldDescriptor,
      value: JValue
  ): PValue =
    fd.scalaType match {
      case ScalaType.Enum(ed) => {
        val res: Option[EnumValueDescriptor] =
          config.formatRegistry.getEnumParser(ed) match {
            case Some(parser) => parser(this, value)
            case None         => defaultEnumParser(ed, value)
          }

        res.fold[PValue](PEmpty)(PEnum.apply)
      }
      case ScalaType.Message(md) =>
        fromJsonToPMessage(
          containerCompanion.messageCompanionForFieldNumber(fd.number),
          value,
          false
        )
      case st =>
        JsonFormat.parsePrimitive(
          fd.protoType,
          value,
          throw new JsonFormatException(
            s"Unexpected value ($value) for field ${fd.name} of ${fd.containingMessage.name}"
          )
        )
    }
}

object JsonFormat {
  import com.google.protobuf.wrappers

  // type GenericCompanion = GeneratedMessageCompanion[T] forSome { type T <: GeneratedMessage}
  type GenericCompanion = GeneratedMessageCompanion[_ <: GeneratedMessage]

  val DefaultRegistry = FormatRegistry()
    .registerWriter(
      (d: Duration) => JString(Durations.writeDuration(d)),
      jv =>
        jv match {
          case JString(str) => Durations.parseDuration(str)
          case _ => throw new JsonFormatException("Expected a string.")
        }
    )
    .registerWriter(
      (t: Timestamp) => JString(Timestamps.writeTimestamp(t)),
      jv =>
        jv match {
          case JString(str) => Timestamps.parseTimestamp(str)
          case _ => throw new JsonFormatException("Expected a string.")
        }
    )
    .registerWriter(
      (m: FieldMask) => JString(FieldMaskUtil.toJsonString(m)),
      jv =>
        jv match {
          case JString(str) => FieldMaskUtil.fromJsonString(str)
          case _ => throw new JsonFormatException("Expected a string.")
        }
    )
    .registerMessageFormatter[wrappers.DoubleValue](
      primitiveWrapperWriter,
      primitiveWrapperParser[wrappers.DoubleValue]
    )
    .registerMessageFormatter[wrappers.FloatValue](
      primitiveWrapperWriter,
      primitiveWrapperParser[wrappers.FloatValue]
    )
    .registerMessageFormatter[wrappers.Int32Value](
      primitiveWrapperWriter,
      primitiveWrapperParser[wrappers.Int32Value]
    )
    .registerMessageFormatter[wrappers.Int64Value](
      primitiveWrapperWriter,
      primitiveWrapperParser[wrappers.Int64Value]
    )
    .registerMessageFormatter[wrappers.UInt32Value](
      primitiveWrapperWriter,
      primitiveWrapperParser[wrappers.UInt32Value]
    )
    .registerMessageFormatter[wrappers.UInt64Value](
      primitiveWrapperWriter,
      primitiveWrapperParser[wrappers.UInt64Value]
    )
    .registerMessageFormatter[wrappers.BoolValue](
      primitiveWrapperWriter,
      primitiveWrapperParser[wrappers.BoolValue]
    )
    .registerMessageFormatter[wrappers.BytesValue](
      primitiveWrapperWriter,
      primitiveWrapperParser[wrappers.BytesValue]
    )
    .registerMessageFormatter[wrappers.StringValue](
      primitiveWrapperWriter,
      primitiveWrapperParser[wrappers.StringValue]
    )
    .registerEnumFormatter[NullValue](
      (_, _) => JNull,
      (parser, value) =>
        value match {
          case JNull => Some(NullValue.NULL_VALUE.scalaValueDescriptor)
          case _ => parser.defaultEnumParser(NullValue.scalaDescriptor, value)
        }
    )
    .registerWriter[com.google.protobuf.struct.Value](
      StructFormat.structValueWriter,
      StructFormat.structValueParser
    )
    .registerWriter[com.google.protobuf.struct.Struct](
      StructFormat.structWriter,
      StructFormat.structParser
    )
    .registerWriter[com.google.protobuf.struct.ListValue](
      StructFormat.listValueWriter,
      StructFormat.listValueParser
    )
    .registerMessageFormatter[com.google.protobuf.any.Any](
      AnyFormat.anyWriter,
      AnyFormat.anyParser
    )

  def primitiveWrapperWriter[T <: GeneratedMessage](implicit
      cmp: GeneratedMessageCompanion[T]
  ): ((Printer, T) => JValue) = {
    val fieldDesc = cmp.scalaDescriptor.findFieldByNumber(1).get
    (printer, t) =>
      printer.serializeSingleValue(
        fieldDesc,
        t.getField(fieldDesc)
      )
  }

  def primitiveWrapperParser[T <: GeneratedMessage](implicit
      cmp: GeneratedMessageCompanion[T]
  ): ((Parser, JValue) => T) = {
    val fieldDesc = cmp.scalaDescriptor.findFieldByNumber(1).get
    (parser, jv) =>
      cmp.messageReads.read(
        PMessage(
          Map(
            fieldDesc -> JsonFormat.parsePrimitive(
              fieldDesc.protoType,
              jv,
              throw new JsonFormatException(
                s"Unexpected value for ${cmp.scalaDescriptor.name}"
              )
            )
          )
        )
      )
  }

  val printer = new Printer()
  val parser = new Parser()

  def toJsonString[A <: GeneratedMessage](m: A): String = printer.print(m)

  def toJson[A <: GeneratedMessage](m: A): JValue = printer.toJson(m)

  def fromJson[A <: GeneratedMessage: GeneratedMessageCompanion](
      value: JValue
  ): A = parser.fromJson(value)

  def fromJsonEither[A <: GeneratedMessage: GeneratedMessageCompanion](
      value: JValue
  ): Either[MappingException, A] = try { Right(parser.fromJson(value)) }
  catch {
    case e: JsonFormatException => Left(new MappingException("fail", e))
  }

  def fromJsonString[A <: GeneratedMessage: GeneratedMessageCompanion](
      str: String
  ): A = {
    parser.fromJsonString(str)
  }

  implicit def protoToReader[T <: GeneratedMessage: GeneratedMessageCompanion]
      : Reader[T] = Reader.from(fromJsonEither(_))

  implicit def protoToWriter[T <: GeneratedMessage]: Writer[T] =
    new Writer[T] {
      def write(obj: T): JValue = printer.toJson(obj)
    }

  def defaultValue(fd: FieldDescriptor): PValue = {
    require(fd.isOptional)
    fd.scalaType match {
      case ScalaType.Int        => PInt(0)
      case ScalaType.Long       => PLong(0L)
      case ScalaType.Float      => PFloat(0)
      case ScalaType.Double     => PDouble(0)
      case ScalaType.Boolean    => PBoolean(false)
      case ScalaType.String     => PString("")
      case ScalaType.ByteString => PByteString(ByteString.EMPTY)
      case ScalaType.Enum(ed)   => PEnum(ed.values(0))
      case ScalaType.Message(_) =>
        throw new RuntimeException("No default value for message")
    }
  }

  @deprecated("Use parsePrimitive(protoType, value, onError) instead.", "0.9.0")
  def parsePrimitive(
      scalaType: ScalaType,
      protoType: FieldDescriptorProto.Type,
      value: JValue,
      onError: => PValue
  ): PValue =
    parsePrimitive(protoType, value, onError)

  def parsePrimitive(
      protoType: FieldDescriptorProto.Type,
      value: JValue,
      onError: => PValue
  ): PValue =
    (protoType, value) match {
      case (Type.TYPE_UINT32 | Type.TYPE_FIXED32, JInt(x)) =>
        parseUint32(x.toString)
      case (Type.TYPE_UINT32 | Type.TYPE_FIXED32, JDouble(x)) =>
        parseUint32(x.toString)
      case (Type.TYPE_UINT32 | Type.TYPE_FIXED32, JDecimal(x)) =>
        parseUint32(x.toString)
      case (Type.TYPE_UINT32 | Type.TYPE_FIXED32, JString(x)) => parseUint32(x)
      case (Type.TYPE_SINT32 | Type.TYPE_INT32 | Type.TYPE_SFIXED32, JInt(x)) =>
        parseInt32(x.toString)
      case (
            Type.TYPE_SINT32 | Type.TYPE_INT32 | Type.TYPE_SFIXED32,
            JDouble(x)
          ) =>
        parseInt32(x.toString)
      case (
            Type.TYPE_SINT32 | Type.TYPE_INT32 | Type.TYPE_SFIXED32,
            JDecimal(x)
          ) =>
        parseInt32(x.toString)
      case (
            Type.TYPE_SINT32 | Type.TYPE_INT32 | Type.TYPE_SFIXED32,
            JString(x)
          ) =>
        parseInt32(x)

      case (Type.TYPE_UINT64 | Type.TYPE_FIXED64, JInt(x)) =>
        parseUint64(x.toString)
      case (Type.TYPE_UINT64 | Type.TYPE_FIXED64, JDouble(x)) =>
        parseUint64(x.toString)
      case (Type.TYPE_UINT64 | Type.TYPE_FIXED64, JDecimal(x)) =>
        parseUint64(x.toString)
      case (Type.TYPE_UINT64 | Type.TYPE_FIXED64, JString(x)) => parseUint64(x)
      case (Type.TYPE_SINT64 | Type.TYPE_INT64 | Type.TYPE_SFIXED64, JInt(x)) =>
        parseInt64(x.toString)
      case (
            Type.TYPE_SINT64 | Type.TYPE_INT64 | Type.TYPE_SFIXED64,
            JDouble(x)
          ) =>
        parseInt64(x.toString)
      case (
            Type.TYPE_SINT64 | Type.TYPE_INT64 | Type.TYPE_SFIXED64,
            JDecimal(x)
          ) =>
        parseInt64(x.toString)
      case (
            Type.TYPE_SINT64 | Type.TYPE_INT64 | Type.TYPE_SFIXED64,
            JString(x)
          ) =>
        parseInt64(x)

      case (Type.TYPE_DOUBLE, JDouble(x))     => parseDouble(x.toString)
      case (Type.TYPE_DOUBLE, JInt(x))        => parseDouble(x.toString)
      case (Type.TYPE_DOUBLE, JDecimal(x))    => parseDouble(x.toString)
      case (Type.TYPE_DOUBLE, JString(v))     => parseDouble(v)
      case (Type.TYPE_FLOAT, JDouble(x))      => parseFloat(x.toString)
      case (Type.TYPE_FLOAT, JInt(x))         => parseFloat(x.toString)
      case (Type.TYPE_FLOAT, JDecimal(x))     => parseFloat(x.toString)
      case (Type.TYPE_FLOAT, JString(v))      => parseFloat(v)
      case (Type.TYPE_BOOL, JBool(b))         => PBoolean(b)
      case (Type.TYPE_BOOL, JString("true"))  => PBoolean(true)
      case (Type.TYPE_BOOL, JString("false")) => PBoolean(false)
      case (Type.TYPE_STRING, JString(s))     => PString(s)
      case (Type.TYPE_BYTES, JString(s)) =>
        PByteString(
          ByteString.copyFrom(Base64Variants.getDefaultVariant.decode(s))
        )
      case _ => onError
    }

  def parseBigDecimal(value: String): BigDecimal = {
    try {
      // JSON doesn't distinguish between integer values and floating point values so "1" and
      // "1.000" are treated as equal in JSON. For this reason we accept floating point values for
      // integer fields as well as long as it actually is an integer (i.e., round(value) == value).
      BigDecimal(value)
    } catch {
      case e: Exception =>
        throw new JsonFormatException(s"Not a numeric value: $value", e)
    }
  }

  def parseInt32(value: String): PValue = {
    try {
      PInt(value.toInt)
    } catch {
      case _: Exception =>
        try {
          PInt(parseBigDecimal(value).toIntExact)
        } catch {
          case e: Exception =>
            throw new JsonFormatException(s"Not an int32 value: $value", e)
        }
    }
  }

  def parseInt64(value: String): PValue = {
    try {
      PLong(value.toLong)
    } catch {
      case _: Exception =>
        val bd = parseBigDecimal(value)
        try {
          PLong(bd.toLongExact)
        } catch {
          case e: Exception =>
            throw new JsonFormatException(s"Not an int64 value: $value", e)
        }
    }
  }

  def parseUint32(value: String): PValue = {
    try {
      val result = value.toLong
      if (result < 0 || result > 0xffffffffL)
        throw new JsonFormatException(s"Out of range uint32 value: $value")
      return PInt(result.toInt)
    } catch {
      case e: JsonFormatException => throw e
      case e: Exception           => // Fall through.
    }
    parseBigDecimal(value).toBigIntExact.map { intVal =>
      if (intVal < 0 || intVal > 0xffffffffL)
        throw new JsonFormatException(s"Out of range uint32 value: $value")
      PInt(intVal.intValue)
    } getOrElse {
      throw new JsonFormatException(s"Not an uint32 value: $value")
    }
  }

  val MAX_UINT64 = BigInt("FFFFFFFFFFFFFFFF", 16)

  def parseUint64(value: String): PValue = {
    parseBigDecimal(value).toBigIntExact.map { intVal =>
      if (intVal < 0 || intVal > MAX_UINT64) {
        throw new JsonFormatException(s"Out of range uint64 value: $value")
      }
      PLong(intVal.longValue)
    } getOrElse {
      throw new JsonFormatException(s"Not an uint64 value: $value")
    }
  }

  def parseDouble(value: String): PDouble =
    value match {
      case "NaN"       => PDouble(Double.NaN)
      case "Infinity"  => PDouble(Double.PositiveInfinity)
      case "-Infinity" => PDouble(Double.NegativeInfinity)
      case v =>
        try {
          val bd = new java.math.BigDecimal(v)
          if (bd.compareTo(MAX_DOUBLE) > 0 || bd.compareTo(MIN_DOUBLE) < 0) {
            throw new JsonFormatException("Out of range double value: " + v)
          }
          PDouble(bd.doubleValue)
        } catch {
          case e: JsonFormatException => throw e
          case e: Exception =>
            throw new JsonFormatException("Not a double value: " + v)
        }
    }

  def parseFloat(str: String): PFloat =
    str match {
      case "NaN"       => PFloat(Float.NaN)
      case "Infinity"  => PFloat(Float.PositiveInfinity)
      case "-Infinity" => PFloat(Float.NegativeInfinity)
      case v =>
        try {
          val value = java.lang.Double.parseDouble(v)
          if (
            (value > Float.MaxValue * (1.0 + EPSILON)) ||
            (value < -Float.MaxValue * (1.0 + EPSILON))
          ) {
            throw new JsonFormatException("Out of range float value: " + value)
          }
          PFloat(value.toFloat)
        } catch {
          case e: JsonFormatException => throw e
          case e: Exception =>
            throw new JsonFormatException("Not a float value: " + v)
        }
    }

  def jsonName(fd: FieldDescriptor): String = {
    // protoc<3 doesn't know about json_name, so we fill it in if it's not populated.
    fd.asProto.jsonName
      .getOrElse(NameUtils.snakeCaseToCamelCase(fd.asProto.getName))
  }

  /** Given a message descriptor, provides a map from field names to field
    * descriptors.
    */
  private[json4s] object MemorizedFieldNameMap {
    private[this] val fieldNameMap
        : TrieMap[Descriptor, Map[String, FieldDescriptor]] =
      TrieMap.empty

    def apply(descriptor: Descriptor): Map[String, FieldDescriptor] = {
      fieldNameMap.getOrElseUpdate(
        descriptor, {
          val mapBuilder = Map.newBuilder[String, FieldDescriptor]
          descriptor.fields.foreach { fd =>
            mapBuilder += fd.name -> fd
            mapBuilder += JsonFormat.jsonName(fd) -> fd
          }
          mapBuilder.result()
        }
      )
    }
  }

  // From protobuf-java's JsonFormat.java:
  val EPSILON: Double = 1e-6

  val MORE_THAN_ONE =
    new java.math.BigDecimal(String.valueOf(1.toDouble + EPSILON))

  val MAX_DOUBLE = new java.math.BigDecimal(String.valueOf(Double.MaxValue))
    .multiply(MORE_THAN_ONE)

  val MIN_DOUBLE = new java.math.BigDecimal(String.valueOf(Double.MinValue))
    .multiply(MORE_THAN_ONE)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy