// noise-kotlin 1.0.1 source code: Handshake.kt (listing from Maven / Gradle / Ivy)
package nl.sanderdijkhuis.noise
import nl.sanderdijkhuis.noise.Handshake.Token.*
import nl.sanderdijkhuis.noise.Role.INITIATOR
import nl.sanderdijkhuis.noise.Role.RESPONDER
import nl.sanderdijkhuis.noise.cryptography.*
import nl.sanderdijkhuis.noise.data.Data
import nl.sanderdijkhuis.noise.data.Size
import nl.sanderdijkhuis.noise.data.State
/**
 * Encompasses all Noise protocol handshake state required to read and write messages.
 *
 * Start with [initialize], supplying one of the implemented patterns:
 *
 * - [Noise_NK_25519_ChaChaPoly_SHA256]
 * - [Noise_XN_25519_ChaChaPoly_SHA256]
 *
 * Then call [writeMessage] and [readMessage] in the order the pattern prescribes. Each call consumes one message
 * pattern and yields either the next [Handshake] or, after the last message pattern, a [Transport]. Both return
 * `null` when the message cannot be processed in the current state.
 *
 * Proposals to add more patterns or to generate correct patterns dynamically are welcome.
 *
 * @property role which side of the handshake this party plays
 * @property symmetry the symmetric state, updated by every token and by the payload
 * @property messagePatterns the message patterns still to be processed; the first one is handled next
 * @property localStaticKeyPair this party's static key pair (public key first), if the pattern needs one
 * @property localEphemeralKeyPair this party's ephemeral key pair (public key first), if the pattern needs one
 * @property remoteStaticKey the peer's static public key, if known
 * @property remoteEphemeralKey the peer's ephemeral public key, once received
 * @property trustedStaticKeys the only static public keys accepted when the peer transmits its static key
 */
data class Handshake(
    val role: Role,
    val symmetry: Symmetry,
    val messagePatterns: List<List<Token>>,
    val localStaticKeyPair: Pair<PublicKey, PrivateKey>? = null,
    val localEphemeralKeyPair: Pair<PublicKey, PrivateKey>? = null,
    val remoteStaticKey: PublicKey? = null,
    val remoteEphemeralKey: PublicKey? = null,
    val trustedStaticKeys: Set<PublicKey> = emptySet()
) : MessageType {
    // NOTE(review): type arguments in this class were reconstructed from how each value is used; confirm them against
    // the State, Symmetry and Cryptography declarations (in particular PrivateKey and encryptAndHash's result type).

    /**
     * A handshake pattern.
     *
     * @property name the protocol name, used to initialize the symmetric state
     * @property preSharedMessagePatterns the pre-messages: index 0 is the initiator's, index 1 the responder's
     * @property messagePatterns the handshake messages, processed in order
     */
    data class Pattern(
        val name: String,
        val preSharedMessagePatterns: List<List<Token>>,
        val messagePatterns: List<List<Token>>
    )

    /** Message pattern token, indicating which keys are sent and which key agreements are performed. */
    enum class Token {
        /** Sends or receives an ephemeral public key as-is, mixing it into the handshake hash. */
        E,

        /** Sends or receives a static public key through `encryptAndHash` / `decryptAndHash`. */
        S,

        /** Key agreement between both ephemeral keys. */
        EE,

        /** Key agreement between the initiator's ephemeral key and the responder's static key. */
        ES,

        /** Key agreement between the initiator's static key and the responder's ephemeral key. */
        SE,

        /** Key agreement between both static keys. */
        SS
    }

    private val cryptography get() = symmetry.cryptography

    /** Replaces the symmetric state by [f] applied to it and appends [output] to the state's `result` bytes. */
    private fun State<Handshake, Data>.mapSymmetry(output: Data = Data.empty, f: (Symmetry) -> Symmetry) =
        copy(value = value.copy(symmetry = f(value.symmetry)), result = result + output)

    /** Runs a symmetric step that also produces bytes (such as `encryptAndHash`) and appends them; `null` if [f] is. */
    private fun State<Handshake, Data>.appendOutput(f: (Symmetry) -> State<Symmetry, Data>?) =
        f(value.symmetry)?.let { step -> State(value.copy(symmetry = step.value), result + step.result) }

    /**
     * Finishes the current message pattern while keeping `result`: after the last pattern, returns a [Transport] built
     * from `split()` and the handshake hash; otherwise returns the handshake with the processed pattern removed.
     */
    private fun <R> State<Handshake, R>.advance(): State<out MessageType, R> {
        val remaining = value.messagePatterns.drop(1)
        return if (remaining.isEmpty()) {
            val ciphers = value.symmetry.split()
            State(Transport(ciphers.first, ciphers.second, value.symmetry.handshakeHash.digest), result)
        } else {
            State(value.copy(messagePatterns = remaining), result)
        }
    }

    /** Applies one outgoing token to [state], appending any bytes it sends; `null` if a key the token needs is missing. */
    private fun writeToken(state: State<Handshake, Data>, token: Token): State<Handshake, Data>? {
        val handshake = state.value

        fun mix(local: Pair<PublicKey, PrivateKey>?, remote: PublicKey?) =
            local?.let { l -> remote?.let { r -> state.mapSymmetry { it.mixKey(cryptography.agree(l.second, r)) } } }

        return when (token) {
            E -> handshake.localEphemeralKeyPair?.let { e ->
                state.mapSymmetry(e.first.data) { it.mixHash(e.first.data) }
            }
            S -> handshake.localStaticKeyPair?.let { p ->
                state.appendOutput { it.encryptAndHash(p.first.plaintext) }
            }
            EE -> mix(handshake.localEphemeralKeyPair, handshake.remoteEphemeralKey)
            ES -> when (handshake.role) {
                INITIATOR -> mix(handshake.localEphemeralKeyPair, handshake.remoteStaticKey)
                RESPONDER -> mix(handshake.localStaticKeyPair, handshake.remoteEphemeralKey)
            }
            SE -> when (handshake.role) {
                INITIATOR -> mix(handshake.localStaticKeyPair, handshake.remoteEphemeralKey)
                RESPONDER -> mix(handshake.localEphemeralKeyPair, handshake.remoteStaticKey)
            }
            SS -> mix(handshake.localStaticKeyPair, handshake.remoteStaticKey)
        }
    }

    /**
     * Applies one incoming token to [state], whose `result` holds the message bytes not yet consumed; `null` if the
     * token cannot be accepted in this state.
     */
    private fun readToken(state: State<Handshake, Data>, token: Token): State<Handshake, Data>? {
        val handshake = state.value

        fun mix(local: Pair<PublicKey, PrivateKey>?, remote: PublicKey?) =
            local?.let { l -> remote?.let { r -> state.mapSymmetry { it.mixKey(cryptography.agree(l.second, r)) } } }

        // Takes `size` bytes off the unread input and lets `parse` turn them into the next handshake state.
        fun read(size: Size, parse: (Data) -> Handshake?): State<Handshake, Data>? =
            state.result.readFirst(size)?.let { split ->
                parse(split.first)?.let { state.copy(value = it, result = split.second) }
            }

        return when (token) {
            // Reads SharedSecret.SIZE bytes as the key; presumably public keys and shared secrets have the same
            // length for this DH function (both 32 bytes for X25519) — confirm before adding other curves.
            E -> if (handshake.remoteEphemeralKey != null) null else read(SharedSecret.SIZE) { key ->
                handshake.copy(symmetry = handshake.symmetry.mixHash(key), remoteEphemeralKey = PublicKey(key))
            }
            // NOTE(review): always expects an AEAD tag after the static key, which presumes an earlier key agreement
            // already keyed the cipher. In the built-in patterns S is only received after EE; a pattern receiving S
            // before any key agreement would presumably fail here.
            S -> if (handshake.remoteStaticKey != null) null else read(SharedSecret.SIZE + AEAD_TAG_SIZE) { encrypted ->
                handshake.symmetry.decryptAndHash(Ciphertext(encrypted))?.let { decrypted ->
                    handshake.trustedStaticKeys.firstOrNull { it == PublicKey(decrypted.result.data) }
                        ?.let { handshake.copy(symmetry = decrypted.value, remoteStaticKey = it) }
                }
            }
            EE -> mix(handshake.localEphemeralKeyPair, handshake.remoteEphemeralKey)
            ES -> when (handshake.role) {
                INITIATOR -> mix(handshake.localEphemeralKeyPair, handshake.remoteStaticKey)
                RESPONDER -> mix(handshake.localStaticKeyPair, handshake.remoteEphemeralKey)
            }
            SE -> when (handshake.role) {
                INITIATOR -> mix(handshake.localStaticKeyPair, handshake.remoteEphemeralKey)
                RESPONDER -> mix(handshake.localEphemeralKeyPair, handshake.remoteStaticKey)
            }
            SS -> mix(handshake.localStaticKeyPair, handshake.remoteStaticKey)
        }
    }

    /**
     * Writes the next handshake message: processes each token of the current message pattern, then encrypts
     * [payload] with `encryptAndHash`.
     *
     * NOTE(review): this does not check that it is this party's turn to write; callers must follow the pattern order.
     *
     * @param payload the application data to carry in this message
     * @return the message bytes together with the next [Handshake], or with a [Transport] if this was the last message
     * pattern; `null` if no message patterns remain, a key required by a token is missing, or a symmetric step
     * yields no result.
     */
    fun writeMessage(payload: Payload): State<out MessageType, Data>? =
        messagePatterns.firstOrNull()
            ?.fold<Token, State<Handshake, Data>?>(State(this, Data.empty)) { state, token ->
                state?.let { writeToken(it, token) }
            }
            ?.appendOutput { it.encryptAndHash(Plaintext(payload.data)) }
            ?.advance()

    /**
     * Reads the next handshake message: processes each token of the current message pattern, consuming bytes from
     * [data], then decrypts the remaining bytes as the payload with `decryptAndHash`.
     *
     * @param data the complete received message
     * @return the decrypted [Payload] together with the next [Handshake], or with a [Transport] if this was the last
     * message pattern; `null` if no message patterns remain, `readFirst` cannot take the bytes a token needs
     * (presumably because [data] is too short), a key required by a token is missing, an ephemeral or static key
     * arrives while one is already known, a received static key is not in [trustedStaticKeys], or `decryptAndHash`
     * yields no result.
     */
    fun readMessage(data: Data): State<out MessageType, Payload>? =
        messagePatterns.firstOrNull()
            ?.fold<Token, State<Handshake, Data>?>(State(this, data)) { state, token ->
                state?.let { readToken(it, token) }
            }
            ?.let { s ->
                s.value.symmetry.decryptAndHash(Ciphertext(s.result))?.let { decrypted ->
                    State(s.value.copy(symmetry = decrypted.value), Payload(decrypted.result.data))
                }
            }
            ?.advance()

    companion object {
        /** Length of the authentication tag that [readMessage] expects after an encrypted static key. */
        private val AEAD_TAG_SIZE = Size(16u)

        /**
         * Creates the initial handshake state for one party.
         *
         * Initializes the symmetric state from the name of [pattern], mixes in [prologue], and then mixes in the
         * static public keys named by the pattern's pre-messages. Pre-message 0 belongs to the initiator and
         * pre-message 1 to the responder, so each party mixes in either its own or its peer's static key.
         *
         * @return the handshake, or `null` if a pre-message needs a static key that was not provided, a pre-message
         * contains a token other than [Token.S] (only static keys are supported in pre-messages), or the pattern has
         * more than two pre-messages.
         */
        fun initialize(
            cryptography: Cryptography,
            pattern: Pattern,
            role: Role,
            prologue: Data,
            localStaticKeyPair: Pair<PublicKey, PrivateKey>? = null,
            localEphemeralKeyPair: Pair<PublicKey, PrivateKey>? = null,
            remoteStaticKey: PublicKey? = null,
            trustedStaticKeys: Set<PublicKey> = emptySet()
        ): Handshake? {
            val initial: Symmetry? = Symmetry.initialize(cryptography, pattern.name).mixHash(prologue)
            val symmetry = pattern.preSharedMessagePatterns.foldIndexed(initial) { index, acc, preMessage ->
                preMessage.fold(acc) { s, token ->
                    when {
                        index == 0 && token == S && role == INITIATOR ->
                            localStaticKeyPair?.let { s?.mixHash(it.first.data) }
                        index == 0 && token == S && role == RESPONDER ->
                            remoteStaticKey?.let { s?.mixHash(it.data) }
                        index == 1 && token == S && role == RESPONDER ->
                            localStaticKeyPair?.let { s?.mixHash(it.first.data) }
                        index == 1 && token == S && role == INITIATOR ->
                            remoteStaticKey?.let { s?.mixHash(it.data) }
                        else -> null
                    }
                }
            } ?: return null
            return Handshake(
                role = role,
                symmetry = symmetry,
                messagePatterns = pattern.messagePatterns,
                localStaticKeyPair = localStaticKeyPair,
                localEphemeralKeyPair = localEphemeralKeyPair,
                remoteStaticKey = remoteStaticKey,
                remoteEphemeralKey = null,
                trustedStaticKeys = trustedStaticKeys
            )
        }

        /** XN pattern: no pre-messages; a static key is transmitted in the third message. */
        val Noise_XN_25519_ChaChaPoly_SHA256 =
            Pattern(
                "Noise_XN_25519_ChaChaPoly_SHA256",
                listOf(),
                listOf(listOf(E), listOf(E, EE), listOf(S, SE))
            )

        /** NK pattern: the responder's static key is known to the initiator in advance (pre-message 1). */
        val Noise_NK_25519_ChaChaPoly_SHA256 =
            Pattern(
                "Noise_NK_25519_ChaChaPoly_SHA256",
                listOf(listOf(), listOf(S)),
                listOf(listOf(E, ES), listOf(E, EE))
            )
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy