All Downloads are FREE. Search and download functionalities are using the official Maven repository.

scalapb.compiler.FieldTransformations.scala Maven / Gradle / Ivy

The newest version!
package scalapb.compiler

import com.google.protobuf.Message
import scalapb.options.Scalapb
import scalapb.options.Scalapb.MatchType
import com.google.protobuf.Descriptors.FieldDescriptor
import com.google.protobuf.Descriptors.FileDescriptor
import scala.jdk.CollectionConverters._
import com.google.protobuf.DescriptorProtos.{FieldOptions, FieldDescriptorProto}
import scalapb.options.Scalapb.FieldTransformation
import com.google.protobuf.DynamicMessage
import com.google.protobuf.Descriptors.FieldDescriptor.Type
import scalapb.options.Scalapb.ScalaPbOptions.AuxFieldOptions
import com.google.protobuf.Descriptors.Descriptor
import com.google.protobuf.ByteString
import java.util.regex.Pattern
import java.util.regex.Matcher

/** A field transformation whose `when` pattern has been resolved into a nested map, ready to be
  * matched against field descriptors.
  *
  * Instances are normally built through the companion's `apply(file, ft)`.
  *
  * @param whenFields
  *   the `when` pattern as produced by `FieldTransformations.fieldMap`: scalar fields map to their
  *   value, message fields and resolved extensions map to a nested `Map[FieldDescriptor, Any]`
  * @param set
  *   the `[scalapb.field]` options to apply to every field that matches
  * @param matchType
  *   how `whenFields` is compared with a candidate field (see `FieldTransformations.matches`)
  * @param extensions
  *   the field-option extensions that were visible from the file declaring the transformation
  */
private[compiler] case class ResolvedFieldTransformation(
    whenFields: Map[FieldDescriptor, Any],
    set: scalapb.options.Scalapb.FieldOptions,
    matchType: MatchType,
    extensions: Set[FieldDescriptor]
)

/** Information needed to resolve extensions that are only present as unknown fields.
  *
  * @param currentFile
  *   name of the file being processed; used as the prefix of error messages
  * @param extensions
  *   candidate extension descriptors, looked up by field number (`fieldMap`) or by full name
  *   (bracketed components in `fieldByPath`)
  */
private[compiler] case class ExtensionResolutionContext(
    currentFile: String,
    extensions: Set[FieldDescriptor]
)

private[compiler] object ResolvedFieldTransformation {

  /** Resolves a raw `FieldTransformation` declared in `file`.
    *
    * The `set` part of the transformation may only carry the `[scalapb.field]` extension, either
    * as a parsed field or as an unknown field with that extension's number; anything else is
    * rejected.
    *
    * @param file
    *   the file in which the transformation was declared; its (transitive) field-option
    *   extensions are used to resolve extensions mentioned in the `when` pattern
    * @param ft
    *   the transformation to resolve
    * @throws GeneratorException
    *   if `set` contains anything other than `[scalapb.field]`, or if the `when` pattern cannot be
    *   resolved by `FieldTransformations.fieldMap`
    */
  def apply(
      file: FileDescriptor,
      ft: FieldTransformation
  ): ResolvedFieldTransformation = {
    val fileName = file.getFullName()
    val context =
      ExtensionResolutionContext(fileName, FieldTransformations.fieldExtensionsForFile(file))

    val setOptions = ft.getSet()
    // Both the parsed fields and the unknown fields of `set` must be limited to [scalapb.field].
    val onlyScalapbField =
      setOptions
        .getAllFields()
        .keySet()
        .asScala
        .subsetOf(Set(Scalapb.field.getDescriptor())) &&
        setOptions
          .getUnknownFields()
          .asMap()
          .keySet()
          .asScala
          .subsetOf(Set(Scalapb.field.getNumber()))
    if (!onlyScalapbField) {
      throw new GeneratorException(
        s"${file.getFullName}: FieldTransformation.set must contain only [scalapb.field] field."
      )
    }

    // NOTE(review): the pattern is re-parsed from its bytes, presumably so that extensions end up
    // as unknown fields that fieldMap can resolve by number - confirm.
    val whenPattern = FieldDescriptorProto.parseFrom(ft.getWhen.toByteArray())
    ResolvedFieldTransformation(
      whenFields = FieldTransformations.fieldMap(whenPattern, context = context),
      set = setOptions.getExtension(Scalapb.field),
      matchType = ft.getMatchType(),
      extensions = context.extensions
    )
  }
}

private[compiler] object FieldTransformations {
  /** Decides whether `input` satisfies the `when` pattern of `transformation`.
    *
    *   - `CONTAINS`: every field of the pattern is set on `input` with an equal value.
    *   - `EXACT`: `input` has exactly the fields of the pattern, with equal values, at every
    *     nesting level.
    *   - `PRESENCE`: every field of the pattern is set on `input`, regardless of value.
    *
    * @param currentFile
    *   name of the file being processed, forwarded to the extension resolution context
    * @param input
    *   the field to test; extensions are read from its unknown fields
    * @param transformation
    *   the resolved transformation providing the pattern and the match type
    * @return
    *   true when the field matches the pattern under the transformation's match type
    */
  def matches(
      currentFile: String,
      input: FieldDescriptorProto,
      transformation: ResolvedFieldTransformation
  ): Boolean = {
    transformation.matchType match {
      case MatchType.CONTAINS =>
        matchContains(
          input,
          transformation.whenFields,
          ExtensionResolutionContext(currentFile, transformation.extensions)
        )
      case MatchType.EXACT =>
        // `input` is a message and `whenFields` is a map, so `input == whenFields` could never be
        // true; the two have to be compared structurally instead.
        matchExact(input, transformation.whenFields)
      case MatchType.PRESENCE =>
        matchPresence(
          input,
          transformation.whenFields,
          ExtensionResolutionContext(currentFile, transformation.extensions)
        )
    }
  }

  /** Structural equality between a message and a pattern built by `fieldMap`.
    *
    * Unlike resolving `input` through `fieldMap`, this never throws when `input` carries an
    * unknown field that the pattern does not mention: such a field simply makes the match fail.
    */
  private def matchExact(
      input: Message,
      pattern: Map[FieldDescriptor, Any]
  ): Boolean = {
    val (extensionPatterns, regularPatterns) =
      pattern.partition { case (fd, _) => fd.isExtension() }
    val setFields: Set[FieldDescriptor] = input.getAllFields().keySet().asScala.toSet
    val unknownNumbers: Set[Int] =
      input.getUnknownFields().asMap().keySet().asScala.map(_.intValue()).toSet

    // The key sets are compared first, so every lookup below refers to a field that is present.
    setFields == regularPatterns.keySet &&
    unknownNumbers == extensionPatterns.keySet.map(_.getNumber()) &&
    pattern.forall { case (fd, expected) =>
      if (fd.isExtension())
        matchExact(
          getExtensionField(input, fd),
          expected.asInstanceOf[Map[FieldDescriptor, Any]]
        )
      else if (fd.getType() == Type.MESSAGE)
        matchExact(
          input.getField(fd).asInstanceOf[Message],
          expected.asInstanceOf[Map[FieldDescriptor, Any]]
        )
      else input.getField(fd) == expected
    }
  }

  /** Checks that every entry of `pattern` is set on `input` with an equal value.
    *
    * Extension entries are looked up among the unknown fields of `input`; extension and message
    * entries are compared recursively against their nested pattern map. Fields of `input` that
    * the pattern does not mention are ignored.
    *
    * @param input
    *   the message to inspect
    * @param pattern
    *   the expected fields, in the shape produced by `fieldMap`
    * @param context
    *   extension resolution context, passed unchanged to recursive calls
    */
  def matchContains(
      input: Message,
      pattern: Map[FieldDescriptor, Any],
      context: ExtensionResolutionContext
  ): Boolean =
    pattern.forall {
      case (fd, expected) if fd.isExtension() =>
        input.getUnknownFields().hasField(fd.getNumber) &&
        matchContains(
          getExtensionField(input, fd),
          expected.asInstanceOf[Map[FieldDescriptor, Any]],
          context
        )
      case (fd, expected) if fd.getType() == Type.MESSAGE =>
        input.hasField(fd) &&
        matchContains(
          input.getField(fd).asInstanceOf[Message],
          expected.asInstanceOf[Map[FieldDescriptor, Any]],
          context
        )
      case (fd, expected) if fd.isRepeated =>
        // Repeated fields are tested through their element count rather than hasField.
        input.getRepeatedFieldCount(fd) > 0 && input.getField(fd) == expected
      case (fd, expected) =>
        input.hasField(fd) && input.getField(fd) == expected
    }

  /** Checks that every entry of `pattern` is set on `input`, without comparing scalar values.
    *
    * Extension entries are looked up among the unknown fields of `input`; extension and message
    * entries are checked recursively against their nested pattern map.
    *
    * @param input
    *   the message to inspect
    * @param pattern
    *   the fields that must be present, in the shape produced by `fieldMap`
    * @param context
    *   extension resolution context, passed unchanged to recursive calls
    */
  def matchPresence(
      input: Message,
      pattern: Map[FieldDescriptor, Any],
      context: ExtensionResolutionContext
  ): Boolean =
    pattern.forall {
      case (fd, nested) if fd.isExtension() =>
        input.getUnknownFields().hasField(fd.getNumber) &&
        matchPresence(
          getExtensionField(input, fd),
          nested.asInstanceOf[Map[FieldDescriptor, Any]],
          context
        )
      case (fd, nested) if fd.getType() == Type.MESSAGE =>
        input.hasField(fd) &&
        matchPresence(
          input.getField(fd).asInstanceOf[Message],
          nested.asInstanceOf[Map[FieldDescriptor, Any]],
          context
        )
      case (fd, _) if fd.isRepeated =>
        // Repeated fields are tested through their element count rather than hasField.
        input.getRepeatedFieldCount(fd) > 0
      case (fd, _) =>
        input.hasField(fd)
    }

  /** Applies `transforms` to every field declared in `f`, including fields of nested messages.
    *
    * For each (field, transformation) pair that matches, one `AuxFieldOptions` is emitted whose
    * target is the field's full name and whose options are the transformation's `set` options
    * with `$(path)` references interpolated from the field's descriptor proto. For each message,
    * results for its own fields come before results for its nested types.
    *
    * @param f
    *   the file whose messages are scanned
    * @param transforms
    *   the transformations to try, in order, on each field
    * @return
    *   the generated auxiliary options; empty when `transforms` is empty
    */
  def processFieldTransformations(
      f: FileDescriptor,
      transforms: Seq[ResolvedFieldTransformation]
  ): Seq[AuxFieldOptions] =
    if (transforms.isEmpty) Seq.empty
    else {
      val fileName = f.getFullName()
      val context  = ExtensionResolutionContext(fileName, fieldExtensionsForFile(f))

      def optionsForField(fd: FieldDescriptor): Seq[AuxFieldOptions] = {
        val fieldProto = fd.toProto()
        // Re-parsed from bytes: the matchers read extensions from the unknown fields.
        val reparsed = FieldDescriptorProto.parseFrom(fieldProto.toByteArray())
        transforms.collect {
          case transform if matches(fileName, reparsed, transform) =>
            AuxFieldOptions.newBuilder
              .setTarget(fd.getFullName())
              .setOptions(interpolateStrings(transform.set, fieldProto, context))
              .build
        }
      }

      def optionsForMessage(m: Descriptor): Seq[AuxFieldOptions] = {
        val ownFields   = m.getFields().asScala.flatMap(optionsForField(_))
        val nestedTypes = m.getNestedTypes().asScala.flatMap(optionsForMessage(_))
        (ownFields ++ nestedTypes).toSeq
      }

      f.getMessageTypes().asScala.flatMap(optionsForMessage(_)).toSeq
    }

  /** Collects the extensions of `google.protobuf.FieldOptions` declared at the top level of `f`
    * or of any file it depends on, directly or transitively.
    *
    * @param f
    *   the file to start from
    * @return
    *   the set of matching extension descriptors
    */
  def fieldExtensionsForFile(f: FileDescriptor): Set[FieldDescriptor] = {
    // The containing type is compared by full name, not by reference: the FieldOptions descriptor
    // returned by `getContainingType` comes from parsing the code generation request inputs and
    // is a different object from the compiler plugin's own FieldOptions descriptor.
    val declaredHere = f.getExtensions().asScala.filter { ext =>
      ext.getContainingType.getFullName == FieldOptions.getDescriptor().getFullName()
    }
    val fromDependencies =
      f.getDependencies().asScala.flatMap(dependency => fieldExtensionsForFile(dependency))
    (declaredHere ++ fromDependencies).toSet
  }

  // Like m.getAllFields(), but also resolves unknown fields from extensions available in the scope
  // of the message.
  /** Flattens a message into a map from field descriptor to value.
    *
    * Works like `m.getAllFields()`, except that unknown fields are resolved, by number, against
    * the extensions in `context`. Message-typed values and resolved extensions are converted
    * recursively into nested maps; all other values are kept as returned by the message.
    *
    * @param m
    *   the message to flatten
    * @param context
    *   supplies the candidate extensions and the file name used in error messages
    * @throws GeneratorException
    *   if an unknown field number has no matching extension in `context`, or if a message-typed
    *   field is not optional
    */
  def fieldMap(
      m: Message,
      context: ExtensionResolutionContext
  ): Map[FieldDescriptor, Any] = {
    // Unknown fields are resolved first, so a missing extension is reported before anything else.
    val resolvedExtensions = m.getUnknownFields().asMap().keySet().asScala.map { number =>
      val ext = context.extensions.find(_.getNumber == number) match {
        case Some(found) => found
        case None =>
          throw new GeneratorException(
            s"${context.currentFile}: Could not find extension number $number when processing a field " +
              "transformation. A proto file defining this extension needs to be imported directly or transitively in this file."
          )
      }
      ext -> fieldMap(getExtensionField(m, ext), context)
    }

    val declaredFields = m.getAllFields().asScala.map { case (field, value) =>
      val isMessage = field.getType() == Type.MESSAGE
      if (isMessage && !field.isOptional()) {
        throw new GeneratorException(
          s"${context.currentFile}: matching is supported only for scalar types and optional message fields."
        )
      }
      val converted: Any =
        if (isMessage) fieldMap(value.asInstanceOf[Message], context) else value
      field -> converted
    }

    resolvedExtensions.toMap ++ declaredFields
  }

  /** Decodes the value of extension `ext` from the unknown fields of `m`.
    *
    * All length-delimited occurrences stored under the extension's number are concatenated and
    * parsed as a single message of the extension's type.
    *
    * @param m
    *   the message holding the extension as an unknown field
    * @param ext
    *   the extension to decode; must be an optional message field
    * @return
    *   the decoded value as a `DynamicMessage`
    * @throws GeneratorException
    *   if `ext` is not an optional message-typed field
    */
  def getExtensionField(
      m: Message,
      ext: FieldDescriptor
  ): Message = {
    val isOptionalMessage = ext.getType == Type.MESSAGE && ext.isOptional
    if (!isOptionalMessage) {
      throw new GeneratorException(
        s"Unknown extension fields must be optional message types: ${ext}"
      )
    }
    val chunks  = m.getUnknownFields().getField(ext.getNumber()).getLengthDelimitedList().asScala
    val payload = chunks.foldLeft(ByteString.EMPTY)((acc, chunk) => acc.concat(chunk))
    DynamicMessage.parseFrom(ext.getMessageType(), payload)
  }

  /** Splits a field path into its components.
    *
    * Components are separated by `.`. A component written as `[fully.qualified.name]` denotes an
    * extension; dots inside the brackets do not split the path, and the brackets are kept in the
    * returned component. For example `options.[pkg.ext].name` yields
    * `List("options", "[pkg.ext]", "name")`.
    *
    * @param s
    *   the path to split
    * @throws GeneratorException
    *   if a component is empty, a `[` has no matching `]`, or an extension is the last component
    */
  def splitPath(s: String): List[String] = splitPath0(s, s)

  /** Recursive worker for `splitPath`: `s` is the part of the path still to be split and
    * `allPath` the complete path, used only in error messages.
    */
  private def splitPath0(s: String, allPath: String): List[String] = {
    if (s.isEmpty()) throw new GeneratorException(s"Got empty path component in $allPath")
    else {
      val tokenEnd = if (s.startsWith("[")) {
        // indexOf returns -1 when there is no ']'. It has to be tested before adding 1, otherwise
        // the unmatched-bracket case is never detected and a misleading error is reported.
        val closing = s.indexOf(']')
        if (closing == -1) throw new GeneratorException(s"Unmatched [ in path $allPath")
        val end = closing + 1
        if (s.length < end + 1 || s(end) != '.')
          throw new GeneratorException(
            s"Extension can not be the last path component. Expected . after ] in $allPath"
          )
        end
      } else {
        s.indexOf('.')
      }
      if (tokenEnd == -1) s :: Nil
      else s.substring(0, tokenEnd) :: splitPath0(s.substring(tokenEnd + 1), allPath)
    }
  }

  /** Resolves `path` against `message` and returns the string form of the value found.
    *
    * @param message
    *   the message to navigate
    * @param path
    *   a dot-separated path, see `splitPath`
    * @param context
    *   supplies the extensions that bracketed path components may refer to
    * @throws GeneratorException
    *   if the path is empty, malformed, or cannot be resolved
    */
  def fieldByPath(
      message: Message,
      path: String,
      context: ExtensionResolutionContext
  ): String = {
    if (path.isEmpty()) throw new GeneratorException("Got an empty path")
    fieldByPath(message, splitPath(path), path, context).fold(
      error => throw new GeneratorException(error),
      value => value
    )
  }

  /** Resolves an already split path against `message`.
    *
    * The first component is looked up either as an extension (when written in brackets) or as a
    * field of the message; the remaining components are resolved recursively on its value.
    *
    * @param message
    *   the message to navigate
    * @param path
    *   the remaining path components
    * @param allPath
    *   the complete original path, used only in error messages
    * @param context
    *   supplies the extensions that bracketed path components may refer to
    * @return
    *   `Right` with the string form of the value, or `Left` with an error message
    */
  private[compiler] def fieldByPath(
      message: Message,
      path: List[String],
      allPath: String,
      context: ExtensionResolutionContext
  ): Either[String, String] = {
    // Finds the descriptor a path component refers to.
    def lookup(fieldName: String): Either[String, FieldDescriptor] =
      if (fieldName.startsWith("["))
        context.extensions
          .find(_.getFullName == fieldName.substring(1, fieldName.length() - 1))
          .toRight(
            s"Could not find extension $fieldName when resolving $allPath"
          )
      else
        Option(message.getDescriptorForType().findFieldByName(fieldName))
          .toRight(
            s"Could not find field named $fieldName when resolving $allPath"
          )

    // Rejects an extension that extends a different message type than the one being navigated.
    def checkExtendee(fieldName: String, fd: FieldDescriptor): Either[String, Unit] = {
      val messageType = message.getDescriptorForType().getFullName()
      if (fd.isExtension() && fd.getContainingType().getFullName != messageType) {
        val containingType = fd.getContainingType().getFullName()
        val dym =
          if (containingType == "google.protobuf.FieldOptions")
            s" Did you mean options.${fieldName} ?"
          else ""
        Left(
          s"Extension $fieldName is not an extension of ${messageType}, it is an extension of ${containingType}.$dym"
        )
      } else Right(())
    }

    path match {
      case Nil => Left("Got an empty path")
      case fieldName :: rest =>
        for {
          fd <- lookup(fieldName)
          _  <- checkExtendee(fieldName, fd)
          _  <- Either.cond(!fd.isRepeated(), (), "Repeated fields are not supported")
          value = if (fd.isExtension) getExtensionField(message, fd) else message.getField(fd)
          res <- rest match {
            case Nil => Right(value.toString())
            case _ if fd.getType() == Type.MESSAGE =>
              fieldByPath(value.asInstanceOf[Message], rest, allPath, context)
            case next :: _ =>
              Left(
                s"Type ${fd.getType.toString} does not have a field ${next} in $allPath"
              )
          }
        } yield res
    }
  }

  // Substitutes $(path) references in string fields within msg with values coming from the data
  // message
  /** Returns a copy of `msg` in which `$(path)` references inside string values have been
    * replaced with values read from `data`.
    *
    * Only fields that are set on `msg` are visited. Non-repeated string fields are interpolated
    * directly and message fields (single or repeated) are processed recursively; every other
    * field, including repeated strings, is left as is.
    *
    * @param msg
    *   the message whose strings are interpolated
    * @param data
    *   the message that path references are resolved against
    * @param context
    *   supplies the extensions that path references may name
    */
  private[compiler] def interpolateStrings[T <: Message](
      msg: T,
      data: Message,
      context: ExtensionResolutionContext
  ): T = {
    val builder = msg.toBuilder()
    msg.getAllFields().asScala.foreach { case (field, value) =>
      val fieldType = field.getType()
      if (fieldType == Type.STRING && !field.isRepeated()) {
        builder.setField(field, interpolate(value.asInstanceOf[String], data, context))
      } else if (fieldType == Type.MESSAGE && field.isRepeated()) {
        val elements = value.asInstanceOf[java.util.List[Message]].asScala
        builder.setField(field, elements.map(interpolateStrings(_, data, context)).asJava)
      } else if (fieldType == Type.MESSAGE) {
        builder.setField(field, interpolateStrings(value.asInstanceOf[Message], data, context))
      }
    }
    builder.build().asInstanceOf[T]
  }

  /** Matches a `$(path)` reference; group 1 captures the path, which may contain letters, digits,
    * `_`, `.`, `[` and `]`.
    */
  val FieldPath: java.util.regex.Pattern =
    Pattern.compile(raw"[$$]\(([a-zA-Z0-9_.\[\]]*)\)")

  // Interpolates paths in the given string with values coming from the data message
  /** Replaces every `$(path)` reference in `value` with the value found at that path in `data`.
    *
    * @param value
    *   the template string
    * @param data
    *   the message that paths are resolved against
    * @param context
    *   supplies the extensions that paths may name
    */
  private[compiler] def interpolate(
      value: String,
      data: Message,
      context: ExtensionResolutionContext
  ): String = {
    val resolve: Matcher => String = found => fieldByPath(data, found.group(1), context)
    replaceAll(value, FieldPath, resolve)
  }

  // Matcher.replaceAll appeared on Java 9, so we have this Java 8 compatible version instead. Adapted
  // from https://stackoverflow.com/a/43372206/97524
  /** Replaces every match of `pattern` in `templateText` with the string that `replacer` returns
    * for it. The returned string is inserted literally: `$` and `\` in it are not interpreted.
    *
    * A Java 8 compatible stand-in for the function-taking `Matcher.replaceAll`, which appeared
    * in Java 9. Adapted from https://stackoverflow.com/a/43372206/97524
    */
  private[compiler] def replaceAll(
      templateText: String,
      pattern: Pattern,
      replacer: Matcher => String
  ): String = {
    val found  = pattern.matcher(templateText)
    val output = new StringBuffer()
    while (found.find()) {
      val literal = Matcher.quoteReplacement(replacer(found))
      found.appendReplacement(output, literal)
    }
    found.appendTail(output)
    output.toString()
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy