commonMain.io.eqoty.secretk.crypto.elliptic.Curve.kt Maven / Gradle / Ivy
package io.eqoty.secretk.crypto.elliptic
import io.eqoty.secretk.crypto.elliptic.biginteger.BN
import io.eqoty.secretk.crypto.elliptic.biginteger.mont
import io.eqoty.secretk.crypto.elliptic.biginteger.red
import io.eqoty.secretk.crypto.elliptic.curves.Endomorphism
import io.eqoty.secretk.crypto.elliptic.curves.PresetCurve
import io.eqoty.secretk.crypto.elliptic.utils.getJSF
import io.eqoty.secretk.crypto.elliptic.utils.getNAF
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.jsonArray
import kotlin.math.max
sealed class Curve(val presetCurve: PresetCurve) {
val p = presetCurve.p
val red = if (presetCurve.prime != null) BN.red(presetCurve.prime!!) else BN.mont(p)
// Useful for many curves
val zero = BN(0).toRed(this.red)
val one = BN(1).toRed(this.red)
val two = BN(2).toRed(this.red)
val n = presetCurve.n
val g: BasePoint<*>? = presetCurve.g?.let { g ->
pointFromJSON(g, presetCurve.gRed)
}
abstract fun pointFromJSON(g: String, gRed: Boolean): BasePoint<*>
abstract fun validate(basePoint: BasePoint<*>): Boolean
abstract fun decodePoint(bytes: UByteArray, enc: String?): BasePoint<*>
// Temporary arrays
protected val _wnafT1 = Array(4) { null }
protected val _wnafT3 = Array?>(4) { null }
protected val _wnafT4 = Array(4) { null }
protected val bitLength = if (this.n != null) this.n.bitLength() else 0u
}
class ShortCurve(presetCurve: PresetCurve) : Curve(presetCurve) {
val a = presetCurve.a.toRed(red)
val b = presetCurve.b.toRed(red)
//val tinv = two.redInvm();
val zeroA = a.fromRed() == BN.ZERO
val threeA = this.a.fromRed().subtract(this.p).compareTo(-3) == 0
val endo = getEndomorphism()
fun point(x: String, y: String, isRed: Boolean? = null): ShortCurvePoint =
ShortCurvePoint(this, "affine", x, y, isRed)
fun point(x: BN?, y: BN?, isRed: Boolean? = null): ShortCurvePoint =
ShortCurvePoint(this, "affine", x, y, isRed)
private fun getEndomorphism(): Endomorphism? {
// No efficient endomorphism
if (!this.zeroA || this.g == null || this.n == null || p.mod(BN(3)) != BN.ONE)
return null
// Compute beta and lambda, that lambda * P = (beta * Px; Py)
val beta = presetCurve.beta!!.toRed(red)
val lambda = presetCurve.lambda
// Get basis vectors, used for balanced length-two representation
val basis = presetCurve.basis
return Endomorphism(
beta = beta,
lambda = lambda,
basis = basis,
)
}
private data class EndoSplit(var k1: BN, var k2: BN)
private fun endoSplit(k: BN): EndoSplit {
val basis = this.endo!!.basis
val v1 = basis[0]
val v2 = basis[1]
val c1 = v2.b.multiply(k).divRound(n!!)
var c2 = v1.b.negate().multiply(k).divRound(n)
val p1 = c1.multiply(v1.a)
val p2 = c2.multiply(v2.a)
val q1 = c1.multiply(v1.b)
val q2 = c2.multiply(v2.b)
// Calculate answer
var k1 = k.subtract(p1).subtract(p2)
var k2 = q1.add(q2).negate()
return EndoSplit(k1 = k1, k2 = k2)
}
fun endoWnafMulAdd(
points: List,
coeffs: List,
jacobianResult: Any? = null
): BasePoint {
val npoints = MutableList(4) { null }
val ncoeffs = MutableList(4) { null }
for (i in points.indices) {
val split = this.endoSplit(coeffs[i])
var p = points[i]
var beta = p.getBeta()!!
if (split.k1.negative) {
split.k1 = split.k1.negate()
p = p.neg(true)
}
if (split.k2.negative) {
split.k2 = split.k2.negate()
beta = beta.neg(true)
}
npoints[i * 2] = p
npoints[i * 2 + 1] = beta
ncoeffs[i * 2] = split.k1
ncoeffs[i * 2 + 1] = split.k2
}
val res = this.wnafMulAdd(1, npoints, ncoeffs.toList(), points.size * 2, jacobianResult)
// Clean-up references to points and coefficients
for (j in 0 until npoints.size) {
npoints[j] = null
ncoeffs[j] = null
}
return res
}
fun wnafMulAdd(
defW: Int,
points: MutableList,
coeffs: List,
len: Int,
jacobianResult: Any? = null
): BasePoint {
var wndWidth = this._wnafT1
val wnd = Array?>?>(4) { null }
var naf = this._wnafT3
// Fill all arrays
var max = 0
// var i;
// var j;
for (i in 0 until len) {
val p = points[i]!!
val nafPoints = p.getNAFPoints(defW)
wndWidth[i] = nafPoints.wnd
wnd[i] = nafPoints.points as List
}
var p = points[len - 1] as BasePoint?
// Comb small window NAFs
for (i in len - 1 downTo 1 step 2) {
var a = i - 1
var b = i
if (wndWidth[a] != 1 || wndWidth[b] != 1) {
naf[a] = getNAF(coeffs[a]!!, wndWidth[a]!!, this.bitLength.toInt())
naf[b] = getNAF(coeffs[b]!!, wndWidth[b]!!, this.bitLength.toInt())
max = max(naf[a]!!.size, max)
max = max(naf[b]!!.size, max)
continue
}
val comb: MutableList?> = mutableListOf(
points[a] as ShortCurvePoint, /* 1 */
null, /* 3 */
null, /* 5 */
points[b] as ShortCurvePoint, /* 7 */
)
// Try to avoid Projective points, if possible
if (points[a]!!.y!!.compareTo(points[b]!!.y!!) == 0) {
comb[1] = points[a]!!.add(points[b]!!)
comb[2] = points[a]!!.toJ().mixedAdd(points[b]!!.neg())
} else if (points[a]!!.y!!.compareTo(points[b]!!.y!!.redNeg()) == 0) {
comb[1] = points[a]!!.toJ().mixedAdd(points[b]!!)
comb[2] = points[a]!!.add(points[b]!!.neg())
} else {
comb[1] = points[a]!!.toJ().mixedAdd(points[b]!!)
comb[2] = points[a]!!.toJ().mixedAdd(points[b]!!.neg())
}
val index = listOf(
-3, /* -1 -1 */
-1, /* -1 0 */
-5, /* -1 1 */
-7, /* 0 -1 */
0, /* 0 0 */
7, /* 0 1 */
5, /* 1 -1 */
1, /* 1 0 */
3, /* 1 1 */
)
val jsf = getJSF(coeffs[a]!!, coeffs[b]!!)
max = max(jsf[0].size, max)
naf[a] = Array(max) { 0 }
naf[b] = Array(max) { 0 }
for (j in 0 until max) {
val ja = jsf[0][j] or 0
val jb = jsf[1][j] or 0
naf[a]!![j] = index[(ja + 1) * 3 + (jb + 1)]
naf[b]!![j] = 0
wnd[a] = comb.toList()
}
}
var acc = JPoint(this, null, null, null)
var tmp = this._wnafT4
var i = max
while (i >= 0) {
var k = 0
while (i >= 0) {
var zero = true
for (j in 0 until len) {
tmp[j] = naf[j]!!.getOrElse(i) { 0 }
if (tmp[j] != 0)
zero = false
}
if (!zero)
break
k++
i--
}
if (i >= 0)
k++
acc = acc.dblp(k)
if (i < 0)
break
for (j in 0 until len) {
val z = tmp[j]
if (z == 0)
continue
else if (z!! > 0)
p = wnd[j]!![(z - 1) shr 1]
else if (z < 0)
p = wnd[j]!![(-z - 1) shr 1]!!.neg()
if (p!!.type == "affine")
acc = acc.mixedAdd(p as ShortCurvePoint)
else
acc = acc.add(p as JPoint)
}
i--
}
// Zeroify references
for (i in 0 until len) {
wnd[i] = null
}
if (jacobianResult != null)
return acc
else
return acc.toP()
}
fun jpoint(x: BN?, y: BN?, z: BN?): JPoint = JPoint(this, x, y, z)
fun fixedNafMul(p: ShortCurvePoint, k: BN): BasePoint {
require(p.precomputed != null)
var doubles = p.getDoubles()
var naf = getNAF(k, 1, this.bitLength.toInt())
var I = (1 shl (doubles.step + 1)) - (if (doubles.step % 2 == 0) 2 else 1)
I /= 3
// Translate into more windowed form
var repr = mutableListOf()
var j = 0
var nafW: Int?
while (j < naf.size) {
nafW = 0
for (l in j + doubles.step - 1 downTo j) {
nafW = naf.getOrNull(l)?.plus(nafW?.shl(1) ?: 0)
}
repr.add(nafW!!)
j += doubles.step
}
var a = this.jpoint(null, null, null)
var b = this.jpoint(null, null, null)
for (i in I downTo 1) {
for (l in 0 until repr.size) {
nafW = repr[l]
if (nafW == i)
b = b.mixedAdd(doubles.points[l] as ShortCurvePoint)
else if (nafW == -i)
b = b.mixedAdd((doubles.points[l] as ShortCurvePoint).neg())
}
a = a.add(b)
}
return a.toP()
}
override fun pointFromJSON(g: String, gRed: Boolean): BasePoint<*> =
BasePoint.fromJSON(this, Json.parseToJsonElement(g).jsonArray, presetCurve.gRed)
override fun validate(_point: BasePoint<*>): Boolean {
val point = _point as ShortCurvePoint
if (point.inf)
return true
val x = point.x
val y = point.y
val ax = this.a.redMul(x!!)
val rhs = x.redSqr().redMul(x).redAdd(ax).redAdd(this.b)
return y!!.redSqr().redSub(rhs).compareTo(0) == 0
}
override fun decodePoint(bytes: UByteArray, enc: String?): ShortCurvePoint {
var len = this.p.byteLength()
// uncompressed, hybrid-odd, hybrid-even
if ((bytes[0].toInt() == 0x04 || bytes[0].toInt() == 0x06 || bytes[0].toInt() == 0x07) &&
bytes.size - 1 == 2 * len
) {
if (bytes[0].toInt() == 0x06)
require(bytes[bytes.size - 1].toInt() % 2 == 0)
else if (bytes[0].toInt() == 0x07)
require(bytes[bytes.size - 1].toInt() % 2 == 1)
val res = this.point(
BN(bytes.slice(IntRange(1, len)).toUByteArray()),
BN(bytes.slice(IntRange(1 + len, 2 * len)).toUByteArray())
)
return res
} else if ((bytes[0].toInt() == 0x02 || bytes[0].toInt() == 0x03) &&
bytes.size - 1 == len
) {
return this.pointFromX(
bytes.slice(IntRange(1, len)).toUByteArray(),
bytes[0].toInt() == 0x03
)
}
throw Error("Unknown point format")
}
private fun pointFromX(_x: UByteArray, odd: Boolean): ShortCurvePoint {
TODO()
// var x = BN(_x);
// if (x.red == null)
// x = x.toRed(this.red);
//
// val y2 = x.redSqr().redMul(x).redAdd(x.redMul(this.a)).redAdd(this.b);
// var y = y2.redSqrt();
// if (y.redSqr().redSub(y2).cmp(this.zero) !== 0)
// throw Error("invalid point");
//
// // XXX Is there any way to tell if the number is odd without converting it
// // to non-red form?
// var isOdd = y.fromRed().isOdd();
// if (odd && !isOdd || !odd && isOdd)
// y = y.redNeg();
//
// return this.point(x, y);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy