All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.protocol.StartupPlatform.scala Maven / Gradle / Ivy

// Copyright (c) 2018-2021 by Rob Norris
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package skunk.net.protocol

import com.ongres.scram.client.ScramClient
import com.ongres.scram.common.stringprep.StringPreparations

import cats.MonadError
import cats.syntax.all._
import natchez.Trace
import scala.util.control.NonFatal
import skunk.net.MessageSocket
import skunk.net.message._
import skunk.exception.{
  SCRAMProtocolException,
  UnsupportedSASLMechanismsException
}

private[protocol] trait StartupCompanionPlatform { this: Startup.type =>
  
  private[protocol] def authenticationSASL[F[_]: MessageSocket: Trace](
    sm:         StartupMessage,
    password:   Option[String],
    mechanisms: List[String]
  )(
    implicit ev: MonadError[F, Throwable]
  ): F[Unit] =
    Trace[F].span("authenticationSASL") {
        for {
          client <- {
            try ScramClient.
              channelBinding(ScramClient.ChannelBinding.NO).
              stringPreparation(StringPreparations.SASL_PREPARATION).
              selectMechanismBasedOnServerAdvertised(mechanisms.toArray: _*).
              setup().pure[F]
            catch {
              case _: IllegalArgumentException => new UnsupportedSASLMechanismsException(mechanisms).raiseError[F, ScramClient]
              case NonFatal(t) => new SCRAMProtocolException(t.getMessage).raiseError[F, ScramClient]
            }
          }
          session = client.scramSession("*")
          _ <- send(SASLInitialResponse(client.getScramMechanism.getName, bytesUtf8(session.clientFirstMessage)))
          serverFirstBytes <- flatExpectStartup(sm) {
            case AuthenticationSASLContinue(serverFirstBytes) => serverFirstBytes.pure[F]
          }
          serverFirst <- guardScramAction {
            session.receiveServerFirstMessage(new String(serverFirstBytes.toArray, "UTF-8"))
          }
          pw <- requirePassword[F](sm, password)
          clientFinal = serverFirst.clientFinalProcessor(pw)
          _ <- send(SASLResponse(bytesUtf8(clientFinal.clientFinalMessage)))
          serverFinalBytes <- flatExpectStartup(sm) {
            case AuthenticationSASLFinal(serverFinalBytes) => serverFinalBytes.pure[F]
          }
          _ <- guardScramAction {
            clientFinal.receiveServerFinalMessage(new String(serverFinalBytes.toArray, "UTF-8")).pure[F]
          }
          _ <- flatExpectStartup(sm) { case AuthenticationOk => ().pure[F] }
        } yield ()
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy