All downloads are FREE. Search and download functionality uses the official Maven repository.

io.lettuce.core.pubsub.PubSubEndpoint Maven / Gradle / Ivy

Go to download

Advanced and thread-safe Java Redis client for synchronous, asynchronous, and reactive usage. Supports Cluster, Sentinel, Pipelining, Auto-Reconnect, Codecs and much more.

The newest version!
/*
 * Copyright 2011-Present, Redis Ltd. and Contributors
 * All rights reserved.
 *
 * Licensed under the MIT License.
 *
 * This file contains contributions from third-party contributors
 * licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.lettuce.core.pubsub;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

import io.lettuce.core.ClientOptions;
import io.lettuce.core.ConnectionState;
import io.lettuce.core.RedisException;
import io.lettuce.core.protocol.CommandType;
import io.lettuce.core.protocol.DefaultEndpoint;
import io.lettuce.core.protocol.ProtocolVersion;
import io.lettuce.core.protocol.RedisCommand;
import io.lettuce.core.resource.ClientResources;
import io.netty.channel.Channel;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

/**
 * @author Mark Paluch
 * @author dengliming
 */
public class PubSubEndpoint extends DefaultEndpoint {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(PubSubEndpoint.class);

    private static final Set ALLOWED_COMMANDS_SUBSCRIBED;

    private static final Set SUBSCRIBE_COMMANDS;

    private final List> listeners = new CopyOnWriteArrayList<>();

    private final Set> channels;

    private final Set> shardChannels;

    private final Set> patterns;

    private volatile boolean subscribeWritten = false;

    private ConnectionState connectionState;

    static {

        ALLOWED_COMMANDS_SUBSCRIBED = new HashSet<>(6, 1);

        ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.SUBSCRIBE.name());
        ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.PSUBSCRIBE.name());
        ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.UNSUBSCRIBE.name());
        ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.PUNSUBSCRIBE.name());
        ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.SSUBSCRIBE.name());
        ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.QUIT.name());
        ALLOWED_COMMANDS_SUBSCRIBED.add(CommandType.PING.name());

        SUBSCRIBE_COMMANDS = new HashSet<>(2, 1);

        SUBSCRIBE_COMMANDS.add(CommandType.SUBSCRIBE.name());
        SUBSCRIBE_COMMANDS.add(CommandType.PSUBSCRIBE.name());
        SUBSCRIBE_COMMANDS.add(CommandType.SSUBSCRIBE.name());
    }

    /**
     * Initialize a new instance that handles commands from the supplied queue.
     *
     * @param clientOptions client options for this connection, must not be {@code null}
     * @param clientResources client resources for this connection, must not be {@code null}.
     */
    public PubSubEndpoint(ClientOptions clientOptions, ClientResources clientResources) {

        super(clientOptions, clientResources);

        this.channels = ConcurrentHashMap.newKeySet();
        this.patterns = ConcurrentHashMap.newKeySet();
        this.shardChannels = ConcurrentHashMap.newKeySet();
    }

    /**
     * Add a new {@link RedisPubSubListener listener}.
     *
     * @param listener the listener, must not be {@code null}.
     */
    public void addListener(RedisPubSubListener listener) {
        listeners.add(listener);
    }

    /**
     * Remove an existing {@link RedisPubSubListener listener}.
     *
     * @param listener the listener, must not be {@code null}.
     */
    public void removeListener(RedisPubSubListener listener) {
        listeners.remove(listener);
    }

    protected List> getListeners() {
        return listeners;
    }

    public boolean hasChannelSubscriptions() {
        return !channels.isEmpty();
    }

    public Set getChannels() {
        return unwrap(this.channels);
    }

    public boolean hasShardChannelSubscriptions() {
        return !shardChannels.isEmpty();
    }

    public Set getShardChannels() {
        return unwrap(this.shardChannels);
    }

    public boolean hasPatternSubscriptions() {
        return !patterns.isEmpty();
    }

    public Set getPatterns() {
        return unwrap(this.patterns);
    }

    @Override
    public void notifyChannelActive(Channel channel) {
        subscribeWritten = false;
        super.notifyChannelActive(channel);
    }

    @Override
    public  RedisCommand write(RedisCommand command) {

        if (isSubscribed() && !isAllowed(command)) {
            rejectCommand(command);
            return command;
        }

        if (!subscribeWritten && SUBSCRIBE_COMMANDS.contains(command.getType().toString())) {
            subscribeWritten = true;
        }

        return super.write(command);
    }

    @Override
    public  Collection> write(Collection> redisCommands) {

        if (isSubscribed()) {

            if (containsViolatingCommands(redisCommands)) {
                rejectCommands(redisCommands);
                return (Collection>) redisCommands;
            }
        }

        if (!subscribeWritten) {
            for (RedisCommand redisCommand : redisCommands) {
                if (SUBSCRIBE_COMMANDS.contains(redisCommand.getType().toString())) {
                    subscribeWritten = true;
                    break;
                }
            }
        }

        return super.write(redisCommands);
    }

    protected void rejectCommand(RedisCommand command) {
        command.completeExceptionally(
                new RedisException(String.format("Command %s not allowed while subscribed. Allowed commands are: %s",
                        command.getType().toString(), ALLOWED_COMMANDS_SUBSCRIBED)));
    }

    protected void rejectCommands(Collection> redisCommands) {
        for (RedisCommand command : redisCommands) {
            command.completeExceptionally(
                    new RedisException(String.format("Command %s not allowed while subscribed. Allowed commands are: %s",
                            command.getType().toString(), ALLOWED_COMMANDS_SUBSCRIBED)));
        }
    }

    protected boolean containsViolatingCommands(Collection> redisCommands) {

        for (RedisCommand redisCommand : redisCommands) {

            if (!isAllowed(redisCommand)) {
                return true;
            }
        }

        return false;
    }

    private boolean isAllowed(RedisCommand command) {

        ProtocolVersion protocolVersion = connectionState != null ? connectionState.getNegotiatedProtocolVersion() : null;

        if (protocolVersion == null) {
            protocolVersion = getProtocolVersion();
        }

        return protocolVersion == ProtocolVersion.RESP3 || ALLOWED_COMMANDS_SUBSCRIBED.contains(command.getType().toString());
    }

    public boolean isSubscribed() {
        return subscribeWritten && (hasChannelSubscriptions() || hasPatternSubscriptions());
    }

    void setConnectionState(ConnectionState connectionState) {
        this.connectionState = connectionState;
    }

    void notifyMessage(PubSubMessage message) {

        // drop empty messages
        if (message.type() == null || (message.pattern() == null && message.channel() == null && message.body() == null)) {
            return;
        }

        updateInternalState(message);
        try {
            notifyListeners(message);
        } catch (Exception e) {
            logger.error("Unexpected error occurred in RedisPubSubListener callback", e);
        }
    }

    protected void notifyListeners(PubSubMessage message) {
        // update listeners
        for (RedisPubSubListener listener : listeners) {
            switch (message.type()) {
                case message:
                    listener.message(message.channel(), message.body());
                    break;
                case pmessage:
                    listener.message(message.pattern(), message.channel(), message.body());
                    break;
                case psubscribe:
                    listener.psubscribed(message.pattern(), message.count());
                    break;
                case punsubscribe:
                    listener.punsubscribed(message.pattern(), message.count());
                    break;
                case subscribe:
                    listener.subscribed(message.channel(), message.count());
                    break;
                case unsubscribe:
                    listener.unsubscribed(message.channel(), message.count());
                    break;
                case smessage:
                    listener.smessage(message.channel(), message.body());
                    break;
                case ssubscribe:
                    listener.ssubscribed(message.channel(), message.count());
                    break;
                case sunsubscribe:
                    listener.sunsubscribed(message.channel(), message.count());
                    break;
                default:
                    throw new UnsupportedOperationException("Operation " + message.type() + " not supported");
            }
        }
    }

    private void updateInternalState(PubSubMessage message) {
        // update internal state
        switch (message.type()) {
            case psubscribe:
                patterns.add(new Wrapper<>(message.pattern()));
                break;
            case punsubscribe:
                patterns.remove(new Wrapper<>(message.pattern()));
                break;
            case subscribe:
                channels.add(new Wrapper<>(message.channel()));
                break;
            case unsubscribe:
                channels.remove(new Wrapper<>(message.channel()));
                break;
            case ssubscribe:
                shardChannels.add(new Wrapper<>(message.channel()));
                break;
            case sunsubscribe:
                shardChannels.remove(new Wrapper<>(message.channel()));
                break;
            default:
                break;
        }
    }

    private Set unwrap(Set> wrapped) {

        Set result = new LinkedHashSet<>(wrapped.size());

        for (Wrapper channel : wrapped) {
            result.add(channel.name);
        }

        return result;
    }

    /**
     * Comparison/equality wrapper with specific {@code byte[]} equals and hashCode implementations.
     *
     * @param 
     */
    static class Wrapper {

        protected final K name;

        public Wrapper(K name) {
            this.name = name;
        }

        @Override
        public int hashCode() {

            if (name instanceof byte[]) {
                return Arrays.hashCode((byte[]) name);
            }
            return name.hashCode();
        }

        @Override
        public boolean equals(Object obj) {

            if (!(obj instanceof Wrapper)) {
                return false;
            }

            Wrapper that = (Wrapper) obj;

            if (name instanceof byte[] && that.name instanceof byte[]) {
                return Arrays.equals((byte[]) name, (byte[]) that.name);
            }

            return name.equals(that.name);
        }

        @Override
        public String toString() {
            final StringBuffer sb = new StringBuffer();
            sb.append(getClass().getSimpleName());
            sb.append(" [name=").append(name);
            sb.append(']');
            return sb.toString();
        }

    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy