zio.schema.codec.ProtobufCodec.scala Maven / Gradle / Ivy
The newest version!
package zio.schema.codec
import java.nio.charset.StandardCharsets
import java.nio.{ ByteBuffer, ByteOrder }
import java.time._
import java.time.format.DateTimeFormatter
import java.util.UUID
import scala.collection.immutable.ListMap
import scala.util.control.NonFatal
import zio.schema.MutableSchemaBasedValueBuilder.{ CreateValueFromSchemaError, ReadingFieldResult }
import zio.schema._
import zio.schema.annotation.fieldDefaultValue
import zio.schema.codec.DecodeError.{ ExtraFields, MalformedField, MissingField }
import zio.schema.codec.ProtobufCodec.Protobuf.WireType.LengthDelimited
import zio.stream.ZPipeline
import zio.{ Cause, Chunk, ChunkBuilder, Unsafe, ZIO }
object ProtobufCodec {
/**
 * Derives a [[BinaryCodec]] for `A` from its implicit [[Schema]].
 *
 * Encoding delegates to [[Encoder]]; decoding creates a fresh [[Decoder]] for every input chunk.
 */
implicit def protobufCodec[A](implicit schema: Schema[A]): BinaryCodec[A] =
  new BinaryCodec[A] {

    override def encode(value: A): Chunk[Byte] =
      Encoder.process(schema, value)

    override def decode(whole: Chunk[Byte]): Either[DecodeError, A] = {
      val decoder = new Decoder(whole)
      decoder.decode(schema)
    }

    // Each element of an upstream chunk is encoded and the resulting bytes are concatenated.
    override def streamEncoder: ZPipeline[Any, Nothing, A, Byte] =
      ZPipeline.mapChunks(values => values.flatMap(encode))

    // Each upstream chunk is decoded on its own and yields exactly one value.
    override def streamDecoder: ZPipeline[Any, DecodeError, Byte, A] =
      ZPipeline.mapChunksZIO { bytes =>
        ZIO.fromEither(new Decoder(bytes).decode(schema).map(decoded => Chunk(decoded)))
      }
  }
object Protobuf {
/**
 * Wire types that can appear in an encoded key.
 *
 * `Encoder.encodeKey` writes them as the low three bits of the key:
 * VarInt = 0, Bit64 = 1, LengthDelimited = 2, StartGroup = 3, EndGroup = 4, Bit32 = 5.
 */
sealed trait WireType
object WireType {
case object VarInt extends WireType
case object Bit64 extends WireType
// `width` is the payload size in bytes; `encodeKey` writes it as a varint right after the key.
case class LengthDelimited(width: Int) extends WireType
case object StartGroup extends WireType
case object EndGroup extends WireType
case object Bit32 extends WireType
}
/**
 * Tells whether the elements of a collection with the given element schema can be written back to back
 * without a key per element (for example numbers), or whether every element needs its own key.
 *
 * Sequences, transforms and lazy schemas are unwrapped; primitives are decided by their standard type;
 * every other schema cannot be packed.
 */
@scala.annotation.tailrec
private[codec] def canBePacked(schema: Schema[_]): Boolean = schema match {
  case Schema.Primitive(standardType, _)    => canBePacked(standardType)
  case Schema.Sequence(element, _, _, _, _) => canBePacked(element)
  case Schema.Transform(codec, _, _, _, _)  => canBePacked(codec)
  case lzy @ Schema.Lazy(_)                 => canBePacked(lzy.schema)
  // Tuples, optionals, failures, eithers, fallbacks and everything else are never packed.
  case _ => false
}
/**
 * Per-primitive answer to "can values of this type be packed without individual keys?".
 *
 * Every standard type is listed explicitly; there is deliberately no default case.
 */
private def canBePacked(standardType: StandardType[_]): Boolean = standardType match {
  case StandardType.BoolType | StandardType.ByteType | StandardType.ShortType | StandardType.IntType |
      StandardType.LongType | StandardType.FloatType | StandardType.DoubleType | StandardType.CharType |
      StandardType.DayOfWeekType | StandardType.MonthType | StandardType.YearType |
      StandardType.ZoneOffsetType | StandardType.DurationType =>
    true

  case StandardType.UnitType | StandardType.StringType | StandardType.BinaryType |
      StandardType.BigIntegerType | StandardType.BigDecimalType | StandardType.UUIDType |
      StandardType.MonthDayType | StandardType.PeriodType | StandardType.YearMonthType |
      StandardType.ZoneIdType | StandardType.InstantType | StandardType.LocalDateType |
      StandardType.LocalTimeType | StandardType.LocalDateTimeType | StandardType.OffsetTimeType |
      StandardType.OffsetDateTimeType | StandardType.ZonedDateTimeType | StandardType.CurrencyType =>
    false
}
}
/**
 * Context threaded through [[Encoder]] while a value is being encoded.
 *
 * @param fieldNumber        field number used for the key of the next value; with `None`, `encodeKey` emits nothing
 * @param directByteEncoding when `true`, a `Byte` primitive is written as one raw byte instead of a keyed varint
 */
final private[codec] case class EncoderContext(
  fieldNumber: Option[Int],
  directByteEncoding: Boolean
)
object Encoder extends MutableSchemaBasedValueProcessor[Chunk[Byte], EncoderContext] {
import Protobuf._
/**
 * Encodes one primitive value.
 *
 * A `Byte` encountered while `directByteEncoding` is on is emitted as a single raw byte;
 * everything else goes through `encodePrimitive` with the context's field number.
 */
override protected def processPrimitive(context: EncoderContext, value: Any, typ: StandardType[Any]): Chunk[Byte] = {
  val rawByte = context.directByteEncoding && typ == StandardType.ByteType
  if (rawByte) {
    Chunk(value.asInstanceOf[Byte])
  } else {
    encodePrimitive(context.fieldNumber, typ, value)
  }
}
/**
 * Concatenates the already encoded fields (in map order) and wraps them in a length-delimited key
 * for the context's field number.
 */
override protected def processRecord(
  context: EncoderContext,
  schema: Schema.Record[_],
  value: ListMap[String, Chunk[Byte]]
): Chunk[Byte] = {
  val fields = Chunk.fromIterable(value.values).flatten
  val header = encodeKey(WireType.LengthDelimited(fields.size), context.fieldNumber)
  header ++ fields
}
/**
 * Wraps the encoded enum case (second element of `tuple`) in a length-delimited key
 * for the context's field number. The case name is not written.
 */
override protected def processEnum(
  context: EncoderContext,
  schema: Schema.Enum[_],
  tuple: (String, Chunk[Byte])
): Chunk[Byte] = {
  val (_, payload) = tuple
  val header       = encodeKey(WireType.LengthDelimited(payload.size), context.fieldNumber)
  header ++ payload
}
/**
 * Encodes a sequence from its already encoded elements.
 *
 * The body is either an empty length-delimited field 1 (empty sequence) or a length-delimited field 2
 * holding all elements back to back. The body is then wrapped in a length-delimited key for the
 * context's field number.
 */
override protected def processSequence(
  context: EncoderContext,
  schema: Schema.Sequence[_, _, _],
  value: Chunk[Chunk[Byte]]
): Chunk[Byte] = {
  val body =
    if (value.isEmpty) encodeKey(WireType.LengthDelimited(0), Some(1))
    else {
      val elements = value.flatten
      encodeKey(WireType.LengthDelimited(elements.size), Some(2)) ++ elements
    }
  encodeKey(WireType.LengthDelimited(body.size), context.fieldNumber) ++ body
}
/**
 * Encodes a map from its already encoded key/value pairs.
 *
 * An empty map is written as an empty length-delimited field 1. Otherwise each pair is re-keyed:
 * the key written in front of the encoded map key and map value is decoded again, the map key is
 * re-emitted as field 1 and the map value as field 2, and the entry is wrapped using the field
 * number that was found in front of the map key. All entries go into a length-delimited field 2.
 * The result is wrapped in a length-delimited key for the context's field number.
 */
override protected def processDictionary(
  context: EncoderContext,
  schema: Schema.Map[_, _],
  value: Chunk[(Chunk[Byte], Chunk[Byte])]
): Chunk[Byte] = {
  val body =
    if (value.isEmpty) encodeKey(WireType.LengthDelimited(0), Some(1))
    else {
      val entries = value.flatMap {
        case (encodedKey, encodedValue) =>
          val keyReader   = new Decoder(encodedKey)
          val valueReader = new Decoder(encodedValue)

          // Read the key half first, then the value half, each with an unlimited context.
          val (keyWireType, seqIndex) = keyReader.keyDecoder(
            DecoderContext(None, packed = false, dictionaryElementContext = None, directByteEncoding = false)
          )
          val (valueWireType, _) = valueReader.keyDecoder(
            DecoderContext(None, packed = false, dictionaryElementContext = None, directByteEncoding = false)
          )

          val entry =
            encodeKey(keyWireType, Some(1)) ++
              keyReader.remainder ++
              encodeKey(valueWireType, Some(2)) ++
              valueReader.remainder
          encodeKey(WireType.LengthDelimited(entry.size), Some(seqIndex)) ++ entry
      }
      encodeKey(WireType.LengthDelimited(entries.size), Some(2)) ++ entries
    }
  encodeKey(WireType.LengthDelimited(body.size), context.fieldNumber) ++ body
}
/**
 * Encodes a set from its already encoded elements, using the same layout as sequences:
 * an empty length-delimited field 1 for the empty set, otherwise a length-delimited field 2 with all
 * elements, wrapped in a length-delimited key for the context's field number.
 */
override protected def processSet(
  context: EncoderContext,
  schema: Schema.Set[_],
  value: Set[Chunk[Byte]]
): Chunk[Byte] = {
  val body =
    if (value.isEmpty) encodeKey(WireType.LengthDelimited(0), Some(1))
    else {
      val elements = Chunk.fromIterable(value).flatten
      encodeKey(WireType.LengthDelimited(elements.size), Some(2)) ++ elements
    }
  encodeKey(WireType.LengthDelimited(body.size), context.fieldNumber) ++ body
}
/**
 * Wraps whichever side is present in a length-delimited key for the context's field number.
 */
override protected def processEither(
  context: EncoderContext,
  schema: Schema.Either[_, _],
  value: Either[Chunk[Byte], Chunk[Byte]]
): Chunk[Byte] = {
  val payload = value.fold(identity, identity)
  val header  = encodeKey(WireType.LengthDelimited(payload.size), context.fieldNumber)
  header ++ payload
}
/**
 * Wraps the encoded fallback in a length-delimited key for the context's field number.
 * For `Both`, the left bytes are followed directly by the right bytes.
 */
override protected def processFallback(
  context: EncoderContext,
  schema: Schema.Fallback[_, _],
  value: Fallback[Chunk[Byte], Chunk[Byte]]
): Chunk[Byte] = {
  val payload = value match {
    case Fallback.Both(first, second) => first ++ second
    case Fallback.Left(only)          => only
    case Fallback.Right(only)         => only
  }
  val header = encodeKey(WireType.LengthDelimited(payload.size), context.fieldNumber)
  header ++ payload
}
/**
 * Encodes an optional value: `None` becomes an empty length-delimited field 1, `Some` keeps the
 * already encoded bytes. Either way the result is wrapped in a length-delimited key for the
 * context's field number.
 */
override protected def processOption(
  context: EncoderContext,
  schema: Schema.Optional[_],
  value: Option[Chunk[Byte]]
): Chunk[Byte] = {
  val body   = value.getOrElse(encodeKey(WireType.LengthDelimited(0), Some(1)))
  val header = encodeKey(WireType.LengthDelimited(body.size), context.fieldNumber)
  header ++ body
}
/**
 * Concatenates both encoded halves and wraps them in a length-delimited key for the context's
 * field number.
 */
override protected def processTuple(
  context: EncoderContext,
  schema: Schema.Tuple2[_, _],
  left: Chunk[Byte],
  right: Chunk[Byte]
): Chunk[Byte] = {
  val both   = left ++ right
  val header = encodeKey(WireType.LengthDelimited(both.size), context.fieldNumber)
  header ++ both
}
/** Always `None`: this encoder supplies no dedicated encoding for [[DynamicValue]]. */
override protected def processDynamic(
  context: EncoderContext,
  value: DynamicValue
): Option[Chunk[Byte]] = None
/** Aborts encoding by throwing a `RuntimeException` carrying `message`. */
override protected def fail(
  context: EncoderContext,
  message: String
): Chunk[Byte] = throw new RuntimeException(message)
/** Starting context: no field number (so no key is written) and no raw byte encoding. */
override protected val initialContext: EncoderContext = EncoderContext(
  fieldNumber = None,
  directByteEncoding = false
)
/**
 * Context for a record field: uses the number returned by `FieldMapping.getFieldNumber` when there is
 * one, otherwise the one-based position of the field.
 */
override protected def contextForRecordField(
  context: EncoderContext,
  index: Int,
  field: Schema.Field[_, _]
): EncoderContext =
  context.copy(
    fieldNumber = Some(FieldMapping.getFieldNumber(field).getOrElse(index + 1))
  )
/** Context for one half of a tuple: `index` is used as the field number unchanged. */
override protected def contextForTuple(context: EncoderContext, index: Int): EncoderContext = {
  val number = Some(index)
  context.copy(fieldNumber = number)
}
/** Context for an enum case: the field number is the one-based position of the case. */
override protected def contextForEnumConstructor(
  context: EncoderContext,
  index: Int,
  c: Schema.Case[_, _]
): EncoderContext = {
  val number = index + 1
  context.copy(fieldNumber = Some(number))
}
/** Context for an either side: field number 1 for `Left`, 2 for `Right`. */
override protected def contextForEither(context: EncoderContext, e: Either[Unit, Unit]): EncoderContext = {
  val number = if (e.isLeft) 1 else 2
  context.copy(fieldNumber = Some(number))
}
/** Context for a fallback: field number 1 for `Left`, 2 for `Right`, 3 for `Both`. */
override protected def contextForFallback(context: EncoderContext, f: Fallback[Unit, Unit]): EncoderContext = {
  val number = f match {
    case Fallback.Left(_)    => 1
    case Fallback.Right(_)   => 2
    case Fallback.Both(_, _) => 3
  }
  context.copy(fieldNumber = Some(number))
}
/** Context for an optional: field number 1 for `None`, 2 for `Some`. */
override protected def contextForOption(context: EncoderContext, o: Option[Unit]): EncoderContext = {
  val number = if (o.isEmpty) 1 else 2
  context.copy(fieldNumber = Some(number))
}
/**
 * Context for one sequence element.
 *
 * Byte elements get no key and raw byte encoding; other packable elements get no key;
 * all remaining elements are keyed with their one-based position.
 */
override protected def contextForSequence(
  context: EncoderContext,
  s: Schema.Sequence[_, _, _],
  index: Int
): EncoderContext =
  if (s.elementSchema == Schema[Byte]) {
    context.copy(fieldNumber = None, directByteEncoding = true)
  } else {
    val number = if (canBePacked(s.elementSchema)) None else Some(index + 1)
    context.copy(fieldNumber = number)
  }
/**
 * Context for one map entry: no key when the zipped key/value schema can be packed,
 * otherwise the one-based position of the entry.
 */
override protected def contextForMap(context: EncoderContext, s: Schema.Map[_, _], index: Int): EncoderContext = {
  val number = if (canBePacked(s.keySchema <*> s.valueSchema)) None else Some(index + 1)
  context.copy(fieldNumber = number)
}
/**
 * Context for one set element: no key when the element schema can be packed,
 * otherwise the one-based position of the element.
 */
override protected def contextForSet(context: EncoderContext, s: Schema.Set[_], index: Int): EncoderContext = {
  val number = if (canBePacked(s.elementSchema)) None else Some(index + 1)
  context.copy(fieldNumber = number)
}
/**
 * Encodes a single primitive value, prefixed by the key for `fieldNumber`
 * (no key at all when `fieldNumber` is `None`, see `encodeKey`).
 *
 * The match is on the pair (standard type, runtime value); a pair that matches no case
 * raises `NotImplementedError`.
 *
 * @param fieldNumber  field number for the key, or `None` to write the bare value
 * @param standardType standard type describing `value`
 * @param value        value to encode
 */
private def encodePrimitive[A](
fieldNumber: Option[Int],
standardType: StandardType[A],
value: A
): Chunk[Byte] =
(standardType, value) match {
// Unit carries no payload: only a zero-length key is written.
case (StandardType.UnitType, _) =>
encodeKey(WireType.LengthDelimited(0), fieldNumber)
case (StandardType.StringType, str: String) =>
val encoded = Chunk.fromArray(str.getBytes(StandardCharsets.UTF_8))
encodeKey(WireType.LengthDelimited(encoded.size), fieldNumber) ++ encoded
// Booleans and all integral types are written as varints.
case (StandardType.BoolType, b: Boolean) =>
encodeKey(WireType.VarInt, fieldNumber) ++ encodeVarInt(if (b) 1 else 0)
case (StandardType.ShortType, v: Short) =>
encodeKey(WireType.VarInt, fieldNumber) ++ encodeVarInt(v.toLong)
case (StandardType.ByteType, v: Byte) =>
encodeKey(WireType.VarInt, fieldNumber) ++ encodeVarInt(v.toLong)
case (StandardType.IntType, v: Int) =>
encodeKey(WireType.VarInt, fieldNumber) ++ encodeVarInt(v)
case (StandardType.LongType, v: Long) =>
encodeKey(WireType.VarInt, fieldNumber) ++ encodeVarInt(v)
// BigDecimal is a nested message: 1 = unscaled value, 2 = precision, 3 = scale.
case (StandardType.BigDecimalType, v: java.math.BigDecimal) =>
val unscaled = v.unscaledValue()
val precision = v.precision()
val scale = v.scale()
val encodedRecord =
encodePrimitive(Some(1), StandardType.BigIntegerType, unscaled) ++
encodePrimitive(Some(2), StandardType.IntType, precision) ++
encodePrimitive(Some(3), StandardType.IntType, scale)
encodeKey(WireType.LengthDelimited(encodedRecord.size), fieldNumber) ++ encodedRecord
// BigInteger is written as the bytes returned by `toByteArray`.
case (StandardType.BigIntegerType, v: java.math.BigInteger) =>
val encoded = Chunk.fromArray(v.toByteArray)
encodeKey(WireType.LengthDelimited(encoded.size), fieldNumber) ++ encoded
// Floats and doubles are written as 4 / 8 little-endian bytes.
case (StandardType.FloatType, v: Float) =>
val byteBuffer = ByteBuffer.allocate(4)
byteBuffer.order(ByteOrder.LITTLE_ENDIAN)
byteBuffer.putFloat(v)
encodeKey(WireType.Bit32, fieldNumber) ++ Chunk.fromArray(byteBuffer.array)
case (StandardType.DoubleType, v: Double) =>
val byteBuffer = ByteBuffer.allocate(8)
byteBuffer.order(ByteOrder.LITTLE_ENDIAN)
byteBuffer.putDouble(v)
encodeKey(WireType.Bit64, fieldNumber) ++ Chunk.fromArray(byteBuffer.array)
case (StandardType.BinaryType, bytes: Chunk[Byte]) =>
encodeKey(WireType.LengthDelimited(bytes.length), fieldNumber) ++ bytes
// Char and UUID reuse the string encoding of their `toString`.
case (StandardType.CharType, c: Char) =>
encodePrimitive(fieldNumber, StandardType.StringType, c.toString)
case (StandardType.UUIDType, u: UUID) =>
encodePrimitive(fieldNumber, StandardType.StringType, u.toString)
case (StandardType.DayOfWeekType, v: DayOfWeek) =>
encodePrimitive(fieldNumber, StandardType.IntType, v.getValue)
case (StandardType.MonthType, v: Month) =>
encodePrimitive(fieldNumber, StandardType.IntType, v.getValue)
// MonthDay is a nested message: 1 = month, 2 = day of month.
case (StandardType.MonthDayType, v: MonthDay) =>
val encodedRecord =
encodePrimitive(Some(1), StandardType.IntType, v.getMonthValue) ++
encodePrimitive(Some(2), StandardType.IntType, v.getDayOfMonth)
encodeKey(WireType.LengthDelimited(encodedRecord.size), fieldNumber) ++ encodedRecord
// Period is a nested message: 1 = years, 2 = months, 3 = days.
case (StandardType.PeriodType, v: Period) =>
val encodedRecord =
encodePrimitive(Some(1), StandardType.IntType, v.getYears) ++
encodePrimitive(Some(2), StandardType.IntType, v.getMonths) ++
encodePrimitive(Some(3), StandardType.IntType, v.getDays)
encodeKey(WireType.LengthDelimited(encodedRecord.size), fieldNumber) ++ encodedRecord
case (StandardType.YearType, v: Year) =>
encodePrimitive(fieldNumber, StandardType.IntType, v.getValue)
// YearMonth is a nested message: 1 = year, 2 = month.
case (StandardType.YearMonthType, v: YearMonth) =>
val encodedRecord =
encodePrimitive(Some(1), StandardType.IntType, v.getYear) ++
encodePrimitive(Some(2), StandardType.IntType, v.getMonthValue)
encodeKey(WireType.LengthDelimited(encodedRecord.size), fieldNumber) ++ encodedRecord
case (StandardType.ZoneIdType, v: ZoneId) =>
encodePrimitive(fieldNumber, StandardType.StringType, v.getId)
case (StandardType.ZoneOffsetType, v: ZoneOffset) =>
encodePrimitive(fieldNumber, StandardType.IntType, v.getTotalSeconds)
// Duration is a nested message: 1 = seconds, 2 = nanosecond adjustment.
case (StandardType.DurationType, v: Duration) =>
val encodedRecord =
encodePrimitive(Some(1), StandardType.LongType, v.getSeconds) ++
encodePrimitive(Some(2), StandardType.IntType, v.getNano)
encodeKey(WireType.LengthDelimited(encodedRecord.size), fieldNumber) ++ encodedRecord
// The remaining java.time types are written as strings.
case (StandardType.InstantType, v: Instant) =>
encodePrimitive(fieldNumber, StandardType.StringType, v.toString)
case (StandardType.LocalDateType, v: LocalDate) =>
encodePrimitive(fieldNumber, StandardType.StringType, v.toString)
case (StandardType.LocalTimeType, v: LocalTime) =>
encodePrimitive(fieldNumber, StandardType.StringType, v.toString)
case (StandardType.LocalDateTimeType, v: LocalDateTime) =>
encodePrimitive(fieldNumber, StandardType.StringType, v.toString)
case (StandardType.OffsetTimeType, v: OffsetTime) =>
encodePrimitive(fieldNumber, StandardType.StringType, v.toString)
case (StandardType.OffsetDateTimeType, v: OffsetDateTime) =>
encodePrimitive(fieldNumber, StandardType.StringType, v.toString)
case (StandardType.ZonedDateTimeType, v: ZonedDateTime) =>
encodePrimitive(fieldNumber, StandardType.StringType, v.format(DateTimeFormatter.ISO_ZONED_DATE_TIME))
case (StandardType.CurrencyType, v: java.util.Currency) =>
encodePrimitive(fieldNumber, StandardType.StringType, v.getCurrencyCode)
// Any (type, value) pair not handled above is reported as unsupported.
case (_, _) =>
throw new NotImplementedError(s"No encoder for $standardType")
}
/**
 * Varint-encodes an `Int`. The value is widened to `Long` first (sign-extending negatives)
 * and then written by the builder-based overload.
 */
private def encodeVarInt(value: Int): Chunk[Byte] = {
  val out     = ChunkBuilder.make[Byte](5)
  val widened = value.toLong
  encodeVarInt(widened, out)
  out.result()
}
/** Varint-encodes a `Long` into a fresh chunk. */
private def encodeVarInt(value: Long): Chunk[Byte] = {
  val out = ChunkBuilder.make[Byte](10)
  encodeVarInt(value, out)
  out.result()
}
/**
 * Appends the varint encoding of `value` to `builder`.
 *
 * The value is emitted in groups of 7 bits, least significant group first. Every byte except the
 * last has its high bit (0x80) set. Shifts are unsigned, so negative values terminate as well.
 */
private def encodeVarInt(value: Long, builder: ChunkBuilder[Byte]): Unit = {
  var remaining = value
  // While more than 7 significant bits are left, emit a continuation byte.
  while ((remaining >>> 7) != 0x00) {
    builder += (0x80 | (remaining & 0x7F)).byteValue()
    remaining = remaining >>> 7
  }
  // Final byte: high bit clear.
  builder += (remaining & 0x7F).byteValue()
}
/**
 * Encodes a key: the field number shifted left by three bits, combined with the wire type id,
 * written as a varint (1 << 3 => 8, 2 << 3 => 16, 3 << 3 => 24). For length-delimited fields
 * the length follows immediately as a second varint.
 *
 * Without a field number nothing is written at all, not even the length.
 *
 * More info:
 * https://developers.google.com/protocol-buffers/docs/encoding#structure
 */
private[codec] def encodeKey(wireType: WireType, fieldNumber: Option[Int]): Chunk[Byte] =
  fieldNumber match {
    case None => Chunk.empty
    case Some(number) =>
      def tag(wireTypeId: Int): Chunk[Byte] = encodeVarInt(number << 3 | wireTypeId)

      wireType match {
        case WireType.VarInt                  => tag(0)
        case WireType.Bit64                   => tag(1)
        case WireType.LengthDelimited(length) => tag(2) ++ encodeVarInt(length)
        case WireType.StartGroup              => tag(3)
        case WireType.EndGroup                => tag(4)
        case WireType.Bit32                   => tag(5)
      }
  }
}
/**
 * Mutable read cursor over `chunk`.
 *
 * `position` is the offset of the next unread byte. None of the methods check bounds themselves.
 */
final class DecoderState(chunk: Chunk[Byte], private var position: Int) {

  /** Offset of the next unread byte. */
  def currentPosition: Int = position

  /** Advances the cursor by `count` bytes without reading them. */
  def move(count: Int): Unit =
    position = position + count

  /** The byte at the cursor; the cursor is not advanced. */
  def peek: Byte = chunk.byte(position)

  /** Bytes left between the cursor and the context's limit (or the end of the chunk if there is no limit). */
  def length(context: DecoderContext): Int = {
    val end = context.limit.getOrElse(chunk.length)
    end - position
  }

  /** Everything up to the context's limit; the cursor is not advanced. */
  def peek(context: DecoderContext): Chunk[Byte] = {
    val start = position
    chunk.slice(start, start + length(context))
  }

  /** Returns the next `count` bytes and advances the cursor past them. */
  def read(count: Int): Chunk[Byte] = {
    val start = position
    position = start + count
    chunk.slice(start, position)
  }

  /** Reads everything up to the context's limit and advances the cursor past it. */
  def all(context: DecoderContext): Chunk[Byte] = {
    val remaining = length(context)
    read(remaining)
  }
}
/**
 * Context threaded through [[Decoder]] while a value is being decoded.
 *
 * @param limit                    absolute offset at which the current value ends, if known
 * @param packed                   whether collection elements are read without a key per element
 * @param dictionaryElementContext context of the enclosing map entry, kept so the value half can be entered later
 * @param directByteEncoding       whether a `Byte` is read as one raw byte instead of a varint
 */
final case class DecoderContext(
  limit: Option[Int],
  packed: Boolean,
  dictionaryElementContext: Option[DecoderContext],
  directByteEncoding: Boolean
) {

  /** Copy of this context whose limit is `w` bytes past the current position of `state`. */
  def limitedTo(state: DecoderState, w: Int): DecoderContext = {
    val end = state.currentPosition + w
    copy(limit = Some(end))
  }
}
class Decoder(chunk: Chunk[Byte]) extends MutableSchemaBasedValueBuilder[Any, DecoderContext] {
import Protobuf._
// Read cursor over `chunk`, starting at offset 0; shared by all decoding methods of this instance.
private val state: DecoderState = new DecoderState(chunk, 0)
// Queried per record in startReadingField to map wire field numbers to field indices.
private val fieldMappingCache = new FieldMappingCache()
/**
 * Decodes a value of the given schema from this decoder's chunk.
 *
 * A [[DecodeError]] wrapped in `CreateValueFromSchemaError` is returned as is; any other wrapped
 * cause and any other non-fatal exception is converted into a `DecodeError.ReadError`.
 */
def decode[A](schema: Schema[A]): scala.util.Either[DecodeError, A] = {
  def asReadError(failure: Throwable): DecodeError =
    DecodeError.ReadError(Cause.fail(failure), failure.getMessage)

  try Right(create(schema).asInstanceOf[A])
  catch {
    case CreateValueFromSchemaError(_, error: DecodeError) => Left(error)
    case CreateValueFromSchemaError(_, cause)              => Left(asReadError(cause))
    case NonFatal(err)                                     => Left(asReadError(err))
  }
}
/** Typed convenience wrapper around `createPrimitive`; the result is cast to `A` unchecked. */
private def createTypedPrimitive[A](context: DecoderContext, standardType: StandardType[A]): A = {
  val untyped = createPrimitive(context, standardType)
  untyped.asInstanceOf[A]
}
/**
 * Reads one primitive value of the given standard type from the current position.
 *
 * Composite primitives (BigDecimal, MonthDay, Period, YearMonth, Duration) are read as a series of
 * numbered fields through `rawFieldDecoder`. A type without a case ends in `fail`.
 */
override protected def createPrimitive(context: DecoderContext, typ: StandardType[_]): Any =
typ match {
// Unit consumes no bytes.
case StandardType.UnitType => ()
case StandardType.StringType => stringDecoder(context)
case StandardType.BoolType => varIntDecoder(context) != 0
case StandardType.ShortType => varIntDecoder(context).shortValue
case StandardType.ByteType if !context.directByteEncoding => varIntDecoder(context).byteValue
// Raw byte mode: take exactly one byte at the cursor and step over it.
case StandardType.ByteType if context.directByteEncoding =>
val result = state.peek
state.move(1)
result
case StandardType.IntType => varIntDecoder(context).intValue
case StandardType.LongType => varIntDecoder(context)
case StandardType.FloatType => floatDecoder(context)
case StandardType.DoubleType => doubleDecoder(context)
case StandardType.BigIntegerType =>
val bytes = binaryDecoder(context)
new java.math.BigInteger(bytes.toArray)
// Fields 1, 2, 3 = unscaled value, precision, scale.
case StandardType.BigDecimalType =>
val unscaled = createTypedPrimitive(rawFieldDecoder(context, 1), StandardType.BigIntegerType)
val precision = createTypedPrimitive(rawFieldDecoder(context, 2), StandardType.IntType)
val scale = createTypedPrimitive(rawFieldDecoder(context, 3), StandardType.IntType)
val ctx = new java.math.MathContext(precision)
new java.math.BigDecimal(unscaled, scale, ctx)
case StandardType.BinaryType => binaryDecoder(context)
// Only the first character of the decoded string is used.
case StandardType.CharType => stringDecoder(context).charAt(0)
case StandardType.UUIDType =>
val uuid = stringDecoder(context)
try UUID.fromString(uuid)
catch {
case NonFatal(_) =>
throw MalformedField(Schema.primitive[UUID], s"Invalid UUID string $uuid")
}
case StandardType.DayOfWeekType =>
DayOfWeek.of(varIntDecoder(context).intValue)
case StandardType.MonthType =>
Month.of(varIntDecoder(context).intValue)
// Fields 1, 2 = month, day of month.
case StandardType.MonthDayType =>
val month = createTypedPrimitive(rawFieldDecoder(context, 1), StandardType.IntType)
val day = createTypedPrimitive(rawFieldDecoder(context, 2), StandardType.IntType)
MonthDay.of(month, day)
// Fields 1, 2, 3 = years, months, days.
case StandardType.PeriodType =>
val years = createTypedPrimitive(rawFieldDecoder(context, 1), StandardType.IntType)
val months = createTypedPrimitive(rawFieldDecoder(context, 2), StandardType.IntType)
val days = createTypedPrimitive(rawFieldDecoder(context, 3), StandardType.IntType)
Period.of(years, months, days)
case StandardType.YearType =>
Year.of(varIntDecoder(context).intValue)
// Fields 1, 2 = year, month.
case StandardType.YearMonthType =>
val year = createTypedPrimitive(rawFieldDecoder(context, 1), StandardType.IntType)
val month = createTypedPrimitive(rawFieldDecoder(context, 2), StandardType.IntType)
YearMonth.of(year, month)
case StandardType.ZoneIdType => ZoneId.of(stringDecoder(context))
case StandardType.ZoneOffsetType =>
ZoneOffset.ofTotalSeconds(varIntDecoder(context).intValue)
// Fields 1, 2 = seconds, nanosecond adjustment.
case StandardType.DurationType =>
val seconds = createTypedPrimitive(rawFieldDecoder(context, 1), StandardType.LongType)
val nanos = createTypedPrimitive(rawFieldDecoder(context, 2), StandardType.IntType)
Duration.ofSeconds(seconds, nanos.toLong)
// The remaining java.time types are parsed from their string form.
case StandardType.InstantType =>
Instant.parse(stringDecoder(context))
case StandardType.LocalDateType =>
LocalDate.parse(stringDecoder(context))
case StandardType.LocalTimeType =>
LocalTime.parse(stringDecoder(context))
case StandardType.LocalDateTimeType =>
LocalDateTime.parse(stringDecoder(context))
case StandardType.OffsetTimeType =>
OffsetTime.parse(stringDecoder(context))
case StandardType.OffsetDateTimeType =>
OffsetDateTime.parse(stringDecoder(context))
case StandardType.ZonedDateTimeType =>
ZonedDateTime.parse(stringDecoder(context))
case StandardType.CurrencyType => java.util.Currency.getInstance(stringDecoder(context))
case st => fail(context, s"Unsupported primitive type $st")
}
/** Records need no preparation: the incoming context is returned unchanged. */
override protected def startCreatingRecord(
  context: DecoderContext,
  record: Schema.Record[_]
): DecoderContext = context
/**
 * Reads the key of the next record field and decides what to do with it.
 *
 *  - no bytes left within the context's limit: the record is finished;
 *  - field number known to the record: read that field (limited to its width when length-delimited);
 *  - field number unknown: skip its payload according to the wire type and carry on;
 *  - unknown field with a group wire type: its length cannot be determined, so decoding fails.
 */
override protected def startReadingField(
  context: DecoderContext,
  record: Schema.Record[_],
  index: Int
): ReadingFieldResult[DecoderContext] =
  if (state.length(context) > 0) {
    val (wt, fieldNumber) = keyDecoder(context)

    // Result returned after an unknown field has been skipped.
    def skipped: ReadingFieldResult[DecoderContext] = ReadingFieldResult.UpdateContext(context)

    fieldMappingCache.get(record).fieldNumberToIndex.get(fieldNumber) match {
      case Some(fieldIndex) if record.fields.isDefinedAt(fieldIndex) =>
        val fieldContext = wt match {
          case LengthDelimited(width) => context.limitedTo(state, width)
          case _                      => context
        }
        ReadingFieldResult.ReadField(fieldContext, fieldIndex)

      case Some(_) =>
        throw ExtraFields(
          "Unknown",
          s"Failed to decode record. Schema does not contain field number $fieldNumber."
        )

      case None =>
        wt match {
          case WireType.VarInt =>
            varIntDecoder(context)
            skipped
          case WireType.Bit64 =>
            state.move(8)
            skipped
          case LengthDelimited(width) =>
            state.move(width)
            skipped
          case WireType.Bit32 =>
            state.move(4)
            skipped
          case _ =>
            throw ExtraFields(
              "Unknown",
              s"Failed to decode record. Schema does not contain field number $fieldNumber and it's length is unknown"
            )
        }
    }
  } else {
    ReadingFieldResult.Finished()
  }
/**
 * Builds the record from the decoded `(field index, value)` pairs.
 *
 * Fields annotated with `fieldDefaultValue` are pre-filled with that default; decoded values then
 * overwrite them. Pairs whose index lies beyond the last field are ignored. If any field is still
 * unset, or the record's constructor rejects the values, a `DecodeError.ReadError` is thrown.
 */
override protected def createRecord(
  context: DecoderContext,
  record: Schema.Record[_],
  values: Chunk[(Int, Any)]
): Any =
  Unsafe.unsafe { implicit u =>
    val fieldCount = record.fields.length
    val slots      = new Array[Any](fieldCount)
    val filled     = new Array[Boolean](fieldCount)

    // Step 1: defaults taken from annotations.
    record.fields.zipWithIndex.foreach {
      case (field, position) =>
        field.annotations.collectFirst { case fieldDefaultValue(default) => default }.foreach { default =>
          slots(position) = default
          filled(position) = true
        }
    }

    // Step 2: values that were actually present in the input.
    values.foreach {
      case (position, decoded) =>
        if (position < fieldCount) {
          slots(position) = decoded
          filled(position) = true
        }
    }

    if (filled.contains(false)) {
      throw DecodeError.ReadError(
        Cause.empty,
        s"Failed to decode record. Missing fields: ${filled.zip(record.fields).filter(!_._1).map(_._2).mkString(", ")}"
      )
    } else {
      record.construct(Chunk.fromArray(slots)) match {
        case Left(message) => throw DecodeError.ReadError(Cause.empty, message)
        case Right(result) => result
      }
    }
  }
/**
 * Reads the key in front of an enum value and selects the case to decode.
 *
 * The field number is the one-based index of the case. For a length-delimited key the returned
 * context is limited to the width of the payload.
 *
 * @param context current decoder context
 * @param cases   cases of the enum being decoded
 * @return the context for decoding the case together with the zero-based case index
 * @throws MissingField when the field number is larger than the number of cases
 */
override protected def startCreatingEnum(
  context: DecoderContext,
  cases: Chunk[Schema.Case[_, _]]
): (DecoderContext, Int) =
  keyDecoder(context) match {
    case (wt, fieldNumber) if fieldNumber <= cases.length =>
      wt match {
        case LengthDelimited(width) =>
          (context.limitedTo(state, width), fieldNumber - 1)
        case _ =>
          (context, fieldNumber - 1)
      }
    case (_, fieldNumber) =>
      // This branch is only reached when fieldNumber > cases.length, so looking up
      // cases(fieldNumber - 1) here would fail with an index-out-of-bounds error instead of
      // reporting the problem. No schema exists for the unknown case, so the schema of the
      // field number itself is passed as a placeholder.
      throw MissingField(
        Schema.primitive[Int],
        s"Failed to decode enumeration. Schema does not contain field number $fieldNumber."
      )
  }
/** The decoded case value already is the enum value, so it is returned untouched. */
override protected def createEnum(
  context: DecoderContext,
  cases: Chunk[Schema.Case[_, _]],
  index: Int,
  value: Any
): Any = {
  val decodedCase = value
  decodedCase
}
/**
 * Reads the key that opens a sequence body.
 *
 * Field 1 with length 0 marks an empty sequence (`None`). Field 2 carries the elements: the
 * returned context is limited to its width and records whether the elements are packed.
 * Anything else is malformed.
 */
override protected def startCreatingSequence(
  context: DecoderContext,
  schema: Schema.Sequence[_, _, _]
): Option[DecoderContext] = {
  val (wt, fieldNumber) = keyDecoder(context)
  wt match {
    case LengthDelimited(0) if fieldNumber == 1 =>
      None
    case LengthDelimited(width) if fieldNumber == 2 =>
      val limited = context.limitedTo(state, width)
      Some(limited.copy(packed = canBePacked(schema.elementSchema)))
    case _ =>
      throw MalformedField(schema, s"Invalid wire type ($wt) or field number ($fieldNumber) for packed sequence")
  }
}
/**
 * Prepares the context for the next sequence element.
 *
 * Packed elements have no key: the context is reused, with raw byte decoding switched on for byte
 * elements. Non-packed elements must start with a length-delimited key, whose width limits the
 * element.
 */
override protected def startCreatingOneSequenceElement(
  context: DecoderContext,
  schema: Schema.Sequence[_, _, _]
): DecoderContext =
  if (!context.packed) {
    val (wt, _) = keyDecoder(context)
    wt match {
      case LengthDelimited(elemWidth) => context.limitedTo(state, elemWidth)
      case _                          => throw MalformedField(schema, s"Unexpected wire type $wt for non-packed sequence")
    }
  } else if (schema.elementSchema == Schema[Byte]) {
    context.copy(directByteEncoding = true)
  } else {
    context
  }
/** `true` while unread bytes remain within the context's limit. */
override protected def finishedCreatingOneSequenceElement(
  context: DecoderContext,
  schema: Schema.Sequence[_, _, _],
  index: Int
): Boolean = {
  val remaining = state.length(context)
  remaining > 0
}
/** Turns the decoded elements into the target collection via the schema's `fromChunk`. */
override protected def createSequence(
  context: DecoderContext,
  schema: Schema.Sequence[_, _, _],
  values: Chunk[Any]
): Any = {
  val build = schema.fromChunk.asInstanceOf[Chunk[Any] => Any]
  build(values)
}
/**
 * Reads the key that opens a map body.
 *
 * Field 1 with length 0 marks an empty map (`None`). Field 2 carries the entries: the returned
 * context is limited to its width and records whether the zipped key/value schema is packed.
 * Anything else is malformed.
 */
override protected def startCreatingDictionary(
  context: DecoderContext,
  schema: Schema.Map[_, _]
): Option[DecoderContext] = {
  val (wt, fieldNumber) = keyDecoder(context)
  wt match {
    case LengthDelimited(0) if fieldNumber == 1 =>
      None
    case LengthDelimited(width) if fieldNumber == 2 =>
      val limited = context.limitedTo(state, width)
      Some(limited.copy(packed = canBePacked(schema.keySchema.zip(schema.valueSchema))))
    case _ =>
      throw MalformedField(schema, s"Invalid wire type ($wt) or field number ($fieldNumber) for packed sequence")
  }
}
/**
 * Prepares the context for the key half of the next map entry.
 *
 * Unless the map is packed, the entry starts with a length-delimited key whose width limits the
 * entry. The entry context is stored in `dictionaryElementContext` so the value half can be
 * entered from it later.
 */
override protected def startCreatingOneDictionaryElement(
  context: DecoderContext,
  schema: Schema.Map[_, _]
): DecoderContext = {
  val entryContext =
    if (context.packed) context
    else {
      val (wt, _) = keyDecoder(context)
      wt match {
        case LengthDelimited(elemWidth) => context.limitedTo(state, elemWidth)
        case _                          => throw MalformedField(schema, s"Unexpected wire type $wt for non-packed sequence")
      }
    }
  val keyContext = enterFirstTupleElement(entryContext, schema)
  keyContext.copy(dictionaryElementContext = Some(entryContext))
}
/**
 * Prepares the context for the value half of a map entry, starting from the stored entry context
 * (or from `context` itself when none was stored).
 */
override protected def startCreatingOneDictionaryValue(
  context: DecoderContext,
  schema: Schema.Map[_, _]
): DecoderContext = {
  val entryContext = context.dictionaryElementContext.getOrElse(context)
  enterSecondTupleElement(entryContext, schema)
}
/** `true` while unread bytes remain within the context's limit. */
override protected def finishedCreatingOneDictionaryElement(
  context: DecoderContext,
  schema: Schema.Map[_, _],
  index: Int
): Boolean = {
  val remaining = state.length(context)
  remaining > 0
}
/** Collects the decoded key/value pairs into a `Map`. */
override protected def createDictionary(
  context: DecoderContext,
  schema: Schema.Map[_, _],
  values: Chunk[(Any, Any)]
): Any = {
  val entries = values.toMap
  entries
}
/**
 * Reads the key that opens a set body.
 *
 * Field 1 with length 0 marks an empty set (`None`). Field 2 carries the elements: the returned
 * context is limited to its width and records whether the elements are packed.
 * Anything else is malformed.
 */
override protected def startCreatingSet(context: DecoderContext, schema: Schema.Set[_]): Option[DecoderContext] = {
  val (wt, fieldNumber) = keyDecoder(context)
  wt match {
    case LengthDelimited(0) if fieldNumber == 1 =>
      None
    case LengthDelimited(width) if fieldNumber == 2 =>
      val limited = context.limitedTo(state, width)
      Some(limited.copy(packed = canBePacked(schema.elementSchema)))
    case _ =>
      throw MalformedField(schema, s"Invalid wire type ($wt) or field number ($fieldNumber) for packed sequence")
  }
}
/**
 * Prepares the context for the next set element: packed elements reuse the context, non-packed
 * elements must start with a length-delimited key whose width limits the element.
 */
override protected def startCreatingOneSetElement(context: DecoderContext, schema: Schema.Set[_]): DecoderContext =
  if (!context.packed) {
    val (wt, _) = keyDecoder(context)
    wt match {
      case LengthDelimited(elemWidth) => context.limitedTo(state, elemWidth)
      case _                          => throw MalformedField(schema, s"Unexpected wire type $wt for non-packed sequence")
    }
  } else {
    context
  }
/** `true` while unread bytes remain within the context's limit. */
override protected def finishedCreatingOneSetElement(
  context: DecoderContext,
  schema: Schema.Set[_],
  index: Int
): Boolean = {
  val remaining = state.length(context)
  remaining > 0
}
/** Collects the decoded elements into a `Set`. */
override protected def createSet(
  context: DecoderContext,
  schema: Schema.Set[_],
  values: Chunk[Any]
): Any = {
  val elements = values.toSet
  elements
}
/**
 * Reads the key that opens an optional value.
 *
 * Field 1 with length 0 means `None`. Field 2 means `Some`: a length-delimited key limits the
 * context to its width, any other wire type leaves the context as is. Every other combination is
 * malformed.
 */
override protected def startCreatingOptional(
  context: DecoderContext,
  schema: Schema.Optional[_]
): Option[DecoderContext] = {
  val (wt, fieldNumber) = keyDecoder(context)
  wt match {
    case LengthDelimited(0) if fieldNumber == 1     => None
    case LengthDelimited(width) if fieldNumber == 2 => Some(context.limitedTo(state, width))
    case _ if fieldNumber == 2                      => Some(context)
    case _ =>
      throw MalformedField(schema, s"Invalid field number $fieldNumber for option")
  }
}
/** The decoded `Option` already is the result, so it is returned untouched. */
override protected def createOptional(
  context: DecoderContext,
  schema: Schema.Optional[_],
  value: Option[Any]
): Any = {
  val decoded = value
  decoded
}
/**
 * Reads the key that opens an either and picks the side from its field number:
 * 1 selects `Left`, 2 selects `Right`, anything else fails. The context itself is not changed.
 */
override protected def startCreatingEither(
  context: DecoderContext,
  schema: Schema.Either[_, _]
): Either[DecoderContext, DecoderContext] = {
  val (_, side) = keyDecoder(context)
  side match {
    case 1 => Left(context)
    case 2 => Right(context)
    case fieldNumber =>
      throw ExtraFields(fieldNumber.toString, s"Invalid field number ($fieldNumber) for either")
  }
}
/** The decoded `Either` already is the result, so it is returned untouched. */
override protected def createEither(
  context: DecoderContext,
  schema: Schema.Either[_, _],
  value: Either[Any, Any]
): Any = {
  val decoded = value
  decoded
}
/**
 * Reads the key that opens a fallback and picks the variant from its field number:
 * 1 selects `Left`, 2 selects `Right`, 3 selects `Both`. The context itself is not changed.
 *
 * @param context current decoder context
 * @param schema  fallback schema being decoded
 * @return the selected variant, carrying the unchanged context on each populated side
 * @throws ExtraFields when the field number is not 1, 2 or 3
 */
override protected def startCreatingFallback(
  context: DecoderContext,
  schema: Schema.Fallback[_, _]
): Fallback[DecoderContext, DecoderContext] =
  keyDecoder(context) match {
    case (_, 1) => Fallback.Left(context)
    case (_, 2) => Fallback.Right(context)
    case (_, 3) => Fallback.Both(context, context)
    // The field number has to be bound here: a bare wildcard pattern leaves `fieldNumber`
    // undefined for the error message below.
    case (_, fieldNumber) =>
      throw ExtraFields(fieldNumber.toString, s"Invalid field number ($fieldNumber) for fallback")
  }
/**
 * Reads the key in front of the right side of a fallback. It must carry field number 2;
 * a length-delimited key limits the context to its width.
 */
override protected def startReadingRightFallback(
  context: DecoderContext,
  schema: Schema.Fallback[_, _]
): DecoderContext = {
  val (wt, fieldNumber) = keyDecoder(context)
  if (fieldNumber != 2)
    throw MalformedField(schema, s"Invalid field number $fieldNumber for fallback's right field")
  else
    wt match {
      case LengthDelimited(width) => context.limitedTo(state, width)
      case _                      => context
    }
}
/** Returns the decoded fallback as is when the schema asks for a full decode, otherwise its simplified form. */
override protected def createFallback(
  context: DecoderContext,
  schema: Schema.Fallback[_, _],
  value: Fallback[Any, Any]
): Any =
  if (!schema.fullDecode) value.simplify
  else value
/** Starts a tuple by reading the key of its first element. */
override protected def startCreatingTuple(
  context: DecoderContext,
  schema: Schema.Tuple2[_, _]
): DecoderContext = enterFirstTupleElement(context, schema)
/**
 * Reads the key in front of the first tuple element. It must carry field number 1;
 * a length-delimited key limits the context to its width.
 */
private def enterFirstTupleElement(context: DecoderContext, schema: Schema[_]): DecoderContext = {
  val (wt, fieldNumber) = keyDecoder(context)
  if (fieldNumber != 1)
    throw MalformedField(schema, s"Invalid field number $fieldNumber for tuple's first field")
  else
    wt match {
      case LengthDelimited(width) => context.limitedTo(state, width)
      case _                      => context
    }
}
/** Continues a tuple by reading the key of its second element. */
override protected def startReadingSecondTupleElement(
  context: DecoderContext,
  schema: Schema.Tuple2[_, _]
): DecoderContext = {
  val secondContext = enterSecondTupleElement(context, schema)
  secondContext
}
/**
 * Reads the key of a tuple's second element; it must carry field number 2.
 *
 * @param context decoder context the key is read with
 * @param schema  schema reported in the error
 * @return `context.limitedTo(state, width)` for a length-delimited field, otherwise `context` itself
 * @throws MalformedField when the field number is not 2
 */
private def enterSecondTupleElement(context: DecoderContext, schema: Schema[_]): DecoderContext = {
  val (wireType, fieldNumber) = keyDecoder(context)
  if (fieldNumber != 2) {
    throw MalformedField(schema, s"Invalid field number $fieldNumber for tuple's second field")
  } else {
    wireType match {
      case LengthDelimited(width) => context.limitedTo(state, width)
      case _                      => context
    }
  }
}
/**
 * Finishes decoding a tuple by pairing up the two decoded elements.
 *
 * @param context decoder context (not consulted here)
 * @param schema  schema of the tuple (not consulted here)
 * @param left    decoded first element
 * @param right   decoded second element
 * @return a `Tuple2` of `left` and `right`
 */
override protected def createTuple(
  context: DecoderContext,
  schema: Schema.Tuple2[_, _],
  left: Any,
  right: Any
): Any =
  Tuple2(left, right)
/**
 * Always yields no value for dynamic schemas.
 *
 * @param context decoder context (not consulted here)
 * @return `None`
 */
override protected def createDynamic(context: DecoderContext): Option[Any] =
  Option.empty[Any]
/**
 * Applies the transformation `f` to a decoded value.
 *
 * @param context decoder context (not consulted here)
 * @param value   the decoded value to transform
 * @param f       transformation returning either an error message or the transformed value
 * @param schema  schema reported in the error
 * @return the transformed value when `f` succeeds
 * @throws MalformedField carrying the message when `f` returns a `Left`
 */
override protected def transform(
  context: DecoderContext,
  value: Any,
  f: Any => Either[String, Any],
  schema: Schema[_]
): Any =
  f(value).fold[Any](
    message => throw MalformedField(schema, message),
    transformed => transformed
  )
/**
 * Aborts decoding with the given message.
 *
 * @param context decoder context (not consulted here)
 * @param message description of the failure
 * @throws DecodeError.ReadError always, with an empty cause and `message`
 */
override protected def fail(context: DecoderContext, message: String): Any = {
  throw DecodeError.ReadError(Cause.empty, message)
}
/** Context decoding starts with: no limit, not packed, no dictionary element context, no direct byte encoding. */
override protected val initialContext: DecoderContext =
  DecoderContext(
    limit = None,
    packed = false,
    dictionaryElementContext = None,
    directByteEncoding = false
  )
/**
 * Decodes a field key, which consists of a wire type (the low three bits) and a field number (the remaining bits).
 *
 * 8 >>> 3 => 1, 16 >>> 3 => 2, 24 >>> 3 => 3, 32 >>> 3 => 4
 * 0 & 0x07 => 0, 1 & 0x07 => 1, 2 & 0x07 => 2, 9 & 0x07 => 1, 15 & 0x07 => 7
 *
 * For wire type 2 a second varint is read immediately and carried as the width of `LengthDelimited`.
 *
 * @param context decoder context the varint(s) are read with
 * @return the wire type together with the field number
 * @throws ExtraFields when the field number is below 1 or the wire type is not one of 0 to 5
 */
private[codec] def keyDecoder(context: DecoderContext): (WireType, Int) = {
  val key         = varIntDecoder(context)
  val fieldNumber = (key >>> 3).toInt
  if (fieldNumber < 1)
    throw ExtraFields(fieldNumber.toString, s"Failed decoding key. Invalid field number $fieldNumber")
  else {
    val wireType: WireType = (key & 0x07) match {
      case 0 => WireType.VarInt
      case 1 => WireType.Bit64
      case 2 => WireType.LengthDelimited(varIntDecoder(context).toInt)
      case 3 => WireType.StartGroup
      case 4 => WireType.EndGroup
      case 5 => WireType.Bit32
      case n =>
        throw ExtraFields(fieldNumber.toString, s"Failed decoding key. Unknown wire type $n")
    }
    (wireType, fieldNumber)
  }
}
/**
 * Reads the next field key and checks that it carries the expected field number.
 *
 * @param context             decoder context the key is read with
 * @param expectedFieldNumber field number the key must have
 * @return `context.limitedTo(state, width)` for a length-delimited field, otherwise `context` itself
 * @throws ExtraFields when the key carries a different field number
 */
private def rawFieldDecoder(context: DecoderContext, expectedFieldNumber: Int): DecoderContext =
  keyDecoder(context) match {
    case (LengthDelimited(width), `expectedFieldNumber`) =>
      context.limitedTo(state, width)
    case (_, `expectedFieldNumber`) =>
      context
    case _ =>
      throw ExtraFields(
        "Unknown",
        s"Failed to decode record. Schema does not contain field number $expectedFieldNumber."
      )
  }
/**
 * Reads 4 bytes from `state` and interprets them as a little-endian `Float`.
 *
 * @param context decoder context used to determine how many bytes are available
 * @return the decoded float
 * @throws MalformedField when `state.length(context)` is less than 4
 */
private def floatDecoder(context: DecoderContext): Float = {
  if (state.length(context) < 4) {
    throw MalformedField(
      Schema.primitive[Float],
      s"Invalid number of bytes for Float. Expected 4, got ${state.length(context)}"
    )
  }
  val raw = state.read(4).toArray
  ByteBuffer.wrap(raw).order(ByteOrder.LITTLE_ENDIAN).getFloat()
}
/**
 * Reads 8 bytes from `state` and interprets them as a little-endian `Double`.
 *
 * @param context decoder context used to determine how many bytes are available
 * @return the decoded double
 * @throws MalformedField when `state.length(context)` is less than 8
 */
private def doubleDecoder(context: DecoderContext): Double = {
  if (state.length(context) < 8) {
    throw MalformedField(
      Schema.primitive[Double],
      s"Invalid number of bytes for Double. Expected 8, got ${state.length(context)}"
    )
  }
  val raw = state.read(8).toArray
  ByteBuffer.wrap(raw).order(ByteOrder.LITTLE_ENDIAN).getDouble()
}
/**
 * Takes everything `state.all(context)` yields and decodes it as a UTF-8 string.
 *
 * @param context decoder context passed to `state.all`
 * @return the decoded string
 */
private def stringDecoder(context: DecoderContext): String =
  new String(state.all(context).toArray, StandardCharsets.UTF_8)
/**
 * Decodes a variable-length integer, the encoding behind int32, int64, uint32, uint64, sint32, sint64, bool and enumN.
 * Bytes are consumed from `state` one at a time until a byte inside the 0 - 127 range (high bit clear) is found;
 * the low 7 bits of every byte are combined, least significant group first.
 *
 * (0 -> 127) & 0x80 => 0, (128 -> 255) & 0x80 => 128
 * 0 << 7 => 0, 1 << 7 => 128, 2 << 7 => 256, 3 << 7 => 384
 * 1 & 0x7F => 1, 127 & 0x7F => 127, 128 & 0x7F => 0, 129 & 0x7F => 1
 *
 * Throws MalformedField when `state.length(context)` is 0, or when that many bytes were consumed
 * without reaching a terminating byte.
 */
private def varIntDecoder(context: DecoderContext): Long = {
val maxLength = state.length(context)
if (maxLength == 0) {
throw MalformedField(Schema.primitive[Long], "Failed to decode VarInt. Unexpected end of chunk")
} else {
// count: number of continuation bytes consumed so far, which is also the 7-bit group index.
var count = 0
var done = false
var result = 0L
while (count < maxLength && !done) {
val byte = state.peek
// Place the 7 payload bits of this byte at group position `count`.
// NOTE(review): the shift distance count * 7 is not capped at 63, so an input with more than
// 10 continuation bytes would not be rejected here - confirm whether overlong varints need a guard.
result = result | (byte & 0x7f).toLong << (count * 7)
state.move(1)
// A clear high bit marks the last byte of the varint.
if ((byte & 0x80) == 0) {
done = true
} else {
count += 1
}
}
// The loop ran out of bytes before seeing a terminating byte.
if (!done) {
throw MalformedField(
Schema.primitive[Long],
"Failed to decode VarInt. No byte within the range 0 - 127 are present"
)
}
result
}
}
/**
 * Returns the raw bytes `state.all(context)` yields, without any interpretation.
 *
 * @param context decoder context passed to `state.all`
 * @return the bytes as a chunk
 */
private def binaryDecoder(context: DecoderContext): Chunk[Byte] = {
  state.all(context)
}
/**
 * Peeks at `state` using a fresh context without a limit
 * (presumably the bytes not consumed yet - confirm against `state.peek`).
 *
 * @return the chunk `state.peek` yields for that context
 */
private[codec] def remainder: Chunk[Byte] =
  state.peek(
    DecoderContext(
      None,
      packed = false,
      dictionaryElementContext = None,
      directByteEncoding = false
    )
  )
}
}