All Downloads are FREE. Search and download functionalities are using the official Maven repository.

wvlet.airframe.http.grpc.internal.GrpcRequestHandler.scala Maven / Gradle / Ivy

There is a newer version: 2025.1.8
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package wvlet.airframe.http.grpc.internal

import io.grpc.stub.ServerCalls.{BidiStreamingMethod, ClientStreamingMethod, ServerStreamingMethod, UnaryMethod}
import io.grpc.stub.{ClientCallStreamObserver, ServerCallStreamObserver, StreamObserver}
import wvlet.airframe.codec.{MessageCodec, MessageCodecException, MessageCodecFactory}
import wvlet.airframe.http
import wvlet.airframe.http.{RPCEncoding, RPCMethod}
import wvlet.airframe.http.grpc.{GrpcContext, GrpcResponse}
import wvlet.airframe.http.internal.RPCCallContext
import wvlet.airframe.http.{RPCEncoding, RPCStatus}
import wvlet.airframe.http.router.{Route, HttpRequestMapper}
import wvlet.airframe.msgpack.spi.MsgPack
import wvlet.airframe.msgpack.spi.Value.MapValue
import wvlet.airframe.rx.*
import wvlet.airframe.surface.{CName, MethodSurface, Surface}
import wvlet.log.LogSupport

import java.lang.reflect.InvocationTargetException
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.{Callable, ExecutorService}
import scala.annotation.tailrec
import scala.concurrent.{ExecutionContext, ExecutionException, Promise}
import scala.util.{Failure, Success, Try}

/**
  * Handles gRPC calls for a single RPC method. For unary-style requests, it receives a MessagePack (or JSON) Map
  * value, binds its entries to the arguments of the controller method, and calls that method. For streaming requests,
  * it feeds the incoming messages to the controller method through an Rx queue.
  *
  * @param rpcMethod
  *   RPC method descriptor, used for building the RPC call context passed to the request logger
  * @param controller
  *   Controller instance that owns the method to call
  * @param methodSurface
  *   Controller method to call for RPC
  * @param codecFactory
  *   Factory of the codecs used for decoding method arguments and streaming messages
  * @param executorService
  *   Executor for running streaming controller methods and their completion callbacks
  * @param requestLogger
  *   Logger that receives the records of successful and failed RPC calls
  */
class GrpcRequestHandler(
    rpcMethod: RPCMethod,
    // Controller instance
    controller: Any,
    // Controller method to call for RPC
    methodSurface: MethodSurface,
    codecFactory: MessageCodecFactory,
    executorService: ExecutorService,
    requestLogger: GrpcRequestLogger
) extends LogSupport {

  // One codec per method argument, in the declaration order of the arguments
  private val argCodecs = methodSurface.args.map(a => codecFactory.of(a.surface))

  // Call context without arguments. The actual arguments are attached with withRPCArgs once they are decoded
  private val rpcContext = RPCCallContext(rpcMethod, methodSurface, Seq.empty)

  /**
    * Read the input value (MessagePack or Json) and convert to MessagePack Value
    *
    * @param grpcContext
    *   The context of the current call, if any. Its encoding is used for non-JSON input. MessagePack is assumed when
    *   no context is given
    * @param request
    *   The raw request payload
    * @return
    *   The request decoded as a MapValue. If the payload is not a map, an INVALID_REQUEST_U1 exception is logged and
    *   thrown. A MessageCodecException raised while decoding is logged and rethrown
    */
  private def readRequestAsValue(grpcContext: Option[GrpcContext], request: MsgPack): MapValue = {
    try {
      val value = if (RPCEncoding.isJsonObjectMessage(request)) {
        // The input is a JSON message
        RPCEncoding.JSON.unpackValue(request)
      } else {
        // Check the message type using the accept header:
        val encoding = grpcContext.map(_.encoding).getOrElse(RPCEncoding.MsgPack)
        encoding.unpackValue(request)
      }

      value match {
        case m: MapValue =>
          m
        case _ =>
          val e = RPCStatus.INVALID_REQUEST_U1.newException(s"Request data is not a MapValue: ${value}")
          reportError(e)
          requestLogger.logError(e, grpcContext, rpcContext)
          throw e
      }
    } catch {
      case e: MessageCodecException =>
        reportError(e)
        requestLogger.logError(e, grpcContext, rpcContext)
        throw e
    }
  }

  /**
    * Decode a single message of a client stream
    *
    * @param grpcContext
    *   The context of the current call, if any. MessagePack encoding is assumed when no context is given
    * @param codec
    *   The codec for the element type of the client stream
    * @param request
    *   The raw message payload
    * @return
    *   The decoded message
    */
  private def readStreamingInput[A](grpcContext: Option[GrpcContext], codec: MessageCodec[A], request: MsgPack): A = {
    val encoding = grpcContext.map(_.encoding).getOrElse(RPCEncoding.MsgPack)
    encoding match {
      case RPCEncoding.MsgPack =>
        codec.fromMsgPack(request)
      case RPCEncoding.JSON =>
        codec.fromJson(request)
    }
  }

  /**
    * Write the given exception to the error log after removing the wrapper exceptions around it
    */
  private def reportError(e: Throwable): Unit = {

    /**
      * Find the root cause of the exception from wrapped exception classes
      */
    @tailrec
    def findCause(e: Throwable): Throwable = {
      e match {
        case i: InvocationTargetException if i.getTargetException != null =>
          findCause(i.getTargetException)
        case ee: ExecutionException if ee.getCause != null =>
          findCause(ee.getCause)
        case _ =>
          e
      }
    }
    logger.error(findCause(e))
  }

  /**
    * Submit the controller method call to the executor, and complete the promise with its outcome
    *
    * @param input
    *   The queue passed to the controller method as its only argument
    * @param promise
    *   The promise to complete with the return value, or the failure, of the controller method
    */
  private def submitStreamingCall(input: RxBlockingQueue[Any], promise: Promise[Any]): Unit = {
    // We need to run the server RPC method in another thread as rx implementation is blocking
    executorService.submit(new Callable[Unit] {
      override def call(): Unit = {
        promise.complete(Try(methodSurface.call(controller, input)))
      }
    })
  }

  /**
    * Decode the request, bind its values to the method arguments, and call the controller method
    *
    * @param request
    *   The raw request payload, which must be decoded as a map of argument names to values
    * @return
    *   Success with the return value of the controller method, or Failure with the exception thrown while decoding the
    *   request or calling the method
    */
  def invokeMethod(request: MsgPack): Try[Any] = {
    val grpcContext = GrpcContext.current

    // Build method arguments from MsgPack
    Try {
      val m: MapValue = readRequestAsValue(grpcContext, request)
      val mapValue    = HttpRequestMapper.toCanonicalKeyNameMap(m)
      val args = for ((arg, i) <- methodSurface.args.zipWithIndex) yield {
        val argOpt = mapValue.get(CName.toCanonicalName(arg.name)) match {
          case Some(paramValue) =>
            Option(argCodecs(i).fromMsgPack(paramValue.toMsgpack)).orElse {
              throw RPCStatus.INVALID_REQUEST_U1.newException(s"Failed to parse ${paramValue} for ${arg}")
            }
          case None =>
            // If no value is found, use the method parameter's default argument
            arg.getMethodArgDefaultValue(controller)
        }
        argOpt.getOrElse {
          if (arg.surface.isOption) {
            // If the argument is an option type, we can use None for the missing value
            None
          } else {
            val msg = s"No key for ${arg.name} is found in ${m} for calling ${methodSurface}"
            error(msg)
            throw RPCStatus.INVALID_REQUEST_U1.newException(msg)
          }
        }
      }
      trace(s"RPC call ${methodSurface.name}(${args.mkString(", ")})")
      try {
        val ret = methodSurface.call(controller, args: _*)
        requestLogger.logRPC(grpcContext, rpcContext.withRPCArgs(args))
        ret
      } catch {
        case e: Throwable =>
          // Record the failure with the decoded arguments, then let the exception propagate
          reportError(e)
          requestLogger.logError(e, grpcContext, rpcContext.withRPCArgs(args))
          throw e
      }
    }
  }

  /**
    * Create an observer for a client-streaming call. Incoming messages are decoded and passed to the controller method
    * through an Rx queue. When the client completes the stream, the return value of the controller method is sent as
    * the single response.
    *
    * @param responseObserver
    *   The observer for sending the response to the client
    * @param clientStreamingType
    *   The element type of the client stream
    * @return
    *   The observer that receives the streaming messages from the client
    */
  def invokeClientStreamingMethod(
      responseObserver: StreamObserver[Any],
      clientStreamingType: Surface
  ): StreamObserver[MsgPack] = {
    val codec = codecFactory.of(clientStreamingType)
    // The context and encoding are captured once, when the observer is created
    val grpcContext = GrpcContext.current
    val encoding    = GrpcContext.currentEncoding

    // An observer that receives streaming messages from the client
    new StreamObserver[MsgPack] {
      private val isStarted = new AtomicBoolean(false)
      // A queue for passing incoming messages to the application through Rx interface
      private val rx                    = new RxBlockingQueue[Any]
      private val promise: Promise[Any] = Promise()

      private def invokeServerMethod(): Unit = {
        // Avoid duplicated execution
        if (isStarted.compareAndSet(false, true)) {
          submitStreamingCall(rx, promise)
        }
      }

      override def onNext(value: MsgPack): Unit = {
        invokeServerMethod()
        Try(readStreamingInput(grpcContext, codec, value)) match {
          case Success(v) =>
            rx.add(OnNext(v))
          case Failure(e) =>
            reportError(e)
            rx.add(OnError(e))
        }
      }

      override def onError(t: Throwable): Unit = {
        reportError(t)
        requestLogger.logError(t, grpcContext, rpcContext)
        rx.add(OnError(t))
        responseObserver.onError(GrpcException.wrap(t))
      }

      override def onCompleted(): Unit = {
        // Start the controller method even if no message has arrived
        invokeServerMethod()
        rx.add(OnCompletion)
        promise.future.onComplete {
          case Success(value) =>
            requestLogger.logRPC(grpcContext, rpcContext)
            responseObserver.onNext(GrpcResponse(value, encoding))
            responseObserver.onCompleted()
          case Failure(e) =>
            reportError(e)
            requestLogger.logError(e, grpcContext, rpcContext)
            responseObserver.onError(GrpcException.wrap(e))
        }(ExecutionContext.fromExecutor(executorService))
      }
    }
  }

  /**
    * Create an observer for a bidirectional-streaming call. Incoming messages are decoded and passed to the controller
    * method through an Rx queue. When the client completes the stream, the return value of the controller method is
    * sent to the client: each element of an Rx result becomes one response message, and any other value is sent as a
    * single response.
    *
    * @param responseObserver
    *   The observer for sending the responses to the client
    * @param clientStreamingType
    *   The element type of the client stream
    * @return
    *   The observer that receives the streaming messages from the client
    */
  def invokeBidiStreamingMethod(
      responseObserver: StreamObserver[Any],
      clientStreamingType: Surface
  ): StreamObserver[MsgPack] = {

    val codec = codecFactory.of(clientStreamingType)
    // The context and encoding are captured once, when the observer is created
    val grpcContext = GrpcContext.current
    val encoding    = GrpcContext.currentEncoding

    new StreamObserver[MsgPack] {
      private val isStarted = new AtomicBoolean(false)
      // A queue for passing incoming messages to the application through Rx interface
      private val rx                    = new RxBlockingQueue[Any]
      private val promise: Promise[Any] = Promise()

      private def invokeServerMethod(): Unit = {
        // Avoid duplicated execution
        if (isStarted.compareAndSet(false, true)) {
          submitStreamingCall(rx, promise)
        }
      }

      override def onNext(value: MsgPack): Unit = {
        invokeServerMethod()
        Try(readStreamingInput(grpcContext, codec, value)) match {
          case Success(v) =>
            rx.add(OnNext(v))
          case Failure(e) =>
            // Report decoding failures in the same manner as the client-streaming observer
            reportError(e)
            rx.add(OnError(e))
        }
      }

      override def onError(t: Throwable): Unit = {
        reportError(t)
        requestLogger.logError(t, grpcContext, rpcContext)
        rx.add(OnError(t))
        // Convert the error with GrpcException.wrap as all the other error paths of this handler do
        responseObserver.onError(GrpcException.wrap(t))
      }

      override def onCompleted(): Unit = {
        // Start the controller method even if no message has arrived
        invokeServerMethod()
        rx.add(OnCompletion)
        promise.future.onComplete {
          case Success(responseStream: Rx[_]) =>
            // Forward each event of the returned stream to the client
            RxRunner.run(responseStream) {
              case OnNext(value) =>
                responseObserver.onNext(GrpcResponse(value, encoding))
              case OnError(e) =>
                reportError(e)
                requestLogger.logError(e, grpcContext, rpcContext)
                responseObserver.onError(GrpcException.wrap(e))
              case OnCompletion =>
                requestLogger.logRPC(grpcContext, rpcContext)
                if (isReady(responseObserver)) {
                  responseObserver.onCompleted()
                }
            }
          case Success(v) =>
            // A non-Rx return value is sent as a single response
            requestLogger.logRPC(grpcContext, rpcContext)
            responseObserver.onNext(GrpcResponse(v, encoding))
            responseObserver.onCompleted()
          case Failure(e) =>
            reportError(e)
            requestLogger.logError(e, grpcContext, rpcContext)
            responseObserver.onError(GrpcException.wrap(e))
        }(ExecutionContext.fromExecutor(executorService))
      }
    }
  }

  /**
    * Check whether the observer is ready for accepting more message
    *
    * @param observer
    *   The observer to check
    * @return
    *   The isReady value of a ServerCallStreamObserver or a ClientCallStreamObserver, and true for any other observer
    */
  private def isReady(observer: StreamObserver[_]): Boolean = {
    observer match {
      case so: ServerCallStreamObserver[_] =>
        so.isReady
      case co: ClientCallStreamObserver[_] =>
        co.isReady
      case _ =>
        true
    }
  }

}

/**
  * A gRPC unary method that delegates the request to a GrpcRequestHandler, and sends back either its single result or
  * its failure.
  */
private[grpc] class RPCUnaryMethodHandler(rpcRequestHandler: GrpcRequestHandler)
    extends UnaryMethod[MsgPack, Any]
    with LogSupport {

  override def invoke(request: MsgPack, responseObserver: StreamObserver[Any]): Unit = {
    // Resolve the response encoding before calling the RPC method
    val responseEncoding = GrpcContext.currentEncoding
    val outcome          = rpcRequestHandler.invokeMethod(request)
    outcome match {
      case Failure(cause) =>
        responseObserver.onError(GrpcException.wrap(cause))
      case Success(result) =>
        // A unary call replies with exactly one message, and then closes the stream
        responseObserver.onNext(GrpcResponse(result, responseEncoding))
        responseObserver.onCompleted()
    }
  }
}

/**
  * A gRPC server-streaming method that delegates the request to a GrpcRequestHandler. When the result is an Rx stream,
  * each of its events is forwarded to the client. Any other result is sent as a single message.
  */
private[grpc] class RPCServerStreamingMethodHandler(rpcRequestHandler: GrpcRequestHandler)
    extends ServerStreamingMethod[MsgPack, Any]
    with LogSupport {

  override def invoke(request: MsgPack, responseObserver: StreamObserver[Any]): Unit = {
    // Resolve the response encoding before calling the RPC method
    val responseEncoding = GrpcContext.currentEncoding
    rpcRequestHandler.invokeMethod(request) match {
      case Success(stream: Rx[_]) =>
        // Relay the stream events to the client. The Cancelable returned by the runner is not retained
        RxRunner.run(stream) {
          case OnNext(item) =>
            responseObserver.onNext(GrpcResponse(item, responseEncoding))
          case OnError(cause) =>
            responseObserver.onError(GrpcException.wrap(cause))
          case OnCompletion =>
            responseObserver.onCompleted()
        }
      case Success(single) =>
        // A non-Rx result is sent as the only message of the response stream
        responseObserver.onNext(GrpcResponse(single, responseEncoding))
        responseObserver.onCompleted()
      case Failure(cause) =>
        responseObserver.onError(GrpcException.wrap(cause))
    }
  }
}

/**
  * A gRPC client-streaming method that delegates to GrpcRequestHandler.invokeClientStreamingMethod
  *
  * @param rpcRequestHandler
  *   The handler that creates the observer of the client stream
  * @param clientStreamingType
  *   The element type of the client stream, passed to the handler as is
  */
private[grpc] class RPCClientStreamingMethodHandler(rpcRequestHandler: GrpcRequestHandler, clientStreamingType: Surface)
    extends ClientStreamingMethod[MsgPack, Any] {

  override def invoke(responseObserver: StreamObserver[Any]): StreamObserver[MsgPack] =
    rpcRequestHandler.invokeClientStreamingMethod(responseObserver, clientStreamingType)
}

/**
  * A gRPC bidirectional-streaming method that delegates to GrpcRequestHandler.invokeBidiStreamingMethod
  *
  * @param rpcRequestHandler
  *   The handler that creates the observer of the client stream
  * @param clientStreamingType
  *   The element type of the client stream, passed to the handler as is
  */
private[grpc] class RPCBidiStreamingMethodHandler(rpcRequestHandler: GrpcRequestHandler, clientStreamingType: Surface)
    extends BidiStreamingMethod[MsgPack, Any] {

  override def invoke(responseObserver: StreamObserver[Any]): StreamObserver[MsgPack] =
    rpcRequestHandler.invokeBidiStreamingMethod(responseObserver, clientStreamingType)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy