All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.pekko.grpc.internal.AbstractGrpcProtocol.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * license agreements; and to You under the Apache License, version 2.0:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * This file is part of the Apache Pekko project, which was derived from Akka.
 */

/*
 * Copyright (C) 2020-2021 Lightbend Inc. 
 */

package org.apache.pekko.grpc.internal
import org.apache.pekko
import pekko.NotUsed
import pekko.grpc.GrpcProtocol
import pekko.grpc.GrpcProtocol.{ Frame, GrpcProtocolReader, GrpcProtocolWriter }
import pekko.http.javadsl.{ model => jmodel }
import pekko.http.scaladsl.model.HttpEntity.ChunkStreamPart
import pekko.http.scaladsl.model.{ ContentType, HttpHeader, HttpResponse, MediaType, Trailer }
import pekko.stream.Attributes
import pekko.stream.impl.io.ByteStringParser
import pekko.stream.impl.io.ByteStringParser.{ ByteReader, ParseResult, ParseStep }
import pekko.stream.scaladsl.Flow
import pekko.stream.stage.GraphStageLogic
import pekko.util.{ ByteString, ByteStringBuilder }
import io.grpc.StatusException

import java.nio.ByteOrder
import scala.collection.immutable

abstract class AbstractGrpcProtocol(subType: String) extends GrpcProtocol {

  override val contentType: ContentType.Binary =
    MediaType.applicationBinary(s"$subType+proto", MediaType.Compressible).toContentType

  override val mediaTypes: Set[jmodel.MediaType] =
    Set(contentType.mediaType, MediaType.applicationBinary(subType, MediaType.Compressible))

  private lazy val knownWriters = Codecs.supportedCodecs.map(c => c -> writer(c)).toMap.withDefault(writer)
  private lazy val knownReaders = Codecs.supportedCodecs.map(c => c -> reader(c)).toMap.withDefault(reader)

  /**
   * Obtains a writer for this protocol:
   * @param codec the compression codec to apply to data frame contents.
   */
  override def newWriter(codec: Codec): GrpcProtocolWriter = knownWriters(codec)

  /**
   * Obtains a reader for this protocol.
   *
   * @param codec the codec to use for compressed frames.
   */
  override def newReader(codec: Codec): GrpcProtocolReader = knownReaders(codec)

  protected def writer(codec: Codec): GrpcProtocolWriter

  protected def reader(codec: Codec): GrpcProtocolReader

}
object AbstractGrpcProtocol {

  /** Field marker to signal the start of an uncompressed frame */
  private val notCompressed: ByteString = ByteString(0)

  /** Field marker to signal the start of an frame compressed according to the Message-Encoding from the header */
  private val compressed: ByteString = ByteString(1)

  def fieldType(codec: Codec) = if (codec == Identity) notCompressed else compressed

  /**
   * Adjusts the compressibility of a content type to suit a message encoding.
   * @param contentType the content type for the gRPC protocol.
   * @param codec the message encoding being used to encode objects.
   * @return the provided content type, with the compressibility adapted to reflect whether HTTP transport level compression should be used.
   */
  private def adjustCompressibility(contentType: ContentType.Binary, codec: Codec): ContentType.Binary =
    contentType.mediaType
      .withComp(codec match {
        case Identity => MediaType.Compressible
        case _        => MediaType.NotCompressible
      })
      .toContentType

  @deprecated("In favor of overload that takes Booleans to specify flags", "akka-grpc 2.1.0")
  def encodeFrameData(frameType: ByteString, data: ByteString): ByteString = {
    val length = data.length
    val encodedLength = ByteString.fromArrayUnsafe(
      Array((length >> 24).toByte, (length >> 16).toByte, (length >> 8).toByte, length.toByte))
    frameType ++ encodedLength ++ data
  }
  def encodeFrameData(data: ByteString, isCompressed: Boolean, isTrailer: Boolean): ByteString = {
    implicit val byteOrder = ByteOrder.BIG_ENDIAN
    val length = data.length
    val builder = new ByteStringBuilder()
    builder.sizeHint(5)
    val flags =
      (if (isCompressed) 1 else 0) | (if (isTrailer) 0x80 else 0)

    builder // ...
      .putByte(flags.toByte)
      .putInt(length)
      .++=(data)
      .result()
  }

  def writer(
      protocol: GrpcProtocol,
      codec: Codec,
      encodeFrame: Frame => ChunkStreamPart,
      encodeDataToResponse: (ByteString, immutable.Seq[HttpHeader], Trailer) => HttpResponse): GrpcProtocolWriter =
    GrpcProtocolWriter(
      adjustCompressibility(protocol.contentType, codec),
      codec,
      encodeFrame,
      encodeDataToResponse,
      Flow[Frame].map(encodeFrame))

  def reader(
      codec: Codec,
      decodeFrame: (Int, ByteString) => Frame,
      preDecodeStrict: ByteString => ByteString = null,
      preDecodeFlow: Flow[ByteString, ByteString, NotUsed] = null): GrpcProtocolReader = {
    val strictAdapter: ByteString => ByteString = if (preDecodeStrict eq null) identity else preDecodeStrict
    val adapter: Flow[ByteString, Frame, NotUsed] => Flow[ByteString, Frame, NotUsed] =
      if (preDecodeFlow eq null) identity
      else x => Flow[ByteString].via(preDecodeFlow).via(x)

    // strict decoder
    def decoder(bs: ByteString): ByteString =
      try {
        val reader = new ByteReader(strictAdapter(bs))
        val frameType = reader.readByte()
        val length = reader.readIntBE()
        val data = reader.take(length)
        if (reader.hasRemaining) throw new IllegalStateException("Unexpected data")
        if ((frameType & 0x80) == 0) codec.uncompress((frameType & 1) == 1, data)
        else throw new IllegalStateException("Cannot read unknown frame")
      } catch { case ByteStringParser.NeedMoreData => throw new MissingParameterException }

    GrpcProtocolReader(codec, decoder, adapter(Flow.fromGraph(new GrpcFramingDecoderStage(codec, decodeFrame))))
  }

  class GrpcFramingDecoderStage(codec: Codec, deframe: (Int, ByteString) => Frame) extends ByteStringParser[Frame] {
    override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
      new ParsingLogic {
        startWith(ReadFrameHeader)

        trait Step extends ParseStep[Frame]

        object ReadFrameHeader extends Step {
          override def parse(reader: ByteReader): ParseResult[Frame] = {
            val frameType = reader.readByte()
            // If we want to support > 2GB frames, this should be unsigned
            val length = reader.readIntBE()

            if (length == 0) ParseResult(Some(deframe(frameType, ByteString.empty)), ReadFrameHeader)
            else ParseResult(None, ReadFrame(frameType, length), acceptUpstreamFinish = false)
          }
        }

        sealed case class ReadFrame(frameType: Int, length: Int) extends Step {
          private val compression = (frameType & 0x01) == 1

          override def parse(reader: ByteReader): ParseResult[Frame] =
            try ParseResult(
                Some(deframe(frameType, codec.uncompress(compression, reader.take(length)))),
                ReadFrameHeader)
            catch {
              case s: StatusException =>
                failStage(s) // handle explicitly to avoid noisy log
                ParseResult(None, Failed)
            }
        }

        case object Failed extends Step {
          override def parse(reader: ByteReader): ParseResult[Frame] = ParseResult(None, Failed)
        }
      }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy