// scalapb/UnknownFieldSet.scala, as shown on a Maven / Gradle / Ivy artifact listing page.
// Source taken from the newest published version.
package scalapb
import com.google.protobuf.{
ByteString,
CodedInputStream,
CodedOutputStream,
InvalidProtocolBufferException
}
import scalapb.lenses.Lens
import scala.collection.mutable
import scala.collection.compat._
/** Immutable collection of unknown protobuf fields, keyed by field number.
  *
  * Updates never mutate the receiver: [[withField]] returns a new set.
  */
final case class UnknownFieldSet(
    private[scalapb] val fields: Map[Int, UnknownFieldSet.Field] = Map.empty
) {

  /** Looks up the values recorded under `fieldNumber`, if there are any. */
  def getField(fieldNumber: Int): Option[UnknownFieldSet.Field] =
    fields.get(fieldNumber)

  /** Returns a copy in which `fieldNumber` maps to `value`, replacing any previous entry. */
  def withField(fieldNumber: Int, value: UnknownFieldSet.Field): UnknownFieldSet =
    new UnknownFieldSet(fields = fields.updated(fieldNumber, value))

  /** Serializes every field to `output`, visiting fields in the map's iteration order. */
  def writeTo(output: CodedOutputStream): Unit =
    for ((fieldNumber, field) <- fields) field.writeTo(fieldNumber, output)

  /** Total number of bytes that [[writeTo]] emits. */
  def serializedSize: Int =
    fields.foldLeft(0) { case (total, (fieldNumber, field)) =>
      total + field.serializedSize(fieldNumber)
    }

  /** The underlying map from field number to the values recorded for it. */
  def asMap: Map[Int, UnknownFieldSet.Field] = fields
}
object UnknownFieldSet {

  /** The shared empty set. [[Builder.result]] returns it when no field was recorded. */
  val empty: UnknownFieldSet = UnknownFieldSet()

  /** Lens extension that zooms from an [[UnknownFieldSet]] into the entry for one field number.
    *
    * Reading a number that is absent yields an empty [[Field]]. Writing always stores the given
    * value under that number, even when the value is empty.
    */
  implicit class UnknownFieldSetLens[UpperPB](
      lens: _root_.scalapb.lenses.Lens[UpperPB, UnknownFieldSet]
  ) {
    def apply(fieldNumber: Int): Lens[UpperPB, UnknownFieldSet.Field] =
      lens.compose(Lens[UnknownFieldSet, UnknownFieldSet.Field]({ t =>
        t.fields.getOrElse(fieldNumber, UnknownFieldSet.Field())
      })({ (c, t) => c.withField(fieldNumber, t) }))
  }

  /** All values recorded for a single field number, bucketed by wire type.
    *
    * Inside each bucket, values keep the order in which they were appended. The relative order
    * between buckets is not kept: [[writeTo]] always emits varints first, then fixed32 values,
    * then fixed64 values, then length-delimited values.
    */
  final case class Field(
      varint: Seq[Long] = Vector.empty,
      fixed64: Seq[Long] = Vector.empty,
      fixed32: Seq[Int] = Vector.empty,
      lengthDelimited: Seq[ByteString] = Vector.empty
  ) {

    /** Writes every stored value to `output` as a tagged entry for `fieldNumber`. */
    def writeTo(fieldNumber: Int, output: CodedOutputStream): Unit = {
      varint.foreach(output.writeUInt64(fieldNumber, _))
      fixed32.foreach(output.writeFixed32(fieldNumber, _))
      fixed64.foreach(output.writeFixed64(fieldNumber, _))
      lengthDelimited.foreach(output.writeBytes(fieldNumber, _))
    }

    /** Number of bytes that [[writeTo]] emits for `fieldNumber`, tags included. */
    def serializedSize(fieldNumber: Int): Int = {
      // Summing over iterators avoids building four throwaway collections on every call.
      // Message serialization calls this for every unknown field.
      varint.iterator.map(CodedOutputStream.computeUInt64Size(fieldNumber, _)).sum +
        fixed32.iterator.map(CodedOutputStream.computeFixed32Size(fieldNumber, _)).sum +
        fixed64.iterator.map(CodedOutputStream.computeFixed64Size(fieldNumber, _)).sum +
        lengthDelimited.iterator.map(CodedOutputStream.computeBytesSize(fieldNumber, _)).sum
    }
  }

  object Field {
    // Lenses over the individual wire-type buckets of a Field.
    val varintLens = Lens[Field, Seq[Long]](_.varint)((c, v) => c.copy(varint = v))
    val fixed64Lens = Lens[Field, Seq[Long]](_.fixed64)((c, v) => c.copy(fixed64 = v))
    val fixed32Lens = Lens[Field, Seq[Int]](_.fixed32)((c, v) => c.copy(fixed32 = v))
    val lengthDelimitedLens =
      Lens[Field, Seq[ByteString]](_.lengthDelimited)((c, v) => c.copy(lengthDelimited = v))

    /** Mutable accumulator for the values of a single field number. */
    class Builder {
      private val varint = Vector.newBuilder[Long]
      private val fixed64 = Vector.newBuilder[Long]
      private val fixed32 = Vector.newBuilder[Int]
      private val lengthDelimited = Vector.newBuilder[ByteString]

      /** Builds a [[Field]] from the values appended so far. */
      def result(): Field =
        Field(
          varint = varint.result(),
          fixed64 = fixed64.result(),
          fixed32 = fixed32.result(),
          lengthDelimited = lengthDelimited.result()
        )

      /** Reads from `input` one value of the wire type encoded in `tag` and appends it.
        *
        * `tag` is presumably already consumed from `input` by the caller. Only the value is
        * read here.
        *
        * @throws InvalidProtocolBufferException for any wire type other than varint, fixed64,
        *   length-delimited or fixed32. This includes the group start/end wire types.
        */
      // The result type is deliberately left inferred. Annotating it (for example as Unit)
      // would change the method's erased JVM signature, and that would break callers that
      // are already compiled.
      def parseField(tag: Int, input: CodedInputStream) = {
        val wireType = WireType.getTagWireType(tag)
        wireType match {
          case WireType.WIRETYPE_VARINT =>
            varint += input.readInt64()
          case WireType.WIRETYPE_FIXED64 =>
            fixed64 += input.readFixed64()
          case WireType.WIRETYPE_LENGTH_DELIMITED =>
            lengthDelimited += input.readBytes()
          case WireType.WIRETYPE_FIXED32 =>
            fixed32 += input.readFixed32()
          case _ =>
            throw new InvalidProtocolBufferException(
              s"Protocol message tag had invalid wire type: ${wireType}"
            )
        }
      }
    }

    object Builder {

      /** Returns a new builder that already holds every value of `f`. */
      def fromField(f: Field): Field.Builder = {
        val b = new Field.Builder
        b.varint ++= f.varint
        b.fixed32 ++= f.fixed32
        b.fixed64 ++= f.fixed64
        b.lengthDelimited ++= f.lengthDelimited
        b
      }
    }
  }

  /** Mutable accumulator that groups parsed values by field number. */
  class Builder {
    private val fieldBuilders = new mutable.HashMap[Int, Field.Builder]

    /** Starts from the contents of `base`, so later values are appended to its fields. */
    def this(base: UnknownFieldSet) = {
      this()
      if (base.fields.nonEmpty) {
        fieldBuilders ++= base.fields.view.mapValues(Field.Builder.fromField)
      }
    }

    /** Builds the set. Returns the shared [[UnknownFieldSet.empty]] when nothing was recorded. */
    def result(): UnknownFieldSet =
      if (fieldBuilders.isEmpty) UnknownFieldSet.empty
      else new UnknownFieldSet(fieldBuilders.view.mapValues(_.result()).toMap)

    /** Reads the value for `tag` and adds it to the builder for that tag's field number.
      *
      * The field builder is created on first use. Wire-type errors propagate from
      * [[Field.Builder.parseField]].
      */
    // Result type left inferred for the same binary-compatibility reason as Field.Builder.
    def parseField(tag: Int, input: CodedInputStream) = {
      val fieldNumber = WireType.getTagFieldNumber(tag)
      fieldBuilders.getOrElseUpdate(fieldNumber, new Field.Builder()).parseField(tag, input)
    }
  }
}
/** Constants and helpers for the protobuf tag layout.
  *
  * A tag stores the wire type in its low bits and the field number in the bits above them.
  */
object WireType {
  // Width of the wire-type portion of a tag.
  private final val TagTypeBits = 3

  /** Extracts the wire type, which is the low `TagTypeBits` bits of `tag`. */
  def getTagWireType(tag: Int): Int = tag & ((1 << TagTypeBits) - 1)

  /** Extracts the field number. This is an unsigned shift, so a negative tag never yields a negative result. */
  def getTagFieldNumber(tag: Int): Int = tag >>> TagTypeBits

  val WIRETYPE_VARINT: Int = 0
  val WIRETYPE_FIXED64: Int = 1
  val WIRETYPE_LENGTH_DELIMITED: Int = 2
  val WIRETYPE_START_GROUP: Int = 3
  val WIRETYPE_END_GROUP: Int = 4
  val WIRETYPE_FIXED32: Int = 5

  /** A batch of raw values tagged with the wire type they were encoded with. */
  sealed trait WireValue
  case class Fixed64(value: Seq[Long]) extends WireValue
  // NOTE(review): Fixed32 wraps Long values, although UnknownFieldSet.Field.fixed32 is a
  // Seq[Int]. This is kept as is because it is public API.
  case class Fixed32(value: Seq[Long]) extends WireValue
  case class Varint(value: Seq[Long]) extends WireValue
  case class LengthDelimited(value: Seq[ByteString]) extends WireValue
}