All downloads are free. The search and download functionality uses the official Maven repository.

io.undertow.server.handlers.sse.ServerSentEventConnection Maven / Gradle / Ivy

Go to download

This artifact provides a single JAR that contains all classes required to use remote EJB and JMS, including all dependencies. It is intended for those not using Maven; Maven users should instead import the EJB and JMS BOMs (shaded JARs cause many problems with Maven, because it is very easy to inadvertently end up with different versions of the same classes on the class path).

The newest version!
/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package io.undertow.server.handlers.sse;

import io.undertow.UndertowLogger;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.security.api.SecurityContext;
import io.undertow.security.idm.Account;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.Attachable;
import io.undertow.util.AttachmentKey;
import io.undertow.util.AttachmentList;
import io.undertow.util.HeaderMap;
import org.xnio.ChannelExceptionHandler;
import org.xnio.ChannelListener;
import org.xnio.ChannelListeners;
import org.xnio.IoUtils;
import org.xnio.XnioExecutor;
import org.xnio.channels.StreamSinkChannel;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channel;
import java.nio.channels.ClosedChannelException;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;

/**
 * Represents the server side of a Server Sent Events connection.
 *
 * The class implements Attachable, which provides access to the underlying exchanges attachments.
 *
 * @author Stuart Douglas
 */
public class ServerSentEventConnection implements Channel, Attachable {

    private final HttpServerExchange exchange;
    private final StreamSinkChannel sink;
    private final SseWriteListener writeListener = new SseWriteListener();

    private PooledByteBuffer pooled;

    private final Deque queue = new ConcurrentLinkedDeque<>();
    private final Queue buffered = new ConcurrentLinkedDeque<>();
    /**
     * Messages that have been written to the channel but flush() has failed
     */
    private final Queue flushingMessages = new ArrayDeque<>();
    private final List> closeTasks = new CopyOnWriteArrayList<>();
    private Map parameters;
    private Map properties = new HashMap<>();

    private static final AtomicIntegerFieldUpdater openUpdater = AtomicIntegerFieldUpdater.newUpdater(ServerSentEventConnection.class, "open");
    private volatile int open = 1;
    private volatile boolean shutdown = false;
    private volatile long keepAliveTime = -1;
    private XnioExecutor.Key timerKey;


    public ServerSentEventConnection(HttpServerExchange exchange, StreamSinkChannel sink) {
        this.exchange = exchange;
        this.sink = sink;
        this.sink.getCloseSetter().set(new ChannelListener() {
            @Override
            public void handleEvent(StreamSinkChannel channel) {
                if(timerKey != null) {
                    timerKey.remove();
                }
                for (ChannelListener listener : closeTasks) {
                    ChannelListeners.invokeChannelListener(ServerSentEventConnection.this, listener);
                }
                IoUtils.safeClose(ServerSentEventConnection.this);
            }
        });
        this.sink.getWriteSetter().set(writeListener);
    }

    /**
     * Adds a listener that will be invoked when the channel is closed
     *
     * @param listener The listener to invoke
     */
    public synchronized void addCloseTask(ChannelListener listener) {
        this.closeTasks.add(listener);
    }

    /**
     *
     * @return The principal that was associated with the SSE request
     */
    public Principal getPrincipal() {
        Account account = getAccount();
        if (account != null) {
            return account.getPrincipal();
        }
        return null;
    }

    /**
     *
     * @return The account that was associated with the SSE request
     */
    public Account getAccount() {
        SecurityContext sc = exchange.getSecurityContext();
        if (sc != null) {
            return sc.getAuthenticatedAccount();
        }
        return null;
    }

    /**
     *
     * @return The request headers from the initial request that opened this connection
     */
    public HeaderMap getRequestHeaders() {
        return exchange.getRequestHeaders();
    }

    /**
     *
     * @return The response headers from the initial request that opened this connection
     */
    public HeaderMap getResponseHeaders() {
        return exchange.getResponseHeaders();
    }

    /**
     *
     * @return The request URI from the initial request that opened this connection
     */
    public String getRequestURI() {
        return exchange.getRequestURI();
    }

    /**
     *
     * @return the query parameters
     */
    public Map> getQueryParameters() {
        return exchange.getQueryParameters();
    }

    /**
     *
     * @return the query string
     */
    public String getQueryString() {
        return exchange.getQueryString();
    }

    /**
     * Sends an event to the remote client
     *
     * @param data The event data
     */
    public void send(String data) {
        send(data, null, null, null);
    }

    /**
     * Sends an event to the remote client
     *
     * @param data The event data
     * @param callback A callback that is notified on Success or failure
     */
    public void send(String data, EventCallback callback) {
        send(data, null, null, callback);
    }

    /**
     * Sends the 'retry' message to the client, instructing it how long to wait before attempting a reconnect.
     *
     * @param retry The retry time in milliseconds
     */
    public void sendRetry(long retry) {
        sendRetry(retry, null);
    }


    /**
     * Sends the 'retry' message to the client, instructing it how long to wait before attempting a reconnect.
     *
     * @param retry The retry time in milliseconds
     * @param callback The callback that is notified on success or failure
     */
    public synchronized void sendRetry(long retry, EventCallback callback) {

        if (open == 0 || shutdown) {
            if (callback != null) {
                callback.failed(this, null, null, null, new ClosedChannelException());
            }
            return;
        }
        queue.add(new SSEData(retry, callback));
        sink.getIoThread().execute(new Runnable() {
            @Override
            public void run() {
                synchronized (ServerSentEventConnection.this) {
                    if (pooled == null) {
                        fillBuffer();
                        writeListener.handleEvent(sink);
                    }
                }
            }
        });
    }

    /**
     * Sends an event to the remote client
     *
     * @param data The event data
     * @param event The event name
     * @param id The event ID
     * @param callback A callback that is notified on Success or failure
     */
    public synchronized void send(String data, String event, String id, EventCallback callback) {
        if (open == 0 || shutdown) {
            if (callback != null) {
                callback.failed(this, data, event, id, new ClosedChannelException());
            }
            return;
        }
        queue.add(new SSEData(event, data, id, callback));
        sink.getIoThread().execute(new Runnable() {
            @Override
            public void run() {
                synchronized (ServerSentEventConnection.this) {
                    if (pooled == null) {
                        fillBuffer();
                        writeListener.handleEvent(sink);
                    }
                }
            }
        });
    }

    public String getParameter(String name) {
        if(parameters == null) {
            return null;
        }
        return parameters.get(name);
    }

    public void setParameter(String name, String value) {
        if(parameters == null) {
            parameters = new HashMap<>();
        }
        parameters.put(name, value);
    }

    public Map getProperties() {
        return properties;
    }

    /**
     *
     *
     * @return The keep alive time
     */
    public long getKeepAliveTime() {
        return keepAliveTime;
    }

    /**
     * Sets the keep alive time in milliseconds. If this is larger than zero a ':' message will be sent this often
     * (assuming there is no activity) to keep the connection alive.
     *
     * The spec recommends a value of 15000 (15 seconds).
     *
     * @param keepAliveTime The time in milliseconds between keep alive messaged
     */
    public void setKeepAliveTime(long keepAliveTime) {
        this.keepAliveTime = keepAliveTime;
        if(this.timerKey != null) {
            this.timerKey.remove();
        }
        this.timerKey = sink.getIoThread().executeAtInterval(new Runnable() {
            @Override
            public void run() {
                if(shutdown || open == 0) {
                    if(timerKey != null) {
                        timerKey.remove();
                    }
                    return;
                }
                if(pooled == null) {
                    pooled = exchange.getConnection().getByteBufferPool().allocate();
                    pooled.getBuffer().put(":\n".getBytes(StandardCharsets.UTF_8));
                    pooled.getBuffer().flip();
                    writeListener.handleEvent(sink);
                }
            }
        }, keepAliveTime, TimeUnit.MILLISECONDS);
    }

    private void fillBuffer() {
        if (queue.isEmpty()) {
            if(pooled != null) {
                pooled.close();
                pooled = null;
                sink.suspendWrites();
            }
            return;
        }

        if (pooled == null) {
            pooled = exchange.getConnection().getByteBufferPool().allocate();
        } else {
            pooled.getBuffer().clear();
        }
        ByteBuffer buffer = pooled.getBuffer();

        while (!queue.isEmpty() && buffer.hasRemaining()) {
            SSEData data = queue.poll();
            buffered.add(data);
            if (data.leftOverData == null) {
                StringBuilder message = new StringBuilder();
                if(data.retry > 0) {
                    message.append("retry:");
                    message.append(data.retry);
                    message.append('\n');
                } else {
                    if (data.id != null) {
                        message.append("id:");
                        message.append(data.id);
                        message.append('\n');
                    }
                    if (data.event != null) {
                        message.append("event:");
                        message.append(data.event);
                        message.append('\n');
                    }
                    if (data.data != null) {
                        message.append("data:");
                        for (int i = 0; i < data.data.length(); ++i) {
                            char c = data.data.charAt(i);
                            if (c == '\n') {
                                message.append("\ndata:");
                            } else {
                                message.append(c);
                            }
                        }
                        message.append('\n');
                    }
                }
                message.append('\n');
                byte[] messageBytes = message.toString().getBytes(StandardCharsets.UTF_8);
                if (messageBytes.length < buffer.remaining()) {
                    buffer.put(messageBytes);
                    data.endBufferPosition = buffer.position();
                } else {
                    queue.addFirst(data);
                    int rem = buffer.remaining();
                    buffer.put(messageBytes, 0, rem);
                    data.leftOverData = messageBytes;
                    data.leftOverDataOffset = rem;
                }
            } else {
                int remainingData = data.leftOverData.length - data.leftOverDataOffset;
                if (remainingData > buffer.remaining()) {
                    queue.addFirst(data);
                    int toWrite = buffer.remaining();
                    buffer.put(data.leftOverData, data.leftOverDataOffset, toWrite);
                    data.leftOverDataOffset += toWrite;
                } else {
                    buffer.put(data.leftOverData, data.leftOverDataOffset, remainingData);
                    data.endBufferPosition = buffer.position();
                    data.leftOverData = null;
                }
            }
        }
        buffer.flip();
        sink.resumeWrites();
    }

    /**
     * execute a graceful shutdown once all data has been sent
     */
    public void shutdown() {
        if (open == 0 || shutdown) {
            return;
        }
        shutdown = true;
        sink.getIoThread().execute(new Runnable() {
            @Override
            public void run() {

                synchronized (ServerSentEventConnection.this) {
                    if (queue.isEmpty() && pooled == null) {
                        exchange.endExchange();
                    }
                }
            }
        });
    }

    @Override
    public boolean isOpen() {
        return open != 0;
    }

    @Override
    public void close() throws IOException {
        close(new ClosedChannelException());
    }

    private synchronized void close(IOException e) throws IOException {
        if (openUpdater.compareAndSet(this, 1, 0)) {
            if (pooled != null) {
                pooled.close();
                pooled = null;
            }
            List cb = new ArrayList<>(buffered.size() + queue.size() + flushingMessages.size());
            cb.addAll(buffered);
            cb.addAll(queue);
            cb.addAll(flushingMessages);
            queue.clear();
            buffered.clear();
            flushingMessages.clear();
            for (SSEData i : cb) {
                if (i.callback != null) {
                    try {
                        i.callback.failed(this, i.data, i.event, i.id, e);
                    } catch (Exception ex) {
                        UndertowLogger.REQUEST_LOGGER.failedToInvokeFailedCallback(i.callback, ex);
                    }
                }
            }
            sink.shutdownWrites();
            if(!sink.flush()) {
                sink.getWriteSetter().set(ChannelListeners.flushingChannelListener(null, new ChannelExceptionHandler() {
                    @Override
                    public void handleException(StreamSinkChannel channel, IOException exception) {
                        IoUtils.safeClose(sink);
                    }
                }));
                sink.resumeWrites();
            }
        }
    }

    @Override
    public  T getAttachment(AttachmentKey key) {
        return exchange.getAttachment(key);
    }

    @Override
    public  List getAttachmentList(AttachmentKey> key) {
        return exchange.getAttachmentList(key);
    }

    @Override
    public  T putAttachment(AttachmentKey key, T value) {
        return exchange.putAttachment(key, value);
    }

    @Override
    public  T removeAttachment(AttachmentKey key) {
        return exchange.removeAttachment(key);
    }

    @Override
    public  void addToAttachmentList(AttachmentKey> key, T value) {
        exchange.addToAttachmentList(key, value);
    }

    public interface EventCallback {

        /**
         * Notification that is called when a message is sucessfully sent
         *
         * @param connection The connection
         * @param data The message data
         * @param event The message event
         * @param id The message id
         */
        void done(ServerSentEventConnection connection, String data, String event, String id);

        /**
         * Notification that is called when a message send fails.
         *
         * @param connection The connection
         * @param data The message data
         * @param event The message event
         * @param id The message id
         * @param e The exception
         */
        void failed(ServerSentEventConnection connection, String data, String event, String id, IOException e);

    }

    private static class SSEData {
        final String event;
        final String data;
        final String id;
        final long retry;
        final EventCallback callback;
        private int endBufferPosition = -1;
        private byte[] leftOverData;
        private int leftOverDataOffset;

        private SSEData(String event, String data, String id, EventCallback callback) {
            this.event = event;
            this.data = data;
            this.id = id;
            this.callback = callback;
            this.retry = -1;
        }

        private SSEData(long retry, EventCallback callback) {
            this.event = null;
            this.data = null;
            this.id = null;
            this.callback = callback;
            this.retry = retry;
        }


    }

    private class SseWriteListener implements ChannelListener {
        @Override
        public void handleEvent(StreamSinkChannel channel) {
            synchronized (ServerSentEventConnection.this) {
                try {
                    if (!flushingMessages.isEmpty()) {
                        if (!channel.flush()) {
                            return;
                        }
                        for (SSEData data : flushingMessages) {
                            if (data.callback != null && data.leftOverData == null) {
                                data.callback.done(ServerSentEventConnection.this, data.data, data.event, data.id);
                            }
                        }
                        flushingMessages.clear();
                        ByteBuffer buffer = pooled.getBuffer();
                        if (!buffer.hasRemaining()) {
                            fillBuffer();
                            if (pooled == null) {
                                if (channel.flush()) {
                                    channel.suspendWrites();
                                }
                                return;
                            }
                        }
                    } else if (pooled == null) {
                        if (channel.flush()) {
                            channel.suspendWrites();
                        }
                        return;
                    }

                    ByteBuffer buffer = pooled.getBuffer();
                    int res;
                    do {
                        res = channel.write(buffer);
                        boolean flushed = channel.flush();
                        while (!buffered.isEmpty()) {
                            //figure out which messages are complete
                            SSEData data = buffered.peek();
                            if (data.endBufferPosition > 0 && buffer.position() >= data.endBufferPosition) {
                                buffered.poll();
                                if (flushed) {
                                    if (data.callback != null && data.leftOverData == null) {
                                        data.callback.done(ServerSentEventConnection.this, data.data, data.event, data.id);
                                    }
                                } else {
                                    //if flush was unsuccessful we defer the callback invocation, till it is actually on the wire
                                    flushingMessages.add(data);
                                }

                            } else {
                                if (data.endBufferPosition <= 0) {
                                    buffered.poll();
                                }
                                break;
                            }
                        }
                        if (!flushed && !flushingMessages.isEmpty()) {
                            sink.resumeWrites();
                            return;
                        }

                        if (!buffer.hasRemaining()) {
                            fillBuffer();
                            if (pooled == null) {
                                return;
                            }
                        } else if (res == 0) {
                            sink.resumeWrites();
                            return;
                        }

                    } while (res > 0);
                } catch (IOException e) {
                    handleException(e);
                }
            }
        }
    }

    private void handleException(IOException e) {
        IoUtils.safeClose(this, sink, exchange.getConnection());
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy