// com.chatwork.scala.jwk.ECJWK.scala -- Maven / Gradle / Ivy
package com.chatwork.scala.jwk
import cats.data.NonEmptyList
import java.net.URI
import java.security.interfaces.{ ECPrivateKey, ECPublicKey }
import java.security.spec.{ ECPoint, ECPrivateKeySpec, ECPublicKeySpec, InvalidKeySpecException }
import java.security.{ KeyFactory, NoSuchAlgorithmException, _ }
import java.time.ZonedDateTime
import com.chatwork.scala.jwk.JWKError._
import com.chatwork.scala.jwk.utils.ECChecks
import com.github.j5ik2o.base64scala.{ Base64EncodeError, Base64String, Base64StringFactory, BigIntUtils }
import io.circe.DecodingFailure
/** Factory methods, JSON entry points and shared helpers for [[ECJWK]]. */
object ECJWK extends ECJWKJsonImplicits {
  import io.circe.Json
  import io.circe.parser

  /** Decodes an EC JWK from an already-parsed JSON value.
    *
    * @param json the JSON representation of the key
    * @return the decoded key, or a [[JWKCreationError]] describing why decoding failed
    */
  def parseFromJson(json: Json): Either[JWKCreationError, ECJWK] =
    try {
      json.as[ECJWK].left.map(error => JWKCreationError(error.getMessage, None))
    } catch {
      // Defensive: a decoder may throw instead of returning a failure (for example a
      // failed `require` while constructing the key). Surface that as a Left so this
      // method always honours its Either contract.
      case ex @ (_: IllegalArgumentException | _: NoSuchElementException) =>
        Left(JWKCreationError(ex.getMessage))
    }

  /** Parses `jsonText` as JSON and decodes it as an EC JWK.
    *
    * @param jsonText the textual JSON representation of the key
    * @return the decoded key, or a [[JWKCreationError]] for a JSON syntax or decoding failure
    */
  def parseFromText(jsonText: String): Either[JWKCreationError, ECJWK] =
    parser.parse(jsonText) match {
      case Left(error) => Left(JWKCreationError(error.getMessage()))
      case Right(json) => parseFromJson(json)
    }

  /** Builds an [[ECJWK]], converting constructor validation failures into a `Left`.
    *
    * All arguments are forwarded unchanged to the [[ECJWK]] constructor.
    *
    * @return the new key, or a [[JWKCreationError]] carrying the message of the
    *         `IllegalArgumentException` raised by the constructor
    */
  def apply(
      curve: Curve,
      x: Base64String,
      y: Base64String,
      publicKeyUseType: Option[PublicKeyUseType] = None,
      keyOperations: KeyOperations = KeyOperations.empty,
      algorithmType: Option[JWSAlgorithmType] = None,
      keyId: Option[KeyId] = None,
      x509Url: Option[URI] = None,
      x509CertificateSHA1Thumbprint: Option[Base64String] = None,
      x509CertificateSHA256Thumbprint: Option[Base64String] = None,
      x509CertificateChain: Option[NonEmptyList[Base64String]] = None,
      d: Option[Base64String] = None,
      privateKey: Option[PrivateKey] = None,
      expireAt: Option[ZonedDateTime] = None,
      keyStore: Option[KeyStore] = None
  ): Either[JWKCreationError, ECJWK] = {
    try {
      Right(
        new ECJWK(
          curve,
          x,
          y,
          publicKeyUseType,
          keyOperations,
          algorithmType,
          keyId,
          x509Url,
          x509CertificateSHA1Thumbprint,
          x509CertificateSHA256Thumbprint,
          x509CertificateChain,
          d,
          privateKey,
          expireAt,
          keyStore
        )
      )
    } catch {
      case ex: IllegalArgumentException =>
        Left(JWKCreationError(ex.getMessage))
    }
  }

  /** Builds an [[ECJWK]] from a JCA EC key pair.
    *
    * The public point's affine X and Y coordinates become `x` and `y`, and the private
    * value S becomes `d`. Each value is encoded with [[encodeCoordinate]] using the field
    * size of the key it was read from.
    *
    * @param curve the curve to record in the JWK
    * @param pub   the public key supplying the `x` and `y` coordinates
    * @param priv  the private key supplying the `d` value
    * @return the new key, or a [[JWKCreationError]] if encoding or construction fails
    */
  def fromKeyPair(
      curve: Curve,
      pub: ECPublicKey,
      priv: ECPrivateKey,
      publicKeyUseType: Option[PublicKeyUseType] = None,
      keyOperations: KeyOperations = KeyOperations.empty,
      algorithmType: Option[JWSAlgorithmType] = None,
      keyId: Option[KeyId] = None,
      x509Url: Option[URI] = None,
      x509CertificateSHA1Thumbprint: Option[Base64String] = None,
      x509CertificateSHA256Thumbprint: Option[Base64String] = None,
      x509CertificateChain: Option[NonEmptyList[Base64String]] = None,
      expireAt: Option[ZonedDateTime] = None
  ): Either[JWKCreationError, ECJWK] = {
    val publicFieldSize  = pub.getParams.getCurve.getField.getFieldSize
    val privateFieldSize = priv.getParams.getCurve.getField.getFieldSize
    for {
      encodedX <- encodeCoordinate(publicFieldSize, BigInt(pub.getW.getAffineX)).left.map(error =>
        JWKCreationError(error.message)
      )
      // The Y coordinate must come from the public point; previously the private
      // value S was passed in the `y` position and `d` was never populated.
      encodedY <- encodeCoordinate(publicFieldSize, BigInt(pub.getW.getAffineY)).left.map(error =>
        JWKCreationError(error.message)
      )
      encodedD <- encodeCoordinate(privateFieldSize, BigInt(priv.getS)).left.map(error =>
        JWKCreationError(error.message)
      )
      jwk <- apply(
        curve,
        encodedX,
        encodedY,
        publicKeyUseType,
        keyOperations,
        algorithmType,
        keyId,
        x509Url,
        x509CertificateSHA1Thumbprint,
        x509CertificateSHA256Thumbprint,
        x509CertificateChain,
        d = Some(encodedD),
        expireAt = expireAt
      )
    } yield jwk
  }

  /** Curves accepted by [[ensurePublicCoordinatesOnCurve]]. */
  val SUPPORTED_CURVES = Set(Curve.P_256, Curve.P_256K, Curve.P_384, Curve.P_521)

  /** Verifies that (`x`, `y`) decodes to a point lying on `curve`.
    *
    * @throws IllegalArgumentException if the curve is not in [[SUPPORTED_CURVES]], if no
    *                                  EC parameter spec is available for it, if a
    *                                  coordinate cannot be decoded, or if the point is
    *                                  not on the curve
    */
  private def ensurePublicCoordinatesOnCurve(curve: Curve, x: Base64String, y: Base64String): Unit = {
    require(SUPPORTED_CURVES.contains(curve), "Unknown / unsupported curve: " + curve)
    val onCurve = for {
      dx <- x.decodeToBigInt
      dy <- y.decodeToBigInt
    } yield ECChecks.isPointOnCurve(
      dx,
      dy,
      curve.toECParameterSpec.getOrElse(throw new IllegalArgumentException("Unknown curve instance."))
    )
    // Both a decoding failure (Left) and a negative check (Right(false)) are invalid;
    // inspecting only `isLeft` would let off-curve points through.
    onCurve match {
      case Right(true) => ()
      case _ =>
        throw new IllegalArgumentException(
          "Invalid EC JWK: The 'x' and 'y' public coordinates are not on the " + curve + " curve"
        )
    }
  }

  /** Encodes `coordinate` as URL-safe, unpadded Base64, left-padded with zero bytes to
    * the byte length implied by `fieldSize`.
    *
    * @param fieldSize  the curve field size in bits
    * @param coordinate the value to encode
    * @return the encoded value, or the error reported by the Base64 encoder
    */
  private def encodeCoordinate(fieldSize: Int, coordinate: BigInt): Either[Base64EncodeError, Base64String] = {
    val base64StringFactory = Base64StringFactory(urlSafe = true, isNoPadding = true)
    val notPadded           = BigIntUtils.toBytesUnsigned(coordinate)
    val bytesToOutput       = (fieldSize + 7) / 8 // round the bit size up to whole bytes
    if (notPadded.length >= bytesToOutput) {
      // Already long enough. The greater-than case (malformed key) is emitted as-is so
      // the array copy below is never asked to write at a negative offset.
      base64StringFactory.encode(notPadded)
    } else {
      // A new array is zero-filled; copying into its tail yields leading zero bytes.
      val padded = new Array[Byte](bytesToOutput)
      System.arraycopy(notPadded, 0, padded, bytesToOutput - notPadded.length, notPadded.length)
      base64StringFactory.encode(padded)
    }
  }
}
/** An elliptic-curve JSON Web Key.
  *
  * Construction validates its arguments: `x`, `y`, both thumbprints and `d` must be
  * URL-safe Base64 strings, and the public coordinates are passed to
  * `ECJWK.ensurePublicCoordinatesOnCurve`. A failed check raises
  * `IllegalArgumentException`; use the companion's `apply` to get an `Either` instead.
  *
  * @param curve      the curve this key belongs to
  * @param x          the Base64 encoded public x coordinate
  * @param y          the Base64 encoded public y coordinate
  * @param d          the Base64 encoded private value, if this key carries one
  * @param privateKey an already-materialised private key, used by [[toPrivateKey]] when
  *                   `d` is absent
  */
class ECJWK private[jwk] (
    val curve: Curve,
    val x: Base64String,
    val y: Base64String,
    publicKeyUseType: Option[PublicKeyUseType] = None,
    keyOperations: KeyOperations = KeyOperations.empty,
    algorithmType: Option[JWSAlgorithmType] = None,
    keyId: Option[KeyId] = None,
    x509Url: Option[URI] = None,
    x509CertificateSHA1Thumbprint: Option[Base64String] = None,
    x509CertificateSHA256Thumbprint: Option[Base64String] = None,
    x509CertificateChain: Option[NonEmptyList[Base64String]] = None,
    val d: Option[Base64String] = None,
    val privateKey: Option[PrivateKey] = None,
    expireAt: Option[ZonedDateTime] = None,
    keyStore: Option[KeyStore] = None
) extends JWK(
      KeyType.EC,
      publicKeyUseType,
      keyOperations,
      algorithmType,
      keyId,
      x509Url,
      // NOTE(review): SHA-256 is passed before SHA-1 here, the reverse of this class's
      // own parameter order -- confirm this matches the JWK constructor signature.
      x509CertificateSHA256Thumbprint,
      x509CertificateSHA1Thumbprint,
      x509CertificateChain,
      expireAt,
      keyStore
    )
    with AssymetricJWK
    with CurveBasedJWK {
  require(x.urlSafe)
  require(y.urlSafe)
  require(x509CertificateSHA1Thumbprint.fold(true)(_.urlSafe))
  require(x509CertificateSHA256Thumbprint.fold(true)(_.urlSafe))
  require(d.fold(true)(_.urlSafe))
  ECJWK.ensurePublicCoordinatesOnCurve(curve, x, y)

  /** Looks up an "EC" key factory, from `provider` when given, else the JCA default. */
  private def ecKeyFactory(provider: Option[Provider]): KeyFactory =
    provider.map(p => KeyFactory.getInstance("EC", p)).getOrElse(KeyFactory.getInstance("EC"))

  /** Converts the public coordinates into a JCA [[java.security.interfaces.ECPublicKey]].
    *
    * @param provider the JCA provider to use; the default provider lookup when `None`
    * @return the public key, or a [[PublicKeyCreationError]] if the curve has no parameter
    *         spec, a coordinate cannot be decoded, or the key factory rejects the spec
    */
  def toECPublicKey(provider: Option[Provider] = None): Either[PublicKeyCreationError, ECPublicKey] = {
    curve.toECParameterSpec
      .map { spec =>
        for {
          dx <- x.decodeToBigInt.left.map(error => PublicKeyCreationError(error.message))
          dy <- y.decodeToBigInt.left.map(error => PublicKeyCreationError(error.message))
          publicKeySpec = new ECPublicKeySpec(new ECPoint(dx.bigInteger, dy.bigInteger), spec)
          result <-
            try {
              Right(ecKeyFactory(provider).generatePublic(publicKeySpec).asInstanceOf[ECPublicKey])
            } catch {
              case e @ (_: NoSuchAlgorithmException | _: InvalidKeySpecException) =>
                Left(PublicKeyCreationError(e.getMessage))
            }
        } yield result
      }
      .getOrElse(Left(PublicKeyCreationError("Couldn't get EC parameter spec for curve " + curve)))
  }

  /** Converts `d` into a JCA [[java.security.interfaces.ECPrivateKey]].
    *
    * @param provider the JCA provider to use; the default provider lookup when `None`
    * @return `Right(None)` when `d` is absent, `Right(Some(key))` on success, or a
    *         [[PrivateKeyCreationError]] if the curve has no parameter spec, `d` cannot
    *         be decoded, or the key factory rejects the spec
    */
  def toECPrivateKey(provider: Option[Provider] = None): Either[PrivateKeyCreationError, Option[ECPrivateKey]] = {
    d match {
      case None =>
        Right(None)
      case Some(_d) =>
        curve.toECParameterSpec
          .map { spec =>
            for {
              dx <- _d.decodeToBigInt.left.map(error => PrivateKeyCreationError(error.message))
              privateKeySpec = new ECPrivateKeySpec(dx.bigInteger, spec)
              result <-
                try {
                  Right(Some(ecKeyFactory(provider).generatePrivate(privateKeySpec).asInstanceOf[ECPrivateKey]))
                } catch {
                  case e @ (_: NoSuchAlgorithmException | _: InvalidKeySpecException) =>
                    Left(PrivateKeyCreationError(e.getMessage))
                }
            } yield result
          }
          .getOrElse(Left(PrivateKeyCreationError("Couldn't get EC parameter spec for curve " + curve)))
    }
  }

  /** The `crv`, `kty`, `x` and `y` members of this key, keyed by their JSON names. */
  override def getRequiredParams: Map[String, Any] = Map(
    "crv" -> curve.name,
    "kty" -> keyType.entryName,
    "x"   -> x.asString,
    "y"   -> y.asString
  )

  /** True when this key carries private material, either as `d` or as `privateKey`. */
  override def isPrivate: Boolean = d.nonEmpty || privateKey.nonEmpty

  /** Returns a copy of this key without `d` and `privateKey`.
    *
    * The expiry is carried over so the public view of a key does not silently lose it.
    * `keyStore` is not carried over.
    */
  override def toPublicJWK: JWK = new ECJWK(
    curve,
    x,
    y,
    publicKeyUseType,
    keyOperations,
    algorithmType,
    keyId,
    x509Url,
    x509CertificateSHA1Thumbprint,
    x509CertificateSHA256Thumbprint,
    x509CertificateChain,
    expireAt = expireAt
  )

  /** The field size, in bits, of this key's curve, or an error when the curve has no
    * parameter spec.
    */
  override def size: Either[JWKError.JOSEError, Int] = {
    curve.toECParameterSpec
      .map { spec => Right(spec.getCurve.getField.getFieldSize) }
      .getOrElse(Left(JOSEError("Couldn't determine field size for curve " + curve.name)))
  }

  /** The public key, built with the default provider lookup. */
  override def toPublicKey: Either[PublicKeyCreationError, PublicKey] = toECPublicKey(None)

  /** The private key: the one derived from `d` when present, otherwise the supplied
    * `privateKey`, otherwise an error.
    */
  override def toPrivateKey: Either[PrivateKeyCreationError, PrivateKey] = {
    for {
      prv <- toECPrivateKey()
      result <- prv.map(Right(_)).getOrElse {
        privateKey
          .map(Right(_))
          .getOrElse(Left(PrivateKeyCreationError("Illegal Argument: privateKey is not found")))
      }
    } yield result
  }

  /** The public and private keys combined; fails if either cannot be produced. */
  override def toKeyPair: Either[KeyCreationError, KeyPair] =
    for {
      publicKey  <- toECPublicKey()
      privateKey <- toPrivateKey
    } yield new KeyPair(publicKey, privateKey)

  // NOTE(review): if JWK inherits `compareTo` unchanged from scala.math.Ordered, that
  // method calls `compare`, so this delegation would recurse without terminating.
  // Confirm that JWK overrides `compareTo`.
  override def compare(that: JWK): Int = super.compareTo(that)

  override def canEqual(other: Any): Boolean = other.isInstanceOf[ECJWK]

  /** Equal when the base-class fields match and curve, coordinates and private material
    * are all equal.
    */
  override def equals(other: Any): Boolean = other match {
    case that: ECJWK =>
      super.equals(that) &&
        (that canEqual this) &&
        curve == that.curve &&
        x == that.x &&
        y == that.y &&
        d == that.d &&
        privateKey == that.privateKey
    case _ => false
  }

  /** Combines the base-class hash with the same fields [[equals]] compares. */
  override def hashCode(): Int = {
    val state = Seq(curve, x, y, d, privateKey)
    state.map(_.hashCode()).foldLeft(super.hashCode())((a, b) => 31 * a + b)
  }

  // NOTE(review): the rendered string includes `d` and `privateKey`, i.e. private key
  // material. Avoid logging instances of private keys, or consider redacting here.
  override def toString: String =
    Seq(
      curve,
      x,
      y,
      keyType,
      publicKeyUseType,
      keyOperations,
      algorithmType,
      keyId,
      x509Url,
      x509CertificateSHA256Thumbprint,
      x509CertificateSHA1Thumbprint,
      x509CertificateChain,
      d,
      privateKey
    ).mkString("ECJWK(", ",", ")")
}
/** circe encoders and decoders for [[Curve]] and [[ECJWK]]. */
trait ECJWKJsonImplicits extends JsonImplicits {
  import io.circe.{ Decoder, Encoder, Json }
  import io.circe.syntax._

  /** Encodes a curve as its name. */
  implicit val CurveJsonEncoder: Encoder[Curve] = Encoder[String].contramap(_.name)

  /** Decodes a curve from its name, failing the decode for an unrecognised name.
    *
    * The lookup is wrapped in `Try` so that an unknown name becomes a decoding failure
    * rather than an exception escaping from the decoder.
    */
  implicit val CurveJsonDecoder: Decoder[Curve] = Decoder[String].emap { name =>
    scala.util.Try(Curve.withName(name).get).toEither.left.map(_ => s"Unknown curve: $name")
  }

  /** Encodes an EC JWK as a JSON object; absent optional members are written via their
    * `Option` encoders.
    */
  implicit val ECJWKJsonEncoder: Encoder[ECJWK] = Encoder.instance { v =>
    Json.obj(
      "kty"     -> v.keyType.asJson,
      "use"     -> v.publicKeyUseType.asJson,
      "key_ops" -> v.keyOperations.asJson,
      "alg"     -> v.algorithmType.asJson,
      "kid"     -> v.keyId.asJson,
      "x5u"     -> v.x509Url.asJson,
      "x5t"     -> v.x509CertificateSHA1Thumbprint.asJson,
      // NOTE(review): RFC 7517 names this member "x5t#S256". The emitted name is kept
      // as-is for wire compatibility; consider migrating.
      "x5t#256" -> v.x509CertificateSHA256Thumbprint.asJson,
      "x5c"     -> v.x509CertificateChain.asJson,
      "crv"     -> v.curve.asJson,
      "x"       -> v.x.asJson,
      "y"       -> v.y.asJson,
      "d"       -> v.d.asJson
    )
  }

  /** Decodes an EC JWK from a JSON object.
    *
    * `kty` must be EC; `crv`, `x` and `y` are required; every other member is optional.
    * The X.509 members are read under the names the encoder writes (`x5u`, `x5t`,
    * `x5t#256`, `x5c`). The SHA-256 thumbprint is also accepted under the RFC 7517 name
    * `x5t#S256`, which takes precedence when both are present. A validation failure
    * raised while constructing the key is reported as a `DecodingFailure`.
    */
  implicit val ECJWKJsonDecoder: Decoder[ECJWK] = Decoder.instance { hcursor =>
    for {
      _ <- hcursor.get[KeyType]("kty").flatMap { v =>
        if (v == KeyType.EC) Right(v) else Left(DecodingFailure("Invalid key type", hcursor.history))
      }
      use     <- hcursor.get[Option[PublicKeyUseType]]("use")
      ops     <- hcursor.getOrElse[KeyOperations]("key_ops")(KeyOperations.empty)
      alg     <- hcursor.getOrElse[Option[JWSAlgorithmType]]("alg")(None)
      kid     <- hcursor.getOrElse[Option[KeyId]]("kid")(None)
      x5u     <- hcursor.getOrElse[Option[URI]]("x5u")(None)
      x5t     <- hcursor.getOrElse[Option[Base64String]]("x5t")(None)
      x5tS256 <- hcursor.getOrElse[Option[Base64String]]("x5t#S256")(None)
      x5t256  <- hcursor.getOrElse[Option[Base64String]]("x5t#256")(None)
      x5c     <- hcursor.getOrElse[Option[NonEmptyList[Base64String]]]("x5c")(None)
      crv     <- hcursor.get[Curve]("crv")
      x       <- hcursor.get[Base64String]("x")
      y       <- hcursor.get[Base64String]("y")
      d       <- hcursor.getOrElse[Option[Base64String]]("d")(None)
      jwk <-
        try {
          Right(
            new ECJWK(
              crv,
              x,
              y,
              use,
              ops,
              alg,
              kid,
              x5u,
              x5t,
              x5tS256.orElse(x5t256),
              x5c,
              d = d
            )
          )
        } catch {
          // The constructor validates with `require`; report that as a decode failure.
          case ex: IllegalArgumentException =>
            Left(DecodingFailure(ex.getMessage, hcursor.history))
        }
    } yield jwk
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy