All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonMain.fr.acinq.lightning.payment.Bolt11Invoice.kt Maven / Gradle / Ivy

There is a newer version: 1.8.4
Show newest version
package fr.acinq.lightning.payment

import fr.acinq.bitcoin.*
import fr.acinq.bitcoin.Script.tail
import fr.acinq.bitcoin.io.ByteArrayInput
import fr.acinq.bitcoin.io.ByteArrayOutput
import fr.acinq.bitcoin.utils.Either
import fr.acinq.bitcoin.utils.Try
import fr.acinq.bitcoin.utils.runTrying
import fr.acinq.lightning.*
import fr.acinq.lightning.Lightning.randomBytes32
import fr.acinq.lightning.utils.*
import fr.acinq.lightning.wire.LightningCodecs
import kotlin.experimental.and

/**
 * A BOLT 11 Lightning invoice.
 *
 * Invoices are usually built with [create] (which signs them) or parsed with [read]; [write] produces the bech32
 * string form.
 *
 * @param prefix human-readable prefix of the invoice (one of the values of the companion's `prefixes` map for
 * invoices built by [create] or parsed by [read]).
 * @param amount requested amount, or null when the invoice does not specify one.
 * @param timestampSeconds invoice timestamp, in seconds (encoded on 7 5-bit words).
 * @param nodeId public key of the payee; [sign] requires the matching private key.
 * @param tags tagged fields, in encoding order.
 * @param signature 64-byte signature followed by a 1-byte recovery id (empty before [sign] is called).
 */
data class Bolt11Invoice(
    val prefix: String,
    override val amount: MilliSatoshi?,
    val timestampSeconds: Long,
    override val nodeId: PublicKey,
    val tags: List<TaggedField>,
    val signature: ByteVector
) : PaymentRequest() {
    /** Chain matching [prefix], if any. Testnet3 and Testnet4 share "lntb": the first map entry (Testnet3) is returned. */
    val chain: Chain? get() = prefixes.entries.firstOrNull { it.value == prefix }?.key

    // The init block guarantees exactly one payment hash and one payment secret tag, so first() cannot fail.
    override val paymentHash: ByteVector32 get() = tags.filterIsInstance<TaggedField.PaymentHash>().first().hash

    val paymentSecret: ByteVector32 get() = tags.filterIsInstance<TaggedField.PaymentSecret>().first().secret

    val paymentMetadata: ByteVector? get() = tags.filterIsInstance<TaggedField.PaymentMetadata>().firstOrNull()?.data

    val description: String? get() = tags.filterIsInstance<TaggedField.Description>().firstOrNull()?.description

    val descriptionHash: ByteVector32? get() = tags.filterIsInstance<TaggedField.DescriptionHash>().firstOrNull()?.hash

    val expirySeconds: Long? get() = tags.filterIsInstance<TaggedField.Expiry>().firstOrNull()?.expirySeconds

    val minFinalExpiryDelta: CltvExpiryDelta?
        get() = tags.filterIsInstance<TaggedField.MinFinalCltvExpiry>().firstOrNull()?.let { CltvExpiryDelta(it.cltvExpiry.toInt()) }

    /**
     * On-chain fallback address, if any.
     * Computed on access (like the other tag accessors): an error raised while converting the address surfaces here
     * instead of preventing the invoice from being constructed.
     */
    val fallbackAddress: String? get() = tags.filterIsInstance<TaggedField.FallbackAddress>().firstOrNull()?.toAddress(prefix)

    override val features: Features get() = tags.filterIsInstance<TaggedField.Features>().firstOrNull()?.let { Features(it.bits) } ?: Features.empty

    val routingInfo: List<TaggedField.RoutingInfo> = tags.filterIsInstance<TaggedField.RoutingInfo>()

    init {
        val f = features.invoiceFeatures()
        require(f.hasFeature(Feature.VariableLengthOnion)) { "${Feature.VariableLengthOnion.rfcName} must be supported" }
        require(f.hasFeature(Feature.PaymentSecret)) { "${Feature.PaymentSecret.rfcName} must be supported" }
        val featureGraphError = Features.validateFeatureGraph(f)
        require(featureGraphError == null) { "invalid feature dependencies: $featureGraphError" }

        require(amount == null || amount > 0.msat) { "amount is not valid" }
        require(tags.count { it is TaggedField.PaymentHash } == 1) { "there must be exactly one payment hash tag" }
        require(tags.count { it is TaggedField.PaymentSecret } == 1) { "there must be exactly one payment secret tag" }
        // NOTE(review): this only checks that at least one of the two is present, although the message says
        // "exactly one". Tightening it could reject invoices that currently parse: confirm before changing.
        require(description != null || descriptionHash != null) { "there must be exactly one description tag or one description hash tag" }
    }

    /** An invoice without an expiry tag expires [DEFAULT_EXPIRY_SECONDS] after its timestamp. */
    override fun isExpired(currentTimestampSeconds: Long): Boolean {
        val expiry = expirySeconds ?: DEFAULT_EXPIRY_SECONDS.toLong()
        return timestampSeconds + expiry <= currentTimestampSeconds
    }

    private fun hrp() = prefix + encodeAmount(amount)

    /** Data part without the signature: the timestamp followed by every tagged field (tag, 10-bit length, value). */
    private fun rawData(): List<Int5> {
        val data5 = ArrayList<Int5>()
        data5.addAll(encodeTimestamp(timestampSeconds))
        tags
            // An empty feature field carries no information, so it is not written.
            .filterNot { it is TaggedField.Features && it.bits.isEmpty() }
            .forEach { field ->
                val encoded = field.encode()
                // The length is written on two 5-bit words, so it cannot exceed 1023 words.
                require(encoded.size < 1024) { "tagged field ${field.tag} is too long (${encoded.size} words)" }
                data5.add(field.tag)
                data5.add((encoded.size / 32).toByte())
                data5.add((encoded.size % 32).toByte())
                data5.addAll(encoded)
            }
        return data5
    }

    /** Bytes covered by the signature: the human-readable part followed by the data part packed into bytes. */
    private fun signedPreimage(): ByteArray {
        return hrp().encodeToByteArray() + toByteArray(rawData())
    }

    private fun signedHash(): ByteVector32 = Crypto.sha256(signedPreimage()).toByteVector32()

    /**
     * Sign this invoice.
     *
     * @param privateKey private key, which must match [nodeId].
     * @return a copy of this invoice whose [signature] is the 64-byte signature of the invoice hash followed by a
     * 1-byte recovery id (0 or 1).
     * @throws IllegalArgumentException if [privateKey] does not match [nodeId].
     */
    fun sign(privateKey: PrivateKey): Bolt11Invoice {
        require(privateKey.publicKey() == nodeId) { "private key does not match node id" }
        val msg = signedHash()
        val sig = Crypto.sign(msg, privateKey)
        // Two public keys can be recovered from the signature: the recovery id tells readers which one is ours.
        val (pub1, _) = Crypto.recoverPublicKey(sig, msg.toByteArray())
        val recid = if (nodeId == pub1) 0.toByte() else 1.toByte()
        return this.copy(signature = sig.concat(recid))
    }

    /** Bech32 string representation of this (signed) invoice. */
    override fun write(): String {
        val signature5 = Bech32.eight2five(signature.toByteArray())
        return Bech32.encode(hrp(), rawData().toTypedArray() + signature5, Bech32.Encoding.Bech32)
    }

    companion object {
        const val DEFAULT_EXPIRY_SECONDS = 3600
        val DEFAULT_MIN_FINAL_EXPIRY_DELTA = CltvExpiryDelta(18)

        /** Number of 5-bit words used by the timestamp at the start of the data part. */
        private const val TIMESTAMP_LENGTH = 7

        /** Number of 5-bit words used by the signature and recovery id: 65 bytes = 520 bits = 104 words. */
        private const val SIGNATURE_LENGTH = 104

        private val prefixes = mapOf(
            Chain.Regtest to "lnbcrt",
            Chain.Testnet3 to "lntb",
            Chain.Testnet4 to "lntb",
            Chain.Mainnet to "lnbc"
        )

        /**
         * Create and sign an invoice.
         *
         * @param description either a description (left) or the hash of a description (right).
         * @param features unknown features are removed before encoding.
         * @param extraHops each list of hops is encoded as a separate routing info tag.
         */
        fun create(
            chain: Chain,
            amount: MilliSatoshi?,
            paymentHash: ByteVector32,
            privateKey: PrivateKey,
            description: Either<String, ByteVector32>,
            minFinalCltvExpiryDelta: CltvExpiryDelta,
            features: Features,
            paymentSecret: ByteVector32 = randomBytes32(),
            paymentMetadata: ByteVector? = null,
            expirySeconds: Long? = null,
            extraHops: List<List<TaggedField.ExtraHop>> = listOf(),
            timestampSeconds: Long = currentTimestampSeconds()
        ): Bolt11Invoice {
            val prefix = prefixes[chain] ?: error("unknown chain hash")
            val tags = mutableListOf<TaggedField>(
                TaggedField.PaymentHash(paymentHash),
                TaggedField.MinFinalCltvExpiry(minFinalCltvExpiryDelta.toLong()),
                TaggedField.PaymentSecret(paymentSecret),
                // We remove unknown features which could make the invoice too big.
                TaggedField.Features(features.invoiceFeatures().copy(unknown = setOf()).toByteArray().toByteVector())
            )
            description.left?.let { tags.add(TaggedField.Description(it)) }
            description.right?.let { tags.add(TaggedField.DescriptionHash(it)) }
            paymentMetadata?.let { tags.add(TaggedField.PaymentMetadata(it)) }
            expirySeconds?.let { tags.add(TaggedField.Expiry(it)) }
            extraHops.forEach { tags.add(TaggedField.RoutingInfo(it)) }

            return Bolt11Invoice(
                prefix = prefix,
                amount = amount,
                timestampSeconds = timestampSeconds,
                nodeId = privateKey.publicKey(),
                tags = tags,
                signature = ByteVector.empty
            ).sign(privateKey)
        }

        /** Big-endian base-32 decoding of the first [TIMESTAMP_LENGTH] words. */
        private fun decodeTimestamp(input: List<Int5>): Long = input.take(TIMESTAMP_LENGTH).fold(0L) { a, b -> 32 * a + b }

        /** Big-endian base-32 encoding of [input] on exactly [TIMESTAMP_LENGTH] words (higher bits are dropped). */
        fun encodeTimestamp(input: Long): List<Int5> {
            tailrec fun loop(value: Long, acc: List<Int5>): List<Int5> =
                if (acc.size == TIMESTAMP_LENGTH) acc.reversed() else loop(value / 32, acc + value.rem(32).toByte())
            return loop(input, listOf())
        }

        /**
         * Parse a bech32-encoded invoice.
         *
         * The node id is recovered from the signature, and the invoice is rejected unless re-encoding it yields
         * exactly the signed bytes (canonical encoding).
         */
        fun read(input: String): Try<Bolt11Invoice> = runTrying {
            val (hrp, data) = Bech32.decode(input)
            // "lnbc" is also a prefix of "lnbcrt": pick the longest match so the result does not depend on map ordering.
            val prefix = prefixes.values.filter { hrp.startsWith(it) }.maxByOrNull { it.length } ?: throw IllegalArgumentException("unknown prefix $hrp")
            val amount = decodeAmount(hrp.drop(prefix.length))
            require(data.size >= TIMESTAMP_LENGTH + SIGNATURE_LENGTH) { "invoice is too short" }
            val timestamp = decodeTimestamp(data.toList())
            // The signature and recovery id are the last SIGNATURE_LENGTH words (65 bytes).
            val signatureStart = data.size - SIGNATURE_LENGTH
            val sigandrecid = toByteArray(data.copyOfRange(signatureStart, data.size).toList())
            val sig = sigandrecid.dropLast(1).toByteArray().toByteVector64()
            val recid = sigandrecid.last()
            val tohash = hrp.encodeToByteArray() + toByteArray(data.copyOfRange(0, signatureStart).toList())
            val msg = Crypto.sha256(tohash)
            val nodeId = Crypto.recoverPublicKey(sig, msg, recid.toInt())
            require(Crypto.verifySignature(msg, sig, nodeId)) { "invalid signature" }

            val tags = decodeTaggedFields(data.toList().subList(TIMESTAMP_LENGTH, signatureStart))
            val pr = Bolt11Invoice(prefix, amount, timestamp, nodeId, tags, sigandrecid.toByteVector())
            require(pr.signedPreimage().contentEquals(tohash)) { "invoice isn't canonically encoded" }
            pr
        }

        /** Split the tagged-field section into fields: each is a tag word, a 2-word length, then `length` value words. */
        private fun decodeTaggedFields(input: List<Int5>): List<TaggedField> {
            val fields = ArrayList<TaggedField>()
            var offset = 0
            while (offset < input.size) {
                require(input.size - offset >= 3) { "truncated tagged field header" }
                val tag = input[offset]
                val len = 32 * input[offset + 1] + input[offset + 2]
                val start = offset + 3
                require(start + len <= input.size) { "tagged field $tag is longer than the remaining data" }
                fields.add(decodeTaggedField(tag, input.subList(start, start + len).toList()))
                offset = start + len
            }
            return fields
        }

        /** Decode one field: unknown tags become [TaggedField.UnknownTag], known tags that fail to decode [TaggedField.InvalidTag]. */
        private fun decodeTaggedField(tag: Int5, value: List<Int5>): TaggedField = kotlin.runCatching {
            when (tag) {
                TaggedField.PaymentHash.tag -> TaggedField.PaymentHash.decode(value)
                TaggedField.PaymentSecret.tag -> TaggedField.PaymentSecret.decode(value)
                TaggedField.PaymentMetadata.tag -> TaggedField.PaymentMetadata.decode(value)
                TaggedField.Description.tag -> TaggedField.Description.decode(value)
                TaggedField.DescriptionHash.tag -> TaggedField.DescriptionHash.decode(value)
                TaggedField.Expiry.tag -> TaggedField.Expiry.decode(value)
                TaggedField.MinFinalCltvExpiry.tag -> TaggedField.MinFinalCltvExpiry.decode(value)
                TaggedField.FallbackAddress.tag -> TaggedField.FallbackAddress.decode(value)
                TaggedField.Features.tag -> TaggedField.Features.decode(value)
                TaggedField.RoutingInfo.tag -> TaggedField.RoutingInfo.decode(value)
                else -> TaggedField.UnknownTag(tag, value)
            }
        }.getOrElse { TaggedField.InvalidTag(tag, value) }

        /**
         * Decode the amount part of the human-readable prefix (digits followed by an optional p/n/u/m multiplier).
         *
         * @return the amount, or null when [input] is empty or encodes zero.
         * @throws IllegalArgumentException on sub-millisatoshi precision or when the amount does not fit in a Long.
         */
        fun decodeAmount(input: String): MilliSatoshi? {
            val amount = when (input.lastOrNull()) {
                null -> null
                'p' -> {
                    require(input.endsWith("0p")) { "invalid sub-millisatoshi precision" }
                    // 1 msat == 10 pico-bitcoin.
                    MilliSatoshi(input.dropLast(1).toLong() / 10L)
                }
                'n' -> scaleAmount(input.dropLast(1).toLong(), 100L)
                'u' -> scaleAmount(input.dropLast(1).toLong(), 100_000L)
                'm' -> scaleAmount(input.dropLast(1).toLong(), 100_000_000L)
                else -> scaleAmount(input.toLong(), 100_000_000_000L)
            }
            return if (amount == MilliSatoshi(0)) null else amount
        }

        /** Multiply [value] by [factor] (> 0), failing instead of silently wrapping around on overflow. */
        private fun scaleAmount(value: Long, factor: Long): MilliSatoshi {
            require(value in (Long.MIN_VALUE / factor)..(Long.MAX_VALUE / factor)) { "amount is out of range" }
            return MilliSatoshi(value * factor)
        }

        /**
         * @return the unit allowing for the shortest representation possible ('p', 'n', 'u', 'm'), or null for whole bitcoins
         */
        fun unit(amount: MilliSatoshi): Char? {
            val pico = amount.toLong() * 10
            return when {
                pico.rem(1_000) > 0 -> 'p'
                pico.rem(1_000_000) > 0 -> 'n'
                pico.rem(1_000_000_000) > 0 -> 'u'
                pico.rem(1_000_000_000_000) > 0 -> 'm'
                else -> null
            }
        }

        /** Encode [amount] with the shortest unit (see [unit]); a null amount is encoded as an empty string. */
        fun encodeAmount(amount: MilliSatoshi?): String {
            if (amount == null) return ""
            val msat = amount.toLong()
            return when (unit(amount)) {
                'p' -> "${msat * 10}p" // 1 milli-satoshi == 10 pico-bitcoins
                'n' -> "${msat / 100}n"
                'u' -> "${msat / 100_000}u"
                'm' -> "${msat / 100_000_000}m"
                null -> "${msat / 100_000_000_000}"
                else -> throw IllegalArgumentException("invalid amount $amount")
            }
        }

        /** The 5 low bits of [value], most significant first. */
        fun toBits(value: Int5): List<Boolean> = listOf(
            (value and 16) != 0.toByte(),
            (value and 8) != 0.toByte(),
            (value and 4) != 0.toByte(),
            (value and 2) != 0.toByte(),
            (value and 1) != 0.toByte()
        )

        // converts a list of booleans (1 per bit) to a byte, right-padded if there are less than 8 bits
        internal fun toByte(bits: List<Boolean>): Byte {
            require(bits.size <= 8)
            val raw = bits.fold(0) { a, b -> 2 * a + if (b) 1 else 0 }
            val shift = 8 - bits.size
            return (raw.shl(shift) and 0xff).toByte()
        }

        // converts a list of 5 bits values to a byte array (the last byte is right-padded with zero bits)
        internal fun toByteArray(int5s: List<Int5>): ByteArray {
            val allbits = int5s.flatMap { toBits(it) }
            return allbits.windowed(8, 8, partialWindows = true) { toByte(it) }.toByteArray()
        }
    }

    /** A tagged field of the invoice data part: a 5-bit [tag] and a value encoded as 5-bit words. */
    sealed class TaggedField {
        abstract val tag: Int5

        /** Value of this field (without tag and length) as 5-bit words. */
        abstract fun encode(): List<Int5>

        /** @param description a free-format string that will be included in the payment request */
        data class Description(val description: String) : TaggedField() {
            override val tag: Int5 = Description.tag
            override fun encode(): List<Int5> = Bech32.eight2five(description.encodeToByteArray()).toList()

            companion object {
                const val tag: Int5 = 13
                fun decode(input: List<Int5>): Description = Description(Bech32.five2eight(input.toTypedArray(), 0).decodeToString())
            }
        }

        /** @param hash sha256 hash of an associated description */
        data class DescriptionHash(val hash: ByteVector32) : TaggedField() {
            override val tag: Int5 = DescriptionHash.tag
            override fun encode(): List<Int5> = Bech32.eight2five(hash.toByteArray()).toList()

            companion object {
                const val tag: Int5 = 23

                /** @throws IllegalArgumentException unless [input] has 52 words (the encoding of 32 bytes). */
                fun decode(input: List<Int5>): DescriptionHash {
                    require(input.size == 52)
                    return DescriptionHash(Bech32.five2eight(input.toTypedArray(), 0).toByteVector32())
                }
            }
        }

        /** @param hash payment hash */
        data class PaymentHash(val hash: ByteVector32) : TaggedField() {
            override val tag: Int5 = PaymentHash.tag
            override fun encode(): List<Int5> = Bech32.eight2five(hash.toByteArray()).toList()

            companion object {
                const val tag: Int5 = 1

                /** @throws IllegalArgumentException unless [input] has 52 words (the encoding of 32 bytes). */
                fun decode(input: List<Int5>): PaymentHash {
                    require(input.size == 52)
                    return PaymentHash(Bech32.five2eight(input.toTypedArray(), 0).toByteVector32())
                }
            }
        }

        /** @param secret payment secret */
        data class PaymentSecret(val secret: ByteVector32) : TaggedField() {
            override val tag: Int5 = PaymentSecret.tag
            override fun encode(): List<Int5> = Bech32.eight2five(secret.toByteArray()).toList()

            companion object {
                const val tag: Int5 = 16

                /** @throws IllegalArgumentException unless [input] has 52 words (the encoding of 32 bytes). */
                fun decode(input: List<Int5>): PaymentSecret {
                    require(input.size == 52)
                    return PaymentSecret(Bech32.five2eight(input.toTypedArray(), 0).toByteVector32())
                }
            }
        }

        /** @param data opaque payment metadata bytes */
        data class PaymentMetadata(val data: ByteVector) : TaggedField() {
            override val tag: Int5 = PaymentMetadata.tag
            override fun encode(): List<Int5> = Bech32.eight2five(data.toByteArray()).toList()

            companion object {
                const val tag: Int5 = 27
                fun decode(input: List<Int5>): PaymentMetadata = PaymentMetadata(Bech32.five2eight(input.toTypedArray(), 0).toByteVector())
            }
        }

        /** @param expirySeconds payment expiry (in seconds) */
        data class Expiry(val expirySeconds: Long) : TaggedField() {
            override val tag: Int5 = Expiry.tag

            /** Big-endian base-32 encoding without leading zero words (0 is encoded as an empty list). */
            override fun encode(): List<Int5> {
                tailrec fun loop(value: Long, acc: List<Int5>): List<Int5> = if (value == 0L) acc.reversed() else {
                    loop(value / 32, acc + (value.rem(32)).toByte())
                }
                return loop(expirySeconds, listOf())
            }

            companion object {
                const val tag: Int5 = 6
                fun decode(input: List<Int5>): Expiry = Expiry(input.fold(0L) { acc, word -> acc * 32 + word })
            }
        }

        /** @param cltvExpiry minimum final expiry delta */
        data class MinFinalCltvExpiry(val cltvExpiry: Long) : TaggedField() {
            override val tag: Int5 = MinFinalCltvExpiry.tag

            /** Big-endian base-32 encoding without leading zero words (0 is encoded as an empty list). */
            override fun encode(): List<Int5> {
                tailrec fun loop(value: Long, acc: List<Int5>): List<Int5> = if (value == 0L) acc.reversed() else {
                    loop(value / 32, acc + (value.rem(32)).toByte())
                }
                return loop(cltvExpiry, listOf())
            }

            companion object {
                const val tag: Int5 = 24
                fun decode(input: List<Int5>): MinFinalCltvExpiry = MinFinalCltvExpiry(input.fold(0L) { acc, word -> acc * 32 + word })
            }
        }

        /**
         * Fallback on-chain payment address to be used if LN payment cannot be processed.
         *
         * @param version 17 for a P2PKH hash, 18 for a P2SH hash, otherwise a segwit witness version.
         * @param data the address hash or witness program.
         */
        data class FallbackAddress(val version: Byte, val data: ByteVector) : TaggedField() {
            override val tag: Int5 = FallbackAddress.tag
            override fun encode(): List<Int5> = listOf(version) + Bech32.eight2five(data.toByteArray()).toList()

            /**
             * Convert to an on-chain address string for the network identified by the invoice [prefix].
             *
             * @throws IllegalArgumentException for a segwit version with an unknown prefix.
             */
            fun toAddress(prefix: String): String = when (version.toInt()) {
                17 -> when (prefix) {
                    "lnbc" -> Base58Check.encode(Base58.Prefix.PubkeyAddress, data)
                    else -> Base58Check.encode(Base58.Prefix.PubkeyAddressTestnet, data)
                }

                18 -> when (prefix) {
                    "lnbc" -> Base58Check.encode(Base58.Prefix.ScriptAddress, data)
                    else -> Base58Check.encode(Base58.Prefix.ScriptAddressTestnet, data)
                }

                else -> when (prefix) {
                    "lnbc" -> Bech32.encodeWitnessAddress("bc", version, data.toByteArray())
                    "lntb" -> Bech32.encodeWitnessAddress("tb", version, data.toByteArray())
                    "lnbcrt" -> Bech32.encodeWitnessAddress("bcrt", version, data.toByteArray())
                    else -> throw IllegalArgumentException("unknown prefix $prefix")
                }
            }

            companion object {
                const val tag: Int5 = 9
                fun decode(input: List<Int5>): FallbackAddress = FallbackAddress(input.first().toByte(), Bech32.five2eight(input.tail().toTypedArray(), 0).toByteVector())
            }
        }

        /** @param bits feature bit vector, without leading zero bytes */
        data class Features(val bits: ByteVector) : TaggedField() {
            override val tag: Int5 = Features.tag

            override fun encode(): List<Int5> {
                // Left-pad with zero bytes until the bit count is a multiple of 5, so no padding bits are appended on the right.
                val padded = bits.toByteArray().toMutableList()
                while (padded.size * 8 % 5 != 0) {
                    padded.add(0, 0)
                }
                // Then remove leading zero 5-bit words.
                return Bech32.eight2five(padded.toByteArray()).dropWhile { it == 0.toByte() }
            }

            companion object {
                const val tag: Int5 = 5
                fun decode(input: List<Int5>): Features {
                    // Left-pad with zero words until the bit count is a multiple of 8.
                    val padded = input.toMutableList()
                    while (padded.size * 5 % 8 != 0) {
                        padded.add(0, 0)
                    }
                    // Then remove leading zero bytes.
                    val features = Bech32.five2eight(padded.toTypedArray(), 0).dropWhile { it == 0.toByte() }
                    return Features(features.toByteArray().toByteVector())
                }
            }
        }

        /**
         * Extra hop contained in RoutingInfoTag
         *
         * @param nodeId start of the channel
         * @param shortChannelId channel id
         * @param feeBase node fixed fee
         * @param feeProportionalMillionths node proportional fee
         * @param cltvExpiryDelta node cltv expiry delta
         */
        data class ExtraHop(
            val nodeId: PublicKey,
            val shortChannelId: ShortChannelId,
            val feeBase: MilliSatoshi,
            val feeProportionalMillionths: Long,
            val cltvExpiryDelta: CltvExpiryDelta
        )

        /** @param hints extra routing information for a private route */
        data class RoutingInfo(val hints: List<ExtraHop>) : TaggedField() {
            override val tag: Int5 = RoutingInfo.tag

            /** Each hop is 51 bytes: node id (33), short channel id (8), fee base (4), proportional fee (4), cltv delta (2). */
            override fun encode(): List<Int5> {
                val out = ByteArrayOutput()
                hints.forEach {
                    LightningCodecs.writeBytes(it.nodeId.value, out)
                    LightningCodecs.writeU64(it.shortChannelId.toLong(), out)
                    LightningCodecs.writeU32(it.feeBase.toLong().toInt(), out)
                    LightningCodecs.writeU32(it.feeProportionalMillionths.toInt(), out)
                    LightningCodecs.writeU16(it.cltvExpiryDelta.toInt(), out)
                }
                return Bech32.eight2five(out.toByteArray()).toList()
            }

            companion object {
                const val tag: Int5 = 3

                /** Reads 51-byte hops until fewer than 51 bytes remain; trailing bytes are ignored. */
                fun decode(input: List<Int5>): RoutingInfo {
                    val stream = ByteArrayInput(Bech32.five2eight(input.toTypedArray(), 0))
                    val hints = ArrayList<ExtraHop>()
                    while (stream.availableBytes >= 51) {
                        val hint = ExtraHop(
                            PublicKey(LightningCodecs.bytes(stream, 33)),
                            ShortChannelId(LightningCodecs.u64(stream)),
                            MilliSatoshi(LightningCodecs.u32(stream).toLong()),
                            LightningCodecs.u32(stream).toLong(),
                            CltvExpiryDelta(LightningCodecs.u16(stream))
                        )
                        hints.add(hint)
                    }
                    return RoutingInfo(hints)
                }
            }
        }

        /** Unknown tag (may or may not be valid); kept so the invoice can be re-encoded exactly. */
        data class UnknownTag(override val tag: Int5, val value: List<Int5>) : TaggedField() {
            override fun encode(): List<Int5> = value.toList()
        }

        /** Tag that we know is not valid (value is of the wrong length for example); kept so the invoice can be re-encoded exactly. */
        data class InvalidTag(override val tag: Int5, val value: List<Int5>) : TaggedField() {
            override fun encode(): List<Int5> = value.toList()
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy