All downloads are free. Search and download functionality uses the official Maven repository.

huis.noise-kotlin.1.0.1.source-code.Handshake.kt Maven / Gradle / Ivy

package nl.sanderdijkhuis.noise

import nl.sanderdijkhuis.noise.Handshake.Token.*
import nl.sanderdijkhuis.noise.Role.INITIATOR
import nl.sanderdijkhuis.noise.Role.RESPONDER
import nl.sanderdijkhuis.noise.cryptography.*
import nl.sanderdijkhuis.noise.data.Data
import nl.sanderdijkhuis.noise.data.Size
import nl.sanderdijkhuis.noise.data.State

/**
 * Encompasses all Noise protocol handshake state required to read and write messages.
 *
 * Start with [initialize], supplying one of the implemented patterns:
 *
 * - [Noise_NK_25519_ChaChaPoly_SHA256]
 * - [Noise_XN_25519_ChaChaPoly_SHA256]
 *
 * All properties are read-only. [writeMessage] and [readMessage] do not modify this instance; on success they
 * return a new state whose value is either a copy of this handshake with the processed message pattern
 * removed, or a [Transport] once no message patterns remain. Both return `null` when the current pattern cannot
 * be processed with the keys and data at hand.
 *
 * Proposals to add more patterns or generate correct patterns dynamically are welcomed.
 *
 * NOTE(review): in this copy of the file the generic type arguments look stripped (`List>`, `Pair?`, bare
 * `Set` and `State`). Restore them from the upstream source before compiling; the comments below describe only
 * what the visible code does with each value.
 *
 * @property role whether this side is the [INITIATOR] or the [RESPONDER]; selects which keys the `ES` and `SE`
 *   tokens combine.
 * @property symmetry symmetric state that every token is mixed into; also supplies the cryptography used for
 *   key agreement.
 * @property messagePatterns remaining message patterns; the first one is processed by the next
 *   [writeMessage] or [readMessage] call.
 * @property localStaticKeyPair local static key pair, if any. The first component is what gets sent or hashed,
 *   the second is passed to key agreement (presumably public and private key, respectively — confirm).
 * @property localEphemeralKeyPair local ephemeral key pair, if any; components are used like those of
 *   [localStaticKeyPair].
 * @property remoteStaticKey static key of the other party, either supplied up front or learned from an `S`
 *   token in [readMessage].
 * @property remoteEphemeralKey ephemeral key of the other party, learned from an `E` token in [readMessage].
 * @property trustedStaticKeys static keys that [readMessage] accepts when the other party sends its static
 *   key; a received key not in this set makes [readMessage] return `null`.
 */
data class Handshake(
    val role: Role,
    val symmetry: Symmetry,
    val messagePatterns: List>,
    val localStaticKeyPair: Pair? = null,
    val localEphemeralKeyPair: Pair? = null,
    val remoteStaticKey: PublicKey? = null,
    val remoteEphemeralKey: PublicKey? = null,
    val trustedStaticKeys: Set = emptySet()
) : MessageType {

    /**
     * A handshake pattern.
     *
     * @property name protocol name, passed to `Symmetry.initialize` by [initialize].
     * @property preSharedMessagePatterns token lists for keys known before the handshake starts. [initialize]
     *   treats index 0 as the initiator's keys and index 1 as the responder's keys, and supports only the `S`
     *   token there.
     * @property messagePatterns token lists of the handshake messages, in the order they are exchanged.
     */
    data class Pattern(
        val name: String,
        val preSharedMessagePatterns: List>,
        val messagePatterns: List>
    )

    /** Message pattern token, indicating which keys are sent and which key agreements are performed. */
    enum class Token {
        E, S, EE, ES, SE, SS
    }

    /** Shorthand for the cryptography held by this handshake's own [symmetry]. */
    private val cryptography get() = symmetry.cryptography

    /**
     * Returns a copy of this state in which the handshake's symmetry is replaced by `f(symmetry)` and [d] is
     * appended to the accumulated result. With the default [d] only the symmetry changes.
     */
    private fun State.run(d: Data = Data.empty, f: (Symmetry) -> Symmetry) =
        copy(value = value.copy(symmetry = f(value.symmetry)), result = result + d)

    /**
     * Applies [f] to the handshake's symmetry. If [f] yields a state, returns a new state holding the updated
     * symmetry with the produced result appended to the accumulated result; if [f] yields `null`, returns `null`.
     */
    private fun State.append(f: (Symmetry) -> State?) =
        f(value.symmetry)?.let { s -> State(value.copy(symmetry = s.value), result + s.result) }

    /**
     * Returns a handshake message only if this is appropriate given the pattern and state.
     *
     * Processes the tokens of the first remaining message pattern in order, accumulating the outgoing bytes,
     * then appends the encrypted-and-hashed [payload]. The returned state's value is a [Transport] when this was
     * the last message pattern, otherwise the handshake to continue with; its result is the message to send.
     *
     * Returns `null` as soon as a token needs a key that is not available.
     *
     * NOTE(review): `messagePatterns.first()` throws `NoSuchElementException` when [messagePatterns] is empty
     * instead of returning `null`.
     */
    fun writeMessage(payload: Payload): State? =
        messagePatterns.first().fold(State(this, Data.empty) as State?) { state, token ->
            // A null state means an earlier token failed; it is carried through the remaining tokens unchanged.
            state?.let { s ->
                // Mixes the agreement of a local key pair's second component with a remote key into the
                // symmetry; yields null when either key is missing.
                fun mix(local: Pair?, remote: PublicKey?) =
                    local?.let { l -> remote?.let { r -> s.run { it.mixKey(cryptography.agree(l.second, r)) } } }
                when {
                    // E: emit the first component of the local ephemeral key pair as-is and mix it into the hash.
                    token == E -> localEphemeralKeyPair?.let { e -> s.run(e.first.data) { it.mixHash(e.first.data) } }
                    // S: emit the first component of the local static key pair through encryptAndHash.
                    token == S -> localStaticKeyPair?.let { p -> s.append { it.encryptAndHash(p.first.plaintext) } }
                    token == EE -> mix(localEphemeralKeyPair, remoteEphemeralKey)
                    // ES and SE name the initiator's key first, so the local/remote pairing depends on the role.
                    token == ES && role == INITIATOR -> mix(localEphemeralKeyPair, remoteStaticKey)
                    token == ES && role == RESPONDER -> mix(localStaticKeyPair, remoteEphemeralKey)
                    token == SE && role == INITIATOR -> mix(localStaticKeyPair, remoteEphemeralKey)
                    token == SE && role == RESPONDER -> mix(localEphemeralKeyPair, remoteStaticKey)
                    token == SS -> mix(localStaticKeyPair, remoteStaticKey)
                    else -> null
                }
            }
        }
            ?.append { it.encryptAndHash(Plaintext(payload.data)) }
            ?.let { s ->
                val rest = messagePatterns.drop(1)
                // No patterns left: split the symmetry and hand over to transport, keeping the handshake hash.
                if (rest.isEmpty()) s.value.symmetry.split().let {
                    State(Transport(it.first, it.second, s.value.symmetry.handshakeHash.digest), s.result)
                }
                else State(s.value.copy(messagePatterns = rest), s.result)
            }

    /**
     * Returns a handshake message payload only if this is appropriate given the pattern and state.
     *
     * Processes the tokens of the first remaining message pattern in order while consuming [data] from the
     * front, then decrypts-and-hashes whatever is left as the payload. The returned state's value is a
     * [Transport] when this was the last message pattern, otherwise the handshake to continue with; its result
     * is the decrypted payload.
     *
     * Returns `null` when a required key is missing, when a key token arrives although that remote key is
     * already set, when the data cannot be read or decrypted, or when a received static key is not in
     * [trustedStaticKeys].
     *
     * NOTE(review): `messagePatterns.first()` throws `NoSuchElementException` when [messagePatterns] is empty
     * instead of returning `null`.
     */
    fun readMessage(data: Data): State? =
        // While folding, the state's result holds the part of the message that has not been consumed yet.
        messagePatterns.first().fold(State(this, data) as State?) { state, token ->
            state?.let { s ->
                // Keys are selected from s.value rather than from this instance, so that a remote key read
                // by an earlier token of the same message is visible here.
                fun mix(f: (Handshake) -> Pair?, g: (Handshake) -> PublicKey?) =
                    f(s.value)?.let { l ->
                        g(s.value)?.let { r -> s.run { it.mixKey(cryptography.agree(l.second, r)) } }
                    }

                // Takes the first `size` of the unread data, lets f build the next handshake from it, and
                // continues with the remaining data; null if size is null or if reading or f fails.
                fun read(size: Size?, f: (Data) -> Handshake?): State? =
                    size?.let { s.result.readFirst(it) }?.let { v -> f(v.first)?.let { s.copy(it, v.second) } }
                when {
                    // E: accepted only once. The raw bytes are mixed into the hash and kept as the remote
                    // ephemeral key.
                    // NOTE(review): the length read is SharedSecret.SIZE; presumably public keys have that
                    // same size — confirm.
                    token == E && s.value.remoteEphemeralKey == null -> read(SharedSecret.SIZE) {
                        s.value.copy(symmetry = s.value.symmetry.mixHash(it), remoteEphemeralKey = PublicKey(it))
                    }

                    // S: accepted only once. The extra Size(16u) is presumably the authentication tag of the
                    // encrypted key — TODO confirm. The decrypted key is kept only if it is a trusted key.
                    token == S && s.value.remoteStaticKey == null -> read(SharedSecret.SIZE + Size(16u)) { r ->
                        s.value.symmetry.decryptAndHash(Ciphertext(r))?.let { ss ->
                            trustedStaticKeys.firstOrNull { it == PublicKey(ss.result.data) }
                                ?.let { s.value.copy(symmetry = ss.value, remoteStaticKey = it) }
                        }
                    }

                    token == EE -> mix({ it.localEphemeralKeyPair }, { it.remoteEphemeralKey })
                    // ES and SE name the initiator's key first, so the local/remote pairing depends on the role.
                    token == ES && role == INITIATOR -> mix({ it.localEphemeralKeyPair }, { it.remoteStaticKey })
                    token == ES && role == RESPONDER -> mix({ it.localStaticKeyPair }, { it.remoteEphemeralKey })
                    token == SE && role == INITIATOR -> mix({ it.localStaticKeyPair }, { it.remoteEphemeralKey })
                    token == SE && role == RESPONDER -> mix({ it.localEphemeralKeyPair }, { it.remoteStaticKey })
                    token == SS -> mix({ it.localStaticKeyPair }, { it.remoteStaticKey })
                    else -> null
                }
            }
        }?.let {
            // Everything not consumed by the tokens is the encrypted payload.
            it.value.symmetry.decryptAndHash(Ciphertext(it.result))?.let { decrypted ->
                State(it.value.copy(symmetry = decrypted.value), Payload(decrypted.result.data))
            }
        }?.let { s ->
            val rest = messagePatterns.drop(1)
            // No patterns left: split the symmetry and hand over to transport, keeping the handshake hash.
            if (rest.isEmpty()) s.value.symmetry.split()
                .let { State(Transport(it.first, it.second, s.value.symmetry.handshakeHash.digest), s.result) }
            else State(s.value.copy(messagePatterns = rest), s.result)
        }

    companion object {

        /**
         * Returns initial handshake state only if sufficient keys are provided.
         *
         * Starts from a symmetry initialized with the pattern name and mixes in [prologue], then mixes in the
         * keys named by the pattern's pre-shared message patterns. Only the `S` token is supported there, at
         * index 0 (initiator's static key) and index 1 (responder's static key): the local key pair is used
         * when the index matches [role], otherwise [remoteStaticKey]. Any other token or index, or a missing
         * key, makes the result `null`.
         *
         * The resulting handshake has no remote ephemeral key yet.
         *
         * @param cryptography primitives passed to `Symmetry.initialize`.
         * @param pattern handshake pattern supplying the name, pre-shared patterns and message patterns.
         * @param role role of this side of the handshake.
         * @param prologue data mixed into the handshake hash before any key.
         * @param localStaticKeyPair local static key pair, if the pattern needs one.
         * @param localEphemeralKeyPair local ephemeral key pair, if the pattern needs one.
         * @param remoteStaticKey static key of the other party, if known up front.
         * @param trustedStaticKeys static keys accepted when the other party sends its static key.
         */
        fun initialize(
            cryptography: Cryptography,
            pattern: Pattern,
            role: Role,
            prologue: Data,
            localStaticKeyPair: Pair? = null,
            localEphemeralKeyPair: Pair? = null,
            remoteStaticKey: PublicKey? = null,
            trustedStaticKeys: Set = emptySet()
        ): Handshake? = pattern.preSharedMessagePatterns.foldIndexed(
            Symmetry.initialize(cryptography, pattern.name).mixHash(prologue) as Symmetry?
        ) { index, state, p ->
            // Once the symmetry has become null it stays null for all remaining tokens and patterns.
            p.fold(state) { s, t ->
                when {
                    index == 0 && t == S && role == INITIATOR -> localStaticKeyPair?.let { s?.mixHash(it.first.data) }
                    index == 0 && t == S && role == RESPONDER -> remoteStaticKey?.let { s?.mixHash(it.data) }
                    index == 1 && t == S && role == RESPONDER -> localStaticKeyPair?.let { s?.mixHash(it.first.data) }
                    index == 1 && t == S && role == INITIATOR -> remoteStaticKey?.let { s?.mixHash(it.data) }
                    else -> null
                }
            }
        }?.let {
            Handshake(
                role,
                it,
                pattern.messagePatterns,
                localStaticKeyPair,
                localEphemeralKeyPair,
                remoteStaticKey,
                null,
                trustedStaticKeys
            )
        }

        /**
         * XN pattern: no pre-shared keys; three messages with tokens `E`, then `E, EE`, then `S, SE`.
         */
        val Noise_XN_25519_ChaChaPoly_SHA256 =
            Pattern(
                "Noise_XN_25519_ChaChaPoly_SHA256",
                listOf(),
                listOf(listOf(E), listOf(E, EE), listOf(S, SE))
            )

        /**
         * NK pattern: the responder's static key is pre-shared (index 1 of the pre-shared patterns); two
         * messages with tokens `E, ES`, then `E, EE`.
         */
        val Noise_NK_25519_ChaChaPoly_SHA256 =
            Pattern(
                "Noise_NK_25519_ChaChaPoly_SHA256",
                listOf(listOf(), listOf(S)),
                listOf(listOf(E, ES), listOf(E, EE))
            )
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy