All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonMain.net.folivo.trixnity.client.key.KeyTrustService.kt Maven / Gradle / Ivy

There is a newer version: 4.7.1
Show newest version
package net.folivo.trixnity.client.key

import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.flow.*
import net.folivo.trixnity.client.store.*
import net.folivo.trixnity.client.store.KeySignatureTrustLevel.*
import net.folivo.trixnity.clientserverapi.client.MatrixClientServerApiClient
import net.folivo.trixnity.core.UserInfo
import net.folivo.trixnity.core.model.UserId
import net.folivo.trixnity.core.model.events.m.crosssigning.MasterKeyEventContent
import net.folivo.trixnity.core.model.events.m.secretstorage.SecretKeyEventContent
import net.folivo.trixnity.core.model.keys.*
import net.folivo.trixnity.core.model.keys.CrossSigningKeysUsage.MasterKey
import net.folivo.trixnity.core.model.keys.Key.Ed25519Key
import net.folivo.trixnity.crypto.SecretType
import net.folivo.trixnity.crypto.SecretType.M_CROSS_SIGNING_SELF_SIGNING
import net.folivo.trixnity.crypto.SecretType.M_CROSS_SIGNING_USER_SIGNING
import net.folivo.trixnity.crypto.key.decryptSecret
import net.folivo.trixnity.crypto.sign.*
import net.folivo.trixnity.olm.OlmPkSigning
import net.folivo.trixnity.olm.freeAfter
import net.folivo.trixnity.utils.decodeUnpaddedBase64Bytes
import kotlin.collections.component1
import kotlin.collections.component2
import kotlin.jvm.JvmName

private val log = KotlinLogging.logger {}

/**
 * Calculates, updates and establishes the trust level of device keys and cross signing keys.
 */
interface KeyTrustService {

    /**
     * Decrypts the own master key secret with [key] and compares its public part with the master key
     * advertised for the own user. On a match, the advertised master key and the own device key are
     * handed to [trustAndSignKeys].
     *
     * @param key raw key used to decrypt the stored master key secret.
     * @param keyId id belonging to [key].
     * @param keyInfo description of [key] that is passed on to the decryption.
     * @return a failure when the encrypted master key is missing or the public keys differ, otherwise the
     * outcome of trusting and signing the keys.
     */
    suspend fun checkOwnAdvertisedMasterKeyAndVerifySelf(
        key: ByteArray,
        keyId: String,
        keyInfo: SecretKeyEventContent
    ): Result<Unit>

    /**
     * Marks every key of [keys] that is a known device key or cross signing key of [userId] as verified,
     * signs it when possible and uploads the resulting signatures.
     */
    suspend fun trustAndSignKeys(keys: Set<Ed25519Key>, userId: UserId)

    /** Calculates the trust level of [deviceKeys] without storing it. */
    suspend fun calculateDeviceKeysTrustLevel(deviceKeys: SignedDeviceKeys): KeySignatureTrustLevel

    /** Calculates the trust level of [crossSigningKeys] without storing it. */
    suspend fun calculateCrossSigningKeysTrustLevel(crossSigningKeys: SignedCrossSigningKeys): KeySignatureTrustLevel

    /**
     * Recalculates and stores the trust level of all keys that are (directly or transitively) linked to
     * [signingKey] of [signingUserId] through stored key chain links.
     */
    suspend fun updateTrustLevelOfKeyChainSignedBy(
        signingUserId: UserId,
        signingKey: Ed25519Key,
    )
}

class KeyTrustServiceImpl(
    private val userInfo: UserInfo,
    private val keyStore: KeyStore,
    private val globalAccountDataStore: GlobalAccountDataStore,
    private val signService: SignService,
    private val api: MatrixClientServerApiClient,
) : KeyTrustService {

    /**
     * Verifies that the secret storage [key] decrypts to the private part of the master key advertised for
     * the own user and, if so, trusts and signs that master key together with the own device key.
     *
     * @param key raw key used to decrypt the `m.cross_signing.master` secret.
     * @param keyId id belonging to [key].
     * @param keyInfo description of [key] that is passed on to the decryption.
     * @return a failure wrapping [MasterKeyInvalidException] when the encrypted master key is missing or the
     * decrypted public key differs from the advertised one; otherwise the outcome of [trustAndSignKeys].
     * Coroutine cancellation is rethrown instead of being wrapped into the result.
     */
    override suspend fun checkOwnAdvertisedMasterKeyAndVerifySelf(
        key: ByteArray,
        keyId: String,
        keyInfo: SecretKeyEventContent
    ): Result<Unit> {
        val encryptedMasterKey = globalAccountDataStore.get<MasterKeyEventContent>().first()?.content
            ?: return Result.failure(MasterKeyInvalidException("could not find encrypted master key"))
        val decryptedPublicKey =
            kotlin.runCatching {
                decryptSecret(key, keyId, keyInfo, "m.cross_signing.master", encryptedMasterKey, api.json)
            }.onFailure { error ->
                // cancellation must propagate; any other failure is treated as "no usable key" below
                if (error is kotlin.coroutines.cancellation.CancellationException) throw error
                log.warn { "could not decrypt master key: ${error.message}" }
            }.getOrNull()
                ?.let { privateKey ->
                    // derive the public key from the decrypted private key and release the signing object again
                    freeAfter(OlmPkSigning.create(privateKey)) { it.publicKey }
                }
        val advertisedPublicKey =
            keyStore.getCrossSigningKey(userInfo.userId, MasterKey)?.value?.signed?.get<Ed25519Key>()
        // compare decoded bytes instead of the base64 strings
        val publicKeysMatch = advertisedPublicKey?.value?.decodeUnpaddedBase64Bytes()
            ?.contentEquals(decryptedPublicKey?.decodeUnpaddedBase64Bytes()) == true
        return if (publicKeysMatch) {
            val ownDeviceKeys =
                keyStore.getDeviceKey(userInfo.userId, userInfo.deviceId).first()?.value?.get<Ed25519Key>()
            kotlin.runCatching {
                trustAndSignKeys(setOfNotNull(advertisedPublicKey, ownDeviceKeys), userInfo.userId)
            }.onFailure { error ->
                if (error is kotlin.coroutines.cancellation.CancellationException) throw error
            }
        } else Result.failure(MasterKeyInvalidException("master public key $decryptedPublicKey did not match the advertised ${advertisedPublicKey?.value}"))
    }

    /**
     * Entry point of the key chain walk: starts at [signingKey] of [signingUserId] with an empty set of
     * already visited keys.
     */
    override suspend fun updateTrustLevelOfKeyChainSignedBy(
        signingUserId: UserId,
        signingKey: Ed25519Key,
    ) {
        updateTrustLevelOfKeyChainSignedBy(
            signingUserId = signingUserId,
            signingKey = signingKey,
            visitedKeys = mutableSetOf(),
        )
    }

    /**
     * Recursively updates the trust level of every key reachable from [signingKey] through stored key chain
     * links.
     *
     * @param signingUserId owner of [signingKey].
     * @param signingKey key whose signed keys get their trust level recalculated.
     * @param visitedKeys user id / key id pairs that were already processed; guards against cycles and is
     * extended by this call.
     */
    private suspend fun updateTrustLevelOfKeyChainSignedBy(
        signingUserId: UserId,
        signingKey: Ed25519Key,
        visitedKeys: MutableSet<Pair<UserId, String?>>
    ) {
        log.trace { "update trust level of all keys signed by $signingUserId $signingKey" }
        visitedKeys.add(signingUserId to signingKey.keyId)
        keyStore.getKeyChainLinksBySigningKey(signingUserId, signingKey)
            .forEach { keyChainLink ->
                // Checked per element (not up front), so keys that a previous sibling's recursion already
                // handled are not processed a second time.
                if (visitedKeys.contains(keyChainLink.signedUserId to keyChainLink.signedKey.keyId)) return@forEach
                updateTrustLevelOfKey(keyChainLink.signedUserId, keyChainLink.signedKey)
                updateTrustLevelOfKeyChainSignedBy(keyChainLink.signedUserId, keyChainLink.signedKey, visitedKeys)
            }
    }

    /**
     * Recalculates the trust level of the device keys or cross signing keys of [userId] that contain [key]
     * and writes the result back to the key store.
     *
     * Device keys are looked up first; cross signing keys are only searched when no device keys matched.
     * A warning is logged when [key] has no key id or when nothing matching is stored.
     */
    private suspend fun updateTrustLevelOfKey(userId: UserId, key: Ed25519Key) {
        val keyId = key.keyId
        if (keyId == null) {
            log.warn { "could not update trust level, because key id of $key was null" }
            return
        }

        // set from within the store update lambdas as soon as a matching key has been found
        val foundKey = MutableStateFlow(false)

        keyStore.updateDeviceKeys(userId) { oldDeviceKeys ->
            val foundDeviceKeys = oldDeviceKeys?.get(keyId)
            if (foundDeviceKeys != null) {
                val newTrustLevel = calculateDeviceKeysTrustLevel(foundDeviceKeys.value)
                foundKey.value = true
                log.trace { "updated device keys ${foundDeviceKeys.value.signed.deviceId} of user $userId with trust level $newTrustLevel" }
                oldDeviceKeys + (keyId to foundDeviceKeys.copy(trustLevel = newTrustLevel))
            } else oldDeviceKeys
        }
        if (foundKey.value.not()) {
            keyStore.updateCrossSigningKeys(userId) { oldKeys ->
                val foundCrossSigningKeys = oldKeys?.firstOrNull { keys ->
                    keys.value.signed.keys.keys.filterIsInstance<Ed25519Key>().any { it.keyId == keyId }
                }
                if (foundCrossSigningKeys != null) {
                    val newTrustLevel = calculateCrossSigningKeysTrustLevel(foundCrossSigningKeys.value)
                    foundKey.value = true
                    log.trace { "updated cross signing key ${foundCrossSigningKeys.value.signed.usage.firstOrNull()?.name} of user $userId with trust level $newTrustLevel" }
                    // replace the stored entry with a copy carrying the new trust level
                    (oldKeys - foundCrossSigningKeys) + foundCrossSigningKeys.copy(trustLevel = newTrustLevel)
                } else oldKeys
            }
        }
        if (foundKey.value.not()) log.warn { "could not find device or cross signing keys of $key" }
    }

    /**
     * Calculates the trust level of [deviceKeys] from their stored verification state and their signatures.
     *
     * @return [Invalid] when [deviceKeys] contain no ed25519 key, otherwise the calculated trust level.
     */
    override suspend fun calculateDeviceKeysTrustLevel(deviceKeys: SignedDeviceKeys): KeySignatureTrustLevel {
        log.trace { "calculate trust level for ${deviceKeys.signed}" }
        val userId = deviceKeys.signed.userId
        val signedKey = deviceKeys.signed.keys.get<Ed25519Key>()
            ?: return Invalid("missing ed25519 key")
        return calculateTrustLevel(
            userId = userId,
            verifySignedObject = { signService.verify(deviceKeys, it) },
            signedKey = signedKey,
            signatures = deviceKeys.signatures ?: mapOf(),
            keyVerificationState = deviceKeys.getVerificationState(),
            isMasterKey = false, // device keys are never a master key
        ).also { log.trace { "calculated trust level of ${deviceKeys.signed} from $userId is $it" } }
    }

    /**
     * Calculates the trust level of [crossSigningKeys] from their stored verification state and their
     * signatures. Keys whose usage contains [MasterKey] are treated as master key.
     *
     * @return [Invalid] when [crossSigningKeys] contain no ed25519 key, otherwise the calculated trust level.
     */
    override suspend fun calculateCrossSigningKeysTrustLevel(crossSigningKeys: SignedCrossSigningKeys): KeySignatureTrustLevel {
        log.trace { "calculate trust level for ${crossSigningKeys.signed}" }
        val userId = crossSigningKeys.signed.userId
        val signedKey = crossSigningKeys.signed.keys.get<Ed25519Key>()
            ?: return Invalid("missing ed25519 key")
        return calculateTrustLevel(
            userId = userId,
            verifySignedObject = { signService.verify(crossSigningKeys, it) },
            signedKey = signedKey,
            signatures = crossSigningKeys.signatures ?: mapOf(),
            keyVerificationState = crossSigningKeys.getVerificationState(),
            isMasterKey = crossSigningKeys.signed.usage.contains(MasterKey),
        ).also { log.trace { "calculated trust level of ${crossSigningKeys.signed} from $userId is $it" } }
    }

    /**
     * Derives the trust level of [signedKey] from its verification state and, if that is not conclusive,
     * from its signatures.
     *
     * Resolution order:
     * 1. verified master key -> `CrossSigned(true)`
     * 2. verified key of a user without stored master key -> `Valid(true)`
     * 3. blocked key -> [Blocked]
     * 4. result of [searchSignaturesForTrustLevel], if any
     * 5. fallback: `CrossSigned(false)` for a master key, `Valid(false)` when the user has no master key,
     *    [NotCrossSigned] otherwise
     *
     * @param userId owner of [signedKey].
     * @param verifySignedObject verifies the signed object [signedKey] belongs to against the given keys.
     * @param signedKey key whose trust level is calculated.
     * @param signatures signatures of the signed object.
     * @param keyVerificationState stored verification state of the key, if any.
     * @param isMasterKey whether the signed object is a master key.
     */
    private suspend fun calculateTrustLevel(
        userId: UserId,
        verifySignedObject: suspend (signingKeys: Map<UserId, Set<Ed25519Key>>) -> VerifyResult,
        signedKey: Ed25519Key,
        signatures: Signatures<UserId>,
        keyVerificationState: KeyVerificationState?,
        isMasterKey: Boolean
    ): KeySignatureTrustLevel {
        val masterKey = keyStore.getCrossSigningKey(userId, MasterKey)
        return when {
            keyVerificationState is KeyVerificationState.Verified && isMasterKey -> CrossSigned(true)
            keyVerificationState is KeyVerificationState.Verified && (masterKey == null) -> Valid(true)
            keyVerificationState is KeyVerificationState.Blocked -> Blocked
            else -> searchSignaturesForTrustLevel(userId, verifySignedObject, signedKey, signatures)
                ?: when {
                    isMasterKey -> CrossSigned(false)
                    masterKey == null -> Valid(false)
                    else -> NotCrossSigned
                }
        }
    }

    /**
     * Walks the signatures of [signedKey] (recursively following the signatures of the signing keys) to find
     * a trust level, and rebuilds the stored key chain links that point to [signedKey].
     *
     * For every ed25519 signature of a not yet visited signing key, the signing key is looked up both as
     * cross signing key and as device key of the signing user. A signature only counts when
     * [verifySignedObject] reports it as valid for that key.
     *
     * @param signedUserId owner of [signedKey].
     * @param verifySignedObject verifies the signed object [signedKey] belongs to against the given keys.
     * @param signedKey key whose signatures are searched.
     * @param signatures signatures of the signed object.
     * @param visitedKeys user id / key id pairs that were already examined; guards against cycles and is
     * extended by this call.
     * @return `CrossSigned(true)` if any chain ends in a verified key, else `CrossSigned(false)` if any chain
     * ends in an unverified own master key, else [Blocked] if any chain ends in a blocked key, else `null`.
     */
    private suspend fun searchSignaturesForTrustLevel(
        signedUserId: UserId,
        verifySignedObject: suspend (signingKeys: Map<UserId, Set<Ed25519Key>>) -> VerifyResult,
        signedKey: Ed25519Key,
        signatures: Signatures<UserId>,
        visitedKeys: MutableSet<Pair<UserId, String?>> = mutableSetOf()
    ): KeySignatureTrustLevel? {
        log.trace { "search in signatures of $signedKey for trust level calculation: $signatures" }
        visitedKeys.add(signedUserId to signedKey.keyId)
        // existing links are dropped first, they get saved again below for every signing key that is found
        keyStore.deleteKeyChainLinksBySignedKey(signedUserId, signedKey)
        val states = signatures.flatMap { (signingUserId, signatureKeys) ->
            signatureKeys
                .filterIsInstance<Ed25519Key>()
                .filterNot { visitedKeys.contains(signingUserId to it.keyId) }
                .flatMap { signatureKey ->
                    visitedKeys.add(signingUserId to signatureKey.keyId)

                    // case 1: the signature was made with a cross signing key of the signing user
                    val crossSigningKey =
                        signatureKey.keyId?.let { keyStore.getCrossSigningKey(signingUserId, it) }?.value
                    val signingCrossSigningKey = crossSigningKey?.signed?.get<Ed25519Key>()
                    val crossSigningKeyState = if (signingCrossSigningKey != null) {
                        val isValid = verifySignedObject(mapOf(signingUserId to setOf(signingCrossSigningKey)))
                            .also { v ->
                                if (v != VerifyResult.Valid)
                                    log.warn { "signature was $v for key chain $signingCrossSigningKey ($signingUserId) ---> $signedKey ($signedUserId)" }
                            } == VerifyResult.Valid
                        if (isValid) when (crossSigningKey.getVerificationState()) {
                            is KeyVerificationState.Verified -> CrossSigned(true)
                            is KeyVerificationState.Blocked -> Blocked
                            else -> {
                                searchSignaturesForTrustLevel(
                                    signingUserId,
                                    { signService.verify(crossSigningKey, it) },
                                    signingCrossSigningKey,
                                    crossSigningKey.signatures ?: mapOf(),
                                    visitedKeys
                                ) ?: if (crossSigningKey.signed.usage.contains(MasterKey)
                                    && crossSigningKey.signed.userId == signedUserId
                                    && crossSigningKey.signed.userId == signingUserId
                                ) CrossSigned(false) else null // chain ends at the user's own, unverified master key
                            }
                        } else null
                    } else null

                    // case 2: the signature was made with a device key of the signing user
                    val deviceKey = signatureKey.keyId?.let { keyStore.getDeviceKey(signingUserId, it).first() }?.value
                    val signingDeviceKey = deviceKey?.get<Ed25519Key>()
                    val deviceKeyState = if (signingDeviceKey != null) {
                        val isValid = verifySignedObject(mapOf(signingUserId to setOf(signingDeviceKey)))
                            .also { v ->
                                if (v != VerifyResult.Valid)
                                    log.warn { "signature was $v for key chain $signingDeviceKey ($signingUserId) ---> $signedKey ($signedUserId)" }
                            } == VerifyResult.Valid
                        if (isValid) when (deviceKey.getVerificationState()) {
                            is KeyVerificationState.Verified -> CrossSigned(true)
                            is KeyVerificationState.Blocked -> Blocked
                            else -> searchSignaturesForTrustLevel(
                                // the device key being examined next belongs to the signing user
                                signingUserId,
                                { signService.verify(deviceKey, it) },
                                signingDeviceKey,
                                deviceKey.signatures ?: mapOf(),
                                visitedKeys
                            )
                        } else null
                    } else null

                    // remember who signed whom, regardless of the signature's validity
                    val signingKey = signingCrossSigningKey ?: signingDeviceKey
                    if (signingKey != null) {
                        keyStore.saveKeyChainLink(KeyChainLink(signingUserId, signingKey, signedUserId, signedKey))
                    }

                    listOf(crossSigningKeyState, deviceKeyState)
                }.toSet()
        }.toSet()
        return when {
            states.any { it is CrossSigned && it.verified } -> CrossSigned(true)
            states.any { it is CrossSigned && !it.verified } -> CrossSigned(false)
            states.contains(Blocked) -> Blocked
            else -> null
        }
    }

    /**
     * Builds the [SignWith.PrivateKey] for signing with the own self signing or user signing key.
     *
     * @param type secret to sign with; only [M_CROSS_SIGNING_SELF_SIGNING] and [M_CROSS_SIGNING_USER_SIGNING]
     * are supported.
     * @throws IllegalArgumentException when the private key of [type] is not stored, [type] is not supported
     * or the matching public cross signing key of the own user is not stored.
     */
    private suspend fun signWithSecret(type: SecretType): SignWith.PrivateKey {
        val privateKey = keyStore.getSecrets()[type]?.decryptedPrivateKey
        requireNotNull(privateKey) { "could not find private key of $type" }
        val usage = when (type) {
            M_CROSS_SIGNING_SELF_SIGNING -> CrossSigningKeysUsage.SelfSigningKey
            M_CROSS_SIGNING_USER_SIGNING -> CrossSigningKeysUsage.UserSigningKey
            else -> throw IllegalArgumentException("cannot sign with $type")
        }
        val publicKey = keyStore.getCrossSigningKey(userInfo.userId, usage)?.value?.signed?.get<Ed25519Key>()?.keyId
        requireNotNull(publicKey) { "could not find public key of $type" }
        return SignWith.PrivateKey(privateKey, publicKey)
    }

    /**
     * Marks every key of [keys] that is stored as device key or cross signing key of [userId] as verified,
     * updates its trust level, signs it when possible and uploads all created signatures.
     *
     * Signing rules:
     * - own device keys are signed with the own self signing key
     * - the own master key is signed with the own device key
     * - master keys of other users are signed with the own user signing key
     *
     * A key that cannot be signed is skipped with a warning; it stays marked as verified.
     *
     * @param keys keys to trust.
     * @param userId owner of [keys].
     * @throws UploadSignaturesException when the server reports failures for the uploaded signatures.
     */
    override suspend fun trustAndSignKeys(keys: Set<Ed25519Key>, userId: UserId) {
        log.debug { "sign keys (when possible): $keys" }
        val signedDeviceKeys = keys.mapNotNull { key ->
            val deviceKey = key.keyId?.let { keyStore.getDeviceKey(userId, it).first() }?.value?.signed
            if (deviceKey != null) {
                keyStore.saveKeyVerificationState(key, KeyVerificationState.Verified(key.value))
                updateTrustLevelOfKey(userId, key)
                signOrNull("device key $key with self signing key") {
                    if (userId == userInfo.userId && deviceKey.get<Ed25519Key>() == key) {
                        signService.sign(deviceKey, signWithSecret(M_CROSS_SIGNING_SELF_SIGNING))
                            .also { log.info { "signed own accounts device with own self signing key" } }
                    } else null
                }
            } else null
        }
        val signedCrossSigningKeys = keys.mapNotNull { key ->
            val crossSigningKey = key.keyId?.let { keyStore.getCrossSigningKey(userId, it) }?.value?.signed
            if (crossSigningKey != null) {
                keyStore.saveKeyVerificationState(key, KeyVerificationState.Verified(key.value))
                updateTrustLevelOfKey(userId, key)
                when {
                    // only master keys get signed here
                    !crossSigningKey.usage.contains(MasterKey) -> null
                    crossSigningKey.get<Ed25519Key>() != key -> null
                    userId == userInfo.userId ->
                        signOrNull("own master key $key with device key") {
                            signService.sign(crossSigningKey, SignWith.DeviceKey)
                                .also { log.info { "signed own master key with own device key" } }
                        }

                    else ->
                        signOrNull("other users master key $key with user signing key") {
                            signService.sign(crossSigningKey, signWithSecret(M_CROSS_SIGNING_USER_SIGNING))
                                .also { log.info { "signed other users master key with own user signing key" } }
                        }
                }
            } else null
        }
        if (signedDeviceKeys.isNotEmpty() || signedCrossSigningKeys.isNotEmpty()) {
            log.debug { "upload signed keys: ${signedDeviceKeys + signedCrossSigningKeys}" }
            val response = api.key.addSignatures(signedDeviceKeys.toSet(), signedCrossSigningKeys.toSet()).getOrThrow()
            if (response.failures.isNotEmpty()) {
                log.error { "could not add signatures to server: ${response.failures}" }
                throw UploadSignaturesException(response.failures.toString())
            }
        }
    }

    /**
     * Runs [sign] and returns its result; on failure logs "could not sign [target]" as warning and returns
     * `null`. Coroutine cancellation is rethrown.
     */
    private suspend fun <T> signOrNull(target: String, sign: suspend () -> T): T? =
        try {
            sign()
        } catch (cancellation: kotlin.coroutines.cancellation.CancellationException) {
            throw cancellation
        } catch (error: Exception) {
            log.warn { "could not sign $target: ${error.message}" }
            null
        }

    /** Returns the first verification state stored for any of these keys, or `null` when none is stored. */
    private suspend fun Keys.getVerificationState() =
        firstNotNullOfOrNull { key -> keyStore.getKeyVerificationState(key) }

    /** Returns the verification state stored for the keys contained in these cross signing keys, if any. */
    @JvmName("getVerificationStateCsk")
    private suspend fun SignedCrossSigningKeys.getVerificationState() =
        signed.keys.getVerificationState()

    /** Returns the verification state stored for the keys contained in these device keys, if any. */
    @JvmName("getVerificationStateDk")
    private suspend fun SignedDeviceKeys.getVerificationState() =
        signed.keys.getVerificationState()
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy