All downloads are free. Search and download functionality uses the official Maven repository.

com.github.jeroenr.tepkin.auth.ScramSha1Authentication.scala Maven / Gradle / Ivy

The newest version!
package com.github.jeroenr.tepkin.auth

import akka.actor.{Actor, ActorLogging, ActorRef}
import akka.io.Tcp.{Received, Write}
import akka.util.ByteString
import com.github.jeroenr.bson.BsonDsl._
import com.github.jeroenr.bson.Implicits._
import com.github.jeroenr.bson.element.BinarySubtype
import com.github.jeroenr.bson.util.Converters
import com.github.jeroenr.bson.{BsonDocument, BsonDsl, BsonValueNumber, Implicits}
import com.github.jeroenr.tepkin.MongoCredentials
import com.github.jeroenr.tepkin.protocol.command.Command
import com.github.jeroenr.tepkin.protocol.message.Reply
import com.github.jeroenr.tepkin.util.{Crypto, Randomizer}

/**
 * Client side of the SCRAM-SHA-1 SASL exchange (RFC 5802) as used by MongoDB.
 *
 * The exchange is driven as a small state machine by swapping the actor's behaviour:
 *  1. `authenticate` sends `saslStart` carrying the client-first-message and becomes `step1`.
 *  2. `step1` answers the server-first-message with the client-final-message (`saslContinue`)
 *     and becomes `stepN(2, ...)`.
 *  3. `stepN` verifies the server signature, sends one more `saslContinue`, and calls
 *     `authenticated(connection)` once the server replies with `done: true`.
 *
 * Hashing, encoding and nonce helpers (`hmac`, `sha1`, `xor`, `keyDerive`, `encodeBase64`,
 * `randomString`, ...) are not defined here; they come from the mixed-in traits.
 */
trait ScramSha1Authentication extends Authentication with Crypto with Randomizer {
  this: Actor with ActorLogging =>

  /**
   * Client state after building the client-first-message.
   *
   * @param rPrefix                client nonce; the server's combined nonce must start with it
   * @param clientFirstMessageBare client-first-message without the GS2 header (part of AuthMessage)
   * @param clientFirstMessage     full client-first-message as bytes, sent with `saslStart`
   */
  private[this] case class Step0(rPrefix: String,
                                 clientFirstMessageBare: String,
                                 clientFirstMessage: Array[Byte])

  /**
   * Client state after building the client-final-message.
   *
   * @param serverSignature    signature the server must echo back in the `v=` attribute
   * @param clientFinalMessage client-final-message as bytes, sent with `saslContinue`
   */
  private[this] case class Step1(serverSignature: Array[Byte],
                                 clientFinalMessage: Array[Byte])

  /** `saslStart` command opening a SCRAM-SHA-1 conversation with `message` as its binary payload. */
  private[this] case class SaslStart(databaseName: String, message: ByteString) extends Command {
    override def command: BsonDocument = {
      ("saslStart" := 1) ~
        ("mechanism" := "SCRAM-SHA-1") ~
        ("payload" := Binary(message, BinarySubtype.Generic))
    }
  }

  /** `saslContinue` command carrying `message` for the conversation identified by `conversationId`. */
  private[this] case class SaslContinue(databaseName: String, conversationId: Int, message: ByteString) extends Command {
    override def command: BsonDocument = {
      ("saslContinue" := 1) ~
        ("conversationId" := conversationId) ~
        ("payload" := Binary(message, BinarySubtype.Generic))
    }
  }

  /** GS2 header: client does not support channel binding, no authorization identity. */
  private val GS2_HEADER = "n,,"

  /** Length passed to `randomString` when generating the client nonce. */
  private val RANDOM_LENGTH = 24

  /**
   * Escapes a username for the `n=` attribute (RFC 5802 section 5.1).
   *
   * `=` becomes `=3D` and `,` becomes `=2C`. `=` is replaced first so that the escape sequences
   * introduced for commas are not escaped again.
   */
  private def prepareUsername(username: String): String =
    username.replace("=", "=3D").replace(",", "=2C")

  /**
   * Splits a SCRAM message such as `r=...,s=...,i=4096` into an attribute map.
   *
   * Only the first `=` of each attribute separates name from value, because base64 values may
   * themselves end in `=` padding. Attributes without any `=` are skipped instead of failing with
   * an index error; required attributes are checked by `requiredAttribute`.
   */
  private def parseServerResponse(response: String): Map[String, String] =
    response.split(",").iterator.map(_.split("=", 2)).collect {
      case Array(key, value) => key -> value
    }.toMap

  /** Returns attribute `key` of a parsed server message, failing with a descriptive error if absent. */
  private def requiredAttribute(response: Map[String, String], key: String): String =
    response.getOrElse(key, throw new RuntimeException(s"Server response is missing the '$key' attribute."))

  /** Builds the client-first-message (`n,,n=<user>,r=<client nonce>`) for `username`. */
  private def computeStep0(username: String): Step0 = {
    val preparedUsername = "n=" + prepareUsername(username)
    val rPrefix = randomString(RANDOM_LENGTH)
    val nonce = s"r=$rPrefix"
    val clientFirstMessageBare = s"$preparedUsername,$nonce"
    val clientFirstMessage = GS2_HEADER + clientFirstMessageBare
    Step0(
      rPrefix = rPrefix,
      clientFirstMessageBare = clientFirstMessageBare,
      // As used in this file, decodeUtf8 turns a String into bytes (see the Step0 field type).
      clientFirstMessage = decodeUtf8(clientFirstMessage)
    )
  }

  /**
   * Processes the server-first-message and builds the client-final-message with its proof.
   *
   * @param challenge raw server-first-message (`r=...,s=...,i=...`)
   * @param username  unescaped username, used for the MongoDB password digest
   * @param password  clear-text password
   * @param step0     state from the client-first-message
   * @return the expected server signature and the client-final-message bytes
   * @throws RuntimeException if the server nonce does not extend the client nonce, or a required
   *                          attribute is missing
   */
  private def computeStep1(challenge: Array[Byte], username: String, password: String, step0: Step0): Step1 = {
    val serverFirstMessage = encodeUtf8(challenge)
    val response = parseServerResponse(serverFirstMessage)
    val r = requiredAttribute(response, "r")

    // The combined nonce must begin with the nonce we sent, otherwise the reply is not ours.
    if (!r.startsWith(step0.rPrefix)) {
      throw new RuntimeException("Server sent an invalid nonce.")
    }

    val s = requiredAttribute(response, "s")
    val iterations = requiredAttribute(response, "i")

    val channelBinding = "c=" + encodeBase64(decodeUtf8(GS2_HEADER))
    val nonce = s"r=$r"
    val clientFinalMessageWithoutProof = s"$channelBinding,$nonce"

    // The password input is the hex MD5 of "user:mongo:password", not the raw password.
    val saltedPassword = keyDerive(Converters.md5Hex(s"$username:mongo:$password"), decodeBase64(s), iterations.toInt)
    val clientKey = hmac(saltedPassword, "Client Key")
    val storedKey = sha1(clientKey)
    val authMessage = step0.clientFirstMessageBare + "," + serverFirstMessage + "," + clientFinalMessageWithoutProof
    val clientSignature = hmac(storedKey, authMessage)
    val clientProof = xor(clientKey, clientSignature)
    val serverKey = hmac(saltedPassword, "Server Key")
    val serverSignature = hmac(serverKey, authMessage)

    val proof = "p=" + encodeBase64(clientProof)
    val clientFinalMessage = clientFinalMessageWithoutProof + "," + proof

    Step1(
      serverSignature = serverSignature,
      clientFinalMessage = decodeUtf8(clientFinalMessage)
    )
  }

  /**
   * Verifies the server-final-message and returns the payload for the closing `saslContinue`.
   *
   * @throws RuntimeException if the server reports an error (`e=`), omits `v=`, or sends a
   *                          signature that does not match the expected one
   */
  private def computeStep2(challenge: ByteString, step1: Step1): ByteString = {
    val response = parseServerResponse(encodeUtf8(challenge.toArray))

    // RFC 5802: server-final-message is either a server-error ("e=...") or a verifier ("v=...").
    response.get("e").foreach { error =>
      throw new RuntimeException(s"Server rejected the authentication: $error")
    }

    if (requiredAttribute(response, "v") != encodeBase64(step1.serverSignature)) {
      throw new RuntimeException("Server signature was invalid.")
    }

    // The server's final message is echoed back as the last payload.
    // NOTE(review): presumably the server ignores this payload — confirm.
    challenge
  }

  /**
   * Starts the SCRAM-SHA-1 conversation on `connection`.
   *
   * @throws NoSuchElementException if `credentials` is empty or has no password
   */
  override def authenticate(connection: ActorRef, databaseName: String, credentials: Option[MongoCredentials]): Unit = {
    // Validate up front so nothing is sent to the server when the credentials are unusable.
    val creds = credentials.getOrElse(
      throw new NoSuchElementException("SCRAM-SHA-1 authentication requires credentials."))
    val password = creds.password.getOrElse(
      throw new NoSuchElementException("SCRAM-SHA-1 authentication requires a password."))

    val step0 = computeStep0(creds.username)
    val command = SaslStart(databaseName, ByteString(step0.clientFirstMessage))
    context.become(step1(connection, databaseName, creds.username, password, step0))
    connection ! Write(command.encode)
  }

  /**
   * Behaviour awaiting the `saslStart` reply (server-first-message).
   *
   * NOTE(review): assumes each `Received` chunk holds exactly one complete reply; TCP does not
   * guarantee message boundaries — confirm how `Reply` handles partial data.
   */
  private def step1(connection: ActorRef, databaseName: String, username: String, password: String, step0: Step0): Receive = {
    case Received(data) =>
      for {
        reply <- Reply(data.asByteBuffer)
        result <- reply.documents.headOption
      } result.get[BsonValueNumber]("ok") match {
        case Some(number) if number.toInt == 0 =>
          log.error(result.getAs[String]("errmsg").getOrElse("Authentication failed."))

        case _ =>
          (result.getAs[Int]("conversationId"), result.get[Binary]("payload")) match {
            case (Some(conversationId), Some(payload)) =>
              val clientFinal = computeStep1(payload.value.toArray, username, password, step0)
              val command = SaslContinue(databaseName, conversationId, ByteString(clientFinal.clientFinalMessage))

              context.become(stepN(2, connection, databaseName, conversationId, clientFinal))
              connection ! Write(command.encode)

            case _ =>
              log.error("Invalid saslStart reply: missing conversationId or payload.")
          }
      }
  }

  /**
   * Behaviour for the remaining `saslContinue` replies; `n` counts the reply being awaited.
   *
   * Reply 2 carries the server-final-message, which is verified before the closing
   * `saslContinue` is sent. Any later non-`done` reply aborts the negotiation.
   */
  private def stepN(n: Int, connection: ActorRef, databaseName: String, conversationId: Int, step1: Step1): Receive = {
    case Received(data) =>
      for {
        reply <- Reply(data.asByteBuffer)
        result <- reply.documents.headOption
      } result.get[BsonValueNumber]("ok") match {
        case Some(number) if number.toInt == 0 =>
          log.error(result.getAs[String]("errmsg").getOrElse("Authentication failed."))

        case _ =>
          result.getAs[Boolean]("done") match {
            case Some(true) =>
              authenticated(connection)

            case Some(false) if n == 2 =>
              result.get[Binary]("payload") match {
                case Some(payload) =>
                  val command = SaslContinue(databaseName, conversationId, computeStep2(payload.value, step1))
                  context.become(stepN(n + 1, connection, databaseName, conversationId, step1))
                  connection ! Write(command.encode)

                case None =>
                  log.error("Invalid saslContinue reply: missing payload.")
              }

            case Some(false) =>
              log.error("Too many steps involved in the SCRAM-SHA-1 negotiation.")

            case None =>
              log.error("Invalid saslContinue reply: missing 'done' field.")
          }
      }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy