All Downloads are FREE. Search and download functionalities are using the official Maven repository.

wvlet.airframe.http.grpc.GrpcContext.scala Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package wvlet.airframe.http.grpc

import io.grpc.*
import wvlet.airframe.http.internal.TLSSupport
import wvlet.airframe.http.{Http, HttpMessage, RPCContext, RPCEncoding}
import wvlet.log.LogSupport

import scala.collection.mutable

object GrpcContext {
  // io.grpc.Context key under which a GrpcContext value is stored and looked up
  private[grpc] val contextKey: Context.Key[GrpcContext] = Context.key[GrpcContext]("grpc_context")

  /**
    * Look up the GrpcContext registered under `contextKey`.
    *
    * @return
    *   Some(context) when a value is bound to the key; None when the lookup yields null, which means this method was
    *   called outside gRPC's local thread for processing the request
    */
  def current: Option[GrpcContext] = {
    val bound = contextKey.get()
    Option(bound)
  }

  // Encoding of the current context, or MsgPack when no context is available
  private[grpc] def currentEncoding: RPCEncoding = current match {
    case Some(ctx) => ctx.encoding
    case None      => RPCEncoding.MsgPack
  }

  // Metadata keys for the "accept" and "content-type" headers (ASCII string values)
  private[grpc] val KEY_ACCEPT: Metadata.Key[String] =
    Metadata.Key.of("accept", Metadata.ASCII_STRING_MARSHALLER)
  private[grpc] val KEY_CONTENT_TYPE: Metadata.Key[String] =
    Metadata.Key.of("content-type", Metadata.ASCII_STRING_MARSHALLER)

  // Drop every existing value for the key, then store the given value as the only entry
  private def overwrite(target: Metadata, key: Metadata.Key[String], value: String): Unit = {
    target.removeAll(key)
    target.put(key, value)
  }

  /**
    * Extension methods for reading and replacing the accept / content-type entries of a Metadata object
    */
  private[grpc] implicit class RichMetadata(private val underlying: Metadata) extends AnyVal {

    /**
      * The accept header value, falling back to RPCEncoding.ApplicationMsgPack when the header is absent
      */
    def accept: String = Option(underlying.get(KEY_ACCEPT)) match {
      case Some(value) => value
      case None        => RPCEncoding.ApplicationMsgPack
    }

    /**
      * Replace all accept header values with the given one
      */
    def setAccept(s: String): Unit = overwrite(underlying, KEY_ACCEPT, s)

    /**
      * Replace all content-type header values with the given one
      */
    def setContentType(s: String): Unit = overwrite(underlying, KEY_CONTENT_TYPE, s)
  }

}

import GrpcContext.*

/**
  * Context of a single gRPC call: the call authority, transport attributes, request metadata (headers), and the
  * descriptor of the invoked method.
  *
  * @param authority
  *   authority of the call, if any
  * @param attributes
  *   gRPC attributes associated with the call
  * @param metadata
  *   request metadata (headers)
  * @param descriptor
  *   descriptor of the method being called
  */
case class GrpcContext(
    authority: Option[String],
    attributes: Attributes,
    metadata: Metadata,
    descriptor: MethodDescriptor[_, _]
) extends RPCContext
    with TLSSupport
    with LogSupport {

  /**
    * The accept header of the request metadata (see RichMetadata.accept for the default when it is absent)
    */
  def accept: String = metadata.accept

  /**
    * The message encoding derived from the accept header: JSON when the header equals RPCEncoding.ApplicationJson,
    * MsgPack for any other value
    */
  def encoding: RPCEncoding = accept match {
    case RPCEncoding.ApplicationJson =>
      // Json input
      RPCEncoding.JSON
    case _ =>
      // Use msgpack by default
      RPCEncoding.MsgPack
  }

  // Delegates to setTLS (presumably provided by TLSSupport)
  override def setThreadLocal[A](key: String, value: A): Unit = {
    setTLS(key, value)
  }

  // Delegates to getTLS (presumably provided by TLSSupport)
  override def getThreadLocal(key: String): Option[Any] = {
    getTLS(key)
  }

  /**
    * Build an HTTP-style view of this gRPC call: a POST request to "/{full method name}" carrying the ASCII metadata
    * entries as headers. When a key has multiple values, the one returned by Metadata.get is used.
    *
    * Binary metadata entries (names ending with "-bin") are skipped: io.grpc.Metadata.Key.of with an ASCII marshaller
    * rejects such names with an IllegalArgumentException, and their values are not ASCII strings.
    *
    * @return
    *   the request with one header per ASCII metadata key
    */
  override def httpRequest: HttpMessage.Request = {
    import scala.jdk.CollectionConverters.*
    val base = Http.POST(s"/${descriptor.getFullMethodName}")
    metadata.keys().asScala.foldLeft(base) { (request, name) =>
      if (name.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
        // Binary header: cannot be read through an ASCII key
        request
      } else {
        val key = Metadata.Key.of(name, Metadata.ASCII_STRING_MARSHALLER)
        // Metadata.get returns null when no value is present; never pass null as a header value
        Option(metadata.get(key)).fold(request)(value => request.withHeader(name, value))
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy