All Downloads are FREE. Search and download functionalities are using the official Maven repository.

wvlet.airframe.http.grpc.GrpcServer.scala Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package wvlet.airframe.http.grpc

import io.grpc.*
import wvlet.airframe.codec.MessageCodecFactory
import wvlet.airframe.http.{Router, RxRouter}
import wvlet.airframe.http.grpc.internal.{GrpcRequestLogger, GrpcServiceBuilder}
import wvlet.airframe.{Design, Session}
import wvlet.log.LogSupport
import wvlet.log.io.IOUtil

import java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{ExecutorService, ForkJoinPool, ForkJoinWorkerThread}
import scala.language.existentials

/**
  */
case class GrpcServerConfig(
    // The server name
    name: String = "default",
    private val serverPort: Option[Int] = None,
    router: Router = Router.empty,
    interceptors: Seq[ServerInterceptor] = Seq.empty,
    serverInitializer: ServerBuilder[_] => ServerBuilder[_] = identity,
    executorProvider: GrpcServerConfig => ExecutorService = GrpcServer.newAsyncExecutorFactory,
    maxThreads: Int = (Runtime.getRuntime.availableProcessors() * 2).max(2),
    codecFactory: MessageCodecFactory = MessageCodecFactory.defaultFactoryForMapOutput,
    requestLoggerProvider: GrpcServerConfig => GrpcRequestLogger = { (config: GrpcServerConfig) =>
      GrpcRequestLogger
        .newLogger(config.name)
    }
) extends LogSupport {
  lazy val port = serverPort.getOrElse(IOUtil.unusedPort)

  def withName(name: String): GrpcServerConfig = this.copy(name = name)
  def withPort(port: Int): GrpcServerConfig    = this.copy(serverPort = Some(port))

  @deprecated("Use withRouter(RxRouter)", "23.5.0")
  def withRouter(router: Router): GrpcServerConfig   = this.copy(router = router)
  def withRouter(router: RxRouter): GrpcServerConfig = this.copy(router = Router.fromRxRouter(router))

  /**
    * Use this method to customize gRPC server, e.g., setting tracer, add transport filter, etc.
    *
    * @param serverInitializer
    * @return
    */
  def withServerInitializer(serverInitializer: ServerBuilder[_] => ServerBuilder[_]) =
    this.copy(serverInitializer = serverInitializer)

  /**
    * Add an gRPC interceptor
    *
    * @param interceptor
    * @return
    */
  def withInterceptor(interceptor: ServerInterceptor): GrpcServerConfig =
    this.copy(interceptors = interceptors :+ interceptor)
  def noInterceptor: GrpcServerConfig = this.copy(interceptors = Seq.empty)

  /**
    * Set a custom thread pool. The default is Executors.newCachedThreadPool()
    */
  def withExecutorServiceProvider(provider: GrpcServerConfig => ExecutorService) =
    this.copy(executorProvider = provider)

  /**
    * Set the maximum number of grpc server threads. The default is max(the number of CPU x 2, 2). If you are using a
    * custom ExecutorService, this setting might not be effective.
    */
  def withMaxThreads(numThreads: Int) = this.copy(maxThreads = numThreads)

  def withCodecFactory(newCodecFactory: MessageCodecFactory) = this.copy(codecFactory = newCodecFactory)

  def withRequestLoggerProvider(provider: GrpcServerConfig => GrpcRequestLogger) = this
    .copy(requestLoggerProvider = provider)
  // Disable RPC logging
  def noRequestLogging = this.copy(requestLoggerProvider = { (config: GrpcServerConfig) =>
    GrpcRequestLogger.nullLogger
  })

  /**
    * Create and start a new server based on this config.
    */
  def newServer(session: Session): GrpcServer = {
    val grpcService = GrpcServiceBuilder.buildService(this, session)
    grpcService.newServer
  }

  /**
    * Start a standalone gRPC server and execute the given code block. After exiting the code block, it will stop the
    * gRPC server.
    *
    * If you want to keep running the server inside the code block, call server.awaitTermination.
    */
  def start[U](body: GrpcServer => U): U = {
    design.noLifeCycleLogging.run[GrpcServer, U] { server =>
      body(server)
    }
  }

  /**
    * Create a GrpcServer design for Airframe DI
    */
  def design: Design = {
    Design.newDesign
      .bind[GrpcServerConfig].toInstance(this)
      .bind[GrpcServer].toProvider { (config: GrpcServerConfig, session: Session) => config.newServer(session) }
  }

  /**
    * Create a design for GrpcServer and ManagedChannel. Useful for testing purpose
    *
    * @return
    */
  def designWithChannel: Design = {
    design
      .bind[Channel].toProvider { (server: GrpcServer) =>
        ManagedChannelBuilder.forTarget(server.localAddress).usePlaintext().build()
      }
      .onShutdown {
        case m: ManagedChannel => m.shutdownNow()
        case _                 =>
      }
  }
}

class GrpcServer(grpcService: GrpcService, server: Server) extends AutoCloseable with LogSupport {
  def port: Int            = grpcService.config.port
  def localAddress: String = s"localhost:${port}"

  def start: Unit = {
    info(s"Starting gRPC server [${grpcService.config.name}] at ${localAddress}")
    server.start()
  }

  def awaitTermination: Unit = {
    server.awaitTermination()
  }

  override def close(): Unit = {
    info(s"Closing gRPC server [${grpcService.config.name}] at ${localAddress}")
    server.shutdownNow()
    grpcService.close()
  }
}

object GrpcServer {
  private[grpc] def newAsyncExecutorFactory(config: GrpcServerConfig): ExecutorService = {
    new ForkJoinPool(
      // The number of threads
      config.maxThreads,
      new ForkJoinWorkerThreadFactory() {
        private val threadCount = new AtomicInteger()
        override def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = {
          val thread = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool)
          val name   = s"grpc-${config.name}-${threadCount.getAndIncrement()}"
          thread.setDaemon(true)
          thread.setName(name)
          thread
        }
      },
      null, // Use the default behavior for unrecoverable exceptions
      true  // Enable asyncMode as grpc server will never join
    );
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy