All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.epam.deltix.util.vsocket.VSServerFramework Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2023 EPAM Systems, Inc
 *
 * See the NOTICE file distributed with this work for additional information
 * regarding copyright ownership. Licensed under the Apache License,
 * Version 2.0 (the "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.epam.deltix.util.vsocket;

import com.epam.deltix.util.ContextContainer;
import com.epam.deltix.util.collections.generated.IntegerToObjectHashMap;
import com.epam.deltix.util.concurrent.QuickExecutor;
import com.epam.deltix.util.io.aeron.DXAeron;
import com.epam.deltix.util.io.offheap.OffHeap;
import com.epam.deltix.util.lang.DisposableListener;
import com.epam.deltix.util.tomcat.ConnectionHandshakeHandler;
import com.epam.deltix.util.vsocket.transport.Connection;
import com.epam.deltix.util.vsocket.transport.SocketConnectionFactory;

import java.net.*;
import java.util.*;
import java.io.*;
import java.util.logging.Level;

/**
 *
 */
public class VSServerFramework implements ConnectionHandshakeHandler, DisposableListener, Closeable {
    public static volatile VSServerFramework INSTANCE = null;

    static final int         MIN_COMPATIBLE_CLIENT_VERSION = 1014;
    static final int         MAX_COMPATIBLE_CLIENT_VERSION = VSProtocol.VERSION;

    public final static int        MAX_CONNECTIONS                = 100;
    public final static short      MAX_SOCKETS_PER_CONNECTION      = 8;

    private final Map         dispatchers =
        new HashMap <> ();

    private final QuickExecutor                 executor;
    private final ContextContainer contextContainer;

    private volatile VSConnectionListener       connectionListener;
    private final int                           connectionsLimit;
    private final short                         transportsLimit;
    private final long                          time;
    private int                                 reconnectInterval;
    private final VSCompression                 compression;

    private TLSContext                          tlsContext;

    private TransportType                       transportType = TransportType.SOCKET_TCP;

    public static final Comparator  comparator = new Comparator () {

        @Override
        public int              compare (VSDispatcher o1, VSDispatcher o2) {
            return (o1.getCreationDate ().compareTo (o2.getCreationDate ()));
        }
    };

    public VSServerFramework(QuickExecutor executor, int reconnectInterval,
                             VSCompression compression, int connectionsLimit, short socketsPerConnection, ContextContainer contextContainer) {
        this.executor = executor;
        this.reconnectInterval = reconnectInterval;
        this.time = System.currentTimeMillis();
        this.compression = compression;
        this.connectionsLimit = connectionsLimit;
        this.transportsLimit = socketsPerConnection;
        this.contextContainer = contextContainer;
        INSTANCE = this;
    }

    public VSServerFramework(QuickExecutor executor, int reconnectInterval, VSCompression compression, ContextContainer contextContainer) {
        this(executor, reconnectInterval, compression, MAX_CONNECTIONS, MAX_SOCKETS_PER_CONNECTION, contextContainer);
    }

    public QuickExecutor getExecutor () {
        return executor;
    }

    public int                  getDispatchersCount() {
        synchronized (dispatchers) {
            return dispatchers.size();
        }
    }
    public VSDispatcher []      getDispatchers () {
        VSDispatcher [] ret;

        synchronized (dispatchers) {
            ret = new VSDispatcher[dispatchers.size()];
            int index = 0;
            for (Connector value : dispatchers.values())
                ret[index++] = value.dispatcher;
        }

        Arrays.sort (ret, comparator);
        return (ret);
    }

    public VSDispatcher         getDispatcher(String id) {
        synchronized (dispatchers) {
            Connector connector = dispatchers.get(id);
            return connector != null ? connector.dispatcher : null;
        }
    }

    /*
        Returns current throughput: bytes per second
     */

    public long                 getThroughput() {
        long throughput = 0;
        synchronized (dispatchers) {
            for (Connector value : dispatchers.values())
                throughput += value.dispatcher.getAverageThroughput();
        }

        return throughput;
    }

    public void                 setConnectionListener (VSConnectionListener lnr) {
        connectionListener = lnr;
    }

    public void                 initSSLSocketFactory(TLSContext context) {
        this.tlsContext = context;
    }

    public void                 initTransport(TransportProperties transportProperties) {
        if (transportProperties != null) {
            transportType = transportProperties.transportType;
            if (transportType == TransportType.AERON_IPC) {
                DXAeron.start(transportProperties.transportDir, false);
            } else if (transportType == TransportType.OFFHEAP_IPC) {
                OffHeap.start(transportProperties.transportDir, true);
            }
        }
    }

    public boolean              handleHandshake (Socket s) throws IOException {
        s.setSoTimeout (0);
        s.setTcpNoDelay(true);
        s.setKeepAlive(true);

        return handleHandshake(
            SocketConnectionFactory.createConnection(
                s, new BufferedInputStream(s.getInputStream()), s.getOutputStream())
        );
    }

    /** Handles inbound HTTP connection handshake when TB is running inside Tomcat */
    @Override
    public boolean handleHandshake(Socket s, BufferedInputStream is, OutputStream os) throws IOException {
        s.setSoTimeout (0);
        s.setTcpNoDelay(true);
        s.setKeepAlive(true);

        return handleHandshake(
            SocketConnectionFactory.createConnection(s, is, os)
        );
    }

    /**
     *  Handle the initial transport-level handshake with a client.
     *
     *  @param c    Socket to perform the handshake with.
     *  @return     true if the connection is accepted and socket added to the
     *              set of transport channels. false if socket should be closed.
     *
     *  @throws IOException
     */
    public boolean              handleHandshake (Connection c) throws IOException {
        Level handshakeTimeLogLevel = Level.FINE;
        long handshakeStart = VSProtocol.LOGGER.isLoggable(handshakeTimeLogLevel) ? System.currentTimeMillis() : 0;
        try {
            return handleHandshakeInternal(c);
        } finally {
            if (VSProtocol.LOGGER.isLoggable(handshakeTimeLogLevel)) {
                long handshakeEnd = System.currentTimeMillis();
                VSProtocol.LOGGER.log(handshakeTimeLogLevel, "handleHandshake took " + (handshakeEnd - handshakeStart) + " ms");
            }
        }
    }

    /**
     * See {@link #handleHandshake(Connection)}.
     */
    private boolean              handleHandshakeInternal (Connection c) throws IOException {

        processSSLHandshake(c);

        DataInputStream     dis = new DataInputStream (c.getInputStream());
        DataOutputStream    dout = new DataOutputStream (c.getOutputStream());

        int                 clientVersion = dis.readInt ();

        boolean isCompatible = clientVersion >= MIN_COMPATIBLE_CLIENT_VERSION && clientVersion <= MAX_COMPATIBLE_CLIENT_VERSION;

        dout.writeInt (isCompatible ? clientVersion : VSProtocol.VERSION);
        dout.writeUTF(String.valueOf(VSProtocol.VERSION));
        dout.flush ();

        String              clientId = dis.readUTF ();
        //dout.writeUTF (Version.VERSION_STRING);

        if (!isCompatible) {
            VSProtocol.LOGGER.severe (
                "Connection from " + clientId + " rejected due to incompatible protocol version #" +
                clientVersion + " (accepted: " +
                MIN_COMPATIBLE_CLIENT_VERSION + " .. " +
                MAX_COMPATIBLE_CLIENT_VERSION + ")"
            );

            dout.writeByte (VSProtocol.CONN_RESP_INCOMPATIBLE_CLIENT);
            dout.flush ();
            return (false);
        }

        dout.writeInt(tlsContext == null ? 0 : tlsContext.port);

        if (clientVersion > 1014)
            processTransportHandshake(c);

        boolean             isNew = dis.readBoolean();
        int                 sCode = dis.readInt();
        long                recieved = dis.readLong();

        Connector           connector = process(clientId);
        if (connector == null) {
            dout.writeByte (VSProtocol.CONN_RESP_CONNECTION_REJECTED);
            dout.flush ();
            return (false);
        }

        if (VSProtocol.LOGGER.isLoggable (Level.FINE))
            VSProtocol.LOGGER.fine ("Accepted connection from " + clientId);

        VSocketRecoveryInfo brokenSocketRecoveryInfo = isNew ? null : connector.remove(sCode);
        VSocket broken = null;

        if (brokenSocketRecoveryInfo != null) {
            synchronized (brokenSocketRecoveryInfo) {
                // Prevent other threads against attempting to recover same socket
                if (brokenSocketRecoveryInfo.startRecoveryAttempt()) {
                    broken = brokenSocketRecoveryInfo.getSocket();
                } else {
                    VSProtocol.LOGGER.fine("Recovery attempt failed because another attempt for this thread in progress");
                }
            }
        }

        boolean success = false;
        try {

            String transportTag = sCode + " / " + Integer.toHexString(sCode);
            if (broken != null) {
                VSProtocol.LOGGER.info("Restoring connection (" + transportTag + ") for " + clientId);
                broken.getOutputStream().confirm(recieved);
            } else {
                if (!isNew) {
                    VSProtocol.LOGGER.warning("Connection restore failed for transport (" + transportTag + ") for " + clientId + " because server side transport is not found");
                }
            }

            dout.writeByte(VSProtocol.CONN_RESP_OK);
            dout.writeLong(time);
            dout.writeInt(reconnectInterval);
            dout.writeUTF(compression.toString());

            // writing -1 means socket wasn't found
            dout.writeLong(broken != null ? broken.getInputStream().getBytesRead() : (isNew ? 0 : -1));
            dout.flush();

            VSocket socket = broken != null ? c.create(broken) : c.create(sCode);

            if (!connector.addTransportChannel(socket)) {
                VSProtocol.LOGGER.info("Connection(" + transportTag + ") rejected for " + clientId);
                return (false);
            }

            if (broken != null)
                VSProtocol.LOGGER.info("Connection(" + transportTag + ") restored for " + clientId);

            success = true;
            return (true);
        } finally {
            if (broken != null) {
                // assert brokenSocketRecoveryInfo != null;
                synchronized (brokenSocketRecoveryInfo) {
                    brokenSocketRecoveryInfo.stopRecoveryAttempt();
                    if (success) {
                        brokenSocketRecoveryInfo.markRecoverySucceeded();
                    }
                    brokenSocketRecoveryInfo.notifyAll();
                }
            }
        }
    }

    private Connector                process(String clientId) {
        Connector connector;

        synchronized (dispatchers) {
            connector = dispatchers.get(clientId);

            if (connector == null && dispatchers.size() >= connectionsLimit) {
                VSProtocol.LOGGER.severe (
                        "Connection from " + clientId + " rejected due to connections limit = (" + connectionsLimit + ")"
                );
                return null;
            }

            if (connector == null) {
                VSDispatcher dispatcher = new VSDispatcher (clientId, false, contextContainer);
                dispatcher.setConnectionListener(connectionListener);
                dispatcher.setLingerInterval(reconnectInterval);
                dispatcher.addDisposableListener(this);
                dispatchers.put (clientId, (connector = new Connector(dispatcher, transportsLimit)));
            }
        }

        return connector;
    }

    private void                    processSSLHandshake(Connection c) throws IOException {
        boolean enableSSL = (tlsContext != null);
        boolean sslForLoopback = (tlsContext != null) && !tlsContext.preserveLoopback;

        BufferedInputStream     is = c.getInputStream();
        OutputStream            os = c.getOutputStream();

        //initial byte (0 for VS protocol)
        int b = is.read();
        assert b == 0;

        int clientHeader = is.read();
        if (clientHeader == VSProtocol.SSL_HEADER && !enableSSL) {
            os.write(VSProtocol.CONN_RESP_SSL_NOT_SUPPORTED);
            throw new IOException("Client wants SSL but server have not prepared for handshake.");
        }
        os.write(VSProtocol.CONN_RESP_OK);

        //server choice
        boolean sslConnection = false;
        if (enableSSL && sslForLoopback)
            sslConnection = true;
        else if (enableSSL && !sslForLoopback && !c.isLoopback())
            sslConnection = true;
        else if (enableSSL && !sslForLoopback && c.isLoopback() && clientHeader == VSProtocol.SSL_HEADER)
            sslConnection = true;

        //write server decision (SSL or NON-SSL)
        if (sslConnection) {
            os.write(VSProtocol.SSL_HEADER);
            c.upgradeToSSL(tlsContext.context.getSocketFactory());
            VSProtocol.LOGGER.fine("Socket upgraded to SSL socket! Now connection is secured.");
        } else {
            os.write(VSProtocol.HEADER);
        }
    }

    private void                    processTransportHandshake(Connection c) throws IOException {
        DataOutputStream dout = new DataOutputStream (c.getOutputStream());

        TransportType type = (!c.isLoopback() || transportType == null) ? TransportType.SOCKET_TCP : transportType;
        dout.writeInt(type.ordinal());

        c.setTransportType(type);
        if (type == TransportType.AERON_IPC)
            dout.writeUTF(DXAeron.getAeronDir());
        else if (type == TransportType.OFFHEAP_IPC)
            dout.writeUTF(OffHeap.getOffHeapDir());
    }

    public VSCompression            getCompression() {
        return compression;
    }

    @Override
    public void                     disposed(VSDispatcher resource) {
        synchronized (dispatchers) {
            Connector c = dispatchers.remove(resource.getClientId());
            if (c != null)
                c.close();
        }
    }

    @Override
    public void close() {
        if (transportType == TransportType.AERON_IPC)
            DXAeron.shutdown();
    }

    static class Connector extends ConnectionStateListener implements Closeable {
        // May contain null values. Null value indicates that transport is still considered active (not stopped).
        private final IntegerToObjectHashMap   stopped =
                new IntegerToObjectHashMap<>();

        private final VSDispatcher dispatcher;
        private final int           limit; // limit of transport channels

        Connector(VSDispatcher d, int limit) {
            this.dispatcher = d;
            this.limit = limit;
            dispatcher.setStateListener(this);
        }

        boolean addTransportChannel(VSocket socket) throws IOException {

            synchronized (stopped) {

                // not accept new channels, only that can be restored
                if (!stopped.containsKey(socket.getCode()) && stopped.size() >= limit)
                    return false;

                // Note: we may override previous state of socket here (in case of recovery)
                stopped.put(socket.getCode(), ACTIVE);
                stopped.notifyAll();
            }

            dispatcher.addTransportChannel(socket);
            return true;
        }

        @Override
        boolean onTransportStopped(VSocketRecoveryInfo recoveryInfo) {
            VSocket socket = recoveryInfo.getSocket();

            int code = socket.getCode();
            synchronized (stopped) {
                VSocketRecoveryInfo prevValue = stopped.get(code, null);
                if (prevValue != ACTIVE) {
                    String clientId = dispatcher.getClientId();
                    if (prevValue == null) {
                        VSProtocol.LOGGER.severe("Transport for " + clientId + " has stopped. It can't be recovered because the recovery state is invalid.");
                    } else {
                        VSProtocol.LOGGER.warning("Transport for " + clientId + " has stopped. It can't be recovered because a recovery attempt is already initiated.");
                    }
                    return true;
                }
                stopped.put(code, recoveryInfo);
                stopped.notifyAll();
                return false;
            }
        }

        @Override
        boolean onTransportBroken(VSocketRecoveryInfo recoveryInfo) {
            try {
                synchronized (recoveryInfo) {
                    while (recoveryInfo.isRecoveryAttemptInProgress()) {
                        recoveryInfo.wait();
                    }
                    if (recoveryInfo.isRecoverySucceeded()) {
                        return false;
                    }
                    recoveryInfo.markRecoveryFailed();

                    VSocket socket = recoveryInfo.getSocket();
                    int code = socket.getCode();
                    synchronized (stopped) {
                        VSocketRecoveryInfo currentValue = stopped.get(code, null);
                        if (currentValue == recoveryInfo) {
                            stopped.remove(code);
                            stopped.notifyAll();
                            recoveryInfo.notifyAll();
                            return true;
                        }
                    }

                    recoveryInfo.notifyAll();
                    return false;
                }
            } catch (InterruptedException e) {
                return true;
            }
        }

        @Override
        public void         close() {
            dispatcher.setStateListener(null);

            synchronized (stopped) {
                stopped.clear();
                stopped.notifyAll();
            }
        }

        /**
         * Attempts to get transport for a recovery attempt.
         */
        public VSocketRecoveryInfo remove(int code) {
            try {
                synchronized (stopped) {
                    VSocketRecoveryInfo currentValue;
                    while (true) {
                        currentValue = stopped.get(code, null);
                        // request for restoring connection may came faster that socket is stopped
                        if (currentValue == ACTIVE) {
                            stopped.wait();
                        } else {
                            break;
                        }
                    }

                    // assert currentValue != ACTIVE

                    if (currentValue == null) {
                        // This transport never existed OR permanently removed
                        return null;
                    }

                    return currentValue;
                }
            } catch (InterruptedException e) {
                return null;
            }
        }

        @Override
        void onDisconnected() {
            synchronized (stopped) {
                stopped.clear();
                stopped.notifyAll();
            }
        }

        @Override
        void onReconnected() {
        }
    }

    // Special marker values

    // This value indicates that the transport for corresponding code is active (not stopped yet)
    private static final FakeRecoveryInfo ACTIVE = new FakeRecoveryInfo("ACTIVE");

    /**
     * Special class to represent special values in {@link Connector#stopped} map.
     */
    private static class FakeRecoveryInfo extends VSocketRecoveryInfo {
        private final String label;

        FakeRecoveryInfo(String label) {
            super(null, Long.MIN_VALUE);
            this.label = label;
        }

        @Override
        public String toString() {
            return "FakeVSocket{" +
                    "" + label + '\'' +
                    '}';
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy