// smithy4s.http.json.SchemaVisitorJCodec.scala (Maven / Gradle / Ivy)
/*
* Copyright 2021-2022 Disney Streaming
*
* Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://disneystreaming.github.io/TOST-1.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package smithy4s
package http
package json
import java.util.UUID
import java.util
import com.github.plokhotnyuk.jsoniter_scala.core.JsonReader
import com.github.plokhotnyuk.jsoniter_scala.core.JsonWriter
import smithy.api.HttpPayload
import smithy.api.JsonName
import smithy.api.TimestampFormat
import smithy4s.api.Discriminated
import smithy4s.api.Untagged
import smithy4s.internals.DiscriminatedUnionMember
import smithy4s.schema._
import smithy4s.schema.Primitive._
import smithy4s.Timestamp
import scala.collection.compat.immutable.ArraySeq
import scala.collection.immutable.VectorBuilder
import scala.collection.mutable.ListBuffer
import scala.collection.mutable.{Map => MMap}
/** Schema visitor that compiles smithy4s schemas into `JCodec` instances
  * (JSON encoders/decoders built on jsoniter-scala readers and writers).
  *
  * @param maxArity
  *   upper bound on the number of elements decoded from a single JSON array
  *   or object by the collection, map and document codecs; exceeding it
  *   raises a payload error
  * @param cache
  *   compilation cache exposed to `SchemaVisitor.Cached`
  */
private[smithy4s] class SchemaVisitorJCodec(
maxArity: Int,
val cache: CompilationCache[JCodec]
) extends SchemaVisitor.Cached[JCodec] { self =>
// Shared empty metadata map, passed by the struct codecs when a value is
// decoded without any out-of-band metadata.
private val emptyMetadata: MMap[String, Any] = MMap.empty
object PrimitiveJCodecs {
/** Codec for JSON booleans, usable both in value and in key position. */
val boolean: JCodec[Boolean] = new JCodec[Boolean] {
  def expecting: String = "boolean"

  // Value position.
  def decodeValue(cursor: Cursor, in: JsonReader): Boolean = {
    in.readBoolean()
  }
  def encodeValue(value: Boolean, out: JsonWriter): Unit = {
    out.writeVal(value)
  }

  // Key position.
  def decodeKey(in: JsonReader): Boolean = {
    in.readKeyAsBoolean()
  }
  def encodeKey(value: Boolean, out: JsonWriter): Unit = {
    out.writeKey(value)
  }
}
/** Codec for JSON strings, usable both in value and in key position. */
val string: JCodec[String] = new JCodec[String] {
  def expecting: String = "string"

  // Value position.
  def decodeValue(cursor: Cursor, in: JsonReader): String = {
    in.readString(null)
  }
  def encodeValue(value: String, out: JsonWriter): Unit = {
    out.writeVal(value)
  }

  // Key position.
  def decodeKey(in: JsonReader): String = {
    in.readKeyAsString()
  }
  def encodeKey(value: String, out: JsonWriter): Unit = {
    out.writeKey(value)
  }
}
/** Codec for 32-bit integers, usable both in value and in key position. */
val int: JCodec[Int] = new JCodec[Int] {
  def expecting: String = "int"

  // Value position.
  def decodeValue(cursor: Cursor, in: JsonReader): Int = {
    in.readInt()
  }
  def encodeValue(value: Int, out: JsonWriter): Unit = {
    out.writeVal(value)
  }

  // Key position.
  def decodeKey(in: JsonReader): Int = {
    in.readKeyAsInt()
  }
  def encodeKey(value: Int, out: JsonWriter): Unit = {
    out.writeKey(value)
  }
}
/** Codec for 64-bit integers, usable both in value and in key position. */
val long: JCodec[Long] = new JCodec[Long] {
  def expecting: String = "long"

  // Value position.
  def decodeValue(cursor: Cursor, in: JsonReader): Long = {
    in.readLong()
  }
  def encodeValue(value: Long, out: JsonWriter): Unit = {
    out.writeVal(value)
  }

  // Key position.
  def decodeKey(in: JsonReader): Long = {
    in.readKeyAsLong()
  }
  def encodeKey(value: Long, out: JsonWriter): Unit = {
    out.writeKey(value)
  }
}
/** Codec for single-precision floats, usable both in value and in key position. */
val float: JCodec[Float] = new JCodec[Float] {
  def expecting: String = "float"

  // Value position.
  def decodeValue(cursor: Cursor, in: JsonReader): Float = {
    in.readFloat()
  }
  def encodeValue(value: Float, out: JsonWriter): Unit = {
    out.writeVal(value)
  }

  // Key position.
  def decodeKey(in: JsonReader): Float = {
    in.readKeyAsFloat()
  }
  def encodeKey(value: Float, out: JsonWriter): Unit = {
    out.writeKey(value)
  }
}
/** Codec for double-precision floats, usable both in value and in key position. */
val double: JCodec[Double] = new JCodec[Double] {
  def expecting: String = "double"

  // Value position.
  def decodeValue(cursor: Cursor, in: JsonReader): Double = {
    in.readDouble()
  }
  def encodeValue(value: Double, out: JsonWriter): Unit = {
    out.writeVal(value)
  }

  // Key position.
  def decodeKey(in: JsonReader): Double = {
    in.readKeyAsDouble()
  }
  def encodeKey(value: Double, out: JsonWriter): Unit = {
    out.writeKey(value)
  }
}
/** Codec for 16-bit integers, usable both in value and in key position. */
val short: JCodec[Short] = new JCodec[Short] {
  def expecting: String = "short"

  // Value position.
  def decodeValue(cursor: Cursor, in: JsonReader): Short = {
    in.readShort()
  }
  def encodeValue(value: Short, out: JsonWriter): Unit = {
    out.writeVal(value)
  }

  // Key position.
  def decodeKey(in: JsonReader): Short = {
    in.readKeyAsShort()
  }
  def encodeKey(value: Short, out: JsonWriter): Unit = {
    out.writeKey(value)
  }
}
/** Codec for 8-bit integers, usable both in value and in key position. */
val byte: JCodec[Byte] = new JCodec[Byte] {
  def expecting: String = "byte"

  // Value position.
  def decodeValue(cursor: Cursor, in: JsonReader): Byte = {
    in.readByte()
  }
  def encodeValue(value: Byte, out: JsonWriter): Unit = {
    out.writeVal(value)
  }

  // Key position.
  def decodeKey(in: JsonReader): Byte = {
    in.readKeyAsByte()
  }
  def encodeKey(value: Byte, out: JsonWriter): Unit = {
    out.writeKey(value)
  }
}
/** Codec for binary blobs, carried as padded base64 strings.
  *
  * Blobs are never valid object keys: `canBeKey` is false and both key
  * methods raise an error.
  */
val bytes: JCodec[ByteArray] = new JCodec[ByteArray] {
  def expecting: String = "byte-array" // or blob?

  override def canBeKey: Boolean = false

  def decodeValue(cursor: Cursor, in: JsonReader): ByteArray = {
    val raw = in.readBase64AsBytes(null)
    ByteArray(raw)
  }

  def encodeValue(value: ByteArray, out: JsonWriter): Unit = {
    out.writeBase64Val(value.array, doPadding = true)
  }

  def decodeKey(in: JsonReader): ByteArray = {
    in.decodeError("Cannot use byte array as key")
  }

  def encodeKey(value: ByteArray, out: JsonWriter): Unit = {
    out.encodeError("Cannot use byte array as key")
  }
}
/** Codec for arbitrary-precision decimals, usable both in value and in key
  * position.
  */
val bigdecimal: JCodec[BigDecimal] =
  new JCodec[BigDecimal] {
    def expecting: String = "big-decimal"
    def decodeValue(cursor: Cursor, in: JsonReader): BigDecimal =
      in.readBigDecimal(null)
    def decodeKey(in: JsonReader): BigDecimal = in.readKeyAsBigDecimal()
    def encodeValue(value: BigDecimal, out: JsonWriter): Unit =
      out.writeVal(value)
    // Keys must go through `writeKey` so that the output mirrors what
    // `readKeyAsBigDecimal` expects; `writeVal` would emit the number in
    // value position instead of as an object key.
    def encodeKey(value: BigDecimal, out: JsonWriter): Unit =
      out.writeKey(value)
  }
/** Codec for arbitrary-precision integers, usable both in value and in key
  * position.
  */
val bigint: JCodec[BigInt] =
  new JCodec[BigInt] {
    def expecting: String = "big-int"
    def decodeValue(cursor: Cursor, in: JsonReader): BigInt =
      in.readBigInt(null)
    def decodeKey(in: JsonReader): BigInt = in.readKeyAsBigInt()
    def encodeValue(value: BigInt, out: JsonWriter): Unit =
      out.writeVal(value)
    // Keys must go through `writeKey` so that the output mirrors what
    // `readKeyAsBigInt` expects; `writeVal` would emit the number in value
    // position instead of as an object key.
    def encodeKey(value: BigInt, out: JsonWriter): Unit =
      out.writeKey(value)
  }
/** Codec for UUIDs, usable both in value and in key position. */
val uuid: JCodec[UUID] = new JCodec[UUID] {
  def expecting: String = "uuid"

  // Value position.
  def decodeValue(cursor: Cursor, in: JsonReader): UUID = {
    in.readUUID(null)
  }
  def encodeValue(value: UUID, out: JsonWriter): Unit = {
    out.writeVal(value)
  }

  // Key position.
  def decodeKey(in: JsonReader): UUID = {
    in.readKeyAsUUID()
  }
  def encodeKey(value: UUID, out: JsonWriter): Unit = {
    out.writeKey(value)
  }
}
/** Codec for timestamps carried as strings in the `DATE_TIME` format. */
val timestampDateTime: JCodec[Timestamp] = new JCodec[Timestamp] {
  val expecting: String = Timestamp.showFormat(TimestampFormat.DATE_TIME)

  // Unwraps a parse result, or reports a decode error on the reader.
  private[this] def orFail(
      parsed: Option[Timestamp],
      in: JsonReader
  ): Timestamp =
    parsed match {
      case Some(ts) => ts
      case _        => in.decodeError("expected " + expecting)
    }

  def decodeValue(cursor: Cursor, in: JsonReader): Timestamp = {
    val raw = in.readString(null)
    orFail(Timestamp.parse(raw, TimestampFormat.DATE_TIME), in)
  }

  def decodeKey(in: JsonReader): Timestamp = {
    val raw = in.readKeyAsString()
    orFail(Timestamp.parse(raw, TimestampFormat.DATE_TIME), in)
  }

  def encodeValue(value: Timestamp, out: JsonWriter): Unit = {
    val rendered = value.format(TimestampFormat.DATE_TIME)
    out.writeNonEscapedAsciiVal(rendered)
  }

  def encodeKey(value: Timestamp, out: JsonWriter): Unit = {
    val rendered = value.format(TimestampFormat.DATE_TIME)
    out.writeNonEscapedAsciiKey(rendered)
  }
}
/** Codec for timestamps carried as strings in the `HTTP_DATE` format. */
val timestampHttpDate: JCodec[Timestamp] = new JCodec[Timestamp] {
  val expecting: String = Timestamp.showFormat(TimestampFormat.HTTP_DATE)

  // Unwraps a parse result, or reports a decode error on the reader.
  private[this] def orFail(
      parsed: Option[Timestamp],
      in: JsonReader
  ): Timestamp =
    parsed match {
      case Some(ts) => ts
      case _        => in.decodeError("expected " + expecting)
    }

  def decodeValue(cursor: Cursor, in: JsonReader): Timestamp = {
    val raw = in.readString(null)
    orFail(Timestamp.parse(raw, TimestampFormat.HTTP_DATE), in)
  }

  def decodeKey(in: JsonReader): Timestamp = {
    val raw = in.readKeyAsString()
    orFail(Timestamp.parse(raw, TimestampFormat.HTTP_DATE), in)
  }

  def encodeValue(value: Timestamp, out: JsonWriter): Unit = {
    val rendered = value.format(TimestampFormat.HTTP_DATE)
    out.writeNonEscapedAsciiVal(rendered)
  }

  def encodeKey(value: Timestamp, out: JsonWriter): Unit = {
    val rendered = value.format(TimestampFormat.HTTP_DATE)
    out.writeNonEscapedAsciiKey(rendered)
  }
}
/** Codec for timestamps carried as a JSON number of seconds since the epoch,
  * with an optional fractional part holding the nanoseconds.
  */
val timestampEpochSeconds: JCodec[Timestamp] = new JCodec[Timestamp] {
  val expecting: String =
    Timestamp.showFormat(TimestampFormat.EPOCH_SECONDS)

  // Splits a decimal number of seconds into whole seconds and nanoseconds.
  private[this] def fromDecimal(timestamp: BigDecimal): Timestamp = {
    val epochSecond = timestamp.toLong
    val nanos = ((timestamp - epochSecond) * 1000000000).toInt
    Timestamp(epochSecond, nanos)
  }

  def decodeValue(cursor: Cursor, in: JsonReader): Timestamp =
    fromDecimal(in.readBigDecimal(null))

  def decodeKey(in: JsonReader): Timestamp =
    fromDecimal(in.readKeyAsBigDecimal())

  def encodeValue(value: Timestamp, out: JsonWriter): Unit = {
    if (value.nano != 0) {
      out.writeVal(BigDecimal(value.epochSecond) + value.nano / 1000000000.0)
    } else {
      // Whole seconds are written as a plain integer.
      out.writeVal(value.epochSecond)
    }
  }

  def encodeKey(value: Timestamp, out: JsonWriter): Unit = {
    val seconds = BigDecimal(value.epochSecond) + value.nano / 1000000000.0
    out.writeKey(seconds)
  }
}
/** Codec for `Unit`, represented as an empty JSON object (`{}`).
  *
  * `Unit` is never a valid object key.
  */
val unit: JCodec[Unit] = new JCodec[Unit] {
  def expecting: String = "empty object"

  override def canBeKey: Boolean = false

  def decodeValue(cursor: Cursor, in: JsonReader): Unit = {
    // Short-circuits exactly like the original: '}' is only consumed
    // once '{' has been seen.
    val isEmptyObject = in.isNextToken('{') && in.isNextToken('}')
    if (!isEmptyObject) in.decodeError("Expected empty object")
  }

  def encodeValue(value: Unit, out: JsonWriter): Unit = {
    out.writeObjectStart()
    out.writeObjectEnd()
  }

  def decodeKey(in: JsonReader): Unit = {
    in.decodeError("Cannot use Unit as keys")
  }

  def encodeKey(value: Unit, out: JsonWriter): Unit = {
    out.encodeError("Cannot use Unit as keys")
  }
}
/** Builds a codec for free-form JSON `Document` values.
  *
  * @param maxArity
  *   maximum number of elements accepted in a single JSON array or object;
  *   exceeding it raises a payload error
  * @return
  *   a codec that cannot be used in key position
  */
def document(maxArity: Int): JCodec[Document] = new JCodec[Document] {
import Document._
override def canBeKey: Boolean = false
// Writes the document recursively; any case not listed explicitly is
// written as JSON null.
def encodeValue(doc: Document, out: JsonWriter): Unit = doc match {
case s: DString => out.writeVal(s.value)
case b: DBoolean => out.writeVal(b.value)
case n: DNumber => out.writeVal(n.value)
case a: DArray =>
out.writeArrayStart()
a.value match {
// ArraySeq-backed arrays are walked through the underlying array with
// an index loop rather than through `foreach`.
case x: ArraySeq[Document] =>
val xs = x.unsafeArray.asInstanceOf[Array[Document]]
var i = 0
while (i < xs.length) {
encodeValue(xs(i), out)
i += 1
}
case xs =>
xs.foreach(encodeValue(_, out))
}
out.writeArrayEnd()
case o: DObject =>
out.writeObjectStart()
o.value.foreach { kv =>
out.writeKey(kv._1)
encodeValue(kv._2, out)
}
out.writeObjectEnd()
case _ => out.writeNull()
}
def decodeKey(in: JsonReader): Document =
in.decodeError("Cannot use JSON document as keys")
def encodeKey(x: Document, out: JsonWriter): Unit =
out.encodeError("Cannot use JSON documents as keys")
def expecting: String = "JSON document"
// Borrowed from: https://github.com/plokhotnyuk/jsoniter-scala/blob/e80d51019b39efacff9e695de97dce0c23ae9135/jsoniter-scala-benchmark/src/main/scala/io/circe/CirceJsoniter.scala
// Dispatches on the first token of the value: string, boolean, number,
// array, object, and finally null (or an error).
def decodeValue(cursor: Cursor, in: JsonReader): Document = {
val b = in.nextToken()
if (b == '"') {
in.rollbackToken()
new DString(in.readString(null))
} else if (b == 'f' || b == 't') {
in.rollbackToken()
new DBoolean(in.readBoolean())
} else if ((b >= '0' && b <= '9') || b == '-') {
in.rollbackToken()
new DNumber(in.readBigDecimal(null))
} else if (b == '[') {
new DArray({
if (in.isNextToken(']')) ArraySeq.empty[Document]
else
ArraySeq.unsafeWrapArray {
in.rollbackToken()
// Elements are accumulated in an array that starts at 4 slots and
// doubles whenever it is full.
var arr = new Array[Document](4)
var i = 0
while ({
if (i >= maxArity) maxArityError(cursor)
if (i == arr.length)
arr = java.util.Arrays.copyOf(arr, i << 1)
// NOTE(review): this calls the (JsonReader, default) overload rather
// than the cursor-based one declared here, so nested values do not
// receive `cursor` — presumably intentional, confirm.
arr(i) = decodeValue(in, null)
i += 1
in.isNextToken(',')
}) {}
if (in.isCurrentToken(']')) {
// Trim the buffer to the number of elements actually read.
if (i == arr.length) arr
else java.util.Arrays.copyOf(arr, i)
} else in.arrayEndOrCommaError()
}
})
} else if (b == '{') {
new DObject({
if (in.isNextToken('}')) Map.empty
else {
in.rollbackToken()
// We use the maxArity limit to mitigate DoS vulnerability in default Scala `Map` implementation: https://github.com/scala/bug/issues/11203
val obj = Map.newBuilder[String, Document]
var i = 0
while ({
if (i >= maxArity) maxArityError(cursor)
obj += ((in.readKeyAsString(), decodeValue(in, null)))
i += 1
in.isNextToken(',')
}) {}
if (in.isCurrentToken('}')) obj.result()
else in.objectEndOrCommaError()
}
})
} else in.readNullOrError(DNull, "expected JSON document")
}
// Raises a payload error reporting that the arity limit was exceeded.
private def maxArityError(cursor: Cursor): Nothing =
throw cursor.payloadError(
this,
s"Input $expecting exceeded max arity of $maxArity"
)
}
}
// Document codec instantiated once with this visitor's arity limit.
private val documentJCodec = PrimitiveJCodecs.document(maxArity)
/** Selects the codec matching a primitive tag.
  *
  * Timestamps are additionally dispatched on the `TimestampFormat` hint,
  * falling back to `DATE_TIME` when the hint is absent.
  *
  * @param shapeId identifier of the shape (not used here)
  * @param hints hints of the shape; only `TimestampFormat` is consulted
  * @param tag the primitive tag to dispatch on
  */
override def primitive[P](
shapeId: ShapeId,
hints: Hints,
tag: Primitive[P]
): JCodec[P] = {
tag match {
case PBigDecimal => PrimitiveJCodecs.bigdecimal
case PBigInt => PrimitiveJCodecs.bigint
case PBlob => PrimitiveJCodecs.bytes
case PBoolean => PrimitiveJCodecs.boolean
case PByte => PrimitiveJCodecs.byte
case PDocument => documentJCodec
case PDouble => PrimitiveJCodecs.double
case PFloat => PrimitiveJCodecs.float
case PInt => PrimitiveJCodecs.int
case PLong => PrimitiveJCodecs.long
case PShort => PrimitiveJCodecs.short
case PString => PrimitiveJCodecs.string
case PTimestamp =>
hints.get(TimestampFormat).getOrElse(TimestampFormat.DATE_TIME) match {
case TimestampFormat.DATE_TIME => PrimitiveJCodecs.timestampDateTime
case TimestampFormat.EPOCH_SECONDS =>
PrimitiveJCodecs.timestampEpochSeconds
case TimestampFormat.HTTP_DATE => PrimitiveJCodecs.timestampHttpDate
}
case PUnit => PrimitiveJCodecs.unit
case PUUID => PrimitiveJCodecs.uuid
}
}
/** Builds a codec reading/writing a JSON array as a `List[A]`.
  *
  * Decoding fails with a payload error once more than `maxArity` elements
  * have been read. Lists cannot be used in key position.
  *
  * @param member schema of the list elements
  */
private def listImpl[A](member: Schema[A]) = new JCodec[List[A]] {
  private[this] val a: JCodec[A] = apply(member)
  def expecting: String = "list"
  override def canBeKey: Boolean = false
  def decodeValue(cursor: Cursor, in: JsonReader): List[A] =
    if (in.isNextToken('[')) {
      if (in.isNextToken(']')) Nil
      else {
        in.rollbackToken()
        val builder = new ListBuffer[A]
        var i = 0
        // `while ({ ...; cond }) ()` emulates do/while: at least one
        // element is read, then the loop continues on each ','.
        while ({
          if (i >= maxArity) maxArityError(cursor)
          builder += cursor.under(i)(cursor.decode(a, in))
          i += 1
          in.isNextToken(',')
        }) ()
        if (in.isCurrentToken(']')) builder.result()
        else in.arrayEndOrCommaError()
      }
    } else in.decodeError("Expected JSON array")
  def encodeValue(xs: List[A], out: JsonWriter): Unit = {
    out.writeArrayStart()
    var list = xs
    while (list ne Nil) {
      a.encodeValue(list.head, out)
      list = list.tail
    }
    out.writeArrayEnd()
  }
  // The messages below used to say "vectors", copied from the vector codec.
  def decodeKey(in: JsonReader): List[A] =
    in.decodeError("Cannot use lists as keys")
  def encodeKey(xs: List[A], out: JsonWriter): Unit =
    out.encodeError("Cannot use lists as keys")
  private[this] def maxArityError(cursor: Cursor): Nothing =
    throw cursor.payloadError(
      this,
      s"Input $expecting exceeded max arity of $maxArity"
    )
}
/** Builds a codec reading/writing a JSON array as a `Vector[A]`.
  *
  * Decoding fails with a payload error once more than `maxArity` elements
  * have been read. Vectors cannot be used in key position.
  *
  * @param member schema of the vector elements
  */
private def vector[A](
    member: Schema[A]
): JCodec[Vector[A]] = new JCodec[Vector[A]] {
  private[this] val a = apply(member)

  def expecting: String = "list"

  override def canBeKey: Boolean = false

  def decodeValue(cursor: Cursor, in: JsonReader): Vector[A] = {
    if (!in.isNextToken('[')) in.decodeError("Expected JSON array")
    else if (in.isNextToken(']')) Vector.empty
    else {
      in.rollbackToken()
      val elements = Vector.newBuilder[A]
      var index = 0
      var more = true
      // At least one element is present; keep reading while ',' follows.
      while (more) {
        if (index >= maxArity) maxArityError(cursor)
        elements += cursor.under(index)(cursor.decode(a, in))
        index += 1
        more = in.isNextToken(',')
      }
      if (in.isCurrentToken(']')) elements.result()
      else in.arrayEndOrCommaError()
    }
  }

  def encodeValue(xs: Vector[A], out: JsonWriter): Unit = {
    out.writeArrayStart()
    val it = xs.iterator
    while (it.hasNext) a.encodeValue(it.next(), out)
    out.writeArrayEnd()
  }

  def decodeKey(in: JsonReader): Vector[A] = {
    in.decodeError("Cannot use vectors as keys")
  }

  def encodeKey(xs: Vector[A], out: JsonWriter): Unit = {
    out.encodeError("Cannot use vectors as keys")
  }

  private[this] def maxArityError(cursor: Cursor): Nothing = {
    val message = s"Input $expecting exceeded max arity of $maxArity"
    throw cursor.payloadError(this, message)
  }
}
/** Builds a codec reading/writing a JSON array as an `IndexedSeq[A]`.
  *
  * Decoded elements are accumulated through the compact builder provided by
  * `CollectionTag.IndexedSeqTag`. Decoding fails with a payload error once
  * more than `maxArity` elements have been read.
  *
  * @param member schema of the sequence elements
  */
private def indexedSeq[A](
    member: Schema[A]
): JCodec[IndexedSeq[A]] = new JCodec[IndexedSeq[A]] {
  private[this] val a = apply(member)

  def expecting: String = "list"

  override def canBeKey: Boolean = false

  val withBuilder = CollectionTag.IndexedSeqTag.compactBuilder(member)

  def decodeValue(cursor: Cursor, in: JsonReader): IndexedSeq[A] = {
    if (!in.isNextToken('[')) in.decodeError("Expected JSON array")
    else if (in.isNextToken(']')) Vector.empty
    else {
      in.rollbackToken()
      withBuilder { put =>
        var index = 0
        var more = true
        // At least one element is present; keep reading while ',' follows.
        while (more) {
          if (index >= maxArity) maxArityError(cursor)
          put(cursor.under(index)(cursor.decode(a, in)))
          index += 1
          more = in.isNextToken(',')
        }
        if (!in.isCurrentToken(']')) {
          in.arrayEndOrCommaError()
        }
      }
    }
  }

  def encodeValue(xs: IndexedSeq[A], out: JsonWriter): Unit = {
    out.writeArrayStart()
    xs match {
      case wrapped: ArraySeq[A] =>
        // Walk the backing array directly with an index loop.
        val backing = wrapped.unsafeArray.asInstanceOf[Array[A]]
        var index = 0
        while (index < backing.length) {
          a.encodeValue(backing(index), out)
          index += 1
        }
      case _ =>
        val it = xs.iterator
        while (it.hasNext) a.encodeValue(it.next(), out)
    }
    out.writeArrayEnd()
  }

  def decodeKey(in: JsonReader): IndexedSeq[A] = {
    in.decodeError("Cannot use vectors as keys")
  }

  def encodeKey(xs: IndexedSeq[A], out: JsonWriter): Unit = {
    out.encodeError("Cannot use vectors as keys")
  }

  private[this] def maxArityError(cursor: Cursor): Nothing = {
    val message = s"Input $expecting exceeded max arity of $maxArity"
    throw cursor.payloadError(this, message)
  }
}
/** Builds a codec reading/writing a JSON array as a `Set[A]`.
  *
  * Decoding fails with a payload error once more than `maxArity` elements
  * have been read. Sets cannot be used in key position.
  *
  * @param member schema of the set elements
  */
private def set[A](
    member: Schema[A]
): JCodec[Set[A]] = new JCodec[Set[A]] {
  private[this] val a = apply(member)
  def expecting: String = "list"
  override def canBeKey: Boolean = false
  def decodeValue(cursor: Cursor, in: JsonReader): Set[A] =
    if (in.isNextToken('[')) {
      if (in.isNextToken(']')) Set.empty
      else {
        in.rollbackToken()
        val builder = Set.newBuilder[A]
        var i = 0
        // `while ({ ...; cond }) ()` emulates do/while: at least one
        // element is read, then the loop continues on each ','.
        while ({
          if (i >= maxArity) maxArityError(cursor)
          builder += cursor.under(i)(cursor.decode(a, in))
          i += 1
          in.isNextToken(',')
        }) ()
        if (in.isCurrentToken(']')) builder.result()
        else in.arrayEndOrCommaError()
      }
    } else in.decodeError("Expected JSON array")
  def encodeValue(xs: Set[A], out: JsonWriter): Unit = {
    out.writeArrayStart()
    xs.foreach(x => a.encodeValue(x, out))
    out.writeArrayEnd()
  }
  // The messages below used to say "vectors", copied from the vector codec.
  def decodeKey(in: JsonReader): Set[A] =
    in.decodeError("Cannot use sets as keys")
  def encodeKey(xs: Set[A], out: JsonWriter): Unit =
    out.encodeError("Cannot use sets as keys")
  private[this] def maxArityError(cursor: Cursor): Nothing =
    throw cursor.payloadError(
      this,
      s"Input $expecting exceeded max arity of $maxArity"
    )
}
/** Builds a codec reading/writing a `Map[K, V]` as a JSON object, using
  * the key codec in key position.
  *
  * Decoding fails with a payload error once more than `maxArity` entries
  * have been read. Maps cannot themselves be used in key position.
  *
  * @param jk codec used for the object keys
  * @param jv codec used for the object values
  */
private def objectMap[K, V](
    jk: JCodec[K],
    jv: JCodec[V]
): JCodec[Map[K, V]] = new JCodec[Map[K, V]] {
  val expecting: String = "map"

  override def canBeKey: Boolean = false

  def decodeValue(cursor: Cursor, in: JsonReader): Map[K, V] = {
    if (!in.isNextToken('{')) in.decodeError("Expected JSON object")
    else if (in.isNextToken('}')) Map.empty
    else {
      in.rollbackToken()
      val entries = Map.newBuilder[K, V]
      var count = 0
      var more = true
      // At least one entry is present; keep reading while ',' follows.
      while (more) {
        if (count >= maxArity) maxArityError(cursor)
        val entryKey = jk.decodeKey(in)
        val entryValue = cursor.under(count)(cursor.decode(jv, in))
        entries += ((entryKey, entryValue))
        count += 1
        more = in.isNextToken(',')
      }
      if (in.isCurrentToken('}')) entries.result()
      else in.objectEndOrCommaError()
    }
  }

  def encodeValue(xs: Map[K, V], out: JsonWriter): Unit = {
    out.writeObjectStart()
    xs.foreach { case (entryKey, entryValue) =>
      jk.encodeKey(entryKey, out)
      jv.encodeValue(entryValue, out)
    }
    out.writeObjectEnd()
  }

  def decodeKey(in: JsonReader): Map[K, V] = {
    in.decodeError("Cannot use maps as keys")
  }

  def encodeKey(xs: Map[K, V], out: JsonWriter): Unit = {
    out.encodeError("Cannot use maps as keys")
  }

  private[this] def maxArityError(cursor: Cursor): Nothing = {
    val message = s"Input $expecting exceeded max arity of $maxArity"
    throw cursor.payloadError(this, message)
  }
}
/** Builds a codec reading/writing a `Map[K, V]` as a JSON array of
  * `{"key": ..., "value": ...}` objects.
  *
  * The entry codec is derived from a two-field struct schema and wrapped in
  * the list codec, then converted to and from `Map`.
  *
  * @param k schema of the map keys
  * @param v schema of the map values
  */
private def arrayMap[K, V](
    k: Schema[K],
    v: Schema[V]
): JCodec[Map[K, V]] = {
  val keyField = Field.required[Schema, (K, V), K]("key", k, _._1)
  val valueField = Field.required[Schema, (K, V), V]("value", v, _._2)
  val entrySchema = Schema.struct(Vector(keyField, valueField)) { values =>
    // Values arrive in field order: key first, value second.
    (values(0).asInstanceOf[K], values(1).asInstanceOf[V])
  }
  val entriesCodec = listImpl(entrySchema)
  entriesCodec.biject(_.toMap, _.toList)
}
/** Selects the collection codec matching the collection tag.
  *
  * @param shapeId identifier of the shape (not used here)
  * @param hints hints of the shape (not used here)
  * @param tag tag identifying the concrete collection type
  * @param member schema of the collection elements
  */
override def collection[C[_], A](
shapeId: ShapeId,
hints: Hints,
tag: CollectionTag[C],
member: Schema[A]
): JCodec[C[A]] = {
tag match {
case CollectionTag.ListTag => listImpl(member)
case CollectionTag.SetTag => set(member)
case CollectionTag.VectorTag => vector(member)
case CollectionTag.IndexedSeqTag => indexedSeq(member)
}
}
/** Builds a map codec.
  *
  * When the key codec can be used in key position the map is a JSON object;
  * otherwise it is an array of key/value entries.
  *
  * @param shapeId identifier of the shape (not used here)
  * @param hints hints of the shape (not used here)
  * @param key schema of the map keys
  * @param value schema of the map values
  */
override def map[K, V](
    shapeId: ShapeId,
    hints: Hints,
    key: Schema[K],
    value: Schema[V]
): JCodec[Map[K, V]] = {
  // Both codecs are compiled up front, regardless of the chosen encoding.
  val keyCodec = apply(key)
  val valueCodec = apply(value)
  if (!keyCodec.canBeKey) arrayMap(key, value)
  else objectMap(keyCodec, valueCodec)
}
/** Derives a codec for `B` from the codec of `A` through a bijection.
  *
  * @param schema schema of the underlying type
  * @param bijection conversion between `A` and `B`
  */
override def biject[A, B](
    schema: Schema[A],
    bijection: Bijection[A, B]
): JCodec[B] = {
  val underlying = apply(schema)
  underlying.biject(bijection, bijection.from)
}
/** Derives a codec for the refined type `B` from the codec of `A`.
  *
  * @param schema schema of the underlying type
  * @param refinement refinement between `A` and `B`
  */
override def refine[A, B](
    schema: Schema[A],
    refinement: Refinement[A, B]
): JCodec[B] = {
  val underlying = apply(schema)
  JCodec.jcodecInvariant.xmap(underlying)(
    refinement.asFunction,
    refinement.from
  )
}
/** Wraps a lazily-defined schema in a codec that only compiles the
  * underlying codec on first use and delegates every operation to it.
  *
  * @param suspend the suspended schema
  */
override def lazily[A](suspend: Lazy[Schema[A]]): JCodec[A] = new JCodec[A] {
  // Compiled on first access.
  lazy val underlying = apply(suspend.value)

  def expecting: String = {
    underlying.expecting
  }

  // Value position.
  def decodeValue(cursor: Cursor, in: JsonReader): A = {
    underlying.decodeValue(cursor, in)
  }
  def encodeValue(value: A, out: JsonWriter): Unit = {
    underlying.encodeValue(value, out)
  }

  // Key position.
  def decodeKey(in: JsonReader): A = {
    underlying.decodeKey(in)
  }
  def encodeKey(value: A, out: JsonWriter): Unit = {
    underlying.encodeKey(value, out)
  }
}
// Curried encoder: given a value, yields the action writing it to a JsonWriter.
private type Writer[A] = A => JsonWriter => Unit
/** Builds a codec for unions encoded as a single-entry JSON object whose key
  * names the alternative and whose value is the alternative's payload.
  *
  * @param alternatives the union alternatives
  * @param dispatch dispatcher used to compile the per-alternative writers
  */
private def taggedUnion[U](
    alternatives: Vector[Alt[Schema, U, _]]
)(dispatch: Alt.Dispatcher[Schema, U]): JCodec[U] =
  new JCodec[U] {
    val expecting: String = "tagged-union"

    override def canBeKey: Boolean = false

    // JSON key of an alternative: the `JsonName` hint when present,
    // otherwise the alternative's label.
    def jsonLabel[A](alt: Alt[Schema, U, A]): String =
      alt.hints.get(JsonName).map(_.value).getOrElse(alt.label)

    // Decoders indexed by JSON key, built once at construction.
    private[this] val handlerMap = {
      val table = new util.HashMap[String, (Cursor, JsonReader) => U]()
      def handler[A](alt: Alt[Schema, U, A]): (Cursor, JsonReader) => U = {
        val codec = apply(alt.instance)
        (cursor: Cursor, reader: JsonReader) =>
          alt.inject(cursor.decode(codec, reader))
      }
      alternatives.foreach(alt => table.put(jsonLabel(alt), handler(alt)))
      table
    }

    def decodeValue(cursor: Cursor, in: JsonReader): U = {
      if (!in.isNextToken('{')) in.decodeError("Expected JSON object")
      else if (in.isNextToken('}'))
        in.decodeError("Expected a single key/value pair")
      else {
        in.rollbackToken()
        val key = in.readKeyAsString()
        val decoded = cursor.under(key) {
          val handler = handlerMap.get(key)
          if (handler eq null) in.discriminatorValueError(key)
          handler(cursor, in)
        }
        // Exactly one entry is allowed in the object.
        if (!in.isNextToken('}')) {
          in.rollbackToken()
          in.decodeError(s"Expected no other field after $key")
        } else decoded
      }
    }

    val precompiler = new smithy4s.schema.Alt.Precompiler[Schema, Writer] {
      def apply[A](label: String, instance: Schema[A]): Writer[A] = {
        val jsonLabel = instance.hints.get(JsonName) match {
          case Some(name) => name.value
          case None       => label
        }
        val jcodecA = instance.compile(self)
        (a: A) =>
          (out: JsonWriter) => {
            out.writeObjectStart()
            out.writeKey(jsonLabel)
            jcodecA.encodeValue(a, out)
            out.writeObjectEnd()
          }
      }
    }

    val writer = dispatch.compile(precompiler)

    def encodeValue(u: U, out: JsonWriter): Unit = {
      val write = writer(u)
      write(out)
    }

    def decodeKey(in: JsonReader): U = {
      in.decodeError("Cannot use coproducts as keys")
    }

    def encodeKey(u: U, out: JsonWriter): Unit = {
      out.encodeError("Cannot use coproducts as keys")
    }
  }
/** Builds a codec for unions encoded without any tag: on decoding, each
  * alternative is tried in declaration order and the first one that succeeds
  * wins; on encoding, the alternative's own codec is used directly.
  *
  * @param alternatives the union alternatives, in the order they are tried
  * @param dispatch dispatcher used to compile the per-alternative writers
  */
private def untaggedUnion[U](
    alternatives: Vector[Alt[Schema, U, _]]
)(dispatch: Alt.Dispatcher[Schema, U]): JCodec[U] = new JCodec[U] {
  import scala.util.control.NonFatal

  def expecting: String = "untaggedUnion"
  override def canBeKey: Boolean = false
  // One decoder per alternative, in declaration order.
  private[this] val handlerList: Array[(Cursor, JsonReader) => U] = {
    val res = Array.newBuilder[(Cursor, JsonReader) => U]
    def handler[A](alt: Alt[Schema, U, A]) = {
      val codec = apply(alt.instance)
      (cursor: Cursor, reader: JsonReader) =>
        alt.inject(cursor.decode(codec, reader))
    }
    alternatives.foreach(alt => res += handler(alt))
    res.result()
  }
  def decodeValue(cursor: Cursor, in: JsonReader): U = {
    var z: U = null.asInstanceOf[U]
    val len = handlerList.length
    var i = 0
    while (z == null && i < len) {
      // Remember the current position so a failed attempt can be rewound.
      in.setMark()
      val handler = handlerList(i)
      try {
        z = handler(cursor, in)
      } catch {
        // Only recoverable failures mean "try the next alternative";
        // fatal errors (out-of-memory, interruption, linkage errors, ...)
        // must propagate instead of being silently swallowed.
        case NonFatal(_) =>
          in.rollbackToMark()
          i += 1
      }
    }
    if (z != null) z
    else cursor.payloadError(this, "Could not decode untagged union")
  }
  val precompiler = new smithy4s.schema.Alt.Precompiler[Schema, Writer] {
    def apply[A](label: String, instance: Schema[A]): Writer[A] = {
      val jcodecA = instance.compile(self)
      a => out => jcodecA.encodeValue(a, out)
    }
  }
  val writer = dispatch.compile(precompiler)
  def encodeValue(u: U, out: JsonWriter): Unit = {
    writer(u)(out)
  }
  def decodeKey(in: JsonReader): U =
    in.decodeError("Cannot use coproducts as keys")
  def encodeKey(u: U, out: JsonWriter): Unit =
    out.encodeError("Cannot use coproducts as keys")
}
/** Builds a codec for unions whose alternative is identified by a
  * discriminator property inside the JSON object itself.
  *
  * On decoding, the discriminator is located first, then the reader is
  * rewound so that the selected alternative decodes the whole object. On
  * encoding, each alternative is compiled with a `DiscriminatedUnionMember`
  * hint carrying the property name and the alternative's JSON label.
  *
  * @param alternatives the union alternatives
  * @param discriminated hint carrying the discriminator property name
  * @param dispatch dispatcher used to compile the per-alternative writers
  */
private def discriminatedUnion[U](
    alternatives: Vector[Alt[Schema, U, _]],
    discriminated: Discriminated
)(dispatch: Alt.Dispatcher[Schema, U]): JCodec[U] =
  new JCodec[U] {
    def expecting: String = "discriminated-union"

    override def canBeKey: Boolean = false

    // JSON label of an alternative: the `JsonName` hint when present,
    // otherwise the alternative's label.
    def jsonLabel[A](alt: Alt[Schema, U, A]): String =
      alt.hints.get(JsonName).map(_.value).getOrElse(alt.label)

    // Decoders indexed by discriminator value, built once at construction.
    private[this] val handlerMap = {
      val table = new util.HashMap[String, (Cursor, JsonReader) => U]()
      def handler[A](alt: Alt[Schema, U, A]): (Cursor, JsonReader) => U = {
        val codec = apply(alt.instance)
        (cursor: Cursor, reader: JsonReader) =>
          alt.inject(cursor.decode(codec, reader))
      }
      alternatives.foreach(alt => table.put(jsonLabel(alt), handler(alt)))
      table
    }

    def decodeValue(cursor: Cursor, in: JsonReader): U = {
      if (!in.isNextToken('{')) in.decodeError("Expected JSON object")
      else {
        in.setMark()
        if (!in.skipToKey(discriminated.value))
          in.decodeError(
            s"Unable to find discriminator ${discriminated.value}"
          )
        else {
          val key = in.readString("")
          // Rewind to just before the opening '{' so that the selected
          // alternative sees the complete object.
          in.rollbackToMark()
          in.rollbackToken()
          cursor.under(key) {
            val handler = handlerMap.get(key)
            if (handler eq null) in.discriminatorValueError(key)
            handler(cursor, in)
          }
        }
      }
    }

    val precompiler = new smithy4s.schema.Alt.Precompiler[Schema, Writer] {
      def apply[A](label: String, instance: Schema[A]): Writer[A] = {
        val jsonLabel = instance.hints.get(JsonName) match {
          case Some(name) => name.value
          case None       => label
        }
        val memberHint =
          Hints(DiscriminatedUnionMember(discriminated.value, jsonLabel))
        val jcodecA = instance.addHints(memberHint).compile(self)
        (a: A) => (out: JsonWriter) => jcodecA.encodeValue(a, out)
      }
    }

    val writer = dispatch.compile(precompiler)

    def encodeValue(u: U, out: JsonWriter): Unit = {
      val write = writer(u)
      write(out)
    }

    def decodeKey(in: JsonReader): U = {
      in.decodeError("Cannot use coproducts as keys")
    }

    def encodeKey(x: U, out: JsonWriter): Unit = {
      out.encodeError("Cannot use coproducts as keys")
    }
  }
/** Selects the union encoding from the shape hints: untagged when the
  * `Untagged` hint is present, discriminated when the `Discriminated` hint
  * is present, tagged otherwise. The cases are tried in that order.
  *
  * @param shapeId identifier of the shape (not used here)
  * @param hints hints of the union shape
  * @param alternatives the union alternatives
  * @param dispatch dispatcher used to compile the per-alternative writers
  */
override def union[U](
shapeId: ShapeId,
hints: Hints,
alternatives: Vector[SchemaAlt[U, _]],
dispatch: Alt.Dispatcher[Schema, U]
): JCodec[U] = hints match {
case Untagged.hint(_) => untaggedUnion(alternatives)(dispatch)
case Discriminated.hint(d) => discriminatedUnion(alternatives, d)(dispatch)
case _ => taggedUnion(alternatives)(dispatch)
}
/** Builds an enumeration codec: integer-based when the shape carries the
  * `IntEnum` hint, string-based otherwise.
  *
  * @param shapeId identifier of the shape
  * @param hints hints of the shape
  * @param values all values of the enumeration
  * @param total maps an enumeration member to its `EnumValue`
  */
override def enumeration[E](
    shapeId: ShapeId,
    hints: Hints,
    values: List[EnumValue[E]],
    total: E => EnumValue[E]
): JCodec[E] = {
  val stringBased = !hints.has[IntEnum]
  if (stringBased) handleEnum(shapeId, hints, values, total)
  else handleIntEnum(shapeId, hints, values, total)
}
/** Builds a codec reading/writing an enumeration by its string value.
  *
  * @param shapeId identifier of the shape (not used here)
  * @param hints hints of the shape (not used here)
  * @param values all values of the enumeration
  * @param total maps an enumeration member to its `EnumValue`
  */
private def handleEnum[E](
    shapeId: ShapeId,
    hints: Hints,
    values: List[EnumValue[E]],
    total: E => EnumValue[E]
): JCodec[E] = new JCodec[E] {
  // Lookup table built once, instead of a linear scan on every decode.
  // The list is reversed so that, should two values share a string value,
  // the first one declared wins — as with the former `values.find`.
  private[this] val byName: Map[String, E] =
    values.reverse.map(v => v.stringValue -> v.value).toMap
  def fromName(v: String): Option[E] = byName.get(v)
  val expecting: String =
    s"enumeration: [${values.map(_.stringValue).mkString(", ")}]"
  def decodeValue(cursor: Cursor, in: JsonReader): E = {
    val str = in.readString(null)
    fromName(str) match {
      case Some(value) => value
      case None        => in.enumValueError(str)
    }
  }
  def encodeValue(x: E, out: JsonWriter): Unit =
    out.writeVal(total(x).stringValue)
  def decodeKey(in: JsonReader): E = {
    val str = in.readKeyAsString()
    fromName(str) match {
      case Some(value) => value
      case None        => in.enumValueError(str)
    }
  }
  def encodeKey(x: E, out: JsonWriter): Unit =
    out.writeKey(total(x).stringValue)
}
/** Builds a codec reading/writing an enumeration by its integer value.
  *
  * @param shapeId identifier of the shape (not used here)
  * @param hints hints of the shape (not used here)
  * @param values all values of the enumeration
  * @param total maps an enumeration member to its `EnumValue`
  */
private def handleIntEnum[E](
    shapeId: ShapeId,
    hints: Hints,
    values: List[EnumValue[E]],
    total: E => EnumValue[E]
): JCodec[E] = new JCodec[E] {
  // Lookup table built once, instead of a linear scan on every decode.
  // The list is reversed so that, should two values share an int value,
  // the first one declared wins — as with the former `values.find`.
  private[this] val byOrdinal: Map[Int, E] =
    values.reverse.map(v => v.intValue -> v.value).toMap
  def fromOrdinal(v: Int): Option[E] = byOrdinal.get(v)
  // NOTE(review): the description lists string values although the wire
  // format is an integer; kept as-is to preserve the error message.
  val expecting: String =
    s"enumeration: [${values.map(_.stringValue).mkString(", ")}]"
  def decodeValue(cursor: Cursor, in: JsonReader): E = {
    val i = in.readInt()
    fromOrdinal(i) match {
      case Some(value) => value
      case None        => in.enumValueError(i)
    }
  }
  def encodeValue(x: E, out: JsonWriter): Unit =
    out.writeVal(total(x).intValue)
  def decodeKey(in: JsonReader): E = {
    val i = in.readKeyAsInt()
    fromOrdinal(i) match {
      case Some(value) => value
      case None        => in.enumValueError(i)
    }
  }
  def encodeKey(x: E, out: JsonWriter): Unit =
    out.writeKey(total(x).intValue)
}
/** JSON key of a struct field: the `JsonName` hint when present, otherwise
  * the field's own label.
  */
private def jsonLabel[A, Z](field: Field[Schema, Z, A]): String =
  field.hints.get(JsonName).map(_.value).getOrElse(field.label)
// Decodes one field from the reader and stores the result in the buffer.
private type Handler = (Cursor, JsonReader, util.HashMap[String, Any]) => Unit
/** Builds the decoding handler of a struct field.
  *
  * The decoded value is stored under the field's label. For optional fields
  * a JSON `null` is accepted and leaves the buffer untouched.
  *
  * @param field the field to decode
  */
private def fieldHandler[Z, A](
    field: Field[Schema, Z, A]
): Handler = {
  val codec = apply(field.instance)
  val label = field.label

  def decodeRequired(
      cursor: Cursor,
      in: JsonReader,
      sink: util.HashMap[String, Any]
  ): Unit = {
    sink.put(label, cursor.under(label)(cursor.decode(codec, in)))
    ()
  }

  def decodeOptional(
      cursor: Cursor,
      in: JsonReader,
      sink: util.HashMap[String, Any]
  ): Unit =
    cursor.under[Unit](label) {
      if (!in.isNextToken('n')) {
        in.rollbackToken()
        sink.put(label, cursor.decode(codec, in))
        ()
      } else in.readNullOrError[Unit]((), "Expected null")
    }

  if (field.isRequired) decodeRequired else decodeOptional
}
/** Builds the encoder of a struct field: writes the field's JSON key
  * followed by its value. Optional fields that are absent write nothing.
  *
  * Keys made only of non-escaped ASCII characters are written through the
  * dedicated `writeNonEscapedAsciiKey` method; the choice is made once,
  * when the encoder is built.
  *
  * @param field the field to encode
  */
private def fieldEncoder[Z, A](
    field: Field[Schema, Z, A]
): (Z, JsonWriter) => Unit = {
  val folder = new Field.Folder[Schema, Z, (Z, JsonWriter) => Unit] {
    def onRequired[AA](
        label: String,
        instance: Schema[AA],
        get: Z => AA
    ): (Z, JsonWriter) => Unit = {
      val codec = apply(instance)
      val jLabel = jsonLabel(field)
      val asciiOnly = jLabel.forall(JsonWriter.isNonEscapedAscii)
      if (asciiOnly) { (z: Z, out: JsonWriter) =>
        out.writeNonEscapedAsciiKey(jLabel)
        codec.encodeValue(get(z), out)
      } else { (z: Z, out: JsonWriter) =>
        out.writeKey(jLabel)
        codec.encodeValue(get(z), out)
      }
    }

    def onOptional[AA](
        label: String,
        instance: Schema[AA],
        get: Z => Option[AA]
    ): (Z, JsonWriter) => Unit = {
      val codec = apply(instance)
      val jLabel = jsonLabel(field)
      val asciiOnly = jLabel.forall(JsonWriter.isNonEscapedAscii)
      if (asciiOnly) { (z: Z, out: JsonWriter) =>
        get(z) match {
          case Some(present) =>
            out.writeNonEscapedAsciiKey(jLabel)
            codec.encodeValue(present, out)
          case _ => // absent: nothing is written
        }
      } else { (z: Z, out: JsonWriter) =>
        get(z) match {
          case Some(present) =>
            out.writeKey(jLabel)
            codec.encodeValue(present, out)
          case _ => // absent: nothing is written
        }
      }
    }
  }
  field.fold(folder)
}
// All fields of a struct `Z`.
private type Fields[Z] = Vector[Field[Schema, Z, _]]
// Fields paired with their JSON label and decoded default value (or null).
private type LabelledFields[Z] = Vector[(SchemaField[Z, _], String, Any)]
/** Pairs every field with its JSON label and its default value.
  *
  * The default is obtained by decoding the field's default document with a
  * decoder derived from the field schema; it is `null` when the field has no
  * default or when that decoding fails.
  *
  * @param fields the struct fields
  */
private def labelledFields[Z](fields: Fields[Z]): LabelledFields[Z] =
  fields.map { field =>
    val jLabel = jsonLabel(field)
    val default: Any = field.getDefault
      .flatMap[Any] { doc =>
        Document.Decoder.fromSchema(field.instance).decode(doc).toOption
      }
      .orNull
    (field, jLabel, default)
  }
/** Builds a codec for a struct whose JSON body is an object containing the
  * struct's document fields.
  *
  * Decoding happens in two stages: the JSON object is parsed first, and the
  * result is a function that, given a metadata map, merges both sources and
  * calls the constructor.
  *
  * @param fields all fields of the struct, with JSON label and default value
  * @param structHints hints of the struct, used when looking up HTTP bindings
  * @param const builds the struct from its field values, in field order
  * @param encode writes the struct given the encoders of its document fields
  */
private def nonPayloadStruct[Z](
fields: LabelledFields[Z],
structHints: Hints
)(
const: Vector[Any] => Z,
encode: (Z, JsonWriter, Vector[(Z, JsonWriter) => Unit]) => Unit
): JCodec[Z] =
new JCodec[Z] {
// Fields for which no HTTP binding is derived from the hints; only those
// are read from / written to the JSON object.
private[this] val documentFields =
fields.filter { case (field, _, _) =>
HttpBinding
.fromHints(field.label, field.hints, structHints)
.isEmpty
}
// Field decoders indexed by JSON label.
private[this] val handlers =
new util.HashMap[String, Handler](documentFields.length << 1, 0.5f) {
documentFields.foreach { case (field, jLabel, _) =>
put(jLabel, fieldHandler(field))
}
}
private[this] val documentEncoders =
documentFields.map(labelledField => fieldEncoder(labelledField._1))
def expecting: String = "object"
override def canBeKey = false
// Decoding without metadata: the second stage runs immediately with an
// empty metadata map.
def decodeValue(cursor: Cursor, in: JsonReader): Z =
decodeValue_(cursor, in)(emptyMetadata)
override def decodeMessage(
in: JsonReader
): scala.collection.Map[String, Any] => Z =
Cursor.withCursor(expecting)(decodeValue_(_, in))
private def decodeValue_(
cursor: Cursor,
in: JsonReader
): scala.collection.Map[String, Any] => Z = {
val buffer = new util.HashMap[String, Any](handlers.size << 1, 0.5f)
if (in.isNextToken('{')) {
// In this case, metadata and payload are mixed together
// and field values must be sought from either.
if (!in.isNextToken('}')) {
in.rollbackToken()
while ({
val handler = handlers.get(in.readKeyAsString())
// Keys without a handler are skipped.
if (handler eq null) in.skip()
else handler(cursor, in, buffer)
in.isNextToken(',')
}) ()
if (!in.isCurrentToken('}')) in.objectEndOrCommaError()
}
} else in.decodeError("Expected JSON object")
// At this point, we have parsed the json and retrieved
// all the values that interest us for the construction
// of our domain object.
// We therefore reconcile the values pulled from the json
// with the ones pulled from the metadata, and call the constructor
// on it.
{ (meta: scala.collection.Map[String, Any]) =>
meta.foreach(kv => buffer.put(kv._1, kv._2))
val stage2 = new VectorBuilder[Any]
fields.foreach { case (f, jsonLabel, default) =>
stage2 += {
// Values are looked up by field label, not by JSON label.
val value = buffer.get(f.label)
if (f.isRequired) {
if (value == null) {
// A missing required field falls back to its default, if any.
if (default == null)
cursor.requiredFieldError(jsonLabel, jsonLabel)
else default
} else value
} else {
Option(value)
}
}
}
const(stage2.result())
}
}
def encodeValue(z: Z, out: JsonWriter): Unit =
encode(z, out, documentEncoders)
def decodeKey(in: JsonReader): Z =
in.decodeError("Cannot use products as keys")
def encodeKey(x: Z, out: JsonWriter): Unit =
out.encodeError("Cannot use products as keys")
}
/** Builds a codec for a struct in which a single field takes up the whole
  * JSON payload; every other field value comes from the metadata map.
  *
  * @param payloadField the field that takes up the whole payload
  * @param fields all fields of the struct, with JSON label and default value
  * @param codec codec of the payload field
  * @param const builds the struct from its field values, in field order
  */
private def payloadStruct[A, Z](
payloadField: Field[Schema, Z, _],
fields: LabelledFields[Z]
)(codec: JCodec[payloadField.T], const: Vector[Any] => Z): JCodec[Z] =
new JCodec[Z] {
def expecting: String = "object"
override def canBeKey = false
// Decoding without metadata: the second stage runs immediately with an
// empty metadata map.
def decodeValue(cursor: Cursor, in: JsonReader): Z =
decodeValue_(cursor, in)(emptyMetadata)
override def decodeMessage(
in: JsonReader
): scala.collection.Map[String, Any] => Z =
Cursor.withCursor(expecting)(decodeValue_(_, in))
private def decodeValue_(
cursor: Cursor,
in: JsonReader
): scala.collection.Map[String, Any] => Z = {
val buffer = new util.HashMap[String, Any](2, 0.5f)
// In this case, one field assumes the whole payload. We use
// its associated codec.
buffer.put(payloadField.label, cursor.decode(codec, in))
// At this point, we have parsed the json and retrieved
// all the values that interest us for the construction
// of our domain object.
// We therefore reconcile the values pulled from the json
// with the ones pulled from the metadata, and call the constructor
// on it.
{ (meta: scala.collection.Map[String, Any]) =>
meta.foreach(kv => buffer.put(kv._1, kv._2))
val stage2 = new VectorBuilder[Any]
// Unlike the non-payload variant, field defaults are not consulted here.
fields.foreach { case (f, jsonLabel, _) =>
stage2 += {
val value = buffer.get(f.label)
if (f.isRequired) {
if (value == null)
cursor.requiredFieldError(jsonLabel, jsonLabel)
value
} else Option(value)
}
}
const(stage2.result())
}
}
// Only the payload field is written.
def encodeValue(z: Z, out: JsonWriter): Unit =
payloadField.foreachT(z)(codec.encodeValue(_, out))
def decodeKey(in: JsonReader): Z =
in.decodeError("Cannot use products as keys")
def encodeKey(x: Z, out: JsonWriter): Unit =
out.encodeError("Cannot use products as keys")
}
/** Builds a codec for a plain struct: a JSON object holding the struct's
  * document fields and nothing else.
  *
  * @param fields all fields of the struct, with JSON label and default value
  * @param structHints hints of the struct
  * @param make builds the struct from its field values, in field order
  */
private def basicStruct[A, S](
    fields: LabelledFields[S],
    structHints: Hints
)(make: Vector[Any] => S): JCodec[S] = {
  // Writes `{`, every document field in order, then `}`.
  def writeObject(
      z: S,
      out: JsonWriter,
      documentEncoders: Vector[(S, JsonWriter) => Unit]
  ): Unit = {
    out.writeObjectStart()
    val it = documentEncoders.iterator
    while (it.hasNext) it.next().apply(z, out)
    out.writeObjectEnd()
  }
  nonPayloadStruct(fields, structHints)(make, writeObject)
}
/** Builds a struct codec, choosing among three encodings:
  *
  *  - a field carries the `HttpPayload` hint: that field is the whole payload;
  *  - the struct carries a `DiscriminatedUnionMember` hint: the discriminator
  *    property is written first, followed by the document fields;
  *  - otherwise: a plain JSON object.
  *
  * @param shapeId identifier of the shape (not used here)
  * @param hints hints of the struct
  * @param fields fields of the struct
  * @param make builds the struct from its field values, in field order
  */
override def struct[S](
    shapeId: ShapeId,
    hints: Hints,
    fields: Vector[SchemaField[S, _]],
    make: IndexedSeq[Any] => S
): JCodec[S] = {
  type Encoders = Vector[(S, JsonWriter) => Unit]
  val lFields = labelledFields[S](fields)
  val payloadCandidate = fields.find(_.hints.get(HttpPayload).isDefined)
  (payloadCandidate, hints) match {
    case (Some(payloadField), _) =>
      val codec = apply(payloadField.instance)
      payloadStruct(payloadField, lFields)(codec, make)
    case (None, DiscriminatedUnionMember.hint(d)) =>
      // When both the property name and the label are plain ASCII, the
      // dedicated non-escaping writer methods are used.
      val asciiOnly =
        d.propertyName.forall(JsonWriter.isNonEscapedAscii) &&
          d.alternativeLabel.forall(JsonWriter.isNonEscapedAscii)
      val encode: (S, JsonWriter, Encoders) => Unit =
        if (asciiOnly) {
          (z: S, out: JsonWriter, documentEncoders: Encoders) =>
            out.writeObjectStart()
            out.writeNonEscapedAsciiKey(d.propertyName)
            out.writeNonEscapedAsciiVal(d.alternativeLabel)
            documentEncoders.foreach(_.apply(z, out))
            out.writeObjectEnd()
        } else { (z: S, out: JsonWriter, documentEncoders: Encoders) =>
          out.writeObjectStart()
          out.writeKey(d.propertyName)
          out.writeVal(d.alternativeLabel)
          documentEncoders.foreach(_.apply(z, out))
          out.writeObjectEnd()
        }
      nonPayloadStruct(lFields, hints)(make, encode)
    case _ =>
      basicStruct(lFields, hints)(make)
  }
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy