
org.openziti.net.ZitiSocketChannel.kt Maven / Gradle / Ivy
/*
* Copyright (c) 2018-2023 NetFoundry Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.openziti.net
import com.google.gson.Gson
import com.goterl.lazysodium.utils.Key
import com.goterl.lazysodium.utils.KeyPair
import com.goterl.lazysodium.utils.SessionPair
import kotlinx.coroutines.*
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.sync.Mutex
import org.openziti.Errors
import org.openziti.ZitiAddress
import org.openziti.ZitiConnection
import org.openziti.ZitiException
import org.openziti.api.Session
import org.openziti.api.SessionType
import org.openziti.crypto.Crypto
import org.openziti.impl.ZitiContextImpl
import org.openziti.net.ZitiProtocol.CryptoMethod
import org.openziti.net.ZitiProtocol.Header
import org.openziti.net.nio.FutureHandler
import org.openziti.net.nio.readSuspend
import org.openziti.net.nio.writeCompletely
import org.openziti.util.Logged
import org.openziti.util.ZitiLog
import java.io.ByteArrayOutputStream
import java.io.Externalizable
import java.io.IOException
import java.io.ObjectOutputStream
import java.net.ConnectException
import java.net.InetSocketAddress
import java.net.SocketAddress
import java.net.SocketOption
import java.nio.ByteBuffer
import java.nio.channels.*
import java.nio.channels.CompletionHandler
import java.nio.channels.spi.AsynchronousChannelProvider
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
import kotlin.text.Charsets.UTF_8
import kotlinx.coroutines.channels.Channel as Chan
internal class ZitiSocketChannel private constructor(internal val ctx: ZitiContextImpl, val connId: Int):
AsynchronousSocketChannel(Provider),
Channel.MessageReceiver,
ZitiConnection,
InputChannel,
CoroutineScope by ctx,
Logged by ZitiLog("ziti-conn[${ctx.name()}/${connId}]") {
constructor(ztx:ZitiContextImpl): this(ztx, ztx.nextConnId())
object Provider: AsynchronousChannelProvider() {
override fun openAsynchronousSocketChannel(group: AsynchronousChannelGroup?): AsynchronousSocketChannel =
TODO()
override fun openAsynchronousServerSocketChannel(group: AsynchronousChannelGroup?): AsynchronousServerSocketChannel {
TODO("Not yet implemented")
}
override fun openAsynchronousChannelGroup(nThreads: Int, threadFactory: ThreadFactory?): AsynchronousChannelGroup =
TODO("Not yet implemented")
override fun openAsynchronousChannelGroup(executor: ExecutorService?, initialSize: Int): AsynchronousChannelGroup =
TODO("Not yet implemented")
}
internal enum class State {
initial,
connecting,
connected,
closed
}
val sentFin = AtomicBoolean(false)
override var timeout: Long = 0
val state = AtomicReference(State.initial)
val channel = CompletableDeferred()
val seq = AtomicInteger(1)
lateinit var serviceName: String
var remote: SocketAddress? = null
var local: ZitiAddress? = null
val receiveQueue = Chan(16)
override val inputSupport = InputChannel.InputSupport(receiveQueue)
val crypto = CompletableDeferred()
override fun getLocalAddress(): SocketAddress? = local
override fun getRemoteAddress(): SocketAddress? = remote
override fun supportedOptions(): MutableSet> = mutableSetOf()
override fun getOption(name: SocketOption?): T? = null
override fun setOption(name: SocketOption?, value: T): AsynchronousSocketChannel = this
override fun isClosed() = !isOpen
override fun isConnected() = isOpen && state.get() != State.initial
private var connectOp: Job? = null
private var writeOp: Job? = null
private val writeMutex = Mutex()
internal suspend fun connectInternal(addr: ZitiAddress.Dial, attachment: A, handler: CompletionHandler) {
d{"connecting to $serviceName"}
val (service,ns) = ctx.runCatching {
val service = getService(serviceName) ?: throw ZitiException(Errors.ServiceNotAvailable)
val ns = getNetworkSession(service, SessionType.DIAL)
service to ns
}.getOrElse {
w{"failed to connect: $it"}
channel.completeExceptionally(it)
close()
handler.failed(it, attachment)
return
}
d{"using session[${ns.id}]"}
val ch = ctx.runCatching { getChannel(ns) }.getOrElse {
w{"failed to connect: $it"}
channel.completeExceptionally(it)
close()
handler.failed(it, attachment)
return
}
d{"using ch[$ch]"}
channel.complete(ch)
ch.registerReceiver(connId, this)
val kp = if (service.encryptionRequired) Crypto.newKeyPair() else null
runCatching { doZitiHandshake(ch, addr, ns, kp) }.onFailure {
if (it is CancellationException) {
d{"connect was canceled"}
handler.failed(AsynchronousCloseException(), attachment)
} else {
w{"failed to connect: $it"}
close()
handler.failed(it, attachment)
}
}.onSuccess {
handler.completed(null, attachment)
}
}
override fun connect(remote: SocketAddress, attachment: A, handler: CompletionHandler) {
ctx.isEnabled() || throw ShutdownChannelGroupException()
state.getAndUpdate { st ->
when(st) {
State.initial -> {}
State.connecting -> throw ConnectionPendingException()
State.connected -> throw AlreadyConnectedException()
State.closed -> throw ClosedChannelException()
null -> error("not possible")
}
State.connecting
}
d{"connecting to $remote"}
val addr = when (remote) {
is InetSocketAddress -> ctx.getDialAddress(remote, Protocol.TCP) ?: throw UnresolvedAddressException()
is ZitiAddress.Dial -> remote
else -> throw UnsupportedAddressTypeException()
}
serviceName = addr.service
val conOp = ctx.launch { connectInternal(addr, attachment, handler) }
conOp.invokeOnCompletion { ex ->
if (ex != null)
e{" failed to connect: $ex"}
else
d{"connected"}
}
connectOp = conOp
}
override fun connect(remote: SocketAddress): Future {
val result = CompletableFuture()
connect(remote, result, FutureHandler())
return result
}
override fun isOpen(): Boolean = state.get() != State.closed
override fun bind(local: SocketAddress?): AsynchronousSocketChannel {
return this
}
override fun shutdownInput(): ZitiSocketChannel {
when(state.get()) {
State.connecting, State.connected -> deregister()
else -> {}
}
return super.shutdownInput()
}
private fun deregister() {
ctx.launch {
channel.runCatching {
await().deregisterReceiver(connId)
}
}
}
override fun close() {
connectOp?.cancel("close")
writeOp?.cancel("close")
deregister()
super.close()
runCatching { shutdownOutput() }
runCatching { closeInternal() }
}
override fun shutdownOutput(): AsynchronousSocketChannel {
if (state.get() == State.connected && sentFin.compareAndSet(false, true)) {
val finMsg = Message(ZitiProtocol.ContentType.Data).apply {
setHeader(Header.ConnId, connId)
setHeader(Header.FlagsHeader, ZitiProtocol.EdgeFlags.FIN)
setHeader(Header.SeqHeader, seq.getAndIncrement())
}
d("sending FIN")
ctx.async {
val ch = channel.getCompleted()
ch.SendSynch(finMsg)
}.invokeOnCompletion { ex ->
ex.takeIf { it !is CancellationException }?.let { e ->
w{ "failed to send FIN message: $e" }
}
}
}
return this
}
internal fun closeInternal(): AsynchronousSocketChannel {
synchronized(state) {
when (state.get()) {
State.initial ->
state.set(State.closed)
State.connecting, State.connected -> {
val closeMsg = Message(ZitiProtocol.ContentType.StateClosed).apply {
setHeader(Header.ConnId, connId)
}
d("closing conn = ${this.connId}")
ctx.async {
val ch = channel.getCompleted()
ch.SendSynch(closeMsg)
}.invokeOnCompletion {
it.takeIf { it !is CancellationException }?.let {
w { "failed to send StateClosed message: ${it.localizedMessage}" }
}
}
state.set(State.closed)
}
State.closed -> {}
else -> {}
}
ctx.close(this)
}
return this
}
override
fun read(
dst: ByteBuffer, timeout: Long, unit: TimeUnit,
att: A, handler: CompletionHandler
) = super.read(dst, timeout, unit, att, handler)
override fun read(dst: ByteBuffer): Future = super.read(dst)
override fun read(
dsts: Array, offset: Int, length: Int,
to: Long, unit: TimeUnit, att: A, handler: CompletionHandler
) = super.read(dsts, offset, length, to, unit, att, handler)
override
fun write(src: ByteBuffer, to: Long, unit: TimeUnit?, att: A, handler: CompletionHandler) {
write(arrayOf(src), 0, 1, to, unit, att, object : CompletionHandler{
override fun completed(result: Long, a: A): Unit = handler.completed(result.toInt(), a)
override fun failed(exc: Throwable, a: A) = handler.failed(exc, a)
})
}
override fun write(src: ByteBuffer): Future {
val result = CompletableFuture()
write(src, result, FutureHandler())
return result
}
override fun write(
_srcs: Array, offset: Int, length: Int,
timeout: Long, unit: TimeUnit?,
att: A, handler: CompletionHandler
) {
when (state.get()) {
State.initial,
State.connecting -> throw NotYetConnectedException()
State.connected -> {}
State.closed -> throw ClosedChannelException()
else -> error("should not be here")
}
channel.isCompleted || throw NotYetConnectedException()
writeMutex.tryLock() || throw WritePendingException()
val srcs = _srcs.slice(offset until offset + length)
val wop = ctx.async {
var sent = 0L
for (b in srcs) {
var data = ByteArray(b.remaining())
b.get(data)
crypto.getCompleted()?.let {
data = it.encrypt(data)
}
val dataMessage = Message(ZitiProtocol.ContentType.Data, data)
dataMessage.setHeader(Header.ConnId, connId)
dataMessage.setHeader(Header.SeqHeader, seq.getAndIncrement())
sent += data.size
v("sending $dataMessage")
channel.await().Send(dataMessage)
}
sent
}
writeOp = wop
wop.invokeOnCompletion { ex ->
writeOp = null
writeMutex.unlock()
if (ex == null) {
val sent = wop.getCompleted()
handler.completed(sent, att)
} else if (ex is TimeoutCancellationException) {
handler.failed(InterruptedByTimeoutException(), att)
} else if (ex is CancellationException) {
handler.failed(AsynchronousCloseException(), att)
} else {
handler.failed(ex, att)
}
}
}
override suspend fun receive(msg: Result) {
msg.onSuccess {
receiveMsg(it)
}.onFailure {
close()
}
}
private suspend fun receiveMsg(msg: Message) {
v{"conn[$connId] received message[${msg.content}] with seq[${msg.getIntHeader(Header.SeqHeader)}]"}
when (msg.content) {
ZitiProtocol.ContentType.StateClosed -> {
t{"signaling EOF"}
receiveQueue.close()
deregister()
close()
}
ZitiProtocol.ContentType.Data -> {
t{"received data(${msg.body.size} bytes) for conn[$connId]"}
if (msg.body.size > 0) {
val crypt = crypto.await()
if (crypt != null) {
if (crypt.initialized()) {
receiveQueue.send(crypt.decrypt(msg.body))
} else {
crypt.init(msg.body)
d { "crypto init finished conn[$connId]" }
}
} else {
receiveQueue.send(msg.body)
}
}
msg.getIntHeader(Header.FlagsHeader)?.let {
if (it and ZitiProtocol.EdgeFlags.FIN != 0 ) {
d("received FIN")
receiveQueue.close()
}
}
}
else -> {
e{"unexpected message type[${msg.content}] for conn[$connId]"}
receiveQueue.close(IllegalStateException())
deregister()
close()
}
}
}
override suspend fun send(data: ByteArray) = send(data, 0, data.size)
suspend fun send(data: ByteArray, offset: Int, len: Int) {
writeCompletely(ByteBuffer.wrap(data, offset, len))
}
override suspend fun receive(out: ByteArray, off: Int, len: Int): Int {
val dst = ByteBuffer.wrap(out, off, len)
return try {
readSuspend(dst, timeout, TimeUnit.MILLISECONDS)
} catch (ex: TimeoutCancellationException) {
0
}
}
internal suspend fun doZitiHandshake(ch: Channel, remote: ZitiAddress.Dial, ns: Session, kp: KeyPair?) {
val connectMsg = Message(ZitiProtocol.ContentType.Connect, ns.token.toByteArray(UTF_8)).apply {
setHeader(Header.ConnId, connId)
setHeader(Header.SeqHeader, 0)
kp?.let {
setHeader(Header.PublicKeyHeader, it.publicKey.asBytes)
setHeader(Header.CryptoMethodHeader, CryptoMethod.Libsodium)
}
remote.identity?.let {
setHeader(Header.TerminatorIdentityHeader, it)
}
(remote.callerId ?: ctx.getId()?.name)?.let {
setHeader(Header.CallerIdHeader, it)
}
remote.appData?.let { obj ->
val header = when(obj) {
is ByteArray -> obj
is String -> obj.toByteArray(UTF_8)
is Externalizable ->
runCatching {
val out = ByteArrayOutputStream()
obj.writeExternal(ObjectOutputStream(out))
out.toByteArray()
}.onFailure {
w { "failed to serialize provided app_data: ${it.localizedMessage}" }
}.getOrNull()
else -> Gson().toJson(obj).toByteArray(UTF_8)
}
header?.let { setHeader(Header.AppDataHeader, it) }
}
}
d("starting network connection ${ns.id}/$connId")
val reply = ch.SendAndWait(connectMsg)
when (reply.content) {
ZitiProtocol.ContentType.StateConnected -> {
val peerPk = reply.getHeader(Header.PublicKeyHeader)
if (kp == null || peerPk == null) {
crypto.complete(null)
} else {
setupCrypto(Crypto.kx(kp, Key.fromBytes(peerPk), false))
startCrypto(ch)
}
local = ZitiAddress.Session(ns.id, serviceName, null, null)
d("network connection established ${ns.id}/$connId")
state.set(State.connected)
}
ZitiProtocol.ContentType.StateClosed -> {
val err = reply.body.toString(UTF_8)
w("connection rejected: $err")
throw ConnectException(err)
}
else -> {
throw IOException("Invalid response type")
}
}
}
internal fun setupCrypto(keys: SessionPair?) {
crypto.complete(keys?.let { Crypto.newStream(it) })
}
internal suspend fun startCrypto(ch: Channel) {
crypto.await()?.let {
val header = it.header()
val headerMessage = Message(ZitiProtocol.ContentType.Data, header)
.setHeader(Header.ConnId, connId)
.setHeader(Header.SeqHeader, seq.getAndIncrement())
ch.Send(headerMessage)
}
}
override fun toString(): String {
val s = StringBuilder(this::class.java.simpleName)
s.append("[$state]")
when(state.get()) {
State.connecting -> s.append("(remote=$remote)")
State.connected -> s.append("($local -> $remote)")
else -> {}
}
return s.toString()
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy