All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonMain.fr.acinq.lightning.crypto.noise.Noise.kt Maven / Gradle / Ivy

There is a newer version: 1.8.4
Show newest version
package fr.acinq.lightning.crypto.noise

import kotlin.random.Random

/**
 * Diffie-Hellman primitives used by the Noise handshake.
 *
 * Key pairs are `Pair<ByteArray, ByteArray>` values whose `first` component is the public key (it is what the
 * handshake mixes into the hash and sends on the wire); the `second` component is presumably the private key —
 * confirm against the implementations.
 */
interface DHFunctions {
    /** DH function name, used as a component of the Noise protocol name. */
    fun name(): String

    /** Derives a key pair from the private key bytes [priv]. */
    fun generateKeyPair(priv: ByteArray): Pair<ByteArray, ByteArray>

    /** Computes the DH shared secret between the local [keyPair] and the remote [publicKey]. */
    fun dh(keyPair: Pair<ByteArray, ByteArray>, publicKey: ByteArray): ByteArray

    /** Number of random bytes used as the private key when generating an ephemeral key pair (DHLEN). */
    fun dhLen(): Int

    /** Size in bytes of a serialized public key, as read from handshake messages. */
    fun pubKeyLen(): Int
}

/**
 * AEAD cipher primitives used by a Noise cipher state.
 */
interface CipherFunctions {
    /** Cipher name, used as a component of the Noise protocol name. */
    fun name(): String

    /**
     * Encrypts [plaintext] under the 32-byte key [k], using nonce [n] (a 64-bit counter that must be unique for
     * the key) and associated data [ad], with an AEAD construction.
     *
     * @return the ciphertext: the plaintext length plus 16 bytes of authentication data. The whole ciphertext
     *         must be indistinguishable from random as long as the key is secret.
     */
    fun encrypt(k: ByteArray, n: Long, ad: ByteArray, plaintext: ByteArray): ByteArray

    /**
     * Decrypts [ciphertext] under the 32-byte key [k], nonce [n] and associated data [ad].
     *
     * @return the plaintext. If authentication fails, an error must be signaled to the caller (e.g. by throwing).
     */
    fun decrypt(k: ByteArray, n: Long, ad: ByteArray, ciphertext: ByteArray): ByteArray
}

/**
 * Hash functions used by a Noise symmetric state.
 */
interface HashFunctions {
    /** Hash name, used as a component of the Noise protocol name. */
    fun name(): String

    /** Hashes arbitrary-length [data] with a collision-resistant hash function; returns [hashLen] bytes. */
    fun hash(data: ByteArray): ByteArray

    /** Size in bytes of the hash output (HASHLEN). Must be 32 or 64. */
    fun hashLen(): Int

    /**
     * Size in bytes of the blocks the hash function processes internally (BLOCKLEN). HMAC needs it
     * (it is `B` in the HMAC definition, RFC 2104).
     */
    fun blockLen(): Int

    /** Computes HMAC-HASH(key, data) using [hash]. Only called from [hkdf]. */
    fun hmacHash(key: ByteArray, data: ByteArray): ByteArray

    /**
     * Noise HKDF with two outputs.
     *
     * @param chainingKey   chaining key of length HASHLEN
     * @param inputMaterial input key material: 0 bytes, 32 bytes or DHLEN bytes
     * @return the pair (output1, output2), both HASHLEN bytes, computed as:
     *   - temp_key = HMAC-HASH(chaining_key, input_key_material)
     *   - output1 = HMAC-HASH(temp_key, byte(0x01))
     *   - output2 = HMAC-HASH(temp_key, output1 || byte(0x02))
     */
    fun hkdf(chainingKey: ByteArray, inputMaterial: ByteArray): Pair<ByteArray, ByteArray> {
        val tempKey = hmacHash(chainingKey, inputMaterial)
        val output1 = hmacHash(tempKey, byteArrayOf(0x01))
        val output2 = hmacHash(tempKey, output1 + byteArrayOf(0x02))
        return Pair(output1, output2)
    }
}

/**
 * Noise cipher state: an optional cipher key plus a nonce.
 *
 * Encrypt/decrypt return a `Pair` whose first element is the state to use for the next call; the receiver is
 * not modified by the implementations in this file.
 */
interface CipherState {
    /** Cipher functions used by this state. */
    fun cipher(): CipherFunctions

    /**
     * Returns a new state keyed with [key] that uses the same cipher functions.
     * See [apply] for the accepted key sizes.
     */
    fun initializeKey(key: ByteArray): CipherState =
        apply(key, cipher())

    /** True if this state holds a key (i.e. encryption is not the identity). */
    fun hasKey(): Boolean

    /** Encrypts [plaintext] with associated data [ad]; returns (next state, ciphertext). */
    fun encryptWithAd(ad: ByteArray, plaintext: ByteArray): Pair<CipherState, ByteArray>

    /** Decrypts [ciphertext] with associated data [ad]; returns (next state, plaintext). */
    fun decryptWithAd(ad: ByteArray, ciphertext: ByteArray): Pair<CipherState, ByteArray>

    companion object {
        /**
         * Builds a cipher state from key [k].
         * An empty key yields an unkeyed state; a 32-byte key yields a keyed state with nonce 0.
         *
         * @throws RuntimeException for any other key size
         */
        fun apply(k: ByteArray, cipher: CipherFunctions): CipherState = when (k.size) {
            0 -> UninitializedCipherState(cipher)
            32 -> InitializedCipherState(k, 0, cipher)
            else -> throw RuntimeException("invalid key size")
        }

        /** Builds an unkeyed cipher state. */
        fun apply(cipher: CipherFunctions): CipherState =
            UninitializedCipherState(cipher)
    }
}

/**
 * Cipher state without a key.
 *
 * Encryption and decryption are the identity (ciphertext = plaintext), and the returned state is this same
 * instance.
 *
 * @param cipher cipher functions, reused when a key is later set through `initializeKey`
 */
data class UninitializedCipherState(val cipher: CipherFunctions) :
    CipherState {
    override fun cipher(): CipherFunctions = cipher

    override fun hasKey() = false

    override fun encryptWithAd(ad: ByteArray, plaintext: ByteArray): Pair<CipherState, ByteArray> = Pair(this, plaintext)

    override fun decryptWithAd(ad: ByteArray, ciphertext: ByteArray): Pair<CipherState, ByteArray> = Pair(this, ciphertext)
}

/**
 * Cipher state holding a 32-byte key and the nonce for the next operation.
 *
 * Every encrypt/decrypt call uses the current nonce [n] and returns a copy with nonce `n + 1`. Callers must use
 * the returned state for the next call, so that each nonce is used only once with this key.
 *
 * @param k      key, exactly 32 bytes
 * @param n      nonce used by the next encryption/decryption
 * @param cipher cipher functions
 */
data class InitializedCipherState(val k: ByteArray, val n: Long, val cipher: CipherFunctions) :
    CipherState {
    init {
        require(k.size == 32) { "key size must be 32 bytes" }
    }

    override fun cipher(): CipherFunctions = cipher

    override fun hasKey() = true

    override fun encryptWithAd(ad: ByteArray, plaintext: ByteArray): Pair<CipherState, ByteArray> =
        Pair(copy(n = n + 1), cipher.encrypt(k, n, ad, plaintext))

    override fun decryptWithAd(ad: ByteArray, ciphertext: ByteArray): Pair<CipherState, ByteArray> =
        Pair(copy(n = n + 1), cipher.decrypt(k, n, ad, ciphertext))

    // ByteArray has identity equality, so equals/hashCode are overridden to compare the key by content.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        // A type check (instead of an unchecked cast) keeps equals total. Null or another CipherState
        // implementation (e.g. an unkeyed state) is simply unequal rather than throwing.
        if (other !is InitializedCipherState) return false
        return k.contentEquals(other.k) && n == other.n && cipher == other.cipher
    }

    override fun hashCode(): Int {
        var result = k.contentHashCode()
        result = 31 * result + n.hashCode()
        result = 31 * result + cipher.hashCode()
        return result
    }
}

/**
 * Noise symmetric state: the current [CipherState] together with the chaining key and the handshake hash.
 * Every operation returns a new state instead of modifying this one.
 *
 * @param cipherState   cipher state
 * @param ck            chaining key
 * @param h             handshake hash
 * @param hashFunctions hash functions
 */
data class SymmetricState(val cipherState: CipherState, val ck: ByteArray, val h: ByteArray, val hashFunctions: HashFunctions) {
    /**
     * Mixes [inputKeyMaterial] (e.g. a DH output) into the chaining key.
     *
     * The cipher state is re-keyed with the second HKDF output. That output is truncated to 32 bytes when the
     * hash output is 64 bytes.
     *
     * @throws RuntimeException if the hash length is neither 32 nor 64
     */
    fun mixKey(inputKeyMaterial: ByteArray): SymmetricState {
        val (ck1, tempk) = hashFunctions.hkdf(ck, inputKeyMaterial)
        val tempk1 = when (hashFunctions.hashLen()) {
            32 -> tempk
            64 -> tempk.take(32).toByteArray()
            else -> throw RuntimeException("invalid key size, must be 32 or 64 bytes")
        }
        return copy(cipherState = cipherState.initializeKey(tempk1), ck = ck1)
    }

    /** Sets h = HASH(h || [data]). */
    fun mixHash(data: ByteArray): SymmetricState =
        copy(h = hashFunctions.hash(h + data))

    /**
     * Encrypts [plaintext] with h as associated data, then mixes the ciphertext into h.
     *
     * @return (next symmetric state, ciphertext)
     */
    fun encryptAndHash(plaintext: ByteArray): Pair<SymmetricState, ByteArray> {
        val (cipherstate1, ciphertext) = cipherState.encryptWithAd(h, plaintext)
        return Pair(copy(cipherState = cipherstate1).mixHash(ciphertext), ciphertext)
    }

    /**
     * Decrypts [ciphertext] with h as associated data, then mixes the ciphertext (not the plaintext) into h.
     *
     * @return (next symmetric state, plaintext)
     */
    fun decryptAndHash(ciphertext: ByteArray): Pair<SymmetricState, ByteArray> {
        val (cipherstate1, plaintext) = cipherState.decryptWithAd(h, ciphertext)
        return Pair(copy(cipherState = cipherstate1).mixHash(ciphertext), plaintext)
    }

    /**
     * Derives the two post-handshake cipher states from the chaining key.
     * Each key is the first 32 bytes of an HKDF output with empty input key material.
     *
     * @return (first cipher state, second cipher state, chaining key)
     */
    fun split(): Triple<CipherState, CipherState, ByteArray> {
        val (tempk1, tempk2) = hashFunctions.hkdf(ck, ByteArray(0))
        return Triple(cipherState.initializeKey(tempk1.take(32).toByteArray()), cipherState.initializeKey(tempk2.take(32).toByteArray()), ck)
    }

    // ByteArray has identity equality, so equals/hashCode are overridden to compare ck and h by content.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        // A type check (instead of an unchecked cast) makes equals(null) / equals(otherType) return false.
        if (other !is SymmetricState) return false
        return cipherState == other.cipherState &&
            ck.contentEquals(other.ck) &&
            h.contentEquals(other.h) &&
            hashFunctions == other.hashFunctions
    }

    override fun hashCode(): Int {
        var result = cipherState.hashCode()
        result = 31 * result + ck.contentHashCode()
        result = 31 * result + h.contentHashCode()
        result = 31 * result + hashFunctions.hashCode()
        return result
    }

    companion object {
        /**
         * Initial symmetric state for [protocolName].
         *
         * h is the protocol name zero-padded to HASHLEN if it fits, otherwise HASH(protocolName). The chaining
         * key starts equal to h, and the cipher state has no key.
         */
        fun apply(protocolName: ByteArray, cipherFunctions: CipherFunctions, hashFunctions: HashFunctions): SymmetricState {
            val h = if (protocolName.size <= hashFunctions.hashLen())
                protocolName + ByteArray(hashFunctions.hashLen() - protocolName.size)
            else hashFunctions.hash(protocolName)

            return SymmetricState(CipherState.apply(cipherFunctions), ck = h, h = h, hashFunctions = hashFunctions)
        }
    }
}

/**
 * Tokens that make up a Noise handshake message pattern.
 * Declaration order (and therefore ordinal) is S, E, EE, ES, SE, SS.
 */
enum class MessagePattern {
    /** Static public key, passed through encryptAndHash when written. */
    S,

    /** Ephemeral public key: freshly generated when written, sent in clear and mixed into the hash. */
    E,

    /** DH between the local and remote ephemeral keys. */
    EE,

    /** DH between an ephemeral and a static key; which side holds which is decided by the writer/reader. */
    ES,

    /** DH between a static and an ephemeral key; which side holds which is decided by the writer/reader. */
    SE,

    /** DH between the local and remote static keys. */
    SS
}

/**
 * A Noise handshake pattern.
 *
 * @param name                 pattern name (e.g. "XK"); it becomes part of the Noise protocol name
 * @param initiatorPreMessages keys of the initiator that are known to both sides before the handshake
 * @param responderPreMessages keys of the responder that are known to both sides before the handshake
 * @param messages             one token list per handshake message, processed in order and alternating between
 *                             the two parties
 */
data class HandshakePattern(
    val name: String,
    val initiatorPreMessages: List<MessagePattern>,
    val responderPreMessages: List<MessagePattern>,
    val messages: List<List<MessagePattern>>
) {
    init {
        // Both pre-message lists are restricted to the same shapes: empty, [E], [S] or [E, S].
        require(isValidInitiator(initiatorPreMessages)) { "invalid initiator messages" }
        require(isValidInitiator(responderPreMessages)) { "invalid responder messages" }
    }

    companion object {
        /** Allowed pre-message token lists (for either party, despite the name). */
        val validInitiatorPatterns: Set<List<MessagePattern>> = setOf(
            listOf(),
            listOf(MessagePattern.E),
            listOf(MessagePattern.S),
            listOf(MessagePattern.E, MessagePattern.S)
        )

        /** True if [initiator] is one of [validInitiatorPatterns]. */
        fun isValidInitiator(initiator: List<MessagePattern>): Boolean = validInitiatorPatterns.contains(initiator)
    }
}

/**
 * standard handshake patterns
 */

/**
 * Noise `NN` pattern: no pre-messages and no static keys.
 * Two messages: `-> e` then `<- e, ee`.
 */
val handshakePatternNN = HandshakePattern(
    name = "NN",
    initiatorPreMessages = emptyList(),
    responderPreMessages = emptyList(),
    messages = listOf(
        listOf(MessagePattern.E),
        listOf(MessagePattern.E, MessagePattern.EE)
    )
)

/**
 * Noise `XK` pattern.
 *
 * The responder's static public key is a pre-message, i.e. known to the initiator before the handshake.
 * Three messages: `-> e, es`, `<- e, ee`, `-> s, se`.
 */
val handshakePatternXK = HandshakePattern(
    name = "XK",
    initiatorPreMessages = emptyList(),
    responderPreMessages = listOf(MessagePattern.S),
    messages = listOf(
        listOf(MessagePattern.E, MessagePattern.ES),
        listOf(MessagePattern.E, MessagePattern.EE),
        listOf(MessagePattern.S, MessagePattern.SE)
    )
)

/** Source of bytes; the handshake draws ephemeral private keys from it. */
interface ByteStream {
    /** Returns a new array of [length] bytes. */
    fun nextBytes(length: Int): ByteArray
}

/**
 * [ByteStream] backed by the default `kotlin.random.Random` generator.
 *
 * NOTE(review): `kotlin.random.Random` is not documented as cryptographically secure, yet this is the default
 * source of ephemeral private keys. Confirm that callers pass a secure ByteStream, or replace this with a CSPRNG.
 */
object RandomBytes : ByteStream {
    override fun nextBytes(length: Int): ByteArray {
        val generator: Random = Random.Default
        return generator.nextBytes(length)
    }
}


/**
 * Common supertype of [HandshakeStateWriter] and [HandshakeStateReader].
 * Its companion builds the initial state for either side of a handshake.
 */
sealed class HandshakeState {
    @ExperimentalStdlibApi
    companion object {
        /**
         * Builds the initial symmetric state for the protocol name `Noise_<pattern>_<dh>_<cipher>_<hash>`,
         * then mixes [prologue] into the handshake hash.
         */
        private fun makeSymmetricState(handshakePattern: HandshakePattern, prologue: ByteArray, dh: DHFunctions, cipher: CipherFunctions, hash: HashFunctions): SymmetricState {
            val name = "Noise_${handshakePattern.name}_${dh.name()}_${cipher.name()}_${hash.name()}"
            return SymmetricState.apply(name.encodeToByteArray(), cipher, hash).mixHash(prologue)
        }

        /**
         * Mixes the public keys announced by one party's [preMessages] into the handshake hash, in order.
         * [ephemeralPub] is used for an `E` token and [staticPub] for an `S` token.
         *
         * @throws RuntimeException for any other token
         */
        private fun mixPreMessages(
            state: SymmetricState,
            preMessages: List<MessagePattern>,
            ephemeralPub: ByteArray,
            staticPub: ByteArray
        ): SymmetricState = preMessages.fold(state) { acc, pattern ->
            when (pattern) {
                MessagePattern.E -> acc.mixHash(ephemeralPub)
                MessagePattern.S -> acc.mixHash(staticPub)
                else -> throw RuntimeException("invalid pre-message")
            }
        }

        /**
         * Creates the handshake state of the initiator, which writes the first message.
         *
         * @param handshakePattern handshake pattern (e.g. [handshakePatternXK])
         * @param prologue         data mixed into the handshake hash before any message
         * @param s                local static key pair (`first` is the public key)
         * @param e                local ephemeral key pair (`first` is the public key)
         * @param rs               remote static public key
         * @param re               remote ephemeral public key
         * @param dh               DH functions
         * @param cipher           cipher functions
         * @param hash             hash functions
         * @param byteStream       source of private key bytes for ephemeral keys generated while writing
         */
        fun initializeWriter(
            handshakePattern: HandshakePattern,
            prologue: ByteArray,
            s: Pair<ByteArray, ByteArray>,
            e: Pair<ByteArray, ByteArray>,
            rs: ByteArray,
            re: ByteArray,
            dh: DHFunctions,
            cipher: CipherFunctions,
            hash: HashFunctions,
            byteStream: ByteStream = RandomBytes
        ): HandshakeStateWriter {
            val initial = makeSymmetricState(handshakePattern, prologue, dh, cipher, hash)
            // The writer is the initiator: initiator pre-messages carry our keys, responder pre-messages the remote ones.
            val afterInitiator = mixPreMessages(initial, handshakePattern.initiatorPreMessages, e.first, s.first)
            val afterResponder = mixPreMessages(afterInitiator, handshakePattern.responderPreMessages, re, rs)
            return HandshakeStateWriter(handshakePattern.messages, afterResponder, s, e, rs, re, dh, byteStream)
        }

        /**
         * Creates the handshake state of the responder, which reads the first message.
         * Parameters are the same as for [initializeWriter], seen from the responder's side.
         */
        fun initializeReader(
            handshakePattern: HandshakePattern,
            prologue: ByteArray,
            s: Pair<ByteArray, ByteArray>,
            e: Pair<ByteArray, ByteArray>,
            rs: ByteArray,
            re: ByteArray,
            dh: DHFunctions,
            cipher: CipherFunctions,
            hash: HashFunctions,
            byteStream: ByteStream = RandomBytes
        ): HandshakeStateReader {
            val initial = makeSymmetricState(handshakePattern, prologue, dh, cipher, hash)
            // The reader is the responder: initiator pre-messages carry the remote keys, responder pre-messages ours.
            val afterInitiator = mixPreMessages(initial, handshakePattern.initiatorPreMessages, re, rs)
            val afterResponder = mixPreMessages(afterInitiator, handshakePattern.responderPreMessages, e.first, s.first)
            return HandshakeStateReader(handshakePattern.messages, afterResponder, s, e, rs, re, dh, byteStream)
        }
    }
}

/**
 * Handshake state of the party that sends the next handshake message.
 *
 * @param messages   remaining message patterns; the first one is processed by the next [write]
 * @param state      current symmetric state
 * @param s          local static key pair (`first` is the public key)
 * @param e          local ephemeral key pair (`first` is the public key); replaced when an `E` token is written
 * @param rs         remote static public key
 * @param re         remote ephemeral public key
 * @param dh         DH functions
 * @param byteStream source of private key bytes for ephemeral keys
 */
data class HandshakeStateWriter(
    val messages: List<List<MessagePattern>>,
    val state: SymmetricState,
    val s: Pair<ByteArray, ByteArray>,
    val e: Pair<ByteArray, ByteArray>,
    val rs: ByteArray,
    val re: ByteArray,
    val dh: DHFunctions,
    val byteStream: ByteStream
) : HandshakeState() {

    /** Same remaining messages, symmetric state and keys, as a reader for the other party's next message. */
    fun toReader(): HandshakeStateReader =
        HandshakeStateReader(messages, state, s, e, rs, re, dh, byteStream)

    /**
     * Processes the next message pattern and appends the encrypted [payload].
     *
     * NOTE(review): ES is computed as dh(e, rs) and SE as dh(s, re), i.e. from the initiator's point of view.
     * This is correct when only the initiator sends es/se tokens (as in XK). It would be wrong for a pattern
     * where the responder writes one of them.
     *
     * @param payload input message (can be empty)
     * @return a (reader, output, cipher states) triple.
     *   - The output is sent to the other side; its answer is read with the returned reader.
     *   - When this was the last message pattern, the third item is the result of [SymmetricState.split]: two
     *     cipher states that can be used to encrypt/decrypt further communication, plus the chaining key.
     *   - Otherwise the third item is null.
     * @throws NoSuchElementException if there are no message patterns left
     */
    fun write(payload: ByteArray): Triple<HandshakeStateReader, ByteArray, Triple<CipherState, CipherState, ByteArray>?> {
        val (writer1, buffer1) = messages.first().fold(Pair(this, ByteArray(0))) { (writer, buffer), pattern ->
            when (pattern) {
                MessagePattern.E -> {
                    // Fresh ephemeral key: its public part is mixed into the hash and sent in clear.
                    val e1 = dh.generateKeyPair(byteStream.nextBytes(dh.dhLen()))
                    val state1 = writer.state.mixHash(e1.first)
                    Pair(writer.copy(state = state1, e = e1), buffer + e1.first)
                }
                MessagePattern.S -> {
                    val (state1, ciphertext) = writer.state.encryptAndHash(writer.s.first)
                    Pair(writer.copy(state = state1), buffer + ciphertext)
                }
                MessagePattern.EE -> Pair(writer.copy(state = writer.state.mixKey(dh.dh(writer.e, writer.re))), buffer)
                MessagePattern.SS -> Pair(writer.copy(state = writer.state.mixKey(dh.dh(writer.s, writer.rs))), buffer)
                MessagePattern.ES -> Pair(writer.copy(state = writer.state.mixKey(dh.dh(writer.e, writer.rs))), buffer)
                MessagePattern.SE -> Pair(writer.copy(state = writer.state.mixKey(dh.dh(writer.s, writer.re))), buffer)
            }
        }

        val (state1, ciphertext) = writer1.state.encryptAndHash(payload)
        val remaining = messages.drop(1)
        val writer2 = writer1.copy(messages = remaining, state = state1)
        return Triple(writer2.toReader(), buffer1 + ciphertext, if (remaining.isEmpty()) writer2.state.split() else null)
    }

    companion object {
        /** Builds a writer that uses [RandomBytes] for ephemeral key generation. */
        fun apply(
            messages: List<List<MessagePattern>>,
            state: SymmetricState,
            s: Pair<ByteArray, ByteArray>,
            e: Pair<ByteArray, ByteArray>,
            rs: ByteArray,
            re: ByteArray,
            dh: DHFunctions
        ): HandshakeStateWriter =
            HandshakeStateWriter(messages, state, s, e, rs, re, dh, RandomBytes)
    }
}


/**
 * Handshake state of the party that receives the next handshake message.
 *
 * @param messages   remaining message patterns; the first one is processed by the next [read]
 * @param state      current symmetric state
 * @param s          local static key pair (`first` is the public key)
 * @param e          local ephemeral key pair (`first` is the public key)
 * @param rs         remote static public key; replaced when an `S` token is read
 * @param re         remote ephemeral public key; replaced when an `E` token is read
 * @param dh         DH functions
 * @param byteStream source of private key bytes, handed to the writer returned by [read]
 */
data class HandshakeStateReader(
    val messages: List<List<MessagePattern>>,
    val state: SymmetricState,
    val s: Pair<ByteArray, ByteArray>,
    val e: Pair<ByteArray, ByteArray>,
    val rs: ByteArray,
    val re: ByteArray,
    val dh: DHFunctions,
    val byteStream: ByteStream
) : HandshakeState() {
    /** Same remaining messages, symmetric state and keys, as a writer for our next message. */
    fun toWriter(): HandshakeStateWriter =
        HandshakeStateWriter(messages, state, s, e, rs, re, dh, byteStream)

    /**
     * Processes the next message pattern against [message] and decrypts the trailing payload.
     *
     * NOTE(review): ES is computed as dh(s, re) and SE as dh(e, rs), i.e. from the responder's point of view.
     * This matches patterns where only the initiator sends es/se tokens (as in XK).
     *
     * @param message input message
     * @return a (writer, payload, cipher states) triple.
     *   - The payload is the one used by the sender; the writer is used to create our next message.
     *   - When this was the last message pattern, the third item is the result of [SymmetricState.split]: two
     *     cipher states that can be used to encrypt/decrypt further communication, plus the chaining key.
     *   - Otherwise the third item is null.
     * @throws NoSuchElementException if there are no message patterns left; errors raised by the cipher's
     *         decrypt (e.g. on authentication failure) propagate unchanged
     */
    fun read(message: ByteArray): Triple<HandshakeStateWriter, ByteArray, Triple<CipherState, CipherState, ByteArray>?> {
        val (reader1, buffer1) = messages.first().fold(Pair(this, message)) { (reader, buffer), pattern ->
            when (pattern) {
                MessagePattern.E -> {
                    val (re1, rest) = buffer.splitAt(dh.pubKeyLen())
                    val state1 = reader.state.mixHash(re1)
                    Pair(reader.copy(state = state1, re = re1), rest)
                }
                MessagePattern.S -> {
                    // Once the cipher state has a key, the static key is encrypted and carries a 16-byte tag.
                    val len = if (reader.state.cipherState.hasKey()) dh.pubKeyLen() + 16 else dh.pubKeyLen()
                    val (temp, rest) = buffer.splitAt(len)
                    val (state1, rs1) = reader.state.decryptAndHash(temp)
                    Pair(reader.copy(state = state1, rs = rs1), rest)
                }
                MessagePattern.EE -> Pair(reader.copy(state = reader.state.mixKey(dh.dh(reader.e, reader.re))), buffer)
                MessagePattern.SS -> Pair(reader.copy(state = reader.state.mixKey(dh.dh(reader.s, reader.rs))), buffer)
                MessagePattern.ES -> Pair(reader.copy(state = reader.state.mixKey(dh.dh(reader.s, reader.re))), buffer)
                MessagePattern.SE -> Pair(reader.copy(state = reader.state.mixKey(dh.dh(reader.e, reader.rs))), buffer)
            }
        }

        val (state1, payload) = reader1.state.decryptAndHash(buffer1)
        val remaining = messages.drop(1)
        val reader2 = reader1.copy(messages = remaining, state = state1)
        return Triple(reader2.toWriter(), payload, if (remaining.isEmpty()) reader2.state.split() else null)
    }

    companion object {
        /**
         * Splits this array into its first [n] bytes and the rest.
         * If the array is shorter than [n], the first part is the whole array and the rest is empty.
         */
        private fun ByteArray.splitAt(n: Int): Pair<ByteArray, ByteArray> = Pair(this.take(n).toByteArray(), this.drop(n).toByteArray())

        /** Builds a reader that uses [RandomBytes] for ephemeral key generation. */
        fun apply(
            messages: List<List<MessagePattern>>,
            state: SymmetricState,
            s: Pair<ByteArray, ByteArray>,
            e: Pair<ByteArray, ByteArray>,
            rs: ByteArray,
            re: ByteArray,
            dh: DHFunctions
        ): HandshakeStateReader =
            HandshakeStateReader(messages, state, s, e, rs, re, dh, RandomBytes)
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy