All downloads are FREE. Search and download functionality uses the official Maven repository.

com.netflix.graphql.dgs.internal.DgsDataLoaderProvider.kt Maven / Gradle / Ivy

There is a newer version: 10.0.1
Show newest version
/*
 * Copyright 2022 Netflix, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.netflix.graphql.dgs.internal

import com.netflix.graphql.dgs.DataLoaderInstrumentationExtensionProvider
import com.netflix.graphql.dgs.DgsComponent
import com.netflix.graphql.dgs.DgsDataLoader
import com.netflix.graphql.dgs.DgsDataLoaderOptionsProvider
import com.netflix.graphql.dgs.DgsDataLoaderRegistryConsumer
import com.netflix.graphql.dgs.DgsDispatchPredicate
import com.netflix.graphql.dgs.exceptions.DgsUnnamedDataLoaderOnFieldException
import com.netflix.graphql.dgs.exceptions.InvalidDataLoaderTypeException
import com.netflix.graphql.dgs.exceptions.UnsupportedSecuredDataLoaderException
import com.netflix.graphql.dgs.internal.utils.DataLoaderNameUtil
import jakarta.annotation.PostConstruct
import org.dataloader.BatchLoader
import org.dataloader.BatchLoaderWithContext
import org.dataloader.DataLoader
import org.dataloader.DataLoaderFactory
import org.dataloader.DataLoaderRegistry
import org.dataloader.MappedBatchLoader
import org.dataloader.MappedBatchLoaderWithContext
import org.dataloader.registries.DispatchPredicate
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.springframework.aop.support.AopUtils
import org.springframework.beans.factory.NoSuchBeanDefinitionException
import org.springframework.context.ApplicationContext
import org.springframework.util.ReflectionUtils
import java.util.function.Supplier
import kotlin.system.measureTimeMillis

/**
 * Framework implementation class responsible for finding and configuring data loaders.
 *
 * Discovery runs once, after the container constructs this provider ([findDataLoaders]), and collects:
 * - beans annotated with [DgsDataLoader] that implement one of the `org.dataloader` batch loader interfaces
 *   ([BatchLoader], [BatchLoaderWithContext], [MappedBatchLoader], [MappedBatchLoaderWithContext]);
 * - fields annotated with [DgsDataLoader] declared directly on [DgsComponent] beans.
 *
 * Every call to [buildRegistryWithContextSupplier] creates fresh [DataLoader] instances for all discovered loaders,
 * wraps them with the [DataLoaderInstrumentationExtensionProvider] beans looked up at that moment, and registers
 * them in a new [DgsDataLoaderRegistry].
 *
 * @param applicationContext context scanned for data loader beans and instrumentation extension providers
 * @param dataLoaderOptionsProvider supplies the options used to create each [DataLoader]
 */
class DgsDataLoaderProvider(
    private val applicationContext: ApplicationContext,
    private val dataLoaderOptionsProvider: DgsDataLoaderOptionsProvider = DefaultDataLoaderOptionsProvider()
) {

    /** A discovered loader together with the metadata needed to register it. */
    private data class LoaderHolder<T>(
        val theLoader: T,
        val annotation: DgsDataLoader,
        val name: String,
        val dispatchPredicate: DispatchPredicate? = null
    )

    private val batchLoaders = mutableListOf<LoaderHolder<BatchLoader<*, *>>>()
    private val batchLoadersWithContext = mutableListOf<LoaderHolder<BatchLoaderWithContext<*, *>>>()
    private val mappedBatchLoaders = mutableListOf<LoaderHolder<MappedBatchLoader<*, *>>>()
    private val mappedBatchLoadersWithContext = mutableListOf<LoaderHolder<MappedBatchLoaderWithContext<*, *>>>()

    /**
     * Builds a new registry in which context-aware loaders receive a `null` batch loader context.
     */
    fun buildRegistry(): DataLoaderRegistry {
        return buildRegistryWithContextSupplier { null }
    }

    /**
     * Builds a new [DataLoaderRegistry] holding a freshly created [DataLoader] for every discovered loader.
     *
     * @param contextSupplier supplies the batch loader context for [BatchLoaderWithContext] and
     *   [MappedBatchLoaderWithContext] loaders; context-free loader types do not use it
     * @return a new [DgsDataLoaderRegistry]
     */
    fun <T> buildRegistryWithContextSupplier(contextSupplier: Supplier<T>): DataLoaderRegistry {
        val registry = DgsDataLoaderRegistry()
        val totalTime = measureTimeMillis {
            // Looked up per build, so the current set of extension providers is applied to each new registry.
            val extensionProviders = applicationContext
                .getBeanProvider(DataLoaderInstrumentationExtensionProvider::class.java)
                .orderedStream()
                .toList()

            batchLoaders.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) }
            batchLoadersWithContext.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) }
            mappedBatchLoaders.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) }
            mappedBatchLoadersWithContext.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) }
        }
        logger.debug("Created DGS dataloader registry in {}ms", totalTime)
        return registry
    }

    /**
     * Discovers all data loaders once this provider has been constructed: bean loaders first, then field loaders.
     */
    @PostConstruct
    internal fun findDataLoaders() {
        addDataLoaderComponents()
        addDataLoaderFields()
    }

    /**
     * Collects loaders declared as [DgsDataLoader]-annotated fields on [DgsComponent] beans.
     *
     * Only fields declared directly on each bean's target class are inspected.
     *
     * @throws UnsupportedSecuredDataLoaderException if a bean declaring such fields is an AOP proxy
     * @throws DgsUnnamedDataLoaderOnFieldException if an annotated field does not set an explicit name
     * @throws InvalidDataLoaderTypeException if a field value is not one of the supported batch loader types
     */
    private fun addDataLoaderFields() {
        applicationContext.getBeansWithAnnotation(DgsComponent::class.java).values.forEach { dgsComponent ->
            val targetClass = AopUtils.getTargetClass(dgsComponent)
            val loaderFields = targetClass.declaredFields.filter { it.isAnnotationPresent(DgsDataLoader::class.java) }

            // Field values are read directly off the bean instance below, which is rejected for AOP proxies.
            // The check does not depend on the field, so it runs once per bean instead of once per field.
            if (loaderFields.isNotEmpty() && AopUtils.isAopProxy(dgsComponent)) {
                throw UnsupportedSecuredDataLoaderException(dgsComponent::class.java)
            }

            loaderFields.forEach { field ->
                val annotation = field.getAnnotation(DgsDataLoader::class.java)
                ReflectionUtils.makeAccessible(field)

                // Unlike bean loaders (named via DataLoaderNameUtil from their class), field loaders must be named.
                if (annotation.name == DgsDataLoader.GENERATE_DATA_LOADER_NAME) {
                    throw DgsUnnamedDataLoaderOnFieldException(field)
                }

                fun <L> createHolder(loader: L): LoaderHolder<L> = LoaderHolder(loader, annotation, annotation.name)
                when (val value = field.get(dgsComponent)) {
                    is BatchLoader<*, *> -> batchLoaders.add(createHolder(value))
                    is BatchLoaderWithContext<*, *> -> batchLoadersWithContext.add(createHolder(value))
                    is MappedBatchLoader<*, *> -> mappedBatchLoaders.add(createHolder(value))
                    is MappedBatchLoaderWithContext<*, *> -> mappedBatchLoadersWithContext.add(createHolder(value))
                    else -> throw InvalidDataLoaderTypeException(dgsComponent::class.java)
                }
            }
        }
    }

    /**
     * Collects beans annotated with [DgsDataLoader], together with an optional dispatch predicate read from a
     * [DgsDispatchPredicate]-annotated field declared on the bean's target class.
     *
     * If that field does not hold a [DispatchPredicate], the loader is not registered and a warning is logged.
     * Before, it was dropped without any trace.
     */
    private fun addDataLoaderComponents() {
        val dataLoaders = applicationContext.getBeansWithAnnotation(DgsDataLoader::class.java)
        dataLoaders.values.forEach { dgsComponent ->
            val targetClass = AopUtils.getTargetClass(dgsComponent)
            val annotation = targetClass.getAnnotation(DgsDataLoader::class.java)
            val predicateField = targetClass.declaredFields.firstOrNull {
                it.isAnnotationPresent(DgsDispatchPredicate::class.java)
            }

            if (predicateField == null) {
                addDataLoaders(dgsComponent, targetClass, annotation, null)
                return@forEach
            }

            ReflectionUtils.makeAccessible(predicateField)
            when (val dispatchPredicate = predicateField.get(dgsComponent)) {
                is DispatchPredicate -> addDataLoaders(dgsComponent, targetClass, annotation, dispatchPredicate)
                else -> logger.warn(
                    "Data loader {} was not registered: field {} is annotated with @DgsDispatchPredicate but holds {} instead of a DispatchPredicate",
                    targetClass.name,
                    predicateField.name,
                    dispatchPredicate?.javaClass?.name ?: "null"
                )
            }
        }
    }

    /**
     * Stores [dgsComponent] in the list matching its batch loader type, named via [DataLoaderNameUtil].
     *
     * @throws InvalidDataLoaderTypeException if [dgsComponent] implements none of the supported loader interfaces
     */
    private fun addDataLoaders(dgsComponent: Any, targetClass: Class<*>, annotation: DgsDataLoader, dispatchPredicate: DispatchPredicate?) {
        fun <L> createHolder(loader: L): LoaderHolder<L> =
            LoaderHolder(loader, annotation, DataLoaderNameUtil.getDataLoaderName(targetClass, annotation), dispatchPredicate)
        when (dgsComponent) {
            is BatchLoader<*, *> -> batchLoaders.add(createHolder(dgsComponent))
            is BatchLoaderWithContext<*, *> -> batchLoadersWithContext.add(createHolder(dgsComponent))
            is MappedBatchLoader<*, *> -> mappedBatchLoaders.add(createHolder(dgsComponent))
            is MappedBatchLoaderWithContext<*, *> -> mappedBatchLoadersWithContext.add(createHolder(dgsComponent))
            else -> throw InvalidDataLoaderTypeException(dgsComponent::class.java)
        }
    }

    /** Hands [dataLoaderRegistry] to loaders implementing [DgsDataLoaderRegistryConsumer]; a no-op otherwise. */
    private fun injectRegistry(batchLoader: Any, dataLoaderRegistry: DataLoaderRegistry) {
        if (batchLoader is DgsDataLoaderRegistryConsumer) {
            batchLoader.setDataLoaderRegistry(dataLoaderRegistry)
        }
    }

    /** Creates a [DataLoader] for a list-based [BatchLoader], after wrapping it with the extension providers. */
    private fun createDataLoader(
        batchLoader: BatchLoader<*, *>,
        dgsDataLoader: DgsDataLoader,
        dataLoaderName: String,
        dataLoaderRegistry: DataLoaderRegistry,
        extensionProviders: Iterable<DataLoaderInstrumentationExtensionProvider>
    ): DataLoader<*, *> {
        val options = dataLoaderOptionsProvider.getOptions(dataLoaderName, dgsDataLoader)
        injectRegistry(batchLoader, dataLoaderRegistry)
        val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName, extensionProviders)
        return DataLoaderFactory.newDataLoader(extendedBatchLoader, options)
    }

    /** Creates a mapped [DataLoader] for a [MappedBatchLoader], after wrapping it with the extension providers. */
    private fun createDataLoader(
        batchLoader: MappedBatchLoader<*, *>,
        dgsDataLoader: DgsDataLoader,
        dataLoaderName: String,
        dataLoaderRegistry: DataLoaderRegistry,
        extensionProviders: Iterable<DataLoaderInstrumentationExtensionProvider>
    ): DataLoader<*, *> {
        val options = dataLoaderOptionsProvider.getOptions(dataLoaderName, dgsDataLoader)
        injectRegistry(batchLoader, dataLoaderRegistry)
        val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName, extensionProviders)
        return DataLoaderFactory.newMappedDataLoader(extendedBatchLoader, options)
    }

    /**
     * Creates a [DataLoader] for a [BatchLoaderWithContext]. Its batch loader context is obtained from [supplier].
     */
    private fun <T> createDataLoader(
        batchLoader: BatchLoaderWithContext<*, *>,
        dgsDataLoader: DgsDataLoader,
        dataLoaderName: String,
        supplier: Supplier<T>,
        dataLoaderRegistry: DataLoaderRegistry,
        extensionProviders: Iterable<DataLoaderInstrumentationExtensionProvider>
    ): DataLoader<*, *> {
        val options = dataLoaderOptionsProvider.getOptions(dataLoaderName, dgsDataLoader)
            .setBatchLoaderContextProvider(supplier::get)
        injectRegistry(batchLoader, dataLoaderRegistry)
        val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName, extensionProviders)
        return DataLoaderFactory.newDataLoader(extendedBatchLoader, options)
    }

    /**
     * Creates a mapped [DataLoader] for a [MappedBatchLoaderWithContext]. Its batch loader context is obtained
     * from [supplier].
     */
    private fun <T> createDataLoader(
        batchLoader: MappedBatchLoaderWithContext<*, *>,
        dgsDataLoader: DgsDataLoader,
        dataLoaderName: String,
        supplier: Supplier<T>,
        dataLoaderRegistry: DataLoaderRegistry,
        extensionProviders: Iterable<DataLoaderInstrumentationExtensionProvider>
    ): DataLoader<*, *> {
        val options = dataLoaderOptionsProvider.getOptions(dataLoaderName, dgsDataLoader)
            .setBatchLoaderContextProvider(supplier::get)
        injectRegistry(batchLoader, dataLoaderRegistry)
        val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName, extensionProviders)
        return DataLoaderFactory.newMappedDataLoader(extendedBatchLoader, options)
    }

    /**
     * Creates the [DataLoader] for [holder] and registers it under the holder's name. The holder's dispatch
     * predicate is attached when one was configured.
     *
     * @throws IllegalArgumentException if the held loader is not one of the four supported loader types
     */
    private fun registerDataLoader(
        holder: LoaderHolder<*>,
        registry: DgsDataLoaderRegistry,
        contextSupplier: Supplier<*>,
        extensionProviders: Iterable<DataLoaderInstrumentationExtensionProvider>
    ) {
        val loader = when (holder.theLoader) {
            is BatchLoader<*, *> -> createDataLoader(holder.theLoader, holder.annotation, holder.name, registry, extensionProviders)
            is BatchLoaderWithContext<*, *> -> createDataLoader(holder.theLoader, holder.annotation, holder.name, contextSupplier, registry, extensionProviders)
            is MappedBatchLoader<*, *> -> createDataLoader(holder.theLoader, holder.annotation, holder.name, registry, extensionProviders)
            is MappedBatchLoaderWithContext<*, *> -> createDataLoader(holder.theLoader, holder.annotation, holder.name, contextSupplier, registry, extensionProviders)
            else -> throw IllegalArgumentException("Data loader ${holder.name} has unknown type")
        }
        val predicate = holder.dispatchPredicate
        if (predicate == null) {
            registry.register(holder.name, loader)
        } else {
            registry.registerWithDispatchPredicate(holder.name, loader, predicate)
        }
    }

    /**
     * Passes [loader] through every provider in [extensionProviders], in iteration order. Each provider receives
     * the previous provider's output.
     *
     * If a provider throws [NoSuchBeanDefinitionException], wrapping is abandoned and the failure is logged at
     * debug level. The original, unwrapped [loader] is then returned. Loaders of any other type are returned
     * unchanged.
     */
    private inline fun <reified T> wrappedDataLoader(
        loader: T,
        name: String,
        extensionProviders: Iterable<DataLoaderInstrumentationExtensionProvider>
    ): T {
        try {
            when (loader) {
                is BatchLoader<*, *> -> {
                    var wrappedBatchLoader: BatchLoader<*, *> = loader
                    extensionProviders.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) }
                    return wrappedBatchLoader as T
                }
                is BatchLoaderWithContext<*, *> -> {
                    var wrappedBatchLoader: BatchLoaderWithContext<*, *> = loader
                    extensionProviders.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) }
                    return wrappedBatchLoader as T
                }
                is MappedBatchLoader<*, *> -> {
                    var wrappedBatchLoader: MappedBatchLoader<*, *> = loader
                    extensionProviders.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) }
                    return wrappedBatchLoader as T
                }
                is MappedBatchLoaderWithContext<*, *> -> {
                    var wrappedBatchLoader: MappedBatchLoaderWithContext<*, *> = loader
                    extensionProviders.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) }
                    return wrappedBatchLoader as T
                }
            }
        } catch (ex: NoSuchBeanDefinitionException) {
            logger.debug("Unable to wrap the [{} : {}]", name, loader, ex)
        }
        return loader
    }

    private companion object {
        private val logger: Logger = LoggerFactory.getLogger(DgsDataLoaderProvider::class.java)
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy