All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.rocketmq.mqtt.common.model.Trie Maven / Gradle / Ivy

There is a newer version: 1.0.1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.rocketmq.mqtt.common.model;

import org.apache.rocketmq.mqtt.common.util.TopicUtils;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;


public class Trie {

    private TrieNode rootNode = new TrieNode(null);

    public synchronized V addNode(String key, V nodeValue, K nodeKey) {
        try {
            String[] keyArray = key.split(Constants.MQTT_TOPIC_DELIMITER);
            TrieNode currentNode = rootNode;
            int level = 0;
            while (level < keyArray.length) {
                TrieNode trieNode = currentNode.children.get(keyArray[level]);
                if (trieNode == null) {
                    trieNode = new TrieNode(currentNode);
                    TrieNode oldNode = currentNode.children.putIfAbsent(keyArray[level], trieNode);
                    if (oldNode != null) {
                        trieNode = oldNode;
                    }
                }
                level++;
                currentNode = trieNode;
            }
            V old = currentNode.valueSet.put(nodeKey, nodeValue);
            return old;
        } catch (Throwable e) {
            throw new TrieException(e);
        }
    }

    /**
     * @param key
     * @param valueKey
     * @return null if can not find the key and valueKey or return the value
     */
    public synchronized V deleteNode(String key, K valueKey) {
        try {
            String[] keyArray = key.split(Constants.MQTT_TOPIC_DELIMITER);
            TrieNode currentNode = rootNode;
            int level = 0;
            while (level < keyArray.length) {
                TrieNode trieNode = currentNode.children.get(keyArray[level]);
                if (trieNode == null) {
                    break;
                }
                level++;
                currentNode = trieNode;
            }
            V oldValue = currentNode.valueSet.remove(valueKey);
            //clean the empty node
            while (currentNode.children.isEmpty() && currentNode.valueSet.isEmpty() && currentNode.parentNode != null) {
                currentNode.parentNode.children.remove(keyArray[--level]);
                currentNode = currentNode.parentNode;
            }
            return oldValue;
        } catch (Throwable e) {
            throw new TrieException(e);
        }
    }

    public long countSubRecords() {
        return countLevelRecords(rootNode);
    }

    private long countLevelRecords(TrieNode currentNode) {
        if (currentNode == null) {
            return 0;
        }
        if (currentNode.children.isEmpty()) {
            return currentNode.valueSet.size();
        }
        long childrenCount = 0;
        for (Map.Entry> entry : currentNode.children.entrySet()) {
            childrenCount += countLevelRecords(entry.getValue());
        }
        return childrenCount + currentNode.valueSet.size();
    }

    public Map getNode(String key) {
        try {
            String[] keyArray = key.split(Constants.MQTT_TOPIC_DELIMITER);
            Map result = findValueSet(rootNode, keyArray, 0, keyArray.length, false);
            return result;
        } catch (Throwable e) {
            throw new TrieException(e);
        }
    }

    public void traverseAll(TrieMethod method) {
        StringBuilder builder = new StringBuilder(128);
        traverse(rootNode, method, builder);
    }

    public Set getNodePath(String key) {
        try {
            String[] keyArray = key.split(Constants.MQTT_TOPIC_DELIMITER);
            StringBuilder builder = new StringBuilder(key.length());
            Set result = findValuePath(rootNode, keyArray, 0, keyArray.length, builder, false);
            return result;
        } catch (Throwable e) {
            throw new TrieException(e);
        }
    }

    private Set findValuePath(TrieNode currentNode, String[] topicArray, int level, int maxLevel,
                                      StringBuilder builder, boolean isNumberSign) {
        Set result = new HashSet<>();
        // match end of path
        boolean isPathEnd = (level == maxLevel || isNumberSign) && !currentNode.valueSet.isEmpty() && builder.length() > 0;
        if (isPathEnd) {
            result.add(TopicUtils.normalizeTopic(builder.toString().substring(0, builder.length() - 1)));
        }
        // match the '#'
        TrieNode numberMatch = currentNode.children.get(Constants.NUMBER_SIGN);
        if (numberMatch != null) {
            int start = builder.length();
            builder.append(Constants.NUMBER_SIGN).append(Constants.MQTT_TOPIC_DELIMITER);
            result.addAll(findValuePath(numberMatch, topicArray, level + 1, maxLevel, builder, true));
            builder.delete(start, builder.length());
        }
        // match the mqtt-topic path
        if (level < maxLevel && !currentNode.children.isEmpty()) {
            // match the precise
            TrieNode trieNode = currentNode.children.get(topicArray[level]);
            if (trieNode != null) {
                int start = builder.length();
                builder.append(topicArray[level]).append(Constants.MQTT_TOPIC_DELIMITER);
                result.addAll(findValuePath(trieNode, topicArray, level + 1, maxLevel, builder, false));
                builder.delete(start, builder.length());
            }
            // match the '+'
            TrieNode plusMatch = currentNode.children.get(Constants.PLUS_SIGN);
            if (plusMatch != null) {
                int start = builder.length();
                builder.append(Constants.PLUS_SIGN).append(Constants.MQTT_TOPIC_DELIMITER);
                result.addAll(findValuePath(plusMatch, topicArray, level + 1, maxLevel, builder, false));
                builder.delete(start, builder.length());
            }
        }
        return result;
    }

    private void traverse(TrieNode currentNode, TrieMethod method, StringBuilder builder) {
        for (Map.Entry> entry : currentNode.children.entrySet()) {
            int start = builder.length();
            builder.append(entry.getKey()).append(Constants.MQTT_TOPIC_DELIMITER);
            traverse(entry.getValue(), method, builder);
            builder.delete(start, builder.length());
        }
        Iterator> iterator = currentNode.valueSet.entrySet().iterator();
        while (iterator.hasNext()) {
            Map.Entry entry = iterator.next();
            try {
                method.doMethod(builder.toString(), entry.getKey());
            } catch (Throwable e) {
            }
        }
    }

    private Map findValueSet(TrieNode currentNode, String[] topicArray, int level, int maxLevel,
                                    boolean isNumberSign) {
        Map result = new HashMap<>(16);
        // match the mqtt-topic leaf or match the leaf node of trie
        if (level == maxLevel || isNumberSign) {
            result.putAll(currentNode.valueSet);
        }
        // match the '#'
        TrieNode numberMatch = currentNode.children.get(Constants.NUMBER_SIGN);
        if (numberMatch != null) {
            result.putAll(findValueSet(numberMatch, topicArray, level + 1, maxLevel, true));
        }
        // match the mqtt-topic path
        if (level < maxLevel && !currentNode.children.isEmpty()) {
            // match the precise
            TrieNode trieNode = currentNode.children.get(topicArray[level]);
            if (trieNode != null) {
                result.putAll(findValueSet(trieNode, topicArray, level + 1, maxLevel, false));
            }
            // match the '+'
            TrieNode plusMatch = currentNode.children.get(Constants.PLUS_SIGN);
            if (plusMatch != null) {
                result.putAll(findValueSet(plusMatch, topicArray, level + 1, maxLevel, false));
            }
        }
        return result;
    }

    class TrieNode {
        public TrieNode parentNode;
        public Map> children = new ConcurrentHashMap<>();
        public Map valueSet = new ConcurrentHashMap<>();

        public TrieNode(TrieNode parentNode) {
            this.parentNode = parentNode;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy