// All Downloads are FREE. Search and download functionalities are using the official Maven repository.

// io.kaizensolutions.trace4cats.zio.extras.tapir.TraceInterceptor.scala Maven / Gradle / Ivy

package io.kaizensolutions.trace4cats.zio.extras.tapir

import io.kaizensolutions.trace4cats.zio.extras.{ZSpan, ZTracer}
import sttp.model.{Header, HeaderNames, StatusCode}
import sttp.monad.MonadError
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.interceptor.*
import sttp.tapir.server.interpreter.BodyListener
import sttp.tapir.server.model.ServerResponse
import trace4cats.model.{SpanKind, SpanStatus, TraceHeaders}
import trace4cats.{AttributeValue, ToHeaders}
import zio.*

/**
 * Tapir request interceptor that adds tracing to a server.
 *
 * It does no tracing of its own at the request level: for every request it
 * installs a `TraceEndpointInterceptor` (configured with the same settings) in
 * front of the remaining handler chain, so that the actual tracing happens at
 * the endpoint level where the matched endpoint is known.
 *
 * Instances are created through the factories on the companion object.
 *
 * @param tracer
 *   the tracer to use
 * @param dropHeadersWhen
 *   predicate on header names; headers for which it returns `true` are left
 *   out of the trace
 * @param enrichResponseHeadersWithTraceIds
 *   whether to add trace headers to the response
 * @param enrichLogs
 *   whether to annotate logs with the trace headers
 * @param headerFormat
 *   the format to use for trace headers
 */
final class TraceInterceptor[Env, Err] private (
  private val tracer: ZTracer,
  private val dropHeadersWhen: String => Boolean,
  private val enrichResponseHeadersWithTraceIds: Boolean,
  private val enrichLogs: Boolean,
  private val headerFormat: ToHeaders
) extends RequestInterceptor[ZIO[Env, Err, *]] {

  /**
   * Builds a request handler that forwards every request to the handler
   * obtained by supplying the tracing endpoint interceptor to `requestHandler`.
   * The `responder` is not used here.
   */
  override def apply[R, B](
    responder: Responder[ZIO[Env, Err, *], B],
    requestHandler: EndpointInterceptor[ZIO[Env, Err, *]] => RequestHandler[ZIO[Env, Err, *], R, B]
  ): RequestHandler[ZIO[Env, Err, *], R, B] = {
    // One endpoint-level interceptor per handler, sharing this interceptor's configuration.
    val endpointTracing = new TraceEndpointInterceptor[Env, Err](
      tracer = tracer,
      dropHeadersWhen = dropHeadersWhen,
      enrichResponseHeadersWithTraceIds = enrichResponseHeadersWithTraceIds,
      enrichLogs = enrichLogs,
      headerFormat = headerFormat
    )

    new RequestHandler[ZIO[Env, Err, *], R, B] {
      override def apply(request: ServerRequest, endpoints: List[ServerEndpoint[R, ZIO[Env, Err, *]]])(implicit
        monad: MonadError[ZIO[Env, Err, *]]
      ): ZIO[Env, Err, RequestResult[B]] = {
        val next = requestHandler(endpointTracing)
        next(request, endpoints)
      }
    }
  }
}
/**
 * Factories for [[TraceInterceptor]]. All of them share the same defaults:
 * sensitive headers (per `HeaderNames.isSensitive`) are dropped, response
 * headers and logs are enriched, and `ToHeaders.standard` is the header format.
 */
object TraceInterceptor {

  /**
   * Creates an interceptor for an arbitrary ZIO environment and error type.
   *
   * @param tracer
   *   the tracer to use
   * @param dropHeadersWhen
   *   predicate on header names; matching headers are left out of the trace
   * @param enrichResponseHeadersWithTraceIds
   *   whether to add trace headers to the response
   * @param enrichLogs
   *   whether to annotate logs with the trace headers
   * @param headerFormat
   *   the format to use for trace headers
   */
  def apply[Env, Err](
    tracer: ZTracer,
    dropHeadersWhen: String => Boolean = HeaderNames.isSensitive,
    enrichResponseHeadersWithTraceIds: Boolean = true,
    enrichLogs: Boolean = true,
    headerFormat: ToHeaders = ToHeaders.standard
  ): TraceInterceptor[Env, Err] =
    new TraceInterceptor[Env, Err](
      tracer = tracer,
      dropHeadersWhen = dropHeadersWhen,
      enrichResponseHeadersWithTraceIds = enrichResponseHeadersWithTraceIds,
      enrichLogs = enrichLogs,
      headerFormat = headerFormat
    )

  /**
   * Same as [[apply]], fixed to an `Any` environment and `Throwable` errors.
   */
  def task(
    tracer: ZTracer,
    dropHeadersWhen: String => Boolean = HeaderNames.isSensitive,
    enrichResponseHeadersWithTraceIds: Boolean = true,
    enrichLogs: Boolean = true,
    headerFormat: ToHeaders = ToHeaders.standard
  ): TraceInterceptor[Any, Throwable] =
    apply[Any, Throwable](
      tracer = tracer,
      dropHeadersWhen = dropHeadersWhen,
      enrichResponseHeadersWithTraceIds = enrichResponseHeadersWithTraceIds,
      enrichLogs = enrichLogs,
      headerFormat = headerFormat
    )

  /**
   * Same as [[apply]], for an environment `R` and an error type `E` that is a
   * subtype of `Throwable`.
   */
  def rio[R, E <: Throwable](
    tracer: ZTracer,
    dropHeadersWhen: String => Boolean = HeaderNames.isSensitive,
    enrichResponseHeadersWithTraceIds: Boolean = true,
    enrichLogs: Boolean = true,
    headerFormat: ToHeaders = ToHeaders.standard
  ): TraceInterceptor[R, E] =
    apply[R, E](
      tracer = tracer,
      dropHeadersWhen = dropHeadersWhen,
      enrichResponseHeadersWithTraceIds = enrichResponseHeadersWithTraceIds,
      enrichLogs = enrichLogs,
      headerFormat = headerFormat
    )
}

/**
 * Endpoint interceptor that runs successfully decoded endpoint logic inside a
 * server span named after the endpoint (`ctx.endpoint.showShort`).
 *
 * Only `onDecodeSuccess` is traced; security failures and decode failures are
 * passed straight through to the wrapped handler.
 *
 * @param tracer
 *   the tracer used to open the span
 * @param dropHeadersWhen
 *   predicate on header names; headers for which it returns `true` are not
 *   recorded as span attributes
 * @param enrichResponseHeadersWithTraceIds
 *   whether the span's trace headers are added to the outgoing response
 * @param enrichLogs
 *   whether the span's trace headers are attached as log annotations while the
 *   endpoint logic runs
 * @param headerFormat
 *   the format used when extracting trace headers from the span
 */
private class TraceEndpointInterceptor[Env, Err](
  private val tracer: ZTracer,
  private val dropHeadersWhen: String => Boolean,
  private val enrichResponseHeadersWithTraceIds: Boolean,
  private val enrichLogs: Boolean,
  private val headerFormat: ToHeaders
) extends EndpointInterceptor[ZIO[Env, Err, *]] {

  /**
   * Wraps `endpointHandler` so that its `onDecodeSuccess` is traced. The
   * `responder` is not used.
   */
  override def apply[B](
    responder: Responder[ZIO[Env, Err, *], B],
    endpointHandler: EndpointHandler[ZIO[Env, Err, *], B]
  ): EndpointHandler[ZIO[Env, Err, *], B] = new EndpointHandler[ZIO[Env, Err, *], B] {

    override def onDecodeSuccess[A, U, I](
      ctx: DecodeSuccessContext[ZIO[Env, Err, *], A, U, I]
    )(implicit
      monad: MonadError[ZIO[Env, Err, *]],
      bodyListener: BodyListener[ZIO[Env, Err, *], B]
    ): ZIO[Env, Err, ServerResponse[B]] = {
      val spanName     = ctx.endpoint.showShort
      val request      = ctx.request
      // The incoming request headers are handed to the tracer; presumably this lets it
      // continue a trace started upstream (depends on ZTracer.fromHeaders - confirm).
      val traceHeaders = TraceHeaders.of(request.headers.map(h => (h.name, h.value))*)
      tracer.fromHeaders(traceHeaders, name = spanName, kind = SpanKind.Server) { span =>
        // Optionally expose the span's trace headers as log annotations for the endpoint logic.
        val logTraceContext =
          if (enrichLogs) ZIOAspect.annotated(annotations = extractKVHeaders(span, headerFormat).toList*)
          else ZIOAspect.identity

        enrichSpanFromRequest(request, dropHeadersWhen, span) *>
          (endpointHandler.onDecodeSuccess(ctx) @@ logTraceContext)
            .foldZIO(
              // Typed failure: record it on the span, then re-raise the same error.
              error => span.setStatus(SpanStatus.Internal(error.toString)) *> ZIO.fail(error),
              // Success: record response details, then optionally add trace headers to the response.
              serverResponse =>
                enrichSpanFromResponse(serverResponse, dropHeadersWhen, span).as(
                  if (enrichResponseHeadersWithTraceIds) serverResponse.addHeaders(toHttpHeaders(span, headerFormat))
                  else serverResponse
                )
            )
      }
    }

    // Not traced: delegated unchanged to the wrapped handler.
    override def onSecurityFailure[A](
      ctx: SecurityFailureContext[ZIO[Env, Err, *], A]
    )(implicit
      monad: MonadError[ZIO[Env, Err, *]],
      bodyListener: BodyListener[ZIO[Env, Err, *], B]
    ): ZIO[Env, Err, ServerResponse[B]] =
      endpointHandler.onSecurityFailure(ctx)

    // Not traced: delegated unchanged to the wrapped handler.
    override def onDecodeFailure(
      ctx: DecodeFailureContext
    )(implicit
      monad: MonadError[ZIO[Env, Err, *]],
      bodyListener: BodyListener[ZIO[Env, Err, *], B]
    ): ZIO[Env, Err, Option[ServerResponse[B]]] =
      endpointHandler.onDecodeFailure(ctx)
  }

  /** The span's trace headers in the given format, as HTTP headers; empty values are skipped. */
  private def toHttpHeaders(span: ZSpan, whichHeaders: ToHeaders): Seq[Header] =
    span
      .extractHeaders(whichHeaders)
      .values
      .collect { case (k, v) if v.nonEmpty => Header(k.toString, v) }
      .toSeq

  /** The span's trace headers in the given format, as plain key/value pairs; empty values are skipped. */
  private def extractKVHeaders(span: ZSpan, whichHeaders: ToHeaders): Map[String, String] =
    span
      .extractHeaders(whichHeaders)
      .values
      .collect { case (k, v) if v.nonEmpty => (k.toString, v) }

  /** Records the request headers as span attributes, but only when the span is sampled. */
  private def enrichSpanFromRequest(
    request: ServerRequest,
    dropHeadersWhen: String => Boolean,
    span: ZSpan
  ): UIO[Unit] =
    if (span.isSampled) span.putAll(requestFields(request.headers, dropHeadersWhen)*)
    else ZIO.unit

  /**
   * Records the response status code and headers as span attributes (only when
   * the span is sampled) and then sets the span status from the HTTP status
   * code (always).
   */
  private def enrichSpanFromResponse[A](
    response: ServerResponse[A],
    dropHeadersWhen: String => Boolean,
    span: ZSpan
  ): UIO[Unit] = {
    val recordAttributes: UIO[Unit] =
      if (span.isSampled) {
        // Built only for sampled spans so unsampled requests skip the header filtering and
        // allocation, matching how enrichSpanFromRequest behaves.
        val statusCodeField = "resp.status.code" -> AttributeValue.intToTraceValue(response.code.code)
        val respFields      = statusCodeField +: responseFields(response.headers, dropHeadersWhen)
        span.putAll(respFields*)
      } else ZIO.unit

    recordAttributes *> span.setStatus(toSpanStatus(response.code))
  }

  /**
   * Maps an HTTP status code to a span status. A handful of codes have a
   * dedicated mapping; any other successful code is `Ok` and everything else
   * is `Unknown`.
   */
  private def toSpanStatus(value: StatusCode): SpanStatus =
    value match {
      case StatusCode.BadRequest         => SpanStatus.Internal("Bad Request")
      case StatusCode.Unauthorized       => SpanStatus.Unauthenticated
      case StatusCode.Forbidden          => SpanStatus.PermissionDenied
      case StatusCode.NotFound           => SpanStatus.NotFound
      case StatusCode.TooManyRequests    => SpanStatus.Unavailable
      case StatusCode.BadGateway         => SpanStatus.Unavailable
      case StatusCode.ServiceUnavailable => SpanStatus.Unavailable
      case StatusCode.GatewayTimeout     => SpanStatus.Unavailable
      case status if status.isSuccess    => SpanStatus.Ok
      case _                             => SpanStatus.Unknown
    }

  /** Span attributes for request headers, keyed as `req.header.<name>`. */
  private def requestFields(
    hs: Seq[Header],
    dropHeadersWhen: String => Boolean
  ): Seq[(String, AttributeValue)] =
    headerFields(hs, "req", dropHeadersWhen)

  /** Span attributes for response headers, keyed as `resp.header.<name>`. */
  private def responseFields(
    hs: Seq[Header],
    dropHeadersWhen: String => Boolean
  ): Seq[(String, AttributeValue)] =
    headerFields(hs, "resp", dropHeadersWhen)

  /**
   * Turns headers into span attributes keyed as `<type>.header.<name>`,
   * leaving out every header whose name satisfies `dropHeadersWhen`.
   */
  private def headerFields(
    hs: Seq[Header],
    `type`: String,
    dropHeadersWhen: String => Boolean
  ): Seq[(String, AttributeValue)] =
    hs.filter(h => !dropHeadersWhen(h.name)).map { h =>
      (s"${`type`}.header.${h.name}", h.value: AttributeValue)
    }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy