// caliban.tools.SchemaWriter.scala (Maven / Gradle / Ivy)
package caliban.tools
import caliban.parsing.adt.Definition.TypeSystemDefinition.TypeDefinition
import caliban.parsing.adt.Definition.TypeSystemDefinition.TypeDefinition._
import caliban.parsing.adt.Directives.{ LazyDirective, NewtypeDirective }
import caliban.parsing.adt.Type.{ ListType, NamedType }
import caliban.parsing.adt.{ Directive, Directives, Document, Type }
import scala.annotation.tailrec
import scala.collection.compat._
object SchemaWriter {

  /**
   * Renders a parsed GraphQL schema [[Document]] as Scala source code.
   *
   * Non-root declarations (argument case classes, newtype wrappers, objects, inputs, unions, interfaces and
   * enums) are emitted inside `object Types`; the root query, mutation and subscription types (named by the
   * `schema` definition, defaulting to `Query`, `Mutation` and `Subscription`) are emitted inside
   * `object Operations`.
   *
   * @param schema               the parsed GraphQL document to render
   * @param packageName          package clause for the generated file, if any
   * @param effect               effect type wrapping root fields and fields annotated with `LazyDirective`
   * @param imports              extra imports emitted after the generated ones
   * @param scalarMappings       optional mapping from GraphQL type names to Scala type names
   * @param isEffectTypeAbstract when true, root types and types that (transitively) have a `LazyDirective`
   *                             field take a higher-kinded type parameter named after `effect`
   * @param preserveInputNames   when true, input case classes are annotated with `@GQLInputName`
   * @param addDerives           when true, generated declarations get `derives` clauses
   * @param envForDerives        environment type for derivation; a value other than `Any` (case-insensitive)
   *                             generates an `EnvSchema` object at the top of `Operations`
   * @return the generated Scala source
   */
  def write(
    schema: Document,
    packageName: Option[String] = None,
    effect: String = "zio.UIO",
    imports: Option[List[String]] = None,
    scalarMappings: Option[Map[String, String]],
    isEffectTypeAbstract: Boolean = false,
    preserveInputNames: Boolean = false,
    addDerives: Boolean = false,
    envForDerives: Option[String] = None
  ): String = {
    // `derives` clauses appended to generated declarations; all empty unless `addDerives` is set.
    //   derivesSchema              - unions
    //   derivesSchemaAndArgBuilder - inputs, enums and argument case classes
    //   envSchemaDerivation        - `EnvSchema` helper object placed at the top of `Operations`
    //   derivesEnvSchema           - objects, interfaces and root operation types
    val (derivesSchema, derivesSchemaAndArgBuilder, envSchemaDerivation, derivesEnvSchema) =
      (addDerives, envForDerives) match {
        case (false, _)                                         => ("", "", "", "")
        case (true, Some(env)) if !env.equalsIgnoreCase("Any") =>
          (
            " derives caliban.schema.Schema.SemiAuto",
            " derives caliban.schema.Schema.SemiAuto, caliban.schema.ArgBuilder",
            s"object EnvSchema extends caliban.schema.SchemaDerivation[${safeName(env)}]\n\n",
            " derives Operations.EnvSchema.SemiAuto"
          )
        case (true, _)                                          =>
          (
            " derives caliban.schema.Schema.SemiAuto",
            " derives caliban.schema.Schema.SemiAuto, caliban.schema.ArgBuilder",
            "",
            " derives caliban.schema.Schema.SemiAuto"
          )
      }

    // Interface definition -> object definitions that implement it.
    val interfaceImplementationsMap: Map[InterfaceTypeDefinition, List[ObjectTypeDefinition]] =
      (for {
        objectDef    <- schema.objectTypeDefinitions
        interfaceDef <- schema.interfaceTypeDefinitions
        if objectDef.implements.exists(_.name == interfaceDef.name)
      } yield interfaceDef -> objectDef).groupMap(_._1)(_._2)

    // Object definition -> interfaces it implements (the inverse of `interfaceImplementationsMap`).
    val interfacesExtendedForObject: Map[ObjectTypeDefinition, List[InterfaceTypeDefinition]] =
      interfaceImplementationsMap.iterator.flatMap { case (i, os) => os.map(o => (o, i)) }.toList.groupMap {
        case (o, _) => o
      } { case (_, i) => i }

    lazy val typeNameToDefinitionMap: Map[String, TypeDefinition] =
      schema.typeDefinitions.map(obj => obj.name -> obj).toMap

    // Type name -> definitions of every type reachable (transitively) through its fields or union members.
    lazy val typeNameToNestedFields = typeNameToDefinitionMap.map { case (name, t) =>
      name -> findNestedFieldTypes(t, typeNameToDefinitionMap).flatMap(typeNameToDefinitionMap.get)
    }

    // The first interface implemented by `obj` that declares a field named like `field`, if any.
    def inheritedFromInterface(obj: ObjectTypeDefinition, field: FieldDefinition): Option[InterfaceTypeDefinition] =
      interfacesExtendedForObject.get(obj).flatMap(_.find(_.fields.exists(_.name == field.name)))

    // Types that are always treated as root operation types, whatever the `schema` definition says.
    def reservedType(typeDefinition: ObjectTypeDefinition): Boolean =
      typeDefinition.name == "Query" || typeDefinition.name == "Mutation" || typeDefinition.name == "Subscription"

    // True when the field's innermost type, or any type reachable from it, has a field carrying `directive`.
    def containsNestedDirective(field: FieldDefinition, directive: String): Boolean = {
      val innerTypeName = Type.innerType(field.ofType)
      typeNameToDefinitionMap.get(innerTypeName).exists(hasFieldWithDirective(_, directive)) ||
      typeNameToNestedFields.getOrElse(innerTypeName, List.empty).exists(hasFieldWithDirective(_, directive))
    }

    // Renders a query/mutation field as a constructor parameter: `name : [Args =>] effect[Type]`.
    def writeRootField(field: FieldDefinition, od: ObjectTypeDefinition): String = {
      val argsTypeName = if (field.args.nonEmpty) s" ${argsName(field, od)} =>" else ""
      val fieldType    =
        if (isEffectTypeAbstract && containsNestedDirective(field, LazyDirective))
          writeParameterizedType(field.ofType)
        else
          writeType(field.ofType)
      s"${safeName(field.name)} :$argsTypeName $effect[$fieldType]"
    }

    def isAbstractEffectful(typedef: ObjectTypeDefinition): Boolean =
      isEffectTypeAbstract && isEffectful(typedef)

    def isEffectful(typedef: ObjectTypeDefinition): Boolean = isLocalEffectful(typedef) || isNestedEffectful(typedef)

    // The type itself declares a `LazyDirective` field.
    def isLocalEffectful(typedef: ObjectTypeDefinition): Boolean =
      hasFieldWithDirective(typedef, LazyDirective)

    // Some type reachable from this one declares a `LazyDirective` field.
    def isNestedEffectful(typedef: ObjectTypeDefinition): Boolean =
      typeNameToNestedFields
        .getOrElse(typedef.name, List.empty)
        .exists(t => hasFieldWithDirective(t, LazyDirective))

    // Type-parameter clause `[effect[_]]` for root types (when abstract) and abstract-effectful types, else "".
    def generic(op: ObjectTypeDefinition, isRootDefinition: Boolean = false): String =
      if ((isRootDefinition && isEffectTypeAbstract) || isAbstractEffectful(op))
        s"[${effect}[_]]"
      else
        ""

    // Shared layout of a root operation case class. Lines are joined with explicit newlines rather than
    // `stripMargin`, so description lines that start with `|` (e.g. markdown tables) are not altered.
    def writeRootDef(op: ObjectTypeDefinition, typeParams: String, renderedFields: List[String]): String =
      s"\n${writeTypeAnnotations(op)}final case class ${op.name}$typeParams(\n" +
        renderedFields.mkString(",\n") +
        s"\n)$derivesEnvSchema"

    def writeRootQueryOrMutationDef(op: ObjectTypeDefinition): String =
      writeRootDef(op, generic(op, isRootDefinition = true), op.fields.map(c => writeRootField(c, op)))

    // Renders a subscription field as `name: [Args =>] ZStream[Any, Nothing, Type]`.
    def writeSubscriptionField(field: FieldDefinition, od: ObjectTypeDefinition): String = {
      val argsTypeName = if (field.args.nonEmpty) s" ${argsName(field, od)} =>" else ""
      s"${safeName(field.name)}:$argsTypeName ZStream[Any, Nothing, ${writeType(field.ofType)}]"
    }

    def writeRootSubscriptionDef(op: ObjectTypeDefinition): String =
      writeRootDef(op, "", op.fields.map(c => writeSubscriptionField(c, op)))

    // Renders an object type as a case class extending the given interfaces/unions.
    def writeObject(typedef: ObjectTypeDefinition, extend: List[String]): String = {
      val extendRendered = extend match {
        case Nil      => ""
        case nonEmpty => s" extends ${nonEmpty.mkString(" with ")}"
      }
      // Fields inherited from an interface name their args class after that interface.
      val renderedFields = typedef.fields
        .map(field => writeField(field, inheritedFromInterface(typedef, field).getOrElse(typedef), isMethod = false))
        .mkString(", ")
      s"${writeTypeAnnotations(typedef)}final case class ${typedef.name}${generic(typedef)}($renderedFields)" +
        s"$extendRendered$derivesEnvSchema"
    }

    def writeInputObject(typedef: InputObjectTypeDefinition): String = {
      val name            = typedef.name
      val maybeAnnotation = if (preserveInputNames) s"""@GQLInputName("$name")\n""" else ""
      val renderedFields  = typedef.fields.map(writeInputValue).mkString(", ")
      s"$maybeAnnotation${writeTypeAnnotations(typedef)}final case class $name($renderedFields)$derivesSchemaAndArgBuilder"
    }

    // Renders an enum as a sealed trait plus one case object per value in its companion.
    def writeEnum(typedef: EnumTypeDefinition): String = {
      val cases = typedef.enumValuesDefinition
        .map(v =>
          s"${writeEnumAnnotations(v)}case object ${safeName(v.enumValue)} extends ${typedef.name}$derivesSchemaAndArgBuilder"
        )
        .mkString("\n")
      s"${writeTypeAnnotations(typedef)}sealed trait ${typedef.name} extends scala.Product with scala.Serializable" +
        s"$derivesSchemaAndArgBuilder\nobject ${typedef.name} {\n$cases\n}\n"
    }

    def writeUnions(unions: List[UnionTypeDefinition]): String =
      unions.map(writeUnionSealedTrait).mkString("\n")

    def writeUnionSealedTrait(union: UnionTypeDefinition): String =
      s"${writeTypeAnnotations(union)}sealed trait ${union.name} extends scala.Product with scala.Serializable$derivesSchema"

    // Renders an interface as an annotated sealed trait whose fields become abstract `def`s.
    def writeInterface(interface: InterfaceTypeDefinition): String = {
      val members = interface.fields.map(field => writeField(field, interface, isMethod = true)).mkString("\n")
      "@GQLInterface\n" +
        s"${writeTypeAnnotations(interface)}sealed trait ${interface.name} extends scala.Product with scala.Serializable" +
        s" $derivesEnvSchema {\n$members\n}\n"
    }

    /**
     * Renders one field of an object (`isMethod = false`) or interface (`isMethod = true`).
     * `of` is the type whose name prefixes the args class name.
     */
    def writeField(inputField: FieldDefinition, of: TypeDefinition, isMethod: Boolean): String = {
      val field                              = resolveNewTypeFieldDef(inputField).getOrElse(inputField)
      val fieldIsEffectWrapped               = field.directives.exists(_.name == LazyDirective)
      val fieldTypeIsEffectTypeParameterized = isEffectTypeAbstract && containsNestedDirective(field, LazyDirective)
      val fieldType                          = (fieldIsEffectWrapped, fieldTypeIsEffectTypeParameterized) match {
        case (true, true)   => s"$effect[${writeParameterizedType(field.ofType)}]"
        case (true, false)  => s"$effect[${writeType(field.ofType)}]"
        case (false, true)  => writeParameterizedType(field.ofType)
        case (false, false) => writeType(field.ofType)
      }
      val newTypeAnnotation = writeGQLNewTypeDirective(field.directives)
      val defKeyword        = if (isMethod) "def " else ""
      val argsPrefix        = if (field.args.nonEmpty) s"${argsName(field, of)} => " else ""
      s"$newTypeAnnotation${writeFieldAnnotations(field)}$defKeyword${safeName(field.name)} : $argsPrefix$fieldType"
    }

    // `@GQLDirective(...)` annotation line re-declaring a newtype directive, or "" when absent.
    def writeGQLNewTypeDirective(directives: List[Directive]): String =
      directives
        .find(_.name == NewtypeDirective)
        .fold("") { directive =>
          val fnName = directive.arguments("name").toInputString.replace("\"", "")
          s"""@GQLDirective(Directive("$NewtypeDirective", Map("name" -> StringValue("${safeName(fnName)}"))))\n"""
        }

    def writeInputValue(value: InputValueDefinition): String = {
      val newTypeAnnotation = writeGQLNewTypeDirective(value.directives)
      val inputDef          = resolveNewTypeInputDef(value).getOrElse(value)
      s"$newTypeAnnotation${writeInputAnnotations(inputDef)}${safeName(inputDef.name)} : ${writeType(inputDef.ofType)}"
    }

    // Case class holding a field's arguments, or "" when the field takes none.
    def writeArguments(field: FieldDefinition, of: TypeDefinition): String = {
      def fields(args: List[InputValueDefinition]): String =
        args.map { arg =>
          val resolvedArg = resolveNewTypeInputDef(arg).getOrElse(arg)
          s"${safeName(resolvedArg.name)} : ${writeType(resolvedArg.ofType)}"
        }.mkString(", ")

      if (field.args.nonEmpty) {
        s"final case class ${argsName(field, of)}(${fields(field.args)})$derivesSchemaAndArgBuilder"
      } else {
        ""
      }
    }

    // Value-class wrapper (with Schema/ArgBuilder instances) for a newtype directive over `fieldType`.
    def writeNewTypeClasses(fieldType: Type, directive: Directive): String = {
      @tailrec def writeTypeOf(fieldType: Type): String = fieldType match {
        case NamedType(name, _) => scalarMappings.flatMap(_.get(name)).getOrElse(name)
        case ListType(ftype, _) => writeTypeOf(ftype)
      }
      val fnName  = safeName(directive.arguments("name").toInputString.replace("\"", ""))
      val newtype = safeName(writeTypeOf(fieldType))
      s"""case class $fnName(value : $newtype) extends AnyVal
         |object $fnName {
         | implicit val schema: Schema[Any, $fnName] = implicitly[Schema[Any, $newtype]].contramap(_.value)
         | implicit val argBuilder: ArgBuilder[$fnName] = implicitly[ArgBuilder[$newtype]].map($fnName(_))
         |}""".stripMargin
    }

    // Replaces the innermost named type, keeping list wrappers and nullability.
    def replaceNameOfInnertype(name: String, ftype: Type): Type =
      ftype match {
        case NamedType(_, nt)  => NamedType(name, nt)
        case ListType(nt, opt) => ListType(replaceNameOfInnertype(name, nt), opt)
      }

    def resolveNewTypeFieldDef(field: FieldDefinition): Option[FieldDefinition] =
      if (Directives.isNewType(field.directives)) {
        Directives
          .newTypeName(field.directives)
          .map(name => field.copy(ofType = replaceNameOfInnertype(name, field.ofType)))
      } else None

    def resolveNewTypeInputDef(field: InputValueDefinition): Option[InputValueDefinition] =
      if (Directives.isNewType(field.directives)) {
        Directives
          .newTypeName(field.directives)
          .map(name => field.copy(ofType = replaceNameOfInnertype(name, field.ofType)))
      } else None

    def argsName(field: FieldDefinition, od: TypeDefinition): String =
      s"${od.name.capitalize}${field.name.capitalize}Args"

    def escapeDoubleQuotes(input: String): String =
      input.replace("\"", "\\\"")

    def writeTypeAnnotations(definition: TypeDefinition): String =
      writeDescriptionAndDeprecation(definition.description, definition.directives)

    def writeFieldAnnotations(definition: FieldDefinition): String =
      writeDescriptionAndDeprecation(definition.description, definition.directives)

    def writeInputAnnotations(definition: InputValueDefinition): String =
      writeDescriptionAndDeprecation(definition.description, definition.directives)

    def writeEnumAnnotations(definition: EnumValueDefinition): String =
      writeDescriptionAndDeprecation(definition.description, definition.directives)

    def writeDescriptionAndDeprecation(description: Option[String], directives: List[Directive]): String =
      s"${writeDescription(description)} ${writeDeprecation(directives)}"

    /**
     * Renders `@annotation("value")` followed by a newline, using a triple-quoted literal for multi-line values.
     * The user-supplied text is not passed through `stripMargin`, so lines starting with `|` survive.
     */
    def escapeAndWrap(value: String, annotation: String): String =
      if (value.contains("\n")) {
        // Triple-quoted literals do not interpret backslash escapes, so backslashes are emitted unchanged.
        // NOTE(review): the `\"` added by escapeDoubleQuotes is also kept literally inside a triple-quoted
        // literal; this is left as-is to avoid changing existing output.
        val escapedValue = escapeDoubleQuotes(value)
        s"""@$annotation(\"\"\"$escapedValue\"\"\")""" + "\n"
      } else {
        // Regular literals interpret escapes: double backslashes (before escaping quotes) so text such as
        // `\d` is neither an invalid escape nor silently changed in the generated code.
        val escapedValue = escapeDoubleQuotes(value.replace("\\", "\\\\"))
        s"""@$annotation("$escapedValue")""" + "\n"
      }

    def writeDeprecation(directives: List[Directive]): String =
      Directives.deprecationReason(directives).fold("")(escapeAndWrap(_, "GQLDeprecated"))

    def writeDescription(description: Option[String]): String =
      description.fold("")(escapeAndWrap(_, "GQLDescription"))

    // Scala type for a GraphQL type: nullable -> scala.Option, list -> List, scalars mapped via scalarMappings.
    def writeType(t: Type): String = {
      def mapped(name: String): String = scalarMappings
        .flatMap(_.get(name))
        .getOrElse(name)
      t match {
        case NamedType(name, true)    => mapped(name)
        case NamedType(name, false)   => s"scala.Option[${mapped(name)}]"
        case ListType(ofType, true)   => s"List[${writeType(ofType)}]"
        case ListType(ofType, false)  => s"scala.Option[List[${writeType(ofType)}]]"
      }
    }

    // Like writeType, but the innermost named type is applied to `effect`, e.g. `Foo[zio.UIO]`.
    def writeParameterizedType(t: Type): String = {
      def mapped(name: String): String = {
        val result = scalarMappings
          .flatMap(_.get(name))
          .getOrElse(name)
        s"$result[$effect]"
      }
      t match {
        case Type.NamedType(name, true)   => mapped(name)
        case Type.NamedType(name, false)  => s"scala.Option[${mapped(name)}]"
        case Type.ListType(ofType, true)  =>
          s"List[${writeParameterizedType(ofType)}]"
        case Type.ListType(ofType, false) =>
          s"scala.Option[List[${writeParameterizedType(ofType)}]]"
      }
    }

    val schemaDef = schema.schemaDefinition

    // Argument case classes. Object fields inherited from an interface are skipped, because writeObject names
    // their args class after the interface, whose own args class is generated from `fromInterfaces`.
    val argsTypes = {
      val fromObjects: List[(FieldDefinition, TypeDefinition)] =
        schema.objectTypeDefinitions.flatMap { typeDef =>
          typeDef.fields.collect {
            case f if f.args.nonEmpty && inheritedFromInterface(typeDef, f).isEmpty => (f, typeDef)
          }
        }
      val fromInterfaces: List[(FieldDefinition, TypeDefinition)] =
        schema.interfaceTypeDefinitions.flatMap(typeDef =>
          typeDef.fields.collect { case f if f.args.nonEmpty => (f, typeDef) }
        )
      (fromObjects ++ fromInterfaces).map { case (field, typeDef) => writeArguments(field, typeDef) }.mkString("\n")
    }

    // One wrapper class per distinct newtype name, sorted by name.
    val newTypeClasses = {
      case class FieldAndDirective(fieldName: String, fieldType: Type, directive: Directive)

      def findNewtype(directives: List[Directive]): Option[Directive] =
        directives.find(_.name == NewtypeDirective)

      // Both the field itself and each of its arguments may carry a newtype directive; collect all of them.
      def fromFields(fields: List[FieldDefinition]): List[FieldAndDirective] =
        fields.flatMap { f =>
          findNewtype(f.directives).map(FieldAndDirective(f.name, f.ofType, _)).toList ++
            f.args.flatMap(a => findNewtype(a.directives).map(FieldAndDirective(a.name, a.ofType, _)))
        }

      val fromObjects: List[FieldAndDirective]    = schema.objectTypeDefinitions.flatMap(o => fromFields(o.fields))
      val fromInputTypes: List[FieldAndDirective] =
        schema.inputObjectTypeDefinitions.flatMap(
          _.fields.flatMap(f => findNewtype(f.directives).map(FieldAndDirective(f.name, f.ofType, _)))
        )
      // Interface fields and arguments are rendered with resolved newtype names too, so they need wrappers.
      val fromInterfaces: List[FieldAndDirective] = schema.interfaceTypeDefinitions.flatMap(i => fromFields(i.fields))

      val newtypeClasses = (fromObjects ++ fromInputTypes ++ fromInterfaces)
        .groupBy(_.directive.arguments("name").toInputString)
        .map(_._2.head)
        .toList
        .sortBy(_.directive.arguments("name").toInputString)
        .map(fieldAndDirective => writeNewTypeClasses(fieldAndDirective.fieldType, fieldAndDirective.directive))
        .mkString("\n")
      if (newtypeClasses.nonEmpty) newtypeClasses + "\n" else ""
    }

    // Union -> member object definitions. Kept as a List (not a Map) so that an object belonging to several
    // unions extends them in schema order; a Map with more than four entries does not keep insertion order.
    val unionTypes = schema.unionTypeDefinitions
      .map(union => (union, union.memberTypes.flatMap(schema.objectTypeDefinition)))

    val unions = writeUnions(schema.unionTypeDefinitions)

    val interfacesStr = schema.interfaceTypeDefinitions.map(writeInterface).mkString("\n")

    val objects = schema.objectTypeDefinitions
      .filterNot(obj =>
        reservedType(obj) ||
          schemaDef.exists(_.query.getOrElse("Query") == obj.name) ||
          schemaDef.exists(_.mutation.getOrElse("Mutation") == obj.name) ||
          schemaDef.exists(_.subscription.getOrElse("Subscription") == obj.name)
      )
      .map { obj =>
        val extendsInterfaces = obj.implements.map(name => name.name)
        val partOfUnionTypes  = unionTypes.collect { case (u, os) if os.exists(_.name == obj.name) => u.name }
        writeObject(obj, extend = extendsInterfaces ++ partOfUnionTypes)
      }
      .mkString("\n")

    val inputs = schema.inputObjectTypeDefinitions.map(writeInputObject).mkString("\n")

    val enums = schema.enumTypeDefinitions.map(writeEnum).mkString("\n")

    val queries = schema
      .objectTypeDefinition(schemaDef.flatMap(_.query).getOrElse("Query"))
      .map(t => writeRootQueryOrMutationDef(t))
      .getOrElse("")

    val mutations = schema
      .objectTypeDefinition(schemaDef.flatMap(_.mutation).getOrElse("Mutation"))
      .map(t => writeRootQueryOrMutationDef(t))
      .getOrElse("")

    val subscriptions = schema
      .objectTypeDefinition(schemaDef.flatMap(_.subscription).getOrElse("Subscription"))
      .map(t => writeRootSubscriptionDef(t))
      .getOrElse("")

    val additionalImportsString = imports.fold("")(_.map(i => s"import $i").mkString("\n"))

    val hasSubscriptions = subscriptions.nonEmpty
    // Newtype classes live in `Types`, so they alone are enough to require that object.
    val hasTypes         =
      List(argsTypes, newTypeClasses, objects, enums, unions, inputs, interfacesStr).exists(_.nonEmpty)
    val hasOperations    = List(queries, mutations, subscriptions).exists(_.nonEmpty)

    val typesObject =
      if (hasTypes)
        "object Types {\n" +
          argsTypes + "\n" +
          newTypeClasses +
          objects + "\n" +
          inputs + "\n" +
          unions + "\n" +
          interfacesStr + "\n" +
          enums + "\n" +
          "\n}\n"
      else ""

    val operationsObject =
      if (hasOperations)
        "object Operations {\n" +
          envSchemaDerivation +
          queries + "\n\n" +
          mutations + "\n\n" +
          subscriptions + "\n" +
          "\n}"
      else ""

    val typesAndOperations = s"\n$typesObject\n$operationsObject\n"

    val packageClause     = packageName.fold("")(p => s"package $p\n\n")
    val typesImport       = if (hasTypes && hasOperations) "import Types._\n" else ""
    val annotationsImport =
      if (typesAndOperations.contains("@GQL") || newTypeClasses.nonEmpty) "import caliban.schema.Annotations._\n"
      else ""
    val newTypeImports    =
      if (newTypeClasses.nonEmpty)
        "import caliban.Value._\nimport caliban.parsing.adt.Directive\nimport caliban.schema.{ArgBuilder, Schema}"
      else ""
    val zStreamImport     = if (hasSubscriptions) "import zio.stream.ZStream\n" else ""

    s"$packageClause\n$typesImport\n$annotationsImport\n$newTypeImports\n$zStreamImport\n" +
      s"$additionalImportsString\n$typesAndOperations\n"
  }

  /**
   * Collects the names of the types of all sub-fields reachable from `definition`, transitively.
   *
   * For example, given
   * {{{
   * object A {
   *   field b: B
   *   field c: String
   * }
   * object B {
   *   field d: Int
   * }
   * }}}
   * the result for `A` is `Set(B, String, Int)`.
   *
   * Object, interface and input-object fields contribute their innermost type name (list wrappers removed);
   * unions contribute their member type names. Names missing from `typeNameToDefinitionMap` are included but
   * not expanded. Before recursing, a type's direct field types are recorded in the accumulator, so a type
   * that is reached again is not re-expanded and cyclic references terminate.
   */
  private def findNestedFieldTypes(
    definition: TypeDefinition,
    typeNameToDefinitionMap: Map[String, TypeDefinition]
  ): Set[String] = {
    def findSubFieldTypes(
      obj: TypeDefinition,
      typeNameToNestedFields: Map[String, Set[String]]
    ): (Set[String], Map[String, Set[String]]) =
      typeNameToNestedFields
        .get(obj.name)
        .fold {
          val fieldTypes: Set[String] = obj match {
            case objectTypeDefinition: ObjectTypeDefinition                        =>
              objectTypeDefinition.fields.map(f => Type.innerType(f.ofType)).toSet
            case interfaceTypeDefinition: TypeDefinition.InterfaceTypeDefinition   =>
              interfaceTypeDefinition.fields.map(f => Type.innerType(f.ofType)).toSet
            case inputObjectTypeDefinition: InputObjectTypeDefinition              =>
              inputObjectTypeDefinition.fields.map(f => Type.innerType(f.ofType)).toSet
            case unionTypeDefinition: TypeDefinition.UnionTypeDefinition           =>
              unionTypeDefinition.memberTypes.toSet
            case _                                                                 => Set.empty
          }
          val (subTypes, f) = fieldTypes.foldLeft((Set.empty[String], typeNameToNestedFields)) {
            case ((subTypeSet, reference), t) =>
              typeNameToDefinitionMap.get(t) match {
                case Some(o) =>
                  // Record `obj` before descending so a cycle back to it stops here.
                  val (s, f) = findSubFieldTypes(o, reference.updated(obj.name, fieldTypes))
                  (subTypeSet ++ s, f)
                case None    => (subTypeSet, reference)
              }
          }
          val allTypes = fieldTypes ++ subTypes
          (allTypes, f.updated(obj.name, allTypes))
        } { s =>
          (s, typeNameToNestedFields)
        }

    val (s, _) = findSubFieldTypes(definition, Map.empty)
    s
  }

  // True when an object, interface or input-object type has a field carrying `directive`; false otherwise.
  private def hasFieldWithDirective(definition: TypeDefinition, directive: String): Boolean =
    definition match {
      case ot: ObjectTypeDefinition       =>
        ot.fields.exists(_.directives.exists(_.name == directive))
      case it: InterfaceTypeDefinition    =>
        it.fields.exists(_.directives.exists(_.name == directive))
      case iot: InputObjectTypeDefinition =>
        iot.fields.exists(_.directives.exists(_.name == directive))
      case _: TypeDefinition              => false
    }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy