All downloads are free. The search and download features use the official Maven repository.

com.github.netty.protocol.mqtt.MqttSessionRegistry Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2012-2018 The original author or authors
 * ------------------------------------------------------
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * and Apache License v2.0 which accompanies this distribution.
 *
 * The Eclipse Public License is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * The Apache License v2.0 is available at
 * http://www.opensource.org/licenses/apache2.0.php
 *
 * You may elect to redistribute this code under either of these licenses.
 */
package com.github.netty.protocol.mqtt;

import com.github.netty.core.util.LoggerFactoryX;
import com.github.netty.core.util.LoggerX;
import com.github.netty.protocol.mqtt.MqttSession.SessionStatus;
import com.github.netty.protocol.mqtt.exception.MqttSessionCorruptedException;
import com.github.netty.protocol.mqtt.subscriptions.ISubscriptionsDirectory;
import com.github.netty.protocol.mqtt.subscriptions.Subscription;
import com.github.netty.protocol.mqtt.subscriptions.Topic;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
import io.netty.handler.codec.mqtt.MqttQoS;

import java.net.InetSocketAddress;
import java.util.Collection;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Collectors;

public class MqttSessionRegistry {

    private static final LoggerX LOG = LoggerFactoryX.getLogger(MqttSessionRegistry.class);
    private final ConcurrentMap pool = new ConcurrentHashMap<>();
    private final ISubscriptionsDirectory subscriptionsDirectory;
    private final IQueueRepository queueRepository;
    private final ConcurrentMap> queues = new ConcurrentHashMap<>();

    public MqttSessionRegistry(ISubscriptionsDirectory subscriptionsDirectory, IQueueRepository queueRepository) {
        this.subscriptionsDirectory = subscriptionsDirectory;
        this.queueRepository = queueRepository;
    }

    void bindToSession(MqttConnection mqttConnection, MqttConnectMessage msg, String clientId) {
        boolean isSessionAlreadyStored = false;
        PostConnectAction postConnectAction = PostConnectAction.NONE;
        if (!pool.containsKey(clientId)) {
            // case 1
            final MqttSession newSession = createNewSession(mqttConnection, msg, clientId);

            // publish the session
            final MqttSession previous = pool.putIfAbsent(clientId, newSession);
            final boolean success = previous == null;

            if (success) {
                LOG.trace("case 1, not existing session with CId {}", clientId);
            } else {
                postConnectAction = bindToExistingSession(mqttConnection, msg, clientId, newSession);
                isSessionAlreadyStored = true;
            }
        } else {
            final MqttSession newSession = createNewSession(mqttConnection, msg, clientId);
            postConnectAction = bindToExistingSession(mqttConnection, msg, clientId, newSession);
            isSessionAlreadyStored = true;
        }
        final boolean msgCleanSessionFlag = msg.variableHeader().isCleanSession();
        boolean isSessionAlreadyPresent = !msgCleanSessionFlag && isSessionAlreadyStored;
        mqttConnection.sendConnAck(isSessionAlreadyPresent);

        if (postConnectAction == PostConnectAction.SEND_STORED_MESSAGES) {
            final MqttSession session = pool.get(clientId);
            session.sendQueuedMessagesWhileOffline();
        }
    }

    Collection listConnectedClients() {
        return pool.values().stream()
                .filter(MqttSession::connected)
                .map(this::createClientDescriptor)
                .filter(Optional::isPresent)
                .map(Optional::get)
                .collect(Collectors.toList());
    }

    private Optional createClientDescriptor(MqttSession s) {
        final String clientID = s.getClientID();
        final Optional remoteAddressOpt = s.remoteAddress();
        return remoteAddressOpt.map(r -> new MqttClientDescriptor(clientID, r.getHostString(), r.getPort()));
    }

    private PostConnectAction bindToExistingSession(MqttConnection mqttConnection, MqttConnectMessage msg,
                                                    String clientId, MqttSession newSession) {
        PostConnectAction postConnectAction = PostConnectAction.NONE;
        final boolean newIsClean = msg.variableHeader().isCleanSession();
        final MqttSession oldSession = pool.get(clientId);
        if (newIsClean && oldSession.disconnected()) {
            // case 2
            dropQueuesForClient(clientId);
            unsubscribe(oldSession);

            // publish new session
            boolean result = oldSession.assignState(SessionStatus.DISCONNECTED, SessionStatus.CONNECTING);
            if (!result) {
                throw new MqttSessionCorruptedException("old session was already changed state");
            }
            copySessionConfig(msg, oldSession);
            oldSession.bind(mqttConnection);

            result = oldSession.assignState(SessionStatus.CONNECTING, SessionStatus.CONNECTED);
            if (!result) {
                throw new MqttSessionCorruptedException("old session moved in connected state by other thread");
            }
            final boolean published = pool.replace(clientId, oldSession, oldSession);
            if (!published) {
                throw new MqttSessionCorruptedException("old session was already removed");
            }
            LOG.trace("case 2, oldSession with same CId {} disconnected", clientId);
        } else if (!newIsClean && oldSession.disconnected()) {
            // case 3
            reactivateSubscriptions(oldSession);

            // mark as connected
            final boolean connecting = oldSession.assignState(SessionStatus.DISCONNECTED, SessionStatus.CONNECTING);
            if (!connecting) {
                throw new MqttSessionCorruptedException("old session moved in connected state by other thread");
            }
            oldSession.bind(mqttConnection);

            final boolean connected = oldSession.assignState(SessionStatus.CONNECTING, SessionStatus.CONNECTED);
            if (!connected) {
                throw new MqttSessionCorruptedException("old session moved in other state state by other thread");
            }

            // publish new session
            final boolean published = pool.replace(clientId, oldSession, oldSession);
            if (!published) {
                throw new MqttSessionCorruptedException("old session was already removed");
            }
            postConnectAction = PostConnectAction.SEND_STORED_MESSAGES;
            LOG.trace("case 3, oldSession with same CId {} disconnected", clientId);
        } else if (oldSession.connected()) {
            // case 4
            LOG.trace("case 4, oldSession with same CId {} still connected, force to close", clientId);
            oldSession.closeImmediately();
            //remove(clientId);
            // publish new session
            final boolean published = pool.replace(clientId, oldSession, newSession);
            if (!published) {
                throw new MqttSessionCorruptedException("old session was already removed");
            }
        }
        // case not covered new session is clean true/false and old session not in CONNECTED/DISCONNECTED
        return postConnectAction;
    }

    private void reactivateSubscriptions(MqttSession session) {
        for (Subscription existingSub : session.getSubscriptions()) {
            // TODO
//            subscriptionsDirectory.reactivate(existingSub.getTopicFilter(), session.getClientID());
        }
    }

    private void unsubscribe(MqttSession session) {
        for (Subscription existingSub : session.getSubscriptions()) {
            subscriptionsDirectory.removeSubscription(existingSub.getTopicFilter(), session.getClientID());
        }
    }

    private MqttSession createNewSession(MqttConnection mqttConnection, MqttConnectMessage msg, String clientId) {
        final boolean clean = msg.variableHeader().isCleanSession();
        final Queue sessionQueue =
                queues.computeIfAbsent(clientId, (String cli) -> queueRepository.createQueue(cli, clean));
        final MqttSession newSession;
        if (msg.variableHeader().isWillFlag()) {
            final MqttSession.Will will = createWill(msg);
            newSession = new MqttSession(clientId, clean, will, sessionQueue);
        } else {
            newSession = new MqttSession(clean, clientId, sessionQueue);
        }

        newSession.markConnected();
        newSession.bind(mqttConnection);

        return newSession;
    }

    private void copySessionConfig(MqttConnectMessage msg, MqttSession session) {
        final boolean clean = msg.variableHeader().isCleanSession();
        final MqttSession.Will will;
        if (msg.variableHeader().isWillFlag()) {
            will = createWill(msg);
        } else {
            will = null;
        }
        session.update(clean, will);
    }

    private MqttSession.Will createWill(MqttConnectMessage msg) {
        final ByteBuf willPayload = Unpooled.copiedBuffer(msg.payload().willMessageInBytes());
        final String willTopic = msg.payload().willTopic();
        final boolean retained = msg.variableHeader().isWillRetain();
        final MqttQoS qos = MqttQoS.valueOf(msg.variableHeader().willQos());
        return new MqttSession.Will(willTopic, willPayload, qos, retained);
    }

    MqttSession retrieve(String clientID) {
        return pool.get(clientID);
    }

    public void remove(String clientID) {
        pool.remove(clientID);
    }

    public void disconnect(String clientID) {
        final MqttSession session = retrieve(clientID);
        if (session == null) {
            LOG.debug("Some other thread already removed the session CId={}", clientID);
            return;
        }
        session.disconnect();
    }

    private void dropQueuesForClient(String clientId) {
        queues.remove(clientId);
    }

    private enum PostConnectAction {
        NONE, SEND_STORED_MESSAGES
    }

    public abstract static class EnqueuedMessage {
    }

    static class PublishedMessage extends EnqueuedMessage {

        final Topic topic;
        final MqttQoS publishingQos;
        final ByteBuf payload;

        PublishedMessage(Topic topic, MqttQoS publishingQos, ByteBuf payload) {
            this.topic = topic;
            this.publishingQos = publishingQos;
            this.payload = payload;
        }
    }

    static final class PubRelMarker extends EnqueuedMessage {
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy