All downloads are free. Search and download functionality uses the official Maven repository.

com.coxautodev.graphql.tools.SchemaClassScanner.kt Maven / Gradle / Ivy

There is a newer version: 5.2.4
Show newest version
package com.coxautodev.graphql.tools

import com.google.common.collect.BiMap
import com.google.common.collect.HashBiMap
import com.google.common.collect.Maps
import graphql.language.Definition
import graphql.language.FieldDefinition import graphql.language.InputObjectTypeDefinition
import graphql.language.InputValueDefinition
import graphql.language.InterfaceTypeDefinition
import graphql.language.ObjectTypeDefinition
import graphql.language.ScalarTypeDefinition
import graphql.language.SchemaDefinition
import graphql.language.TypeDefinition
import graphql.language.TypeExtensionDefinition
import graphql.language.TypeName
import graphql.language.UnionTypeDefinition
import graphql.schema.GraphQLScalarType
import graphql.schema.idl.ScalarInfo
import org.slf4j.LoggerFactory
import java.lang.reflect.Field
import java.lang.reflect.Method

/**
 * @author Andrew Potter
 */
/**
 * Matches the types of a parsed GraphQL schema to the Java classes that back them and builds a [SchemaParser]
 * from the result.
 *
 * Scanning starts at the root (query, mutation, subscription) types. Each field is resolved against the root
 * resolvers, and the return and parameter types found there are matched to schema types. Object types discovered
 * this way are queued and scanned in turn. Implementors of discovered interfaces and members of discovered
 * unions are pulled in from the scan results or from the supplied dictionary.
 *
 * All scan state lives in plain, unsynchronized mutable collections.
 *
 * @author Andrew Potter
 */
internal class SchemaClassScanner(
    initialDictionary: BiMap<String, Class<*>>,
    allDefinitions: List<Definition>,
    resolvers: List<GraphQLResolver<*>>,
    private val scalars: CustomScalarMap,
    private val options: SchemaParserOptions
) {

    companion object {
        val log = LoggerFactory.getLogger(SchemaClassScanner::class.java)!!
    }

    private val rootInfo = RootTypeInfo.fromSchemaDefinitions(allDefinitions.filterIsInstance<SchemaDefinition>())

    private val queryResolvers = resolvers.filterIsInstance<GraphQLQueryResolver>()
    private val mutationResolvers = resolvers.filterIsInstance<GraphQLMutationResolver>()
    private val subscriptionResolvers = resolvers.filterIsInstance<GraphQLSubscriptionResolver>()

    // Every resolver that is not a root resolver is treated as a resolver for some data class.
    private val resolverInfos = resolvers.minus(queryResolvers).minus(mutationResolvers).minus(subscriptionResolvers).map { NormalResolverInfo(it, options) }
    private val resolverInfosByDataClass = this.resolverInfos.associateBy { it.dataClassType }

    // Entries are wrapped so that mappings that were never used can be reported after scanning.
    private val initialDictionary = initialDictionary.mapValues { InitialDictionaryEntry(it.value) }
    private val extensionDefinitions = allDefinitions.filterIsInstance<TypeExtensionDefinition>()

    private val definitionsByName = (allDefinitions.filterIsInstance<TypeDefinition>() - extensionDefinitions).associateBy { it.name }
    private val objectDefinitions = (allDefinitions.filterIsInstance<ObjectTypeDefinition>() - extensionDefinitions)
    private val objectDefinitionsByName = objectDefinitions.associateBy { it.name }
    private val interfaceDefinitionsByName = allDefinitions.filterIsInstance<InterfaceTypeDefinition>().associateBy { it.name }

    private val fieldResolverScanner = FieldResolverScanner(options)
    private val typeClassMatcher = TypeClassMatcher(definitionsByName)

    // Type definition -> discovered class plus the places where the type was referenced.
    private val dictionary = mutableMapOf<TypeDefinition, DictionaryEntry>()

    // Types that were reached but are not tracked in the dictionary (root types and scalars).
    private val unvalidatedTypes = mutableSetOf<TypeDefinition>()

    // Object types waiting to be scanned; insertion-ordered and free of duplicates.
    private val queue = linkedSetOf<QueueItem>()

    private val fieldResolversByType = mutableMapOf<ObjectTypeDefinition, MutableMap<FieldDefinition, FieldResolver>>()

    init {
        // In an init block `initialDictionary` is the constructor parameter (name -> class); it shadows the
        // wrapped property of the same name.
        initialDictionary.forEach { (name, clazz) ->
            if (!definitionsByName.containsKey(name)) {
                throw SchemaClassScannerError("Class in supplied dictionary '${clazz.name}' specified type name '$name', but a type definition with that name was not found!")
            }
        }

        if (options.allowUnimplementedResolvers) {
            log.warn("Option 'allowUnimplementedResolvers' should only be set to true during development, as it can cause schema errors to be moved to query time instead of schema creation time.  Make sure this is turned off in production.")
        }
    }

    /**
     * Attempts to discover GraphQL Type -> Java Class relationships by matching return types/argument types on known fields.
     *
     * @return a [SchemaParser] built from the discovered type/class mappings
     * @throws SchemaClassScannerError if a required root type or resolver is missing, a type is mapped to two
     *   different classes, or an interface implementor / union member has no class
     */
    fun scanForClasses(): SchemaParser {

        // Figure out what query, mutation and subscription types are called
        val rootTypeHolder = RootTypesHolder(rootInfo, definitionsByName, queryResolvers, mutationResolvers, subscriptionResolvers)

        handleRootType(rootTypeHolder.query)
        handleRootType(rootTypeHolder.mutation)
        handleRootType(rootTypeHolder.subscription)

        scanQueue()

        // Interface implementors and union members must be resolved even when the root types enqueued no
        // object types (e.g. every root field returns an interface or union). Resolving them can enqueue
        // more object types, so each stage repeats until a pass leaves the queue empty.
        do {
            do {
                // Require all implementors of discovered interfaces to be discovered or provided.
                handleInterfaceOrUnionSubTypes(getAllObjectTypesImplementingDiscoveredInterfaces()) {
                    "Object type '${it.name}' implements a known interface, but no class was found for that type name.  Please pass a class for type '${it.name}' in the parser's dictionary."
                }
            } while (scanQueue())

            // Require all members of discovered unions to be discovered.
            handleInterfaceOrUnionSubTypes(getAllObjectTypeMembersOfDiscoveredUnions()) {
                "Object type '${it.name}' is a member of a known union, but no class was found for that type name.  Please pass a class for type '${it.name}' in the parser's dictionary."
            }
        } while (scanQueue())

        return validateAndCreateParser(rootTypeHolder)
    }

    /**
     * Scans queued object types until the queue is empty, including items enqueued during the scan.
     *
     * @return `true` if at least one item was scanned, `false` if the queue was already empty
     */
    private fun scanQueue(): Boolean {
        if (queue.isEmpty()) {
            return false
        }

        while (queue.isNotEmpty()) {
            val iterator = queue.iterator()
            val item = iterator.next()
            iterator.remove()
            scanQueueItemForPotentialMatches(item)
        }

        return true
    }

    /**
     * Records a root type and its interfaces, then scans the root type's fields against its root resolvers.
     * Does nothing when [rootType] is `null` (an optional root type that is not defined).
     */
    private fun handleRootType(rootType: RootType?) {
        if (rootType == null) {
            return
        }

        unvalidatedTypes.add(rootType.type)
        scanInterfacesOfType(rootType.type)
        scanResolverInfoForPotentialMatches(rootType.type, rootType.resolverInfo)
    }

    /**
     * Post-scan step that builds the [SchemaParser].
     *
     * It logs warnings for dictionary entries, schema types and resolvers that were never used. It also checks
     * that every non-standard scalar has an implementation.
     *
     * @throws SchemaClassScannerError if the type -> class bi-map cannot be built, or a custom scalar has no implementation
     */
    private fun validateAndCreateParser(rootTypeHolder: RootTypesHolder): SchemaParser {
        initialDictionary.filterValues { !it.accessed }.forEach { (name, entry) ->
            log.warn("Dictionary mapping was provided but never used, and can be safely deleted: \"$name\" -> ${entry.get().name}")
        }

        val observedDefinitions = dictionary.keys.toSet() + unvalidatedTypes

        // The dictionary doesn't need to know what classes are used with scalars.
        // In addition, scalars can have duplicate classes so that breaks the bi-map.
        val typeClassDictionary = try {
            Maps.unmodifiableBiMap(HashBiMap.create<TypeDefinition, Class<*>>().also { biMap ->
                dictionary.forEach { (type, entry) ->
                    entry.typeClass?.let { biMap.put(type, it) }
                }
            })
        } catch (t: Throwable) {
            throw SchemaClassScannerError("Error creating bimap of type => class", t)
        }

        val scalarDefinitions = observedDefinitions.filterIsInstance<ScalarTypeDefinition>()

        // Ensure all scalar definitions have implementations and add the definition to those.
        val resolvedScalars = scalarDefinitions
            .filter { !ScalarInfo.STANDARD_SCALAR_DEFINITIONS.containsKey(it.name) }
            .map { definition ->
                val provided = scalars[definition.name] ?: throw SchemaClassScannerError("Expected a user-defined GraphQL scalar type with name '${definition.name}' but found none!")
                GraphQLScalarType(provided.name, SchemaParser.getDocumentation(definition) ?: provided.description, provided.coercing, definition)
            }
            .associateBy { it.name!! }

        (definitionsByName.values - observedDefinitions).forEach { definition ->
            log.warn("Schema type was defined but can never be accessed, and can be safely deleted: ${definition.name}")
        }

        val fieldResolvers = fieldResolversByType.values.flatMap { it.values }
        val observedNormalResolverInfos = fieldResolvers.map { it.resolverInfo }.distinct().filterIsInstance<NormalResolverInfo>()

        (resolverInfos - observedNormalResolverInfos).forEach { resolverInfo ->
            log.warn("Resolver was provided but no methods on it were used in data fetchers, and can be safely deleted: ${resolverInfo.resolver}")
        }

        validateRootResolversWereUsed(rootTypeHolder.query, fieldResolvers)
        validateRootResolversWereUsed(rootTypeHolder.mutation, fieldResolvers)
        validateRootResolversWereUsed(rootTypeHolder.subscription, fieldResolvers)

        return SchemaParser(typeClassDictionary, observedDefinitions + extensionDefinitions, resolvedScalars, rootInfo, fieldResolversByType.toMap())
    }

    /**
     * Logs a warning for every resolver of [rootType] whose class backs none of the discovered field resolvers.
     * Does nothing when [rootType] is `null`.
     */
    fun validateRootResolversWereUsed(rootType: RootType?, fieldResolvers: List<FieldResolver>) {
        if (rootType == null) {
            return
        }

        val observedRootTypes = fieldResolvers.filter { it.resolverInfo is RootResolverInfo && it.resolverInfo == rootType.resolverInfo }.map { it.search.type }.toSet()

        rootType.resolvers.forEach { resolver ->
            if (resolver.javaClass !in observedRootTypes) {
                log.warn("Root ${rootType.name} resolver was provided but no methods on it were used in data fetchers for GraphQL type '${rootType.type.name}'!  Either remove the ${rootType.resolverInterface.name} interface from the resolver or remove the resolver entirely: $resolver")
            }
        }
    }

    /** Object types (excluding extensions) that declare any interface currently present in the dictionary. */
    fun getAllObjectTypesImplementingDiscoveredInterfaces(): List<ObjectTypeDefinition> {
        return dictionary.keys.filterIsInstance<InterfaceTypeDefinition>().flatMap { iface ->
            objectDefinitions.filter { obj -> obj.implements.filterIsInstance<TypeName>().any { it.name == iface.name } }
        }.distinct()
    }

    /**
     * Object types that are members of any union currently present in the dictionary.
     *
     * @throws SchemaClassScannerError if a union member names an object type that does not exist
     */
    fun getAllObjectTypeMembersOfDiscoveredUnions(): List<ObjectTypeDefinition> {
        return dictionary.keys.filterIsInstance<UnionTypeDefinition>().flatMap { union ->
            union.memberTypes.filterIsInstance<TypeName>().map { objectDefinitionsByName[it.name] ?: throw SchemaClassScannerError("No object type found with name '${it.name}' for union: $union") }
        }.distinct()
    }

    /**
     * Registers each of [types] that is not yet in the dictionary, using the class from the supplied dictionary.
     *
     * @throws SchemaClassScannerError with [failureMessage] when no class is known for a type
     */
    fun handleInterfaceOrUnionSubTypes(types: List<ObjectTypeDefinition>, failureMessage: (ObjectTypeDefinition) -> String) {
        types.forEach { type ->
            if (!dictionary.containsKey(type)) {
                val initialEntry = initialDictionary[type.name] ?: throw SchemaClassScannerError(failureMessage(type))
                handleFoundType(type, initialEntry.get(), DictionaryReference())
            }
        }
    }

    /**
     * Scan a new object for types that haven't been mapped yet.
     * A registered resolver for the item's class is used if one exists; otherwise the class itself is scanned.
     */
    private fun scanQueueItemForPotentialMatches(item: QueueItem) {
        scanResolverInfoForPotentialMatches(item.type, resolverInfosByDataClass[item.clazz] ?: DataClassResolverInfo(item.clazz))
    }

    /** Finds a field resolver for every (extended) field of [type], records it, and handles each type it references. */
    private fun scanResolverInfoForPotentialMatches(type: ObjectTypeDefinition, resolverInfo: ResolverInfo) {
        type.getExtendedFieldDefinitions(extensionDefinitions).forEach { field ->
            val fieldResolver = fieldResolverScanner.findFieldResolver(field, resolverInfo)

            fieldResolversByType.getOrPut(type) { mutableMapOf() }[fieldResolver.field] = fieldResolver
            fieldResolver.scanForMatches().forEach { potentialMatch ->
                handleFoundType(typeClassMatcher.match(potentialMatch))
            }
        }
    }

    private fun handleFoundType(match: TypeClassMatcher.Match) {
        handleFoundType(match.type, match.clazz, match.reference)
    }

    /**
     * Enter a found type into the dictionary if it doesn't exist yet, add a reference pointing back to where it was discovered.
     *
     * @param clazz the class matched to [type], or `null` when only the type itself is known (e.g. a declared interface)
     * @throws SchemaClassScannerError if [type] was already mapped to a different class
     */
    private fun handleFoundType(type: TypeDefinition, clazz: Class<*>?, reference: Reference) {
        if (ignoreDictionaryForType(type)) {
            // Scalars are not tracked in the dictionary; only remember that the type was reached.
            unvalidatedTypes.add(type)
            return
        }

        val realEntry = dictionary.getOrPut(type) { DictionaryEntry() }
        var typeWasSet = false

        if (clazz != null) {
            typeWasSet = realEntry.setTypeIfMissing(clazz)

            if (realEntry.typeClass != clazz) {
                throw SchemaClassScannerError("Two different classes used for type ${type.name}:\n${realEntry.joinReferences()}\n\n- $clazz:\n|   ${reference.getDescription()}")
            }
        }

        realEntry.addReference(reference)

        // setTypeIfMissing returns true only the first time a class is recorded for this type, so each newly
        // mapped type is handled exactly once.
        if (typeWasSet && clazz != null) {
            handleNewType(type, clazz)
        }
    }

    private fun ignoreDictionaryForType(type: TypeDefinition): Boolean {
        return type is ScalarTypeDefinition
    }

    /**
     * Handle a newly found type.
     *
     * Object types are queued for scanning and have their interfaces registered. For input object types, the
     * Java type of each input value whose GraphQL type is not a standard scalar is matched as well.
     */
    private fun handleNewType(graphQLType: TypeDefinition, javaType: Class<*>) {
        when (graphQLType) {
            is ObjectTypeDefinition -> {
                enqueue(graphQLType, javaType)
                scanInterfacesOfType(graphQLType)
            }

            is InputObjectTypeDefinition -> {
                graphQLType.inputValueDefinitions.forEach { inputValueDefinition ->
                    findInputValueType(inputValueDefinition.name, javaType)?.let { inputValueJavaType ->
                        val inputGraphQLType = inputValueDefinition.type.unwrap()
                        if (inputGraphQLType is TypeName && !ScalarInfo.STANDARD_SCALAR_DEFINITIONS.containsKey(inputGraphQLType.name)) {
                            handleFoundType(typeClassMatcher.match(TypeClassMatcher.PotentialMatch.parameterType(inputValueDefinition.type, inputValueJavaType, GenericType(javaType, options).relativeToType(inputValueJavaType), InputObjectReference(inputValueDefinition))))
                        }
                    }
                }
            }
        }
    }

    /**
     * Registers every interface declared by [graphQLType] (without a class).
     *
     * @throws SchemaClassScannerError if a declared interface is not defined in the schema
     */
    private fun scanInterfacesOfType(graphQLType: ObjectTypeDefinition) {
        graphQLType.implements.forEach {
            if (it is TypeName) {
                handleFoundType(interfaceDefinitionsByName[it.name] ?: throw SchemaClassScannerError("Object type ${graphQLType.name} declared interface ${it.name}, but no interface with that name was found in the schema!"), null, InterfaceReference(graphQLType))
            }
        }
    }

    private fun enqueue(graphQLType: ObjectTypeDefinition, javaType: Class<*>) {
        queue.add(QueueItem(graphQLType, javaType))
    }

    /**
     * Finds the Java type that provides input value [name] on [clazz].
     *
     * Candidates are checked in order: a public parameterless method named [name], then a public parameterless
     * `getName()` method, then a public field named [name].
     *
     * @return the generic type of that member, or `null` if none matches
     */
    private fun findInputValueType(name: String, clazz: Class<*>): JavaType? {
        // A method that takes arguments cannot act as a getter (e.g. builder-style `name(value)` setters).
        val getters = clazz.methods.filter { it.parameterTypes.isEmpty() }

        return (getters.find { it.name == name }
            ?: getters.find { it.name == "get${name.capitalize()}" })?.genericReturnType
            ?: clazz.fields.find { it.name == name }?.genericType
    }

    private data class QueueItem(val type: ObjectTypeDefinition, val clazz: Class<*>)

    /** Class recorded for one type definition, plus every reference through which the type was reached. */
    private class DictionaryEntry {
        private val references = mutableListOf<Reference>()
        var typeClass: Class<*>? = null
            private set

        /** Records [typeClass] if no class is set yet; returns `true` only when it was recorded by this call. */
        fun setTypeIfMissing(typeClass: Class<*>): Boolean {
            if (this.typeClass == null) {
                this.typeClass = typeClass
                return true
            }

            return false
        }

        fun addReference(reference: Reference) {
            references.add(reference)
        }

        fun joinReferences() = "- $typeClass:\n|   " + references.joinToString("\n|   ") { it.getDescription() }
    }

    /** Human-readable description of where a type was encountered; used in error messages. */
    abstract class Reference {
        abstract fun getDescription(): String
        override fun toString() = getDescription()
    }

    private class DictionaryReference : Reference() {
        override fun getDescription() = "provided dictionary"
    }

    private class InterfaceReference(private val type: ObjectTypeDefinition) : Reference() {
        override fun getDescription() = "interface declarations of ${type.name}"
    }

    private class InputObjectReference(private val type: InputValueDefinition) : Reference() {
        override fun getDescription() = "input object $type"
    }

    /** A class from the supplied dictionary; remembers whether it was ever read. */
    private class InitialDictionaryEntry(private val clazz: Class<*>) {
        var accessed = false
            private set

        fun get(): Class<*> {
            accessed = true
            return clazz
        }
    }

    class ReturnValueReference(private val method: Method) : Reference() {
        override fun getDescription() = "return type of method $method"
    }

    class MethodParameterReference(private val method: Method, private val index: Int) : Reference() {
        override fun getDescription() = "parameter $index of method $method"
    }

    class FieldTypeReference(private val field: Field) : Reference() {
        override fun getDescription() = "type of field $field"
    }

    /**
     * Resolves the query, mutation and subscription root types from the schema and pairs them with their root
     * resolvers. The query type is always required; the others follow [RootTypeInfo].
     */
    class RootTypesHolder(rootInfo: RootTypeInfo, definitionsByName: Map<String, TypeDefinition>, queryResolvers: List<GraphQLQueryResolver>, mutationResolvers: List<GraphQLMutationResolver>, subscriptionResolvers: List<GraphQLSubscriptionResolver>) {
        val queryName = rootInfo.getQueryName()
        val mutationName = rootInfo.getMutationName()
        val subscriptionName = rootInfo.getSubscriptionName()

        val queryDefinition = definitionsByName[queryName]
        val mutationDefinition = definitionsByName[mutationName]
        val subscriptionDefinition = definitionsByName[subscriptionName]

        val queryResolverInfo = RootResolverInfo(queryResolvers)
        val mutationResolverInfo = RootResolverInfo(mutationResolvers)
        val subscriptionResolverInfo = RootResolverInfo(subscriptionResolvers)

        val query = createRootType("query", queryDefinition, queryName, true, queryResolvers, GraphQLQueryResolver::class.java, queryResolverInfo)
        val mutation = createRootType("mutation", mutationDefinition, mutationName, rootInfo.isMutationRequired(), mutationResolvers, GraphQLMutationResolver::class.java, mutationResolverInfo)
        val subscription = createRootType("subscription", subscriptionDefinition, subscriptionName, rootInfo.isSubscriptionRequired(), subscriptionResolvers, GraphQLSubscriptionResolver::class.java, subscriptionResolverInfo)

        /**
         * Builds a [RootType], or returns `null` when the type is not defined and not [required].
         *
         * @throws SchemaClassScannerError if a required type is missing, the type is not an object type, or no
         *   resolvers were supplied for it
         */
        fun createRootType(name: String, type: TypeDefinition?, typeName: String, required: Boolean, resolvers: List<GraphQLRootResolver>, resolverInterface: Class<*>, resolverInfo: RootResolverInfo): RootType? {
            if (type == null) {
                if (required) {
                    throw SchemaClassScannerError("Type definition for root $name type '$typeName' not found!")
                }

                return null
            }

            if (type !is ObjectTypeDefinition) {
                // Report the actual root kind (query/mutation/subscription), not always "query".
                throw SchemaClassScannerError("Expected root $name type '$typeName' to be ${ObjectTypeDefinition::class.java.simpleName}, but it was ${type.javaClass.simpleName}")
            }

            // Find root resolver classes
            if (resolvers.isEmpty()) {
                throw SchemaClassScannerError("No Root resolvers for $name type '$typeName' found!  Provide one or more ${resolverInterface.name} to the builder.")
            }

            return RootType(name, type, resolvers, resolverInterface, resolverInfo)
        }
    }

    class RootType(val name: String, val type: ObjectTypeDefinition, val resolvers: List<GraphQLRootResolver>, val resolverInterface: Class<*>, val resolverInfo: RootResolverInfo)
}

/**
 * Unchecked error raised when scanning a schema for type/class mappings fails.
 *
 * @param message description of the problem
 * @param throwable optional underlying cause, exposed as [cause]
 */
class SchemaClassScannerError(
    message: String,
    throwable: Throwable? = null
) : RuntimeException(message, throwable)

/** Bidirectional mapping between GraphQL type definitions and the Java classes that back them. */
internal typealias TypeClassDictionary = BiMap<TypeDefinition, Class<*>>




© 2015 - 2024 Weber Informatics LLC | Privacy Policy