jvmMain.fr.acinq.lightning.io.JvmTcpSocket.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of lightning-kmp-jvm Show documentation
Show all versions of lightning-kmp-jvm Show documentation
A Kotlin Multiplatform implementation of the Lightning Network
package fr.acinq.lightning.io
import co.touchlab.kermit.Logger
import fr.acinq.lightning.logging.*
import io.ktor.network.selector.*
import io.ktor.network.sockets.*
import io.ktor.network.tls.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.channels.ClosedSendChannelException
import java.net.ConnectException
import java.net.SocketException
import java.security.KeyFactory
import java.security.KeyStore
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.security.spec.X509EncodedKeySpec
import java.util.*
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager
class JvmTcpSocket(val socket: Socket, val loggerFactory: LoggerFactory) : TcpSocket {
private val logger = loggerFactory.newLogger(this::class)
private val connection = socket.connection()
override suspend fun send(bytes: ByteArray?, offset: Int, length: Int, flush: Boolean) =
withContext(Dispatchers.IO) {
ensureActive()
try {
if (bytes != null) connection.output.writeFully(bytes, offset, length)
if (flush) connection.output.flush()
} catch (ex: ClosedSendChannelException) {
throw TcpSocket.IOException.ConnectionClosed(ex)
} catch (ex: java.io.IOException) {
throw TcpSocket.IOException.ConnectionClosed(ex)
} catch (ex: CancellationException) {
throw ex
} catch (ex: Throwable) {
throw TcpSocket.IOException.Unknown(ex.message, ex)
}
}
private inline fun tryReceive(receive: () -> R): R {
try {
return receive()
} catch (ex: ClosedReceiveChannelException) {
throw TcpSocket.IOException.ConnectionClosed(ex)
} catch (ex: java.io.IOException) {
throw TcpSocket.IOException.ConnectionClosed(ex)
} catch (ex: CancellationException) {
throw ex
} catch (ex: Throwable) {
throw TcpSocket.IOException.Unknown(ex.message, ex)
}
}
private suspend fun receive(read: suspend () -> R): R =
withContext(Dispatchers.IO) {
ensureActive()
tryReceive { read() }
}
override suspend fun receiveFully(buffer: ByteArray, offset: Int, length: Int) {
receive { connection.input.readFully(buffer, offset, length) }
}
override suspend fun receiveAvailable(buffer: ByteArray, offset: Int, length: Int): Int {
return receive { connection.input.readAvailable(buffer, offset, length) }
.takeUnless { it == -1 } ?: throw TcpSocket.IOException.ConnectionClosed()
}
override suspend fun startTls(tls: TcpSocket.TLS): TcpSocket = try {
when (tls) {
is TcpSocket.TLS.TRUSTED_CERTIFICATES -> JvmTcpSocket(connection.tls(tlsContext(logger)), loggerFactory)
TcpSocket.TLS.UNSAFE_CERTIFICATES -> JvmTcpSocket(connection.tls(tlsContext(logger)) {
logger.warning { "using unsafe TLS!" }
trustManager = unsafeX509TrustManager()
}, loggerFactory)
is TcpSocket.TLS.PINNED_PUBLIC_KEY -> {
JvmTcpSocket(connection.tls(tlsContext(logger), tlsConfigForPinnedCert(tls.pubKey, logger)), loggerFactory)
}
TcpSocket.TLS.DISABLED -> this
}
} catch (e: Exception) {
throw when (e) {
is ConnectException -> TcpSocket.IOException.ConnectionRefused(e)
is SocketException -> TcpSocket.IOException.Unknown(e.message, e)
else -> e
}
}
override fun close() {
// NB: this safely calls close(), wrapping it into a try/catch.
socket.dispose()
}
companion object {
fun unsafeX509TrustManager() = object : X509TrustManager {
override fun checkClientTrusted(p0: Array?, p1: String?) {}
override fun checkServerTrusted(p0: Array?, p1: String?) {}
override fun getAcceptedIssuers(): Array? = null
}
fun buildPublicKey(encodedKey: ByteArray, logger: Logger): java.security.PublicKey {
val spec = X509EncodedKeySpec(encodedKey)
val algorithms = listOf("RSA", "EC", "DiffieHellman", "DSA", "RSASSA-PSS", "XDH", "X25519", "X448")
algorithms.map {
try {
return KeyFactory.getInstance(it).generatePublic(spec)
} catch (e: Exception) {
logger.debug { "key does not use $it algorithm" }
}
}
throw IllegalArgumentException("unsupported key's algorithm, only $algorithms")
}
fun tlsConfigForPinnedCert(pinnedPubkey: String, logger: Logger): TLSConfig = TLSConfigBuilder().apply {
// build a default X509 trust manager.
val factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())!!
factory.init(null as KeyStore?)
val defaultX509TrustManager = factory.trustManagers!!.filterIsInstance().first()
// create a new trust manager that always accepts certificates for the pinned public key, or falls back to standard procedure.
trustManager = object : X509TrustManager {
override fun checkClientTrusted(chain: Array?, authType: String?) {
defaultX509TrustManager.checkClientTrusted(chain, authType)
}
override fun checkServerTrusted(chain: Array?, authType: String?) {
val serverKey = try {
buildPublicKey(chain?.asList()?.firstOrNull()?.publicKey?.encoded ?: throw RuntimeException("certificate missing"), logger)
} catch (e: Exception) {
logger.error(e) { "failed to read server's pubkey=${pinnedPubkey}" }
throw e
}
val pinnedKey = try {
buildPublicKey(Base64.getDecoder().decode(pinnedPubkey), logger)
} catch (e: Exception) {
logger.error(e) { "failed to read pinned pubkey=${pinnedPubkey}" }
throw e
}
if (serverKey == pinnedKey) {
logger.info { "successfully checked server's certificate against pinned pubkey" }
} else {
logger.warning { "server's certificate does not match pinned pubkey, fallback to default check" }
throw TcpSocket.IOException.ConnectionClosed(CertificateException("certificate does not match pinned key"))
}
}
override fun getAcceptedIssuers(): Array = defaultX509TrustManager.acceptedIssuers
}
}.build()
}
}
internal actual object PlatformSocketBuilder : TcpSocket.Builder {
override suspend fun connect(host: String, port: Int, tls: TcpSocket.TLS, loggerFactory: LoggerFactory): TcpSocket {
val logger = loggerFactory.newLogger(this::class)
return withContext(Dispatchers.IO) {
try {
val socket = aSocket(SelectorManager(Dispatchers.IO)).tcp().connect(host, port).let { socket ->
when (tls) {
is TcpSocket.TLS.TRUSTED_CERTIFICATES -> socket.tls(tlsContext(logger))
TcpSocket.TLS.UNSAFE_CERTIFICATES -> socket.tls(tlsContext(logger)) {
logger.warning { "using unsafe TLS!" }
trustManager = JvmTcpSocket.unsafeX509TrustManager()
}
is TcpSocket.TLS.PINNED_PUBLIC_KEY -> {
logger.info { "using certificate pinning for connections with $host" }
socket.tls(tlsContext(logger), JvmTcpSocket.tlsConfigForPinnedCert(tls.pubKey, logger))
}
else -> socket
}
}
JvmTcpSocket(socket, loggerFactory)
} catch (e: Exception) {
throw when (e) {
is ConnectException -> TcpSocket.IOException.ConnectionRefused(e)
is SocketException -> TcpSocket.IOException.Unknown(e.message, e)
else -> e
}
}
}
}
}
/**
* The TLS internal coroutines are launched in a background scope that doesn't let us do fine-grained supervision.
* They may throw exceptions when the socket is remotely closed, which crashes the application on Android.
* This should be fixed by https://github.com/ktorio/ktor/pull/3690, but for now we need to explicitly handle exceptions.
*/
fun tlsContext(logger: Logger) = Dispatchers.IO + CoroutineExceptionHandler { _, throwable -> logger.error(throwable) { "TLS socket error: " } }
© 2015 - 2024 Weber Informatics LLC | Privacy Policy