All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.activemq.transport.mqtt.MQTTPacketIdGenerator Maven / Gradle / Ivy

There is a newer version: 6.1.2
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.activemq.transport.mqtt;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.activemq.Service;
import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.command.ActiveMQMessage;
import org.apache.activemq.util.LRUCache;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.util.ServiceSupport;
import org.fusesource.mqtt.codec.PUBLISH;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Manages PUBLISH packet ids for clients.
 *
 * @author Dhiraj Bokde
 */
public class MQTTPacketIdGenerator extends ServiceSupport {

    private static final Logger LOG = LoggerFactory.getLogger(MQTTPacketIdGenerator.class);
    private static final Object LOCK = new Object();

    Map clientIdMap = new ConcurrentHashMap();

    private final NonZeroSequenceGenerator messageIdGenerator = new NonZeroSequenceGenerator();

    private MQTTPacketIdGenerator() {
    }

    @Override
    protected void doStop(ServiceStopper stopper) throws Exception {
        synchronized (this) {
            clientIdMap = new ConcurrentHashMap();
        }
    }

    @Override
    protected void doStart() throws Exception {
    }

    public void startClientSession(String clientId) {
        if (!clientIdMap.containsKey(clientId)) {
            clientIdMap.put(clientId, new PacketIdMaps());
        }
    }

    public boolean stopClientSession(String clientId) {
        return clientIdMap.remove(clientId) != null;
    }

    public short setPacketId(String clientId, MQTTSubscription subscription, ActiveMQMessage message, PUBLISH publish) {
        final PacketIdMaps idMaps = clientIdMap.get(clientId);
        if (idMaps == null) {
            // maybe its a cleansession=true client id, use session less message id
            final short id = messageIdGenerator.getNextSequenceId();
            publish.messageId(id);
            return id;
        } else {
            return idMaps.setPacketId(subscription, message, publish);
        }
    }

    public void ackPacketId(String clientId, short packetId) {
        final PacketIdMaps idMaps = clientIdMap.get(clientId);
        if (idMaps != null) {
            idMaps.ackPacketId(packetId);
        }
    }

    public short getNextSequenceId(String clientId) {
        final PacketIdMaps idMaps = clientIdMap.get(clientId);
        return idMaps != null ? idMaps.getNextSequenceId(): messageIdGenerator.getNextSequenceId();
    }

    public static MQTTPacketIdGenerator getMQTTPacketIdGenerator(BrokerService broker) {
        MQTTPacketIdGenerator result = null;
        if (broker != null) {
            synchronized (LOCK) {
                Service[] services = broker.getServices();
                if (services != null) {
                    for (Service service : services) {
                        if (service instanceof MQTTPacketIdGenerator) {
                            return (MQTTPacketIdGenerator) service;
                        }
                    }
                }
                result = new MQTTPacketIdGenerator();
                broker.addService(result);
                if (broker.isStarted()) {
                    try {
                        result.start();
                    } catch (Exception e) {
                        LOG.warn("Couldn't start MQTTPacketIdGenerator");
                    }
                }
            }
        }


        return result;
    }

    private class PacketIdMaps {

        private final NonZeroSequenceGenerator messageIdGenerator = new NonZeroSequenceGenerator();
        final Map activemqToPacketIds = new LRUCache(MQTTProtocolConverter.DEFAULT_CACHE_SIZE);
        final Map packetIdsToActivemq = new LRUCache(MQTTProtocolConverter.DEFAULT_CACHE_SIZE);

        short setPacketId(MQTTSubscription subscription, ActiveMQMessage message, PUBLISH publish) {
            // subscription key
            final StringBuilder subscriptionKey = new StringBuilder();
            subscriptionKey.append(subscription.getConsumerInfo().getDestination().getPhysicalName())
                .append(':').append(message.getJMSMessageID());
            final String keyStr = subscriptionKey.toString();
            Short packetId;
            synchronized (activemqToPacketIds) {
                packetId = activemqToPacketIds.get(keyStr);
                if (packetId == null) {
                    packetId = getNextSequenceId();
                    activemqToPacketIds.put(keyStr, packetId);
                    packetIdsToActivemq.put(packetId, keyStr);
                } else {
                    // mark publish as duplicate!
                    publish.dup(true);
                }
            }
            publish.messageId(packetId);
            return packetId;
        }

        void ackPacketId(short packetId) {
            synchronized (activemqToPacketIds) {
                final String subscriptionKey = packetIdsToActivemq.remove(packetId);
                if (subscriptionKey != null) {
                    activemqToPacketIds.remove(subscriptionKey);
                }
            }
        }

        short getNextSequenceId() {
            return messageIdGenerator.getNextSequenceId();
        }

    }

    private class NonZeroSequenceGenerator {

        private short lastSequenceId;

        public synchronized short getNextSequenceId() {
            final short val = ++lastSequenceId;
            return val != 0 ? val : ++lastSequenceId;
        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy