io.questdb.cutlass.line.tcp.DelegatingTlsChannel Maven / Gradle / Ivy
/*******************************************************************************
* ___ _ ____ ____
* / _ \ _ _ ___ ___| |_| _ \| __ )
* | | | | | | |/ _ \/ __| __| | | | _ \
* | |_| | |_| | __/\__ \ |_| |_| | |_) |
* \__\_\\__,_|\___||___/\__|____/|____/
*
* Copyright (c) 2014-2019 Appsicle
* Copyright (c) 2019-2024 QuestDB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
******************************************************************************/
package io.questdb.cutlass.line.tcp;
import io.questdb.client.Sender;
import io.questdb.cutlass.line.LineChannel;
import io.questdb.cutlass.line.LineSenderException;
import io.questdb.log.Log;
import io.questdb.log.LogFactory;
import io.questdb.std.MemoryTag;
import io.questdb.std.Misc;
import io.questdb.std.Unsafe;
import io.questdb.std.Vect;
import javax.net.ssl.*;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.KeyStore;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;
public final class DelegatingTlsChannel implements LineChannel {
private static final long ADDRESS_FIELD_OFFSET;
private static final int AFTER_HANDSHAKE = 1;
private static final TrustManager[] BLIND_TRUST_MANAGERS = new TrustManager[]{new X509TrustManager() {
public void checkClientTrusted(X509Certificate[] certs, String t) {
}
public void checkServerTrusted(X509Certificate[] certs, String t) {
}
public X509Certificate[] getAcceptedIssuers() {
return null;
}
}};
private static final long CAPACITY_FIELD_OFFSET;
private static final int CLOSED = 3;
private static final int CLOSING = 2;
private static final int INITIAL_BUFFER_CAPACITY = 64 * 1024;
private static final int INITIAL_STATE = 0;
private static final long LIMIT_FIELD_OFFSET;
private static final Log LOG = LogFactory.getLog(DelegatingTlsChannel.class);
private final ByteBuffer dummyBuffer;
private final SSLEngine sslEngine;
private final ByteBuffer wrapInputBuffer;
private LineChannel delegate;
private int state = INITIAL_STATE;
private ByteBuffer unwrapInputBuffer;
private long unwrapInputBufferPtr;
private ByteBuffer unwrapOutputBuffer;
private long unwrapOutputBufferPtr;
private ByteBuffer wrapOutputBuffer;
private long wrapOutputBufferPtr;
public DelegatingTlsChannel(LineChannel delegate, String trustStorePath, char[] password,
Sender.TlsValidationMode validationMode, String peerHost) {
this.delegate = delegate;
this.sslEngine = createSslEngine(trustStorePath, password, validationMode, peerHost);
// wrapInputBuffer is just a placeholder, we set the internal address, capacity and limit in send()
this.wrapInputBuffer = ByteBuffer.allocateDirect(0);
// allows to override in tests, but we don't necessary want to expose this to users.
int initialCapacity = Integer.getInteger("questdb.experimental.tls.buffersize", INITIAL_BUFFER_CAPACITY);
// we want to track allocated memory hence we just create dummy direct byte buffers
// and later reset it to manually allocated memory
this.wrapOutputBuffer = ByteBuffer.allocateDirect(0);
this.unwrapInputBuffer = ByteBuffer.allocateDirect(0);
this.unwrapOutputBuffer = ByteBuffer.allocateDirect(0);
this.wrapOutputBufferPtr = allocateMemoryAndResetBuffer(wrapOutputBuffer, initialCapacity);
this.unwrapInputBufferPtr = allocateMemoryAndResetBuffer(unwrapInputBuffer, initialCapacity);
this.unwrapOutputBufferPtr = allocateMemoryAndResetBuffer(unwrapOutputBuffer, initialCapacity);
this.dummyBuffer = ByteBuffer.allocate(0);
try {
handshakeLoop();
} catch (Throwable e) {
// do not close the delegate - we don't own it when our own constructors fails
close0(false);
throw new LineSenderException("could not perform TLS handshake", e);
}
}
@Override
public void close() {
close0(true);
}
public void close0(boolean closeDelegate) {
int prevState = state;
if (prevState == CLOSED) {
return;
}
state = CLOSING;
if (prevState == AFTER_HANDSHAKE) {
try {
sslEngine.closeOutbound();
wrapLoop(dummyBuffer);
writeToUpstreamAndClear();
} catch (Throwable e) {
LOG.error().$("could not send TLS close_notify alert").$(e).$();
}
}
state = CLOSED;
if (closeDelegate) {
delegate = Misc.free(delegate);
}
// a bit of ceremony to make sure there is no point that a buffer or a pointer is referencing unallocated memory
int capacity = wrapOutputBuffer.capacity();
long ptrToFree = wrapOutputBufferPtr;
wrapOutputBuffer = null; // if there is an attempt to use a buffer after close() then it's better to throw NPE than segfaulting
wrapOutputBufferPtr = 0;
Unsafe.free(ptrToFree, capacity, MemoryTag.NATIVE_TLS_RSS);
capacity = unwrapInputBuffer.capacity();
ptrToFree = unwrapInputBufferPtr;
unwrapInputBuffer = null;
unwrapInputBufferPtr = 0;
Unsafe.free(ptrToFree, capacity, MemoryTag.NATIVE_TLS_RSS);
capacity = unwrapOutputBuffer.capacity();
ptrToFree = unwrapOutputBufferPtr;
unwrapOutputBuffer = null;
unwrapOutputBufferPtr = 0;
Unsafe.free(ptrToFree, capacity, MemoryTag.NATIVE_TLS_RSS);
}
@Override
public int errno() {
return delegate.errno();
}
@Override
public int receive(long ptr, int len) {
try {
unwrapLoop();
unwrapOutputBuffer.flip();
int i = unwrapOutputBufferToPtr(ptr, len);
unwrapOutputBuffer.compact();
return i;
} catch (SSLException e) {
throw new LineSenderException("could not unwrap SSL packet", e);
}
}
@Override
public void send(long ptr, int len) {
try {
resetBufferToPointer(wrapInputBuffer, ptr, len);
wrapInputBuffer.position(0);
wrapLoop(wrapInputBuffer);
assert !wrapInputBuffer.hasRemaining();
} catch (SSLException e) {
throw new LineSenderException("error while sending data to questdb server", e);
}
}
private static long allocateMemoryAndResetBuffer(ByteBuffer buffer, int capacity) {
long newAddress = Unsafe.malloc(capacity, MemoryTag.NATIVE_TLS_RSS);
resetBufferToPointer(buffer, newAddress, capacity);
return newAddress;
}
private static SSLEngine createSslEngine(String trustStorePath, char[] trustStorePassword, Sender.TlsValidationMode validationMode, String peerHost) {
assert trustStorePath == null || validationMode == Sender.TlsValidationMode.DEFAULT;
try {
SSLContext sslContext;
if (trustStorePath != null) {
sslContext = SSLContext.getInstance("TLS");
TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
KeyStore jks = KeyStore.getInstance("JKS");
try (InputStream trustStoreStream = openTruststoreStream(trustStorePath)) {
jks.load(trustStoreStream, trustStorePassword);
}
tmf.init(jks);
TrustManager[] trustManagers = tmf.getTrustManagers();
sslContext.init(null, trustManagers, new SecureRandom());
} else if (validationMode == Sender.TlsValidationMode.INSECURE) {
sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, BLIND_TRUST_MANAGERS, new SecureRandom());
} else {
sslContext = SSLContext.getDefault();
}
// SSLEngine needs to know hostname during TLS handshake to validate a server certificate was issued
// for the server we are connecting to. For details see the comment below.
// Hostname validation does not use port at all hence we can get away with a dummy value -1
SSLEngine sslEngine = sslContext.createSSLEngine(peerHost, -1);
if (validationMode != Sender.TlsValidationMode.INSECURE) {
SSLParameters sslParameters = sslEngine.getSSLParameters();
// The https validation algorithm? That looks confusing! After all we are not using any
// https here at so what does it mean?
// It's actually simple: It just instructs the SSLEngine to perform the same hostname validation
// as it does during HTTPS connections. SSLEngine does not do hostname validation by default. Without
// this option SSLEngine would happily accept any certificate as long as it's signed by a trusted CA.
// This option will make sure certificates are accepted only if they were issued for the
// server we are connecting to.
sslParameters.setEndpointIdentificationAlgorithm("https");
sslEngine.setSSLParameters(sslParameters);
}
sslEngine.setUseClientMode(true);
return sslEngine;
} catch (Throwable t) {
if (t instanceof LineSenderException) {
throw (LineSenderException) t;
}
throw new LineSenderException("could not create SSL engine", t);
}
}
private static long expandBuffer(ByteBuffer buffer, long oldAddress) {
int oldCapacity = buffer.capacity();
int newCapacity = oldCapacity * 2;
long newAddress = Unsafe.realloc(oldAddress, oldCapacity, newCapacity, MemoryTag.NATIVE_TLS_RSS);
resetBufferToPointer(buffer, newAddress, newCapacity);
return newAddress;
}
private static InputStream openTruststoreStream(String trustStorePath) throws FileNotFoundException {
InputStream trustStoreStream;
if (trustStorePath.startsWith("classpath:")) {
String adjustedPath = trustStorePath.substring("classpath:".length());
trustStoreStream = DelegatingTlsChannel.class.getResourceAsStream(adjustedPath);
if (trustStoreStream == null) {
throw new LineSenderException("configured trust store is unavailable ")
.put("[path=").put(trustStorePath).put("]");
}
return trustStoreStream;
}
return new FileInputStream(trustStorePath);
}
private static void resetBufferToPointer(ByteBuffer buffer, long ptr, int len) {
assert buffer.isDirect();
Unsafe.getUnsafe().putLong(buffer, ADDRESS_FIELD_OFFSET, ptr);
Unsafe.getUnsafe().putLong(buffer, LIMIT_FIELD_OFFSET, len);
Unsafe.getUnsafe().putLong(buffer, CAPACITY_FIELD_OFFSET, len);
}
private void growUnwrapInputBuffer() {
unwrapInputBufferPtr = expandBuffer(unwrapInputBuffer, unwrapInputBufferPtr);
}
private void growUnwrapOutputBuffer() {
unwrapOutputBufferPtr = expandBuffer(unwrapOutputBuffer, unwrapOutputBufferPtr);
}
private void growWrapOutputBuffer() {
wrapOutputBufferPtr = expandBuffer(wrapOutputBuffer, wrapOutputBufferPtr);
}
private void handshakeLoop() throws SSLException {
if (state != INITIAL_STATE) {
return;
}
// trigger handshaking - otherwise the initial state is NOT_HANDSHAKING
sslEngine.beginHandshake();
for (; ; ) {
SSLEngineResult.HandshakeStatus status = sslEngine.getHandshakeStatus();
switch (status) {
case NOT_HANDSHAKING:
state = AFTER_HANDSHAKE;
return;
case NEED_TASK:
sslEngine.getDelegatedTask().run();
break;
case NEED_WRAP:
wrapLoop(dummyBuffer);
break;
case NEED_UNWRAP:
unwrapLoop();
break;
case FINISHED:
throw new LineSenderException("getHandshakeStatus() returned FINISHED. It should not have been possible.");
default:
throw new LineSenderException(status + "not supported");
}
}
}
private void readFromUpstream(boolean force) {
if (unwrapInputBuffer.position() != 0 && !force) {
// we don't want to block on receive() if there are still data to be processed
// unless we are forced to do so
return;
}
assert unwrapInputBuffer.limit() == unwrapInputBuffer.capacity();
int remainingLen = unwrapInputBuffer.remaining();
if (remainingLen == 0) {
growUnwrapInputBuffer();
remainingLen = unwrapInputBuffer.remaining();
}
assert Unsafe.getUnsafe().getLong(unwrapInputBuffer, ADDRESS_FIELD_OFFSET) == unwrapInputBufferPtr;
long adjustedPtr = unwrapInputBufferPtr + unwrapInputBuffer.position();
int receive = delegate.receive(adjustedPtr, remainingLen);
if (receive < 0) {
throw new LineSenderException("connection closed");
}
unwrapInputBuffer.position(unwrapInputBuffer.position() + receive);
}
private void unwrapLoop() throws SSLException {
// we want the loop to return as soon as we have some unwrapped data in the output buffer
while (unwrapOutputBuffer.position() == 0) {
readFromUpstream(false);
unwrapInputBuffer.flip();
SSLEngineResult result = sslEngine.unwrap(unwrapInputBuffer, unwrapOutputBuffer);
unwrapInputBuffer.compact();
switch (result.getStatus()) {
case BUFFER_UNDERFLOW:
// we need more input no matter what. so let's force reading from the upstream channel
readFromUpstream(true);
break;
case BUFFER_OVERFLOW:
if (unwrapOutputBuffer.position() != 0) {
// we have at least something, that's enough
// if it's not enough then it's up to the caller to call us again
return;
}
// there was overflow, and we have nothing
// apparently the output buffer cannot fit even a single TLS record. let's grow it!
growUnwrapOutputBuffer();
break;
case OK:
return;
case CLOSED:
throw new LineSenderException("server closed connection unexpectedly");
}
}
}
private int unwrapOutputBufferToPtr(long dstPtr, int dstLen) {
int oldPosition = unwrapOutputBuffer.position();
assert Unsafe.getUnsafe().getLong(unwrapOutputBufferPtr, ADDRESS_FIELD_OFFSET) == unwrapOutputBufferPtr;
long srcPtr = unwrapOutputBufferPtr + oldPosition;
int srcLen = unwrapOutputBuffer.remaining();
int len = Math.min(dstLen, srcLen);
Vect.memcpy(dstPtr, srcPtr, len);
unwrapOutputBuffer.position(oldPosition + len);
return len;
}
private void wrapLoop(ByteBuffer src) throws SSLException {
do {
SSLEngineResult result = sslEngine.wrap(src, wrapOutputBuffer);
switch (result.getStatus()) {
case BUFFER_UNDERFLOW:
throw new LineSenderException("should not happen");
case BUFFER_OVERFLOW:
growWrapOutputBuffer();
break;
case OK:
writeToUpstreamAndClear();
break;
case CLOSED:
if (state != CLOSING) {
throw new LineSenderException("server closed connection unexpectedly");
}
return;
}
} while (src.hasRemaining());
}
private void writeToUpstreamAndClear() {
assert wrapOutputBuffer.limit() == wrapOutputBuffer.capacity();
// we don't flip the wrapOutputBuffer before reading from it
// hence the writer position is the actual length to be sent to the upstream channel
int len = wrapOutputBuffer.position();
assert Unsafe.getUnsafe().getLong(wrapOutputBuffer, ADDRESS_FIELD_OFFSET) == wrapOutputBufferPtr;
delegate.send(wrapOutputBufferPtr, len);
// we know limit == capacity
// thus setting the position to 0 is equivalent to clearing
wrapOutputBuffer.position(0);
}
static {
Field addressField;
Field limitField;
Field capacityField;
try {
addressField = Buffer.class.getDeclaredField("address");
limitField = Buffer.class.getDeclaredField("limit");
capacityField = Buffer.class.getDeclaredField("capacity");
} catch (NoSuchFieldException e) {
// possible improvement: implement a fallback strategy when reflection is unavailable for any reason.
throw new ExceptionInInitializerError(e);
}
ADDRESS_FIELD_OFFSET = Unsafe.getUnsafe().objectFieldOffset(addressField);
LIMIT_FIELD_OFFSET = Unsafe.getUnsafe().objectFieldOffset(limitField);
CAPACITY_FIELD_OFFSET = Unsafe.getUnsafe().objectFieldOffset(capacityField);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy