All Downloads are FREE. Search and download functionalities are using the official Maven repository.

t-scala.jwt-core_3.10.0.1.source-code.JwtUtils.scala Maven / Gradle / Ivy

The newest version!
package pdi.jwt

import java.security.spec.{PKCS8EncodedKeySpec, X509EncodedKeySpec}
import java.security.{KeyFactory, PrivateKey, PublicKey, Signature}
import javax.crypto.spec.SecretKeySpec
import javax.crypto.{Mac, SecretKey}

import pdi.jwt.JwtAlgorithm.{ES256, ES384, ES512}
import pdi.jwt.algorithms.*
import pdi.jwt.exceptions.{JwtNonSupportedAlgorithm, JwtSignatureFormatException}

object JwtUtils {
  val ENCODING = "UTF-8"
  val RSA = "RSA"
  val ECDSA = "EC"
  val EdDSA = "EdDSA"

  /** Convert an array of bytes to its corresponding string using the default encoding.
    *
    * @return
    *   the final string
    * @param arr
    *   the array of bytes to transform
    */
  def stringify(arr: Array[Byte]): String = new String(arr, ENCODING)

  /** Convert a string to its corresponding array of bytes using the default encoding.
    *
    * @return
    *   the final array of bytes
    * @param str
    *   the string to convert
    */
  def bytify(str: String): Array[Byte] = str.getBytes(ENCODING)

  private def escape(value: String): String = value.replaceAll("\"", "\\\\\"")

  /** Convert a sequence to a JSON array
    */
  def seqToJson(seq: Seq[Any]): String = seq
    .map {
      case value: String        => "\"" + escape(value) + "\""
      case value: Boolean       => if (value) "true" else "false"
      case value: Double        => value.toString
      case value: Short         => value.toString
      case value: Float         => value.toString
      case value: Long          => value.toString
      case value: Int           => value.toString
      case value: BigDecimal    => value.toString
      case value: BigInt        => value.toString
      case (key: String, value) => hashToJson(Seq(key -> value))
      case value: Any           => "\"" + escape(value.toString) + "\""
    }
    .mkString("[", ",", "]")

  /** Convert a sequence of tuples to a JSON object
    */
  def hashToJson(hash: Seq[(String, Any)]): String = hash
    .map {
      case (key, value: String)     => "\"" + escape(key) + "\":\"" + escape(value) + "\""
      case (key, value: Boolean)    => "\"" + escape(key) + "\":" + (if (value) "true" else "false")
      case (key, value: Double)     => "\"" + escape(key) + "\":" + value.toString
      case (key, value: Short)      => "\"" + escape(key) + "\":" + value.toString
      case (key, value: Float)      => "\"" + escape(key) + "\":" + value.toString
      case (key, value: Long)       => "\"" + escape(key) + "\":" + value.toString
      case (key, value: Int)        => "\"" + escape(key) + "\":" + value.toString
      case (key, value: BigDecimal) => "\"" + escape(key) + "\":" + value.toString
      case (key, value: BigInt)     => "\"" + escape(key) + "\":" + value.toString
      case (key, (vKey: String, vValue)) =>
        "\"" + escape(key) + "\":" + hashToJson(Seq(vKey -> vValue))
      case (key, value: Seq[Any]) => "\"" + escape(key) + "\":" + seqToJson(value)
      case (key, value: Set[_])   => "\"" + escape(key) + "\":" + seqToJson(value.toSeq)
      case (key, value: Any)      => "\"" + escape(key) + "\":\"" + escape(value.toString) + "\""
    }
    .mkString("{", ",", "}")

  /** Merge multiple JSON strings to a unique one
    */
  def mergeJson(json: String, jsonSeq: String*): String = {
    val initJson = json.trim match {
      case ""    => ""
      case value => value.drop(1).dropRight(1)
    }

    "{" + jsonSeq.map(_.trim).fold(initJson) {
      case (j1, result) if j1.length < 5 => result.drop(1).dropRight(1)
      case (result, j2) if j2.length < 7 => result
      case (j1, j2)                      => j1 + "," + j2.drop(1).dropRight(1)
    } + "}"
  }

  private def parseKey(key: String): Array[Byte] = JwtBase64.decodeNonSafe(
    key.replaceAll("-----BEGIN ([^-]*)-----|-----END ([^-]*)-----|\\s*", "")
  )

  private def parsePrivateKey(key: String, keyAlgo: String) = {
    val spec = new PKCS8EncodedKeySpec(parseKey(key))
    KeyFactory.getInstance(keyAlgo).generatePrivate(spec)
  }

  private def parsePublicKey(key: String, keyAlgo: String): PublicKey = {
    val spec = new X509EncodedKeySpec(parseKey(key))
    KeyFactory.getInstance(keyAlgo).generatePublic(spec)
  }

  /** Generate the signature for a given data using the key and HMAC algorithm provided.
    */
  def sign(data: Array[Byte], key: SecretKey, algorithm: JwtHmacAlgorithm): Array[Byte] = {
    val mac = Mac.getInstance(algorithm.fullName)
    mac.init(key)
    mac.doFinal(data)
  }

  def sign(data: String, key: SecretKey, algorithm: JwtHmacAlgorithm): Array[Byte] =
    sign(bytify(data), key, algorithm)

  /** Generate the signature for a given data using the key and RSA or ECDSA algorithm provided.
    */
  def sign(data: Array[Byte], key: PrivateKey, algorithm: JwtAsymmetricAlgorithm): Array[Byte] = {
    val signer = Signature.getInstance(algorithm.fullName)
    signer.initSign(key)
    signer.update(data)
    algorithm match {
      case _: JwtRSAAlgorithm => signer.sign
      case algorithm: JwtECDSAAlgorithm =>
        transcodeSignatureToConcat(signer.sign, getSignatureByteArrayLength(algorithm))
      case _: JwtEdDSAAlgorithm => signer.sign
    }
  }

  def sign(data: String, key: PrivateKey, algorithm: JwtAsymmetricAlgorithm): Array[Byte] =
    sign(bytify(data), key, algorithm)

  /** Will try to sign some given data by parsing the provided key, if parsing fail, please consider
    * retrieving the SecretKey or the PrivateKey on your side and then use another "sign" method.
    */
  def sign(data: Array[Byte], key: String, algorithm: JwtAlgorithm): Array[Byte] =
    algorithm match {
      case algo: JwtHmacAlgorithm => sign(data, new SecretKeySpec(bytify(key), algo.fullName), algo)
      case algo: JwtRSAAlgorithm  => sign(data, parsePrivateKey(key, RSA), algo)
      case algo: JwtECDSAAlgorithm   => sign(data, parsePrivateKey(key, ECDSA), algo)
      case algo: JwtEdDSAAlgorithm   => sign(data, parsePrivateKey(key, EdDSA), algo)
      case algo: JwtUnknownAlgorithm => throw new JwtNonSupportedAlgorithm(algo.fullName)
    }

  /** Alias to `sign` using a String data which will be converted to an array of bytes.
    */
  def sign(data: String, key: String, algorithm: JwtAlgorithm): Array[Byte] =
    sign(bytify(data), key, algorithm)

  /** Check if a signature is valid for a given data using the key and the HMAC algorithm provided.
    */
  def verify(
      data: Array[Byte],
      signature: Array[Byte],
      key: SecretKey,
      algorithm: JwtHmacAlgorithm
  ): Boolean = {
    JwtArrayUtils.constantTimeAreEqual(sign(data, key, algorithm), signature)
  }

  /** Check if a signature is valid for a given data using the key and the RSA or ECDSA algorithm
    * provided.
    */
  def verify(
      data: Array[Byte],
      signature: Array[Byte],
      key: PublicKey,
      algorithm: JwtAsymmetricAlgorithm
  ): Boolean = {
    val signer = Signature.getInstance(algorithm.fullName)
    signer.initVerify(key)
    signer.update(data)
    algorithm match {
      case _: JwtRSAAlgorithm   => signer.verify(signature)
      case _: JwtECDSAAlgorithm => signer.verify(transcodeSignatureToDER(signature))
      case _: JwtEdDSAAlgorithm => signer.verify(signature)
    }
  }

  /** Will try to check if a signature is valid for a given data by parsing the provided key, if
    * parsing fail, please consider retrieving the SecretKey or the PublicKey on your side and then
    * use another "verify" method.
    */
  def verify(
      data: Array[Byte],
      signature: Array[Byte],
      key: String,
      algorithm: JwtAlgorithm
  ): Boolean = algorithm match {
    case algo: JwtHmacAlgorithm =>
      verify(data, signature, new SecretKeySpec(bytify(key), algo.fullName), algo)
    case algo: JwtRSAAlgorithm     => verify(data, signature, parsePublicKey(key, RSA), algo)
    case algo: JwtECDSAAlgorithm   => verify(data, signature, parsePublicKey(key, ECDSA), algo)
    case algo: JwtEdDSAAlgorithm   => verify(data, signature, parsePublicKey(key, EdDSA), algo)
    case algo: JwtUnknownAlgorithm => throw new JwtNonSupportedAlgorithm(algo.fullName)
  }

  /** Alias for `verify`
    */
  def verify(data: String, signature: String, key: String, algorithm: JwtAlgorithm): Boolean =
    verify(bytify(data), bytify(signature), key, algorithm)

  /** Returns the expected signature byte array length (R + S parts) for the specified ECDSA
    * algorithm.
    *
    * @param algorithm
    *   The ECDSA algorithm. Must be supported and not {@code null} .
    * @return
    *   The expected byte array length for the signature.
    */
  def getSignatureByteArrayLength(algorithm: JwtECDSAAlgorithm): Int = algorithm match {
    case ES256 => 64
    case ES384 => 96
    case ES512 => 132
  }

  /** Transcodes the JCA ASN.1/DER-encoded signature into the concatenated R + S format expected by
    * ECDSA JWS.
    *
    * @param derSignature
    *   The ASN1./DER-encoded. Must not be {@code null} .
    * @param outputLength
    *   The expected length of the ECDSA JWS signature.
    * @return
    *   The ECDSA JWS encoded signature.
    * @throws JwtSignatureFormatException
    *   If the ASN.1/DER signature format is invalid.
    */
  @throws[JwtSignatureFormatException]
  def transcodeSignatureToConcat(derSignature: Array[Byte], outputLength: Int): Array[Byte] = {
    if (derSignature.length < 8 || derSignature(0) != 48)
      throw new JwtSignatureFormatException("Invalid ECDSA signature format")

    val offset: Int = derSignature(1) match {
      case s if s > 0            => 2
      case s if s == 0x81.toByte => 3
      case _ => throw new JwtSignatureFormatException("Invalid ECDSA signature format")
    }

    val rLength: Byte = derSignature(offset + 1)
    var i = rLength.toInt
    while ((i > 0) && (derSignature((offset + 2 + rLength) - i) == 0)) {
      i -= 1
    }

    val sLength: Byte = derSignature(offset + 2 + rLength + 1)
    var j = sLength.toInt
    while ((j > 0) && (derSignature((offset + 2 + rLength + 2 + sLength) - j) == 0)) {
      j -= 1
    }

    val rawLen: Int = Math.max(Math.max(i, j), outputLength / 2)

    if (
      (derSignature(offset - 1) & 0xff) != derSignature.length - offset
      || (derSignature(offset - 1) & 0xff) != 2 + rLength + 2 + sLength
      || derSignature(offset) != 2 || derSignature(offset + 2 + rLength) != 2
    )
      throw new JwtSignatureFormatException("Invalid ECDSA signature format")

    val concatSignature: Array[Byte] = new Array[Byte](2 * rawLen)
    System.arraycopy(derSignature, (offset + 2 + rLength) - i, concatSignature, rawLen - i, i)
    System.arraycopy(
      derSignature,
      (offset + 2 + rLength + 2 + sLength) - j,
      concatSignature,
      2 * rawLen - j,
      j
    )
    concatSignature
  }

  /** Transcodes the ECDSA JWS signature into ASN.1/DER format for use by the JCA verifier.
    *
    * @param signature
    *   The JWS signature, consisting of the concatenated R and S values. Must not be {@code null} .
    * @return
    *   The ASN.1/DER encoded signature.
    * @throws JwtSignatureFormatException
    *   If the ECDSA JWS signature format is invalid.
    */
  @throws[JwtSignatureFormatException]
  def transcodeSignatureToDER(signature: Array[Byte]): Array[Byte] = {
    var (r, s) = signature.splitAt(signature.length / 2)
    r = r.dropWhile(_ == 0)
    if (r.length > 0 && r(0) < 0)
      r +:= 0.toByte

    s = s.dropWhile(_ == 0)
    if (s.length > 0 && s(0) < 0)
      s +:= 0.toByte

    val signatureLength = 2 + r.length + 2 + s.length

    if (signatureLength > 255)
      throw new JwtSignatureFormatException("Invalid ECDSA signature format")

    val signatureDER = scala.collection.mutable.ListBuffer.empty[Byte]
    signatureDER += 48
    if (signatureLength >= 128)
      signatureDER += 0x81.toByte

    signatureDER += signatureLength.toByte
    signatureDER ++= (Seq(2.toByte, r.length.toByte) ++ r)
    signatureDER ++= (Seq(2.toByte, s.length.toByte) ++ s)

    signatureDER.toArray
  }

  def splitString(input: String, separator: Char): Array[String] = {
    val builder = scala.collection.mutable.ArrayBuffer.empty[String]
    var lastIndex = 0
    var index = input.indexOf(separator.toInt, lastIndex)
    while (index != -1) {
      builder += input.substring(lastIndex, index)
      lastIndex = index + 1
      index = input.indexOf(separator.toInt, lastIndex)
    }
    // Add the remainder
    if (lastIndex < input.length) {
      builder += input.substring(lastIndex, input.length)
    }
    builder.toArray
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy