All Downloads are FREE. Search and download functionalities are using the official Maven repository.

lepus-client_native0.4_3.0.5.3.source-code.StartupNegotiation.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2021 Hossein Naderi
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package lepus.client

import cats.effect.Concurrent
import cats.effect.kernel.Deferred
import cats.syntax.all.*
import fs2.Pipe
import fs2.Pull
import fs2.Stream
import lepus.protocol.ConnectionClass
import lepus.protocol.ConnectionClass.Close
import lepus.protocol.ConnectionClass.Secure
import lepus.protocol.ConnectionClass.Start
import lepus.protocol.ConnectionClass.Tune
import lepus.protocol.Frame
import lepus.protocol.constants.ReplyCode
import lepus.protocol.domains.*

/** Handle on the connection startup handshake with the server.
  *
  * The implementation built by the companion `apply` drives the handshake from
  * the frames flowing through [[pipe]] and publishes its outcome through
  * [[config]] and [[capabilities]].
  */
trait StartupNegotiation[F[_]] {

  /** The negotiated connection parameters.
    *
    * In the implementation built by the companion `apply`, this waits until
    * the handshake has either completed or failed, and raises the failure in
    * `F` in the latter case.
    */
  def config: F[NegotiatedConfig]

  /** The capabilities advertised by the server.
    *
    * In the implementation built by the companion `apply`, this waits until
    * the server's `Start` method has been handled, or fails with the error
    * that terminated the handshake.
    */
  def capabilities: F[Capabilities]

  /** Builds the pipe that performs the handshake over the incoming frames.
    *
    * @param sendQ
    *   effect used to send each response frame produced by the handshake
    * @return
    *   a pipe over incoming frames; in the implementation built by the
    *   companion `apply`, frames consumed by the handshake are not emitted and
    *   frames arriving after it has completed are passed through unchanged
    */
  def pipe(sendQ: Frame => F[Unit]): Pipe[F, Frame, Frame]
}
object StartupNegotiation {

  /** Creates a negotiation that has not started yet.
    *
    * The outcome is published through two `Deferred`s: one for the negotiated
    * configuration and one for the server capabilities. Each is completed with
    * `Right` on success or `Left` with the error that ended the handshake.
    *
    * @param auth
    *   source of SASL mechanisms; queried with the mechanisms proposed by the
    *   server
    * @param vhost
    *   virtual host; not referenced by any of the negotiation steps defined
    *   here
    * @param connectionName
    *   optional name placed in the client properties under `connection_name`
    */
  def apply[F[_]: Concurrent](
      auth: AuthenticationConfig[F],
      vhost: Path = Path("/"),
      connectionName: Option[ShortString] = None
  ): F[StartupNegotiation[F]] = for {
    conf <- Deferred[F, Either[Throwable, NegotiatedConfig]]
    caps <- Deferred[F, Either[Throwable, Capabilities]]
  } yield new StartupNegotiation[F] {

    /** Waits for the capabilities result and raises its error, if any. */
    override def capabilities: F[Capabilities] = caps.get.flatMap(_.liftTo)

    /** Fails both results with `ex` so that nobody waits on them forever.
      * Completing an already completed `Deferred` has no effect.
      */
    private def terminate(ex: Throwable) =
      conf.complete(Left(ex)) >> caps.complete(Left(ex)).void

    /** Runs the handshake over `in`, one frame per negotiation step.
      *
      * Each incoming frame is handed to the current step; the step's response
      * is sent through `sendQ`. Once a step reports completion, the
      * configuration is published and all remaining frames are echoed
      * downstream. If the input ends before that, both results are failed with
      * [[NegotiationFailed]] and the output simply ends.
      */
    override def pipe(sendQ: Frame => F[Unit]): Pipe[F, Frame, Frame] = in => {
      // A failed send must also fail the pending results; otherwise callers of
      // `config` / `capabilities` would wait forever while the error only
      // surfaces in this stream.
      def send(response: Frame): F[Unit] =
        sendQ(response).onError(terminate(_))

      def go(
          step: Negotiation[F],
          frames: Stream[F, Frame]
      ): Pull[F, Frame, Unit] =
        frames.pull.uncons1
          // errors raised by the input stream itself
          .onError(ex => Pull.eval(terminate(ex)))
          .flatMap {
            case Some((frame, nextFrames)) =>
              Pull
                // errors raised by the negotiation step
                .eval(step(frame).onError(terminate(_)))
                .flatMap {
                  case NegotiationResult.Continue(response, nextStep) =>
                    Pull.eval(send(response)) >> go(nextStep, nextFrames)
                  case NegotiationResult.Completed(response, config) =>
                    Pull.eval(
                      send(response) >> conf.complete(Right(config))
                    ) >> nextFrames.pull.echo
                }
            // input ended before the handshake finished
            case None => Pull.eval(terminate(NegotiationFailed))
          }

      go(start, in).stream
    }

    /** Waits for the configuration result and raises its error, if any. */
    override def config: F[NegotiatedConfig] = conf.get.flatMap(_.liftTo)

    /** Lifts a handler of connection-class methods into a negotiation step.
      *
      * Only method frames on channel 0 carrying a `ConnectionClass` method are
      * accepted; any other frame, and any method `f` is not defined for, fails
      * with [[NegotiationError]].
      */
    private def method(
        f: PartialFunction[ConnectionClass, F[NegotiationResult[F]]]
    ): Negotiation[F] = {
      case Frame.Method(0, method: ConnectionClass) =>
        f.lift(method).getOrElse(NegotiationError.raiseError)
      case _ => NegotiationError.raiseError
    }

    /** First step: expects `Start` with protocol version 0-9.
      *
      * Publishes the server capabilities, picks a SASL mechanism among those
      * proposed by the server and replies with `StartOk`. Fails with
      * [[NoSupportedSASLMechanism]] when `auth` offers none of them.
      */
    private def start: Negotiation[F] = method {
      case Start(0, 9, serverProperties, mechanisms, locales) =>
        // the server lists its mechanisms separated by spaces
        val proposedMechanisms = mechanisms.split(" ")
        // a missing or non-table "capabilities" entry means no capabilities
        val serverCaps = serverProperties
          .get(ShortString("capabilities"))
          .collect { case t: FieldTable => t }
          .fold(Capabilities.none)(Capabilities.from(_))

        auth.get(proposedMechanisms: _*) match {
          case None => NoSupportedSASLMechanism.raiseError
          case Some(mechanism) =>
            caps.complete(Right(serverCaps)) >>
              mechanism.first.map(response =>
                NegotiationResult.continue(
                  ConnectionClass.StartOk(
                    clientProps(connectionName),
                    mechanism.name,
                    response,
                    ShortString("en-US")
                  )
                )(handleChallenge(mechanism))
              )
        }

    }

    /** Step used after `StartOk` has been sent.
      *
      *   - `Secure`: answers the challenge with `SecureOk` and stays in this
      *     step, so any number of challenges is handled.
      *   - `Tune`: finishes the handshake, see `afterChallenge`.
      *   - `Close` with `ReplyCode.AccessRefused`: fails with
      *     [[AuthenticationFailure]].
      */
    private def handleChallenge(sasl: SaslMechanism[F]): Negotiation[F] =
      method {
        case Secure(challenge) =>
          sasl
            .next(challenge)
            .map(response =>
              NegotiationResult
                .continue(ConnectionClass.SecureOk(response))(
                  handleChallenge(sasl)
                )
            )
        case msg: Tune => afterChallenge(msg)
        case Close(ReplyCode.AccessRefused, details, _, _) =>
          AuthenticationFailure(details).raiseError
      }

    /** Completes the handshake using the values of the server's `Tune`
      * unchanged.
      */
    private def afterChallenge
        : ConnectionClass.Tune => F[NegotiationResult[F]] = {
      case Tune(channelMax, frameMax, heartbeat) =>
        NegotiationResult
          .completed(
            NegotiatedConfig(
              channelMax = channelMax,
              frameMax = frameMax,
              heartbeat = heartbeat
            )
          )
          .pure[F]
    }

  }

  /** Client properties sent to the server in `StartOk`.
    *
    * Contains product and build information and the capabilities this client
    * declares. `connection_name` is set from `connectionName` through
    * `FieldTable.updated`.
    * NOTE(review): presumably a `None` name leaves the key absent; confirm
    * against `FieldTable.updated`.
    */
  private[client] def clientProps(connectionName: Option[ShortString]) =
    FieldTable(
      ShortString("product") -> ShortString("Lepus"),
      ShortString("version") -> ShortString(BuildInfo.version),
      ShortString("platform") -> ShortString("scala"),
      ShortString("scala-version") -> ShortString(BuildInfo.scalaVersion),
      ShortString("capabilities") -> FieldTable(
        ShortString("publisher_confirms") -> true,
        ShortString("authentication_failure_close") -> true,
        ShortString("consumer_cancel_notify") -> true,
        ShortString("basic.nack") -> true,
        ShortString("connection.blocked") -> true
      )
    ).updated(ShortString("connection_name"), connectionName)
}

/** Connection parameters agreed on during startup negotiation.
  *
  * `StartupNegotiation` fills these from the server's `Tune` method, and the
  * same values are sent back in `TuneOk`.
  *
  * @param channelMax
  *   the `channelMax` value carried by `Tune`
  * @param frameMax
  *   the `frameMax` value carried by `Tune`
  * @param heartbeat
  *   the `heartbeat` value carried by `Tune`
  */
final case class NegotiatedConfig(
    channelMax: Short,
    frameMax: Int,
    heartbeat: Short
)

/** Outcome of handling one incoming frame during startup negotiation. */
private enum NegotiationResult[F[_]] {

  /** Send `response` and handle the next incoming frame with `next`. */
  case Continue(response: Frame, next: Negotiation[F])

  /** Send `response`; negotiation is finished and yielded `config`. */
  case Completed(response: Frame, config: NegotiatedConfig)
}

/** Smart constructors that wrap connection-class replies into channel-0 method
  * frames.
  */
private object NegotiationResult {

  /** Wraps a connection-class method into a method frame on channel 0. */
  private def onChannelZero(reply: ConnectionClass): Frame =
    Frame.Method(ChannelNumber(0), reply)

  /** Reply with `response` and keep negotiating with `next`. */
  def continue[F[_]](response: ConnectionClass)(next: Negotiation[F]) =
    Continue(onChannelZero(response), next)

  /** Finish the negotiation, replying with a `TuneOk` that carries the values
    * of `config`.
    */
  def completed[F[_]](config: NegotiatedConfig) = {
    val tuneOk = ConnectionClass
      .TuneOk(config.channelMax, config.frameMax, config.heartbeat)
    Completed[F](onChannelZero(tuneOk), config)
  }
}

/** A single negotiation step: handles one incoming frame and yields the frame
  * to send back, together with either the next step or the final
  * configuration.
  */
type Negotiation[F[_]] = Frame => F[NegotiationResult[F]]

/** Raised by `StartupNegotiation` when an incoming frame is not a channel-0
  * connection-class method, or is a method the current negotiation step does
  * not handle.
  */
case object NegotiationError
    extends Exception("Error while negotiating with server!")
/** Raised by `StartupNegotiation` when the server answers with `Close` and
  * reply code `ReplyCode.AccessRefused` while challenges are being handled.
  *
  * @param details
  *   the details field taken from the server's `Close` method
  */
case class AuthenticationFailure(details: String)
    extends Exception(
      s"Server refused connection due to authentication failure!\nDetails: $details"
    )
/** Raised by `StartupNegotiation` when the authentication config yields no
  * mechanism for the list of mechanisms proposed by the server in `Start`.
  */
case object NoSupportedSASLMechanism
    extends Exception(
      "Server does not support any of your requested SASL mechanisms!"
    )
/** Used by `StartupNegotiation` to fail the pending results when the incoming
  * frame stream ends before negotiation has completed.
  */
case object NegotiationFailed
    extends Exception("Negotiation with server failed!")




© 2015 - 2024 Weber Informatics LLC | Privacy Policy