All downloads are free. Search and download functionality uses the official Maven repository.

com.connectrpc.protocgen.connect.Generator.kt Maven / Gradle / Ivy

There is a newer version: 0.7.2
Show newest version
// Copyright 2022-2023 The Connect Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.connectrpc.protocgen.connect

import com.connectrpc.BidirectionalStreamInterface
import com.connectrpc.ClientOnlyStreamInterface
import com.connectrpc.Idempotency
import com.connectrpc.MethodSpec
import com.connectrpc.ProtocolClientInterface
import com.connectrpc.ResponseMessage
import com.connectrpc.ServerOnlyStreamInterface
import com.connectrpc.StreamType
import com.connectrpc.UnaryBlockingCall
import com.connectrpc.protocgen.connect.internal.CodeGenerator
import com.connectrpc.protocgen.connect.internal.Configuration
import com.connectrpc.protocgen.connect.internal.Plugin
import com.connectrpc.protocgen.connect.internal.SourceInfo
import com.connectrpc.protocgen.connect.internal.getClassName
import com.connectrpc.protocgen.connect.internal.getFileJavaPackage
import com.connectrpc.protocgen.connect.internal.parse
import com.connectrpc.protocgen.connect.internal.withSourceInfo
import com.google.protobuf.DescriptorProtos
import com.google.protobuf.DescriptorProtos.FileDescriptorProto
import com.google.protobuf.DescriptorProtos.MethodOptions.IdempotencyLevel
import com.google.protobuf.Descriptors
import com.google.protobuf.compiler.PluginProtos
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.FileSpec
import com.squareup.kotlinpoet.FunSpec
import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.LambdaTypeName
import com.squareup.kotlinpoet.ParameterSpec
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import com.squareup.kotlinpoet.PropertySpec
import com.squareup.kotlinpoet.TypeSpec
import com.squareup.kotlinpoet.asClassName
import com.squareup.kotlinpoet.asTypeName

/*
 * These are constants because com.connectrpc.Headers and com.connectrpc.http.Cancelable
 * are type aliases, which don't have an underlying class that KotlinPoet can reference.
 *
 * The conventional and nicer way is to use the class type, e.g. Headers::class.asClassName(),
 * but type aliases do not allow for that.
 *
 * Instead, this is the way to reference these types for now. If there is ever a desire to
 * move off of type aliases, this can be changed without breaking the user-facing API.
 */
// Package declaring the Headers type alias; Cancelable lives in its `http` subpackage.
private const val CONNECT_PACKAGE_NAME = "com.connectrpc"

/** KotlinPoet reference to the `com.connectrpc.Headers` type alias. */
private val HEADERS_CLASS_NAME: ClassName = ClassName(CONNECT_PACKAGE_NAME, "Headers")

/** KotlinPoet reference to the `com.connectrpc.http.Cancelable` type alias. */
private val CANCELABLE_CLASS_NAME: ClassName = ClassName("$CONNECT_PACKAGE_NAME.http", "Cancelable")

/**
 * protoc plugin [CodeGenerator] that emits Connect client code for each service in the
 * requested `.proto` files.
 *
 * Every service produces two files: `<Service>ClientInterface`, which declares the client API,
 * and `<Service>Client`, which implements it by delegating to a [ProtocolClientInterface].
 * The unary variants emitted (coroutine, callback, blocking) are selected by [Configuration].
 */
class Generator : CodeGenerator {
    private lateinit var descriptorSource: Plugin.DescriptorSource
    private lateinit var configuration: Configuration

    // Raw FileDescriptorProtos keyed by proto file name; used to build SourceInfo (comments).
    private val protoFileMap = mutableMapOf<String, FileDescriptorProto>()

    /**
     * Generates client sources for every file in [request]'s files-to-generate that declares at
     * least one service, adding each rendered file to [response].
     *
     * @throws RuntimeException if a requested file cannot be found in [descriptorSource].
     * @throws IllegalStateException if rendering or adding a generated file fails.
     */
    override fun generate(
        request: PluginProtos.CodeGeneratorRequest,
        descriptorSource: Plugin.DescriptorSource,
        response: Plugin.Response,
    ) {
        this.descriptorSource = descriptorSource
        configuration = parse(request.parameter)
        for (protoFile in request.protoFileList) {
            protoFileMap[protoFile.name] = protoFile
        }
        for (fileName in request.fileToGenerateList) {
            val file = descriptorSource.findFileByName(fileName)
                ?: throw RuntimeException("no descriptor sources found for $fileName.")
            if (file.services.isEmpty()) {
                // Avoid generating files with no service definitions.
                continue
            }
            for ((className, fileSpec) in parseFile(file)) {
                try {
                    response.addFile("${className.canonicalName.packageToDirectory()}.kt", fileSpec.toString())
                } catch (e: Exception) {
                    // Attach the proto file name so the failing input can be identified.
                    throw IllegalStateException("failure on generating ${file.name}", e)
                }
            }
        }
    }

    override fun getSupportedFeatures(): Array<PluginProtos.CodeGeneratorResponse.Feature> = arrayOf(
        PluginProtos.CodeGeneratorResponse.Feature.FEATURE_PROTO3_OPTIONAL,
        PluginProtos.CodeGeneratorResponse.Feature.FEATURE_SUPPORTS_EDITIONS,
    )

    override fun getMinimumEdition(): DescriptorProtos.Edition = DescriptorProtos.Edition.EDITION_PROTO2

    override fun getMaximumEdition(): DescriptorProtos.Edition = DescriptorProtos.Edition.EDITION_2023

    /**
     * Builds the interface and implementation [FileSpec]s for every service in [file], keyed by
     * the generated class name (which determines the output path).
     *
     * @throws IllegalStateException if the request did not include a FileDescriptorProto for [file].
     */
    private fun parseFile(file: Descriptors.FileDescriptor): Map<ClassName, FileSpec> {
        val fileDescriptorProto = checkNotNull(protoFileMap[file.name]) {
            "no FileDescriptorProto found for ${file.name}"
        }
        val baseSourceInfo = SourceInfo(fileDescriptorProto, descriptorSource, emptyList())
        val packageName = getFileJavaPackage(file)
        val fileSpecs = mutableMapOf<ClassName, FileSpec>()
        for ((sourceInfo, service) in file.services.withSourceInfo(
            baseSourceInfo,
            FileDescriptorProto.SERVICE_FIELD_NUMBER,
        )) {
            fileSpecs[serviceClientInterfaceClassName(packageName, service)] =
                FileSpec.builder(packageName, file.name)
                    .addGeneratedFileHeader(file.name)
                    .addType(serviceClientInterface(packageName, service, sourceInfo))
                    .build()

            val implementationFileSpecBuilder = FileSpec.builder(packageName, file.name)
                .addImport(MethodSpec::class.java.`package`.name, "MethodSpec")
                .addImport(StreamType::class.java.`package`.name, "StreamType")
                .addGeneratedFileHeader(file.name)
                .addType(serviceClientImplementation(packageName, service, sourceInfo))
            // Import Idempotency only when some generated MethodSpec actually references it.
            if (service.methods.any { idempotencyConstant(it.options.idempotencyLevel) != null }) {
                implementationFileSpecBuilder.addImport(Idempotency::class.java.`package`.name, "Idempotency")
            }
            fileSpecs[serviceClientImplementationClassName(packageName, service)] =
                implementationFileSpecBuilder.build()
        }
        return fileSpecs
    }

    /** Adds the standard "generated code" header naming [protoFileName] as the source. */
    private fun FileSpec.Builder.addGeneratedFileHeader(protoFileName: String): FileSpec.Builder =
        addFileComment("Code generated by connect-kotlin. DO NOT EDIT.\n")
            .addFileComment("\n")
            .addFileComment("Source: %L\n", protoFileName)

    /** Builds the `<Service>ClientInterface` type declaring one abstract function per RPC variant. */
    private fun serviceClientInterface(
        packageName: String,
        service: Descriptors.ServiceDescriptor,
        sourceInfo: SourceInfo,
    ): TypeSpec {
        val interfaceBuilder = TypeSpec.interfaceBuilder(serviceClientInterfaceClassName(packageName, service))
        val functionSpecs = interfaceMethods(service.methods, sourceInfo)
        return interfaceBuilder
            .addKdoc(sourceInfo.comment().sanitizeKdoc())
            .addFunctions(functionSpecs)
            .build()
    }

    /**
     * Declares the abstract client functions for [methods]. Streaming RPCs get one suspend
     * function; unary RPCs get the coroutine, callback and/or blocking variants enabled in
     * [configuration].
     */
    private fun interfaceMethods(
        methods: List<Descriptors.MethodDescriptor>,
        baseSourceInfo: SourceInfo,
    ): List<FunSpec> {
        val functions = mutableListOf<FunSpec>()
        // The default value is declared on the interface only; overrides inherit it.
        val headerParameterSpec = ParameterSpec.builder("headers", HEADERS_CLASS_NAME)
            .defaultValue("%L", "emptyMap()")
            .build()
        for ((sourceInfo, method) in methods.withSourceInfo(
            baseSourceInfo,
            DescriptorProtos.ServiceDescriptorProto.METHOD_FIELD_NUMBER,
        )) {
            val kdoc = sourceInfo.comment().sanitizeKdoc()
            val functionName = method.name.lowerCamelCase()
            val inputClassName = classNameFromType(method.inputType)
            val outputClassName = classNameFromType(method.outputType)
            val streaming = streamingSignature(streamTypeOf(method))
            if (streaming != null) {
                functions += FunSpec.builder(functionName)
                    .addKdoc(kdoc)
                    .addModifiers(KModifier.ABSTRACT, KModifier.SUSPEND)
                    .addParameter(headerParameterSpec)
                    .returns(streaming.first.parameterizedBy(inputClassName, outputClassName))
                    .build()
                continue
            }
            if (configuration.generateCoroutineMethods) {
                functions += FunSpec.builder(functionName)
                    .addKdoc(kdoc)
                    .addModifiers(KModifier.ABSTRACT, KModifier.SUSPEND)
                    .addParameter("request", inputClassName)
                    .addParameter(headerParameterSpec)
                    .returns(ResponseMessage::class.asClassName().parameterizedBy(outputClassName))
                    .build()
            }
            if (configuration.generateCallbackMethods) {
                functions += FunSpec.builder(functionName)
                    .addKdoc(kdoc)
                    .addModifiers(KModifier.ABSTRACT)
                    .addParameter("request", inputClassName)
                    .addParameter(headerParameterSpec)
                    .addParameter("onResult", unaryCallbackType(outputClassName))
                    .returns(CANCELABLE_CLASS_NAME)
                    .build()
            }
            if (configuration.generateBlockingUnaryMethods) {
                functions += FunSpec.builder("${functionName}Blocking")
                    .addKdoc(kdoc)
                    .addModifiers(KModifier.ABSTRACT)
                    .addParameter("request", inputClassName)
                    .addParameter(headerParameterSpec)
                    .returns(UnaryBlockingCall::class.asClassName().parameterizedBy(outputClassName))
                    .build()
            }
        }
        return functions
    }

    /** Builds the `<Service>Client` class implementing the interface on top of a [ProtocolClientInterface]. */
    private fun serviceClientImplementation(
        javaPackageName: String,
        service: Descriptors.ServiceDescriptor,
        sourceInfo: SourceInfo,
    ): TypeSpec {
        // The javaPackageName is used instead of the package name for imports and code references.
        val classBuilder = TypeSpec.classBuilder(serviceClientImplementationClassName(javaPackageName, service))
            .addSuperinterface(serviceClientInterfaceClassName(javaPackageName, service))
            .primaryConstructor(
                FunSpec.constructorBuilder()
                    .addParameter("client", ProtocolClientInterface::class)
                    .build(),
            )
            .addProperty(
                PropertySpec.builder("client", ProtocolClientInterface::class, KModifier.PRIVATE)
                    .initializer("client")
                    .build(),
            )
        val functionSpecs = implementationMethods(service.methods, sourceInfo)
        return classBuilder
            .addKdoc(sourceInfo.comment().sanitizeKdoc())
            .addFunctions(functionSpecs)
            .build()
    }

    /**
     * Implements the functions declared by [interfaceMethods]; each one delegates to the matching
     * `client.*` call with a `MethodSpec` describing the RPC.
     */
    private fun implementationMethods(
        methods: List<Descriptors.MethodDescriptor>,
        baseSourceInfo: SourceInfo,
    ): List<FunSpec> {
        val functions = mutableListOf<FunSpec>()
        for ((sourceInfo, method) in methods.withSourceInfo(
            baseSourceInfo,
            DescriptorProtos.ServiceDescriptorProto.METHOD_FIELD_NUMBER,
        )) {
            val kdoc = sourceInfo.comment().sanitizeKdoc()
            val functionName = method.name.lowerCamelCase()
            val inputClassName = classNameFromType(method.inputType)
            val outputClassName = classNameFromType(method.outputType)
            val streamType = streamTypeOf(method)
            val methodSpecCallBlock = methodSpecBlock(method, streamType, inputClassName, outputClassName)
            val streaming = streamingSignature(streamType)
            if (streaming != null) {
                val (streamInterface, clientFunction) = streaming
                functions += FunSpec.builder(functionName)
                    .addKdoc(kdoc)
                    .addModifiers(KModifier.OVERRIDE, KModifier.SUSPEND)
                    .addParameter("headers", HEADERS_CLASS_NAME)
                    .returns(streamInterface.parameterizedBy(inputClassName, outputClassName))
                    .addStatement(
                        "return %L",
                        delegatingCall(clientFunction, listOf("headers"), methodSpecCallBlock),
                    )
                    .build()
                continue
            }
            val requestAndHeaders = listOf("request", "headers")
            if (configuration.generateCoroutineMethods) {
                functions += FunSpec.builder(functionName)
                    .addKdoc(kdoc)
                    .addModifiers(KModifier.OVERRIDE, KModifier.SUSPEND)
                    .addParameter("request", inputClassName)
                    .addParameter("headers", HEADERS_CLASS_NAME)
                    .returns(ResponseMessage::class.asClassName().parameterizedBy(outputClassName))
                    .addStatement(
                        "return %L",
                        delegatingCall("unary", requestAndHeaders, methodSpecCallBlock),
                    )
                    .build()
            }
            if (configuration.generateCallbackMethods) {
                functions += FunSpec.builder(functionName)
                    .addKdoc(kdoc)
                    .addModifiers(KModifier.OVERRIDE)
                    .addParameter("request", inputClassName)
                    .addParameter("headers", HEADERS_CLASS_NAME)
                    .addParameter("onResult", unaryCallbackType(outputClassName))
                    .returns(CANCELABLE_CLASS_NAME)
                    .addStatement(
                        "return %L",
                        delegatingCall("unary", requestAndHeaders, methodSpecCallBlock, trailingArg = "onResult"),
                    )
                    .build()
            }
            if (configuration.generateBlockingUnaryMethods) {
                functions += FunSpec.builder("${functionName}Blocking")
                    .addKdoc(kdoc)
                    .addModifiers(KModifier.OVERRIDE)
                    .addParameter("request", inputClassName)
                    .addParameter("headers", HEADERS_CLASS_NAME)
                    .returns(UnaryBlockingCall::class.asClassName().parameterizedBy(outputClassName))
                    .addStatement(
                        "return %L",
                        delegatingCall("unaryBlocking", requestAndHeaders, methodSpecCallBlock),
                    )
                    .build()
            }
        }
        return functions
    }

    /** Classifies [method] by its client/server streaming flags. */
    private fun streamTypeOf(method: Descriptors.MethodDescriptor): StreamType = when {
        method.isClientStreaming && method.isServerStreaming -> StreamType.BIDI
        method.isClientStreaming -> StreamType.CLIENT
        method.isServerStreaming -> StreamType.SERVER
        else -> StreamType.UNARY
    }

    /**
     * For a streaming [streamType], returns the stream interface class the generated function
     * returns and the name of the `client` function it delegates to; returns null for unary.
     */
    private fun streamingSignature(streamType: StreamType): Pair<ClassName, String>? = when (streamType) {
        StreamType.BIDI -> BidirectionalStreamInterface::class.asClassName() to "stream"
        StreamType.SERVER -> ServerOnlyStreamInterface::class.asClassName() to "serverStream"
        StreamType.CLIENT -> ClientOnlyStreamInterface::class.asClassName() to "clientStream"
        StreamType.UNARY -> null
    }

    /**
     * Maps a proto idempotency level to the [Idempotency] constant emitted in the `MethodSpec`,
     * or null when the MethodSpec default should be used.
     */
    private fun idempotencyConstant(level: IdempotencyLevel): String? = when (level) {
        IdempotencyLevel.NO_SIDE_EFFECTS -> Idempotency.NO_SIDE_EFFECTS.name
        IdempotencyLevel.IDEMPOTENT -> Idempotency.IDEMPOTENT.name
        else -> null
    }

    /** Builds the multi-line `MethodSpec(...)` argument (ending in `),`) passed to every client call. */
    private fun methodSpecBlock(
        method: Descriptors.MethodDescriptor,
        streamType: StreamType,
        inputClassName: ClassName,
        outputClassName: ClassName,
    ): CodeBlock {
        val builder = CodeBlock.builder()
            .addStatement("MethodSpec(")
            .indent()
            .addStatement("\"${method.service.fullName}/${method.name}\",")
            .addStatement("$inputClassName::class,")
            .addStatement("$outputClassName::class,")
            .addStatement("StreamType.${streamType.name},")
        idempotencyConstant(method.options.idempotencyLevel)?.let { constant ->
            builder.addStatement("idempotency = Idempotency.$constant,")
        }
        return builder
            .unindent()
            .addStatement("),")
            .build()
    }

    /**
     * Builds `client.<clientFunction>(...)` with one argument per line: each of [leadingArgs]
     * followed by a comma, then [methodSpecCallBlock], then [trailingArg] when non-null.
     */
    private fun delegatingCall(
        clientFunction: String,
        leadingArgs: List<String>,
        methodSpecCallBlock: CodeBlock,
        trailingArg: String? = null,
    ): CodeBlock {
        val builder = CodeBlock.builder()
            .addStatement("client.$clientFunction(")
            .indent()
        for (arg in leadingArgs) {
            builder.addStatement("$arg,")
        }
        builder.add(methodSpecCallBlock)
        if (trailingArg != null) {
            builder.addStatement(trailingArg)
        }
        return builder
            .unindent()
            .addStatement(")")
            .build()
    }

    /** Builds the `(ResponseMessage<Output>) -> Unit` type used by callback-style unary functions. */
    private fun unaryCallbackType(outputClassName: ClassName) = LambdaTypeName.get(
        parameters = listOf(
            ParameterSpec("", ResponseMessage::class.asTypeName().parameterizedBy(outputClassName)),
        ),
        returnType = Unit::class.java.asTypeName(),
    )

    /**
     * Resolves the generated Java/Kotlin [ClassName] for a message [descriptor], preserving
     * nesting: e.g. "com.connectrpc.EmptyMessage.InnerMessage" becomes package "com.connectrpc"
     * with simple names ["EmptyMessage", "InnerMessage"] (also covers "*OuterClass" wrappers).
     */
    private fun classNameFromType(descriptor: Descriptors.Descriptor): ClassName {
        val packageName = getFileJavaPackage(descriptor.file)
        val names = getClassName(descriptor)
            .removePrefix(packageName)
            .removePrefix(".")
            .split(".")
        return ClassName(packageName, names.first(), *names.drop(1).toTypedArray())
    }

    /**
     * Makes a proto comment safe to embed in a KDoc block via a KotlinPoet format string.
     *
     * Trailing whitespace is removed; comment delimiters (`*/`, `/*`) and KDoc-significant
     * characters (`[`, `]`, `@`) become HTML character references so they render literally
     * instead of closing the comment, creating links, or starting block tags; and `%` is
     * doubled so KotlinPoet does not parse it as a placeholder.
     */
    private fun String.sanitizeKdoc(): String {
        return this
            // Remove trailing whitespace on each line.
            .replace("[^\\S\n]+\n".toRegex(), "\n")
            // Remove trailing whitespace at the end of the comment.
            .replace("\\s+$".toRegex(), "")
            // A literal "*/" would terminate the generated KDoc block early.
            .replace("*/", "&#42;/")
            .replace("/*", "/&#42;")
            .replace("[", "&#91;")
            .replace("]", "&#93;")
            .replace("@", "&#64;")
            // The result is used as a KotlinPoet format string, where '%' starts a placeholder.
            .replace("%", "%%")
    }
}

/**
 * Returns the [ClassName] of the generated client interface for [service]: the service's
 * simple name suffixed with `ClientInterface` (e.g. `FooService` -> `FooServiceClientInterface`).
 */
private fun serviceClientInterfaceClassName(packageName: String, service: Descriptors.ServiceDescriptor): ClassName =
    ClassName(packageName, service.name + "ClientInterface")

/**
 * Returns the [ClassName] of the generated client implementation for [service]: the
 * service's simple name suffixed with `Client` (e.g. `FooService` -> `FooServiceClient`).
 */
private fun serviceClientImplementationClassName(
    packageName: String,
    service: Descriptors.ServiceDescriptor,
): ClassName {
    val simpleName = service.name + "Client"
    return ClassName(packageName, simpleName)
}

/**
 * Lower-cases only the first character, turning a proto method name such as `SayHello` into
 * the Kotlin function name `sayHello`. An empty string is returned unchanged.
 */
private fun String.lowerCamelCase(): String {
    if (isEmpty()) return this
    return first().lowercaseChar() + drop(1)
}

/**
 * Converts a dotted name (e.g. a class's canonical name) into a relative file path by
 * replacing every `.` with `/` and dropping a single leading `/`, so the result is never
 * absolute. An empty string yields an empty string.
 */
private fun String.packageToDirectory(): String =
    // Check the converted string: a leading '.' in the input becomes the leading '/' to strip.
    replace('.', '/').removePrefix("/")




© 2015 - 2025 Weber Informatics LLC | Privacy Policy