All downloads are free. Search and download functionality uses the official Maven repository.

com.yahoo.jrt.TlsCryptoSocket Maven / Gradle / Ivy

The newest version!
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.jrt;

import com.yahoo.security.tls.ConnectionAuthContext;
import com.yahoo.security.tls.PeerAuthorizationFailedException;
import com.yahoo.security.tls.TransportSecurityUtils;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLSession;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SocketChannel;
import java.util.Objects;
import java.util.logging.Logger;

import static javax.net.ssl.SSLEngineResult.Status;

/**
 * A {@link CryptoSocket} using TLS ({@link SSLEngine}).
 *
 * <p>The handshake is driven step by step through {@link #handshake()} and {@link #doHandshakeWork()}.
 * {@link #read(ByteBuffer)}, {@link #drain(ByteBuffer)}, {@link #write(ByteBuffer)} and {@link #flush()}
 * throw {@link SSLException} until the handshake has completed. Encrypted data is staged in two internal
 * buffers: one with wrapped bytes waiting to be written to the channel, and one with bytes received
 * from the channel waiting to be unwrapped.</p>
 *
 * @author bjorncs
 */
public class TlsCryptoSocket implements CryptoSocket {

    /** Empty source buffer handed to {@link SSLEngine#wrap} when only handshake data should be produced. */
    private static final ByteBuffer NULL_BUFFER = ByteBuffer.allocate(0);

    private static final Logger log = Logger.getLogger(TlsCryptoSocket.class.getName());

    /** Progress of the handshake as tracked by this socket between calls to {@link #handshake()}. */
    private enum HandshakeState { NOT_STARTED, NEED_READ, NEED_WRITE, NEED_WORK, COMPLETED }

    private final TransportMetrics metrics = TransportMetrics.getInstance();
    private final SocketChannel channel;
    private final SSLEngine sslEngine;
    // Output of SSLEngine.wrap, pending write to the channel
    private final Buffer wrapBuffer;
    // Data received from the channel (or injected), pending SSLEngine.unwrap
    private final Buffer unwrapBuffer;
    // Buffer sizes reported by the engine's session; refreshed when the handshake completes
    private int sessionPacketBufferSize;
    private int sessionApplicationBufferSize;
    // Destination for unwrap during the handshake; set to null once the handshake has completed
    private ByteBuffer handshakeDummyBuffer;
    private HandshakeState handshakeState;
    // Assigned when the handshake completes
    private ConnectionAuthContext authContext;

    /**
     * Creates a TLS socket in the {@code NOT_STARTED} handshake state.
     *
     * @param channel   the channel that encrypted data is read from and written to
     * @param sslEngine the engine performing the handshake and the wrapping/unwrapping of data
     */
    public TlsCryptoSocket(SocketChannel channel, SSLEngine sslEngine) {
        this.channel = channel;
        this.sslEngine = sslEngine;
        this.wrapBuffer = new Buffer(0);
        this.unwrapBuffer = new Buffer(0);
        // Sizes from the engine's session as it is before any handshake has taken place
        SSLSession nullSession = sslEngine.getSession();
        sessionApplicationBufferSize = nullSession.getApplicationBufferSize();
        sessionPacketBufferSize = nullSession.getPacketBufferSize();
        // Note: Dummy buffer as unwrap requires a full size application buffer even though no application data is unwrapped
        this.handshakeDummyBuffer = ByteBuffer.allocate(sessionApplicationBufferSize);
        this.handshakeState = HandshakeState.NOT_STARTED;
        log.fine(() -> "Initialized with " + sslEngine.toString());
    }

    /**
     * Injects pre-read data into the read pipeline (typically called by MaybeTlsCryptoSocket).
     * The readable content of {@code data} is put into the buffer that feeds {@link SSLEngine#unwrap}.
     *
     * @param data buffer whose readable content should be treated as received from the channel
     */
    public void injectReadData(Buffer data) {
        unwrapBuffer.getWritable(data.bytes()).put(data.getReadable());
    }

    @Override
    public SocketChannel channel() {
        return channel;
    }

    /**
     * Advances the handshake one step and records the resulting state.
     *
     * @return what the caller must do next ({@code NEED_READ}, {@code NEED_WRITE}, {@code NEED_WORK}),
     *         or {@code DONE} when the handshake has completed
     * @throws IOException if channel I/O or the SSL engine fails
     */
    @Override
    public HandshakeResult handshake() throws IOException {
        HandshakeState newHandshakeState = processHandshakeState(this.handshakeState);
        log.fine(() -> String.format("Handshake state '%s -> %s'", this.handshakeState, newHandshakeState));
        this.handshakeState = newHandshakeState;
        return toHandshakeResult(newHandshakeState);
    }

    /** Runs, on the calling thread, every delegated task the SSL engine currently has available. */
    @Override
    public void doHandshakeWork() {
        Runnable task;
        while ((task = sslEngine.getDelegatedTask()) != null) {
            task.run();
        }
    }

    /**
     * Performs the action required by {@code state}, then lets the SSL engine progress until it needs
     * channel I/O or delegated work, or until the handshake is complete.
     *
     * @param state the state produced by the previous invocation
     * @return the next handshake state
     * @throws IOException if channel I/O or the SSL engine fails
     */
    private HandshakeState processHandshakeState(HandshakeState state) throws IOException {
        try {
            // First carry out whatever the previous step asked for
            switch (state) {
                case NOT_STARTED:
                    log.fine(() -> "Initiating handshake");
                    sslEngine.beginHandshake();
                    break;
                case NEED_WRITE:
                    channelWrite();
                    break;
                case NEED_READ:
                    channelRead();
                    break;
                case NEED_WORK:
                    break;
                case COMPLETED:
                    return HandshakeState.COMPLETED;
                default:
                    throw unhandledStateException(state);
            }
            // Then wrap/unwrap for as long as the engine can make progress without outside help
            while (true) {
                // Read the status once so that the logged, inspected and reported value is the same
                HandshakeStatus engineStatus = sslEngine.getHandshakeStatus();
                log.fine(() -> "SSLEngine.getHandshakeStatus(): " + engineStatus);
                switch (engineStatus) {
                    case NOT_HANDSHAKING:
                        // Remaining handshake output must reach the channel before completion is reported
                        if (wrapBuffer.bytes() > 0) return HandshakeState.NEED_WRITE;
                        return completeHandshake();
                    case NEED_TASK:
                        return HandshakeState.NEED_WORK;
                    case NEED_UNWRAP:
                        // Write pending output before waiting for more input
                        if (wrapBuffer.bytes() > 0) return HandshakeState.NEED_WRITE;
                        if (!handshakeUnwrap()) return HandshakeState.NEED_READ;
                        break;
                    case NEED_WRAP:
                        if (!handshakeWrap()) return HandshakeState.NEED_WRITE;
                        break;
                    default:
                        throw new IllegalStateException("Unexpected handshake status: " + engineStatus);
                }
            }
        } catch (SSLHandshakeException e) {
            // Handshake failures not caused by a peer authorization failure are counted as certificate verification failures
            if (!(e.getCause() instanceof PeerAuthorizationFailedException)) {
                metrics.incrementTlsCertificateVerificationFailures();
            }
            throw e;
        }
    }

    /**
     * Finalizes a finished handshake: refreshes the session buffer sizes, resolves the connection
     * auth context and updates the connection metrics.
     *
     * @return always {@link HandshakeState#COMPLETED}
     * @throws IOException declared for parity with the handshake step this was extracted from
     */
    private HandshakeState completeHandshake() throws IOException {
        sslEngine.setEnableSessionCreation(false); // disable renegotiation
        handshakeDummyBuffer = null; // only used by handshakeUnwrap()
        SSLSession session = sslEngine.getSession();
        sessionApplicationBufferSize = session.getApplicationBufferSize();
        sessionPacketBufferSize = session.getPacketBufferSize();
        authContext = TransportSecurityUtils.getConnectionAuthContext(session).orElseThrow();
        if (!authContext.authorized()) {
            metrics.incrementPeerAuthorizationFailures();
        }
        log.fine(() -> String.format("Handshake complete: protocol=%s, cipherSuite=%s", session.getProtocol(), session.getCipherSuite()));
        if (sslEngine.getUseClientMode()) {
            metrics.incrementClientTlsConnectionsEstablished();
        } else {
            metrics.incrementServerTlsConnectionsEstablished();
        }
        return HandshakeState.COMPLETED;
    }

    /**
     * Maps the internal handshake state to the result type exposed through {@link #handshake()}.
     *
     * @throws IllegalStateException for states without a corresponding result (i.e. {@code NOT_STARTED})
     */
    private static HandshakeResult toHandshakeResult(HandshakeState state) {
        switch (state) {
            case NEED_READ:
                return HandshakeResult.NEED_READ;
            case NEED_WRITE:
                return HandshakeResult.NEED_WRITE;
            case NEED_WORK:
                return HandshakeResult.NEED_WORK;
            case COMPLETED:
                return HandshakeResult.DONE;
            default:
                throw unhandledStateException(state);
        }
    }

    /** Returns the application buffer size reported by the engine's session. */
    @Override
    public int getMinimumReadBufferSize() {
        return sessionApplicationBufferSize;
    }

    /**
     * Unwraps application data into {@code dst}. Data that is already buffered is unwrapped first;
     * the channel is only read when that produced nothing.
     *
     * @param dst destination for the decrypted application data
     * @return the number of bytes unwrapped into {@code dst}; 0 if the channel had no data available
     * @throws SSLException           if the handshake has not completed or a renegotiation is detected
     * @throws ClosedChannelException if the channel reports end-of-stream or the engine reports closed
     * @throws IOException            if reading from the channel fails
     */
    @Override
    public int read(ByteBuffer dst) throws IOException {
        verifyHandshakeCompleted();
        int bytesUnwrapped = drain(dst);
        if (bytesUnwrapped > 0) return bytesUnwrapped;

        int bytesRead = channelRead();
        if (bytesRead == 0) return 0;
        return drain(dst);
    }

    /**
     * Unwraps already buffered data into {@code dst} without reading from the channel. Unwrapping
     * continues until the engine reports buffer underflow or overflow.
     *
     * @param dst destination for the decrypted application data
     * @return the total number of bytes unwrapped into {@code dst}
     * @throws SSLException           if the handshake has not completed or a renegotiation is detected
     * @throws ClosedChannelException if the engine reports closed
     */
    @Override
    public int drain(ByteBuffer dst) throws IOException {
        verifyHandshakeCompleted();
        int totalBytesUnwrapped = 0;
        while (true) {
            int result = applicationDataUnwrap(dst);
            if (result < 0) return totalBytesUnwrapped;
            totalBytesUnwrapped += result;
        }
    }

    /**
     * Wraps application data from {@code src} into the internal wrap buffer. Previously wrapped data is
     * flushed first; if it cannot be written completely, nothing is consumed from {@code src}.
     * The newly wrapped data is not written to the channel by this method; use {@link #flush()}.
     *
     * @param src application data to encrypt
     * @return the number of bytes consumed from {@code src}
     * @throws SSLException           if the handshake has not completed or a renegotiation is detected
     * @throws ClosedChannelException if the engine reports closed
     * @throws IOException            if writing to the channel fails
     */
    @Override
    public int write(ByteBuffer src) throws IOException {
        verifyHandshakeCompleted();
        if (flush() == FlushResult.NEED_WRITE) return 0;
        int totalBytesWrapped = 0;
        int bytesWrapped;
        // Stop when the engine consumes nothing more or a full packet's worth of output is pending
        do {
            bytesWrapped = applicationDataWrap(src);
            totalBytesWrapped += bytesWrapped;
        } while (bytesWrapped > 0 && wrapBuffer.bytes() < sessionPacketBufferSize);
        return totalBytesWrapped;
    }

    /**
     * Attempts to write the pending wrapped data to the channel.
     *
     * @return {@code NEED_WRITE} if wrapped data is still pending afterwards, otherwise {@code DONE}
     * @throws SSLException if the handshake has not completed
     * @throws IOException  if writing to the channel fails
     */
    @Override
    public FlushResult flush() throws IOException {
        verifyHandshakeCompleted();
        channelWrite();
        return wrapBuffer.bytes() > 0 ? FlushResult.NEED_WRITE : FlushResult.DONE;
    }

    /** Invokes {@code shrink(0)} on both internal buffers (presumably releasing unused capacity; see {@code Buffer}). */
    @Override public void dropEmptyBuffers() {
        wrapBuffer.shrink(0);
        unwrapBuffer.shrink(0);
    }

    /**
     * Returns the auth context resolved when the handshake completed.
     *
     * @throws IllegalStateException if the handshake has not completed
     */
    @Override
    public ConnectionAuthContext connectionAuthContext() {
        if (handshakeState != HandshakeState.COMPLETED) throw new IllegalStateException("Handshake not complete");
        return Objects.requireNonNull(authContext);
    }

    /**
     * Produces handshake data into the wrap buffer.
     *
     * @return true if the engine wrapped successfully; false if the output did not fit, in which case
     *         the pending data should be written to the channel before retrying
     */
    private boolean handshakeWrap() throws IOException {
        SSLEngineResult result = sslEngineWrap(NULL_BUFFER);
        switch (result.getStatus()) {
            case OK:
                return true;
            case BUFFER_OVERFLOW:
                // This is to ensure we have large enough buffer during handshake phase too.
                sessionPacketBufferSize = sslEngine.getSession().getPacketBufferSize();
                return false;
            default:
                throw unexpectedStatusException(result.getStatus());
        }
    }

    /**
     * Wraps application data from {@code src} into the wrap buffer.
     *
     * @return the number of bytes consumed from {@code src}; 0 if the engine reported buffer overflow
     */
    private int applicationDataWrap(ByteBuffer src) throws IOException {
        SSLEngineResult result = sslEngineWrap(src);
        failIfRenegotiationDetected(result);
        switch (result.getStatus()) {
            case OK:
                return result.bytesConsumed();
            case BUFFER_OVERFLOW:
                return 0;
            default:
                throw unexpectedStatusException(result.getStatus());
        }
    }

    /** Wraps {@code src} into the wrap buffer, failing if the engine reports that it is closed. */
    private SSLEngineResult sslEngineWrap(ByteBuffer src) throws IOException {
        SSLEngineResult result = sslEngine.wrap(src, wrapBuffer.getWritable(sessionPacketBufferSize));
        failIfCloseSignalDetected(result);
        return result;
    }

    /**
     * Consumes handshake data from the unwrap buffer.
     *
     * @return true if the engine unwrapped successfully; false if more data must be read from the channel
     * @throws SSLException if the engine produced application data during the handshake
     */
    private boolean handshakeUnwrap() throws IOException {
        SSLEngineResult result = sslEngineUnwrap(handshakeDummyBuffer);
        switch (result.getStatus()) {
            case OK:
                if (result.bytesProduced() > 0) throw new SSLException("Got application data in handshake unwrap");
                return true;
            case BUFFER_UNDERFLOW:
                return false;
            default:
                throw unexpectedStatusException(result.getStatus());
        }
    }

    /**
     * Unwraps buffered data into {@code dst}.
     *
     * @return the number of bytes produced into {@code dst}, or -1 if the engine reported buffer
     *         overflow or underflow (no further progress possible right now)
     */
    private int applicationDataUnwrap(ByteBuffer dst) throws IOException {
        SSLEngineResult result = sslEngineUnwrap(dst);
        failIfRenegotiationDetected(result);
        switch (result.getStatus()) {
            case OK:
                return result.bytesProduced();
            case BUFFER_OVERFLOW:
            case BUFFER_UNDERFLOW:
                return -1;
            default:
                throw unexpectedStatusException(result.getStatus());
        }
    }

    /** Unwraps from the unwrap buffer into {@code dst}, failing if the engine reports that it is closed. */
    private SSLEngineResult sslEngineUnwrap(ByteBuffer dst) throws IOException {
        SSLEngineResult result = sslEngine.unwrap(unwrapBuffer.getReadable(), dst);
        failIfCloseSignalDetected(result);
        return result;
    }

    /**
     * Reads from the channel into the unwrap buffer.
     *
     * @return number of bytes read
     * @throws ClosedChannelException if the channel reports end-of-stream
     */
    private int channelRead() throws IOException {
        int read = channel.read(unwrapBuffer.getWritable(sessionPacketBufferSize));
        if (read == -1) throw new ClosedChannelException();
        return read;
    }

    /**
     * Writes the readable content of the wrap buffer to the channel.
     *
     * @return number of bytes written
     */
    private int channelWrite() throws IOException {
        return channel.write(wrapBuffer.getReadable());
    }

    /** Throws {@link ClosedChannelException} if the engine result has status {@code CLOSED}. */
    private static void failIfCloseSignalDetected(SSLEngineResult result) throws ClosedChannelException {
        if (result.getStatus() == Status.CLOSED) throw new ClosedChannelException();
    }

    /**
     * Throws if the engine result carries a handshake status other than {@code NOT_HANDSHAKING} or
     * {@code FINISHED}, i.e. the engine signals handshake activity while application data is processed.
     */
    private static void failIfRenegotiationDetected(SSLEngineResult result) throws SSLException {
        if (result.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING
                && result.getHandshakeStatus() != HandshakeStatus.FINISHED) {
            throw new SSLException("Renegotiation detected");
        }
    }

    private static IllegalStateException unhandledStateException(HandshakeState state) {
        return new IllegalStateException("Unhandled state: " + state);
    }

    private static IllegalStateException unexpectedStatusException(Status status) {
        return new IllegalStateException("Unexpected status: " + status);
    }

    /** Guards the data path: throws {@link SSLException} unless the handshake has completed. */
    private void verifyHandshakeCompleted() throws SSLException {
        if (handshakeState != HandshakeState.COMPLETED)
            throw new SSLException("Handshake not completed: handshakeState=" + handshakeState);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy