All Downloads are FREE. Search and download functionalities are using the official Maven repository.

keyvault.KeyVault.kt Maven / Gradle / Ivy

There is a newer version: 0.3.1
Show newest version
package se.wollan.crypto.keyvault

import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import org.slf4j.Logger
import se.wollan.crypto.*
import se.wollan.crypto.KeyStretchIterations
import se.wollan.crypto.KeyStretcher
import se.wollan.datascope.DataScope
import se.wollan.time.HLCTimestamp
import se.wollan.time.HybridLogicalClock
import kotlin.random.Random
import kotlin.time.Duration
import kotlin.time.Duration.Companion.hours

/**
 * Use pincode to protect and safely store a single secret key.
 *
 * The key is either locked (not readable) or unlocked (readable through [getSecretKey]).
 */
interface KeyVault {

    /** True if the vault is currently unlocked, i.e. [getSecretKey] can be expected to succeed. */
    suspend fun isUnlocked(): Boolean

    /** True if secret key exists at all, locked or unlocked. */
    suspend fun hasSecretKeyInVault(): Boolean

    /**
     * Overwrites any old secretKey and pincode. Vault is now locked.
     *
     * Note: the implementation in this file clears both [secretKey] and [pincode] once the new
     * key has been stored, so callers must not reuse those instances afterwards.
     */
    suspend fun replaceSecretKey(secretKey: SecretKey, pincode: Pincode)

    /**
     * In-memory state only, pin is never saved to disk.
     * Unlocked until app restart or lock() is called. The implementation in this file additionally
     * stops honouring an unlock once [KEY_VAULT_UNLOCK_PERIOD] has elapsed.
     * Will try to read key with this, else [InvalidVaultPincodeException].
     * If no key then [EmptyKeyVaultException].
     */
    suspend fun unlock(pincode: Pincode)

    /** Locks the vault; [getSecretKey] fails until [unlock] succeeds again. */
    fun lock()

    /** Requires vault to be unlocked, else [KeyVaultIsLockedException]. */
    suspend fun getSecretKey(): SecretKey
}

/** Thrown by [KeyVault.unlock] when the stored key cannot be decrypted with the supplied pincode. */
class InvalidVaultPincodeException : Exception("The supplied pincode does not unlock the key vault")

/** Thrown by [KeyVault.getSecretKey] when the vault is not currently unlocked. */
class KeyVaultIsLockedException : Exception("The key vault is locked")

/** Thrown by [KeyVault.unlock] when the vault holds no stored secret key to decrypt. */
class EmptyKeyVaultException : Exception("The key vault does not contain a secret key")

// Keys under which the vault's persisted pieces are stored in the KeyVaultRepo.
/** Repo key for the encrypted secret key (the ciphermessage produced by the encryptor). */
private const val STORE_KEY_CIPHERMESSAGE = "keyVaultCiphermessage"
/** Repo key for the salt that was used when stretching the pincode. */
private const val STORE_KEY_SALT = "keyVaultSalt"
/** Repo key for the key-stretch iteration count that was used when the key was stored. */
private const val STORE_KEY_ITERATIONS = "keyVaultIterations"
/** How long after a successful unlock the vault is still reported as unlocked (bound is inclusive). */
val KEY_VAULT_UNLOCK_PERIOD: Duration = 1.hours

/**
 * [KeyVault] that persists the secret key encrypted under a key stretched from the pincode, and
 * keeps the decrypted key in memory only while the vault is unlocked.
 *
 * Three values are persisted through [keyVaultRepo]: the ciphermessage, the salt and the
 * iteration count used for stretching. An unlock is honoured for at most
 * [KEY_VAULT_UNLOCK_PERIOD], measured with [clock].
 *
 * NOTE(review): [vaultState] is `@Volatile`, but read-then-write sequences (expiry check followed
 * by [lock], or two concurrent [unlock] calls) are not guarded by any lock in this class.
 */
internal class KeyVaultImpl(
    private val keyVaultRepo: KeyVaultRepo,
    private val encryptor: SymmetricEncryptor,
    private val keyStretcher: KeyStretcher,
    private val random: Random,
    private val dataScope: DataScope,
    private val keyStretchIterations: KeyStretchIterations,
    private val logger: Logger,
    private val clock: HybridLogicalClock
) : KeyVault {

    @Volatile
    private var vaultState: VaultState = VaultState.Locked

    /** True while an unlock exists that is no older than [KEY_VAULT_UNLOCK_PERIOD]. */
    override suspend fun isUnlocked(): Boolean = unexpiredStateOrNull() != null

    /** True if the vault is unlocked or a ciphermessage is present in the repo. */
    override suspend fun hasSecretKeyInVault(): Boolean =
        isUnlocked() || keyVaultRepo.hasKey(STORE_KEY_CIPHERMESSAGE)

    /**
     * Encrypts [secretKey] under a key stretched from [pincode] with a fresh random salt and the
     * configured iteration count, then stores ciphermessage, salt and iteration count in one
     * [dataScope] write.
     *
     * After the write succeeds, both [secretKey] and [pincode] are cleared and the vault is locked.
     */
    override suspend fun replaceSecretKey(secretKey: SecretKey, pincode: Pincode) {
        // Runs off the caller's dispatcher; presumably because key stretching is CPU-heavy.
        withContext(Dispatchers.Default) {
            val newSalt = Salt.random(random)
            val encryptionKey = keyStretcher.stretch(pincode, newSalt, keyStretchIterations)
            val ciphermessage = encryptor.encrypt(plaintext = secretKey.value, encryptionKey)

            dataScope.write {
                keyVaultRepo.putBytes(STORE_KEY_CIPHERMESSAGE, ciphermessage)
                keyVaultRepo.putBytes(STORE_KEY_SALT, newSalt.value)
                keyVaultRepo.putInt(STORE_KEY_ITERATIONS, keyStretchIterations.value)
            }

            secretKey.clear()
            pincode.clear()
            lock()
        }
    }

    /**
     * Decrypts the stored key with [pincode] and keeps it in memory.
     *
     * If the vault is already unlocked nothing is decrypted and the pincode is not verified.
     * On every successful return [pincode] has been cleared. If the stored iteration count
     * differs from the configured one, the key is re-encrypted with the configured count first.
     *
     * @throws EmptyKeyVaultException if no ciphermessage or no salt is stored.
     * @throws InvalidVaultPincodeException if decryption fails with [BadSecretKeyException].
     */
    override suspend fun unlock(pincode: Pincode) {
        if (isUnlocked()) {
            // Leave the pincode in the same (cleared) state as after a fresh successful unlock,
            // so it does not linger in memory just because the vault happened to be open.
            pincode.clear()
            return
        }

        withContext(Dispatchers.Default) {
            val (ciphermessage, salt, iterations) = dataScope.read {
                val ciphermessage = keyVaultRepo.getBytes(STORE_KEY_CIPHERMESSAGE) ?: throw EmptyKeyVaultException()
                val salt = keyVaultRepo.getBytes(STORE_KEY_SALT)?.let { Salt(it) } ?: throw EmptyKeyVaultException()
                // A missing iteration count falls back to the default count.
                val iterations = keyVaultRepo.getInt(STORE_KEY_ITERATIONS)?.let { KeyStretchIterations(it) }
                    ?: KeyStretchIterations.default
                Triple(ciphermessage, salt, iterations)
            }

            val decryptionKey = keyStretcher.stretch(pincode, salt, iterations)

            val plaintext = try {
                encryptor.decrypt(ciphermessage, decryptionKey)
            } catch (_: BadSecretKeyException) {
                throw InvalidVaultPincodeException()
            }
            val secretKey = SecretKey(plaintext)

            if (iterations != keyStretchIterations) {
                logger.info("changing iteration count from $iterations to $keyStretchIterations...")
                // replaceSecretKey clears the key it is given, hence the copy; it also needs the
                // pincode intact, so this must happen before the pincode is cleared below.
                replaceSecretKey(secretKey.copy(), pincode)
            }

            pincode.clear()
            vaultState = VaultState.Unlocked(secretKey, clock.tick())
        }
    }

    /** Clears the in-memory secret key, if any, and marks the vault as locked. */
    override fun lock() {
        val state = vaultState as? VaultState.Unlocked ?: return
        state.secretKey.clear()
        vaultState = VaultState.Locked
    }

    /**
     * Returns the in-memory secret key.
     *
     * @throws KeyVaultIsLockedException if the vault is locked or the unlock has expired.
     */
    override suspend fun getSecretKey(): SecretKey =
        unexpiredStateOrNull()?.secretKey ?: throw KeyVaultIsLockedException()

    /**
     * Takes a single snapshot of [vaultState] and returns it if it is an unlock that has not
     * expired; otherwise returns null.
     *
     * Using one snapshot for both the expiry check and the returned key means callers never
     * combine the key of one state with the expiry verdict of another.
     */
    private suspend fun unexpiredStateOrNull(): VaultState.Unlocked? {
        val state = vaultState as? VaultState.Unlocked ?: return null
        if (clock.tick() - state.unlockedAt <= KEY_VAULT_UNLOCK_PERIOD)
            return state

        // Update to reflect real state. Basically an optimization.
        // Also to not keep secret key in memory longer than necessary.
        lock()
        return null
    }
}

/** In-memory state of the vault; never persisted by the code in this file. */
private sealed interface VaultState {
    /** No decrypted key is held. */
    object Locked : VaultState
    /** Holds the decrypted [secretKey] together with the clock timestamp taken when it was unlocked. */
    data class Unlocked(val secretKey: SecretKey, val unlockedAt: HLCTimestamp) : VaultState
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy