All downloads are free. Search and download functionality uses the official Maven repository.

io.undertow.server.protocol.ajp.AjpServerResponseConduit Maven / Gradle / Ivy

There is a newer version: 2.3.18.Final
Show newest version
/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package io.undertow.server.protocol.ajp;

import io.undertow.UndertowMessages;
import io.undertow.UndertowOptions;
import io.undertow.conduits.AbstractFramedStreamSinkConduit;
import io.undertow.conduits.ConduitListener;
import io.undertow.server.Connectors;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.HeaderMap;
import io.undertow.util.Headers;
import io.undertow.util.HttpString;
import io.undertow.util.StatusCodes;
import org.jboss.logging.Logger;
import org.xnio.Buffers;
import org.xnio.IoUtils;
import io.undertow.connector.ByteBufferPool;
import io.undertow.connector.PooledByteBuffer;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.ConduitWritableByteChannel;
import org.xnio.conduits.StreamSinkConduit;
import org.xnio.conduits.WriteReadyHandler;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.anyAreSet;

/**
 * AJP response channel. For now we are going to assume that the buffers are sized to
 * fit complete packets. As AJP packets are limited to 8k this is a reasonable assumption.
 * <p>
 * The SEND_HEADERS packet (type 4) is generated lazily, the first time {@link #write(ByteBuffer)},
 * {@link #flush()} or the close-frame logic runs, and is queued ahead of any body data. Body data is
 * framed into SEND_BODY_CHUNK packets (type 3) whose payload is capped by
 * {@link UndertowOptions#MAX_AJP_PACKET_SIZE}, and the response ends with an END_RESPONSE packet (type 5).
 *
 * @author David M. Lloyd
 * @author Stuart Douglas
 */
final class AjpServerResponseConduit extends AbstractFramedStreamSinkConduit {

    private static final Logger log = Logger.getLogger("io.undertow.server.channel.ajp.response");

    /** Packet size used when {@link UndertowOptions#MAX_AJP_PACKET_SIZE} is not configured. */
    private static final int DEFAULT_MAX_DATA_SIZE = 8192;

    /** Response headers that are sent as a two-byte AJP code instead of their name. */
    private static final Map<HttpString, Integer> HEADER_MAP;

    /**
     * An empty SEND_BODY_CHUNK packet, queued by {@link #flush()}. It is shared, so it must only ever be
     * used through {@link ByteBuffer#duplicate()}.
     */
    private static final ByteBuffer FLUSH_PACKET = ByteBuffer.allocateDirect(8);

    static {
        final Map<HttpString, Integer> headers = new HashMap<>();
        headers.put(Headers.CONTENT_TYPE, 0xA001);
        headers.put(Headers.CONTENT_LANGUAGE, 0xA002);
        headers.put(Headers.CONTENT_LENGTH, 0xA003);
        headers.put(Headers.DATE, 0xA004);
        headers.put(Headers.LAST_MODIFIED, 0xA005);
        headers.put(Headers.LOCATION, 0xA006);
        headers.put(Headers.SET_COOKIE, 0xA007);
        headers.put(Headers.SET_COOKIE2, 0xA008);
        headers.put(Headers.SERVLET_ENGINE, 0xA009);
        headers.put(Headers.STATUS, 0xA00A);
        headers.put(Headers.WWW_AUTHENTICATE, 0xA00B);
        HEADER_MAP = Collections.unmodifiableMap(headers);

        FLUSH_PACKET.put((byte) 'A');
        FLUSH_PACKET.put((byte) 'B');
        FLUSH_PACKET.put((byte) 0); //payload length: 4
        FLUSH_PACKET.put((byte) 4);
        FLUSH_PACKET.put((byte) 3); //SEND_BODY_CHUNK
        FLUSH_PACKET.put((byte) 0); //chunk length: 0
        FLUSH_PACKET.put((byte) 0);
        FLUSH_PACKET.put((byte) 0); //terminator
        FLUSH_PACKET.flip();
    }

    private static final int FLAG_START = 1; //indicates that the header has not been generated yet.
    private static final int FLAG_WRITE_RESUMED = 1 << 2;
    private static final int FLAG_WRITE_READ_BODY_CHUNK_FROM_LISTENER = 1 << 3; //a READ_BODY_CHUNK request is queued
    private static final int FLAG_WRITE_SHUTDOWN = 1 << 4;
    private static final int FLAG_READS_DONE = 1 << 5;
    private static final int FLAG_FLUSH_QUEUED = 1 << 6; //a flush packet is queued and not yet written

    /** END_RESPONSE packet telling the peer the connection may be reused. */
    private static final ByteBuffer CLOSE_FRAME_PERSISTENT;
    /** END_RESPONSE packet telling the peer the connection must not be reused. */
    private static final ByteBuffer CLOSE_FRAME_NON_PERSISTENT;

    static {
        ByteBuffer buffer = ByteBuffer.wrap(new byte[6]);
        buffer.put((byte) 'A');
        buffer.put((byte) 'B');
        buffer.put((byte) 0);
        buffer.put((byte) 2);
        buffer.put((byte) 5);
        buffer.put((byte) 1); //reuse
        buffer.flip();
        CLOSE_FRAME_PERSISTENT = buffer;
        buffer = ByteBuffer.wrap(new byte[6]);
        buffer.put(CLOSE_FRAME_PERSISTENT.duplicate());
        buffer.put(5, (byte) 0); //do not reuse
        buffer.flip();
        CLOSE_FRAME_NON_PERSISTENT = buffer;
    }


    private final ByteBufferPool pool;

    /**
     * State flags
     */
    private int state = FLAG_START;

    private final HttpServerExchange exchange;

    private final ConduitListener<? super AjpServerResponseConduit> finishListener;

    private final boolean headRequest;

    /**
     * @param next           the underlying conduit the AJP packets are written to
     * @param pool           pool the header packet and any queued body data are copied into
     * @param exchange       the exchange whose response is being sent
     * @param finishListener notified from {@link #finished()}; may be {@code null}
     * @param headRequest    if {@code true} body data is consumed but never sent
     */
    AjpServerResponseConduit(final StreamSinkConduit next, final ByteBufferPool pool, final HttpServerExchange exchange, ConduitListener<? super AjpServerResponseConduit> finishListener, boolean headRequest) {
        super(next);
        this.pool = pool;
        this.exchange = exchange;
        this.finishListener = finishListener;
        this.headRequest = headRequest;
    }

    /**
     * Writes the low 16 bits of {@code value} as a big-endian two-byte AJP integer.
     */
    private static void putInt(final ByteBuffer buf, int value) {
        buf.put((byte) ((value >> 8) & 0xFF));
        buf.put((byte) (value & 0xFF));
    }

    /**
     * Writes an AJP string: two-byte length, one byte per character, then a NUL terminator.
     * CR and LF are replaced by spaces so a value cannot inject extra lines. Each char is cast
     * to a byte, so characters outside ISO-8859-1 keep only their low byte.
     */
    private static void putString(final ByteBuffer buf, String value) {
        final int length = value.length();
        putInt(buf, length);
        for (int i = 0; i < length; ++i) {
            char c = value.charAt(i);
            if (c != '\r' && c != '\n') {
                buf.put((byte) c);
            } else {
                buf.put((byte) ' ');
            }
        }
        buf.put((byte) 0);
    }

    /**
     * Writes an {@link HttpString} in AJP string format (two-byte length, bytes, NUL terminator).
     */
    private static void putHttpString(final ByteBuffer buf, HttpString value) {
        final int length = value.length();
        putInt(buf, length);
        value.appendTo(buf);
        buf.put((byte) 0);
    }

    /**
     * Handles generating the header if required, and adding it to the frame queue.
     *
     * No attempt is made to actually flush this, so a gathering write can be used to actually flush the data.
     * If building the packet fails, every pooled buffer taken for it is released before the exception propagates.
     */
    private void processAJPHeader() {
        if (allAreClear(state, FLAG_START)) {
            return;
        }

        //merge the cookies into the header map
        Connectors.flattenCookies(exchange);

        //every pooled buffer taken for this packet, in order; all of them are released if building fails
        final List<PooledByteBuffer> allocated = new ArrayList<>(2);
        final FrameCallBack callback;
        final ByteBuffer[] packet;
        try {
            PooledByteBuffer pooled = pool.allocate();
            allocated.add(pooled);
            ByteBuffer buffer = pooled.getBuffer();
            buffer.put((byte) 'A');
            buffer.put((byte) 'B');
            buffer.put((byte) 0); //we fill the size in later
            buffer.put((byte) 0);
            buffer.put((byte) 4);
            putInt(buffer, exchange.getStatusCode());
            String reason = exchange.getReasonPhrase();
            if (reason == null) {
                reason = StatusCodes.getReason(exchange.getStatusCode());
            }
            //the reason phrase needs length + 3 bytes, and the header count that follows needs 2 more
            if (reason.length() + 5 > buffer.remaining()) {
                throw UndertowMessages.MESSAGES.reasonPhraseToLargeForBuffer(reason);
            }
            putString(buffer, reason);

            //we need to count the headers
            final HeaderMap responseHeaders = exchange.getResponseHeaders();
            int headerCount = 0;
            for (HttpString name : responseHeaders.getHeaderNames()) {
                headerCount += responseHeaders.get(name).size();
            }
            putInt(buffer, headerCount);

            for (final HttpString header : responseHeaders.getHeaderNames()) {
                for (String headerValue : responseHeaders.get(header)) {
                    //worst case is a named header: (2 + name + 1) + (2 + value + 1)
                    if (buffer.remaining() < header.length() + headerValue.length() + 6) {
                        //not enough room: switch the full buffer to read mode and continue in a fresh one
                        buffer.flip();
                        pooled = pool.allocate();
                        allocated.add(pooled);
                        buffer = pooled.getBuffer();
                    }

                    final Integer headerCode = HEADER_MAP.get(header);
                    if (headerCode != null) {
                        putInt(buffer, headerCode);
                    } else {
                        putHttpString(buffer, header);
                    }
                    putString(buffer, headerValue);
                }
            }
            //the buffer currently being filled must be in read mode before the packet is measured
            buffer.flip();

            packet = new ByteBuffer[allocated.size()];
            for (int i = 0; i < packet.length; ++i) {
                packet[i] = allocated.get(i).getBuffer();
            }
            //the length field covers everything after the 4 byte 'A' 'B' length prefix
            final int dataLength = (int) (Buffers.remaining(packet) - 4);
            packet[0].put(2, (byte) ((dataLength >> 8) & 0xFF));
            packet[0].put(3, (byte) (dataLength & 0xFF));
            if (allocated.size() == 1) {
                callback = new PooledBufferFrameCallback(allocated.get(0));
            } else {
                callback = new PooledBuffersFrameCallback(allocated.toArray(new PooledByteBuffer[0]));
            }
        } catch (RuntimeException | Error e) {
            for (PooledByteBuffer p : allocated) {
                p.close();
            }
            throw e;
        }
        //from here on the frame callback owns the pooled buffers
        queueFrame(callback, packet);
        state &= ~FLAG_START;
    }


    /**
     * Queues the header packet (if not yet sent) followed by the END_RESPONSE packet, whose reuse
     * byte reflects {@link HttpServerExchange#isPersistent()}.
     */
    @Override
    protected void queueCloseFrames() {
        processAJPHeader();
        final ByteBuffer buffer = exchange.isPersistent() ? CLOSE_FRAME_PERSISTENT.duplicate() : CLOSE_FRAME_NON_PERSISTENT.duplicate();
        queueFrame(null, buffer);
    }

    /**
     * Queues the unwritten remainder of a body-chunk packet after a zero-length write.
     * The remaining payload of {@code src} is copied into pooled buffers (so the caller may reuse
     * {@code src}), and is queued between the remaining header bytes ({@code buffers[0]}) and the
     * footer ({@code buffers[2]}).
     *
     * @param src     the payload; fully consumed by this call
     * @param buffers the header / payload / footer array built by {@link #createHeader(ByteBuffer)}
     */
    private void queueRemainingBytes(final ByteBuffer src, final ByteBuffer[] buffers) {
        final List<PooledByteBuffer> pools = new ArrayList<>(4);

        try {
            PooledByteBuffer newPooledBuffer = pool.allocate();
            pools.add(newPooledBuffer);
            while (src.remaining() > newPooledBuffer.getBuffer().remaining()) {
                //copy as much as fits into the current pooled buffer, then move on to a new one
                ByteBuffer dupSrc = src.duplicate();
                dupSrc.limit(dupSrc.position() + newPooledBuffer.getBuffer().remaining());
                newPooledBuffer.getBuffer().put(dupSrc);
                src.position(dupSrc.position());
                newPooledBuffer.getBuffer().flip();
                newPooledBuffer = pool.allocate();
                pools.add(newPooledBuffer);
            }
            newPooledBuffer.getBuffer().put(src);
            newPooledBuffer.getBuffer().flip();

            ByteBuffer[] savedBuffers = new ByteBuffer[pools.size() + 2];
            int i = 0;
            savedBuffers[i++] = buffers[0];
            for (PooledByteBuffer p : pools) {
                savedBuffers[i++] = p.getBuffer();
            }
            savedBuffers[i] = buffers[2];
            queueFrame(new PooledBuffersFrameCallback(pools.toArray(new PooledByteBuffer[0])), savedBuffers);
        } catch (RuntimeException | Error e) {
            for (PooledByteBuffer p : pools) {
                p.close();
            }
            throw e;
        }
    }

    /**
     * Sends at most one SEND_BODY_CHUNK packet's worth of {@code src}. The payload is limited to
     * {@code MAX_AJP_PACKET_SIZE - 8} bytes. If the channel stops accepting data part-way, the rest
     * of the packet is copied and queued, and the whole payload is reported as written.
     * For HEAD requests the data is consumed and discarded.
     * Any I/O or runtime failure closes the connection before the exception is rethrown.
     *
     * @return the number of bytes consumed from {@code src}; 0 if previously queued data could not be flushed
     */
    public int write(final ByteBuffer src) throws IOException {
        if (queuedDataLength() > 0) {
            //if there is data in the queue we flush and return
            //otherwise the queue can grow indefinitely
            if (!flushQueuedData()) {
                return 0;
            }
        }
        processAJPHeader();
        if (headRequest) {
            int remaining = src.remaining();
            src.position(src.position() + remaining);
            return remaining;
        }
        final int limit = src.limit();
        try {
            //8 bytes of framing: 7 byte chunk header plus 1 byte terminator
            final int maxData = exchange.getConnection().getUndertowOptions().get(UndertowOptions.MAX_AJP_PACKET_SIZE, DEFAULT_MAX_DATA_SIZE) - 8;
            if (src.remaining() > maxData) {
                //temporarily shrink the limit so only one packet is framed; restored in finally
                src.limit(src.position() + maxData);
            }
            final int payloadSize = src.remaining();
            final ByteBuffer[] buffers = createHeader(src);
            long toWrite = Buffers.remaining(buffers);
            do {
                final long written = super.write(buffers, 0, buffers.length);
                if (written == -1) {
                    throw new ClosedChannelException();
                } else if (written == 0) {
                    // we need to queue all the remaining bytes for writing
                    queueRemainingBytes(src, buffers);
                    return payloadSize;
                }
                toWrite -= written;
            } while (toWrite > 0);
            return payloadSize;
        } catch (IOException | RuntimeException e) {
            IoUtils.safeClose(exchange.getConnection());
            throw e;
        } finally {
            src.limit(limit);
        }
    }

    /**
     * Frames the remaining bytes of {@code src} as a SEND_BODY_CHUNK packet.
     *
     * @return {@code [7 byte header, src, 1 byte NUL terminator]}; {@code src} itself is not copied
     */
    private ByteBuffer[] createHeader(final ByteBuffer src) {
        int remaining = src.remaining();
        int chunkSize = remaining + 4; //type byte + two length bytes + terminator
        byte[] header = new byte[7];
        header[0] = (byte) 'A';
        header[1] = (byte) 'B';
        header[2] = (byte) ((chunkSize >> 8) & 0xFF);
        header[3] = (byte) (chunkSize & 0xFF);
        header[4] = (byte) (3 & 0xFF);
        header[5] = (byte) ((remaining >> 8) & 0xFF);
        header[6] = (byte) (remaining & 0xFF);

        byte[] footer = new byte[1];
        footer[0] = 0;

        final ByteBuffer[] buffers = new ByteBuffer[3];
        buffers[0] = ByteBuffer.wrap(header);
        buffers[1] = src;
        buffers[2] = ByteBuffer.wrap(footer);
        return buffers;
    }

    public long write(final ByteBuffer[] srcs) throws IOException {
        return write(srcs, 0, srcs.length);
    }

    /**
     * Writes the given buffers one packet at a time, stopping at the first packet that writes nothing.
     *
     * @return the total number of bytes consumed across {@code srcs}
     */
    public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException {
        long total = 0;
        for (int i = offset; i < offset + length; ++i) {
            while (srcs[i].hasRemaining()) {
                int written = write(srcs[i]);
                if (written == 0) {
                    return total;
                }
                total += written;
            }
        }
        return total;
    }

    public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
        return src.transferTo(position, count, new ConduitWritableByteChannel(this));
    }

    public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
        return IoUtils.transfer(source, count, throughBuffer, new ConduitWritableByteChannel(this));
    }

    /**
     * Notifies the finish listener, if one was supplied.
     */
    @Override
    protected void finished() {
        if (finishListener != null) {
            finishListener.handleEvent(this);
        }
    }

    @Override
    public void setWriteReadyHandler(WriteReadyHandler handler) {
        next.setWriteReadyHandler(new AjpServerWriteReadyHandler(handler));
    }

    /**
     * Marks writes as suspended. The underlying conduit keeps writing while a READ_BODY_CHUNK request
     * is still queued, because that request must reach the peer.
     */
    public void suspendWrites() {
        log.trace("suspend");
        state &= ~FLAG_WRITE_RESUMED;
        if (allAreClear(state, FLAG_WRITE_READ_BODY_CHUNK_FROM_LISTENER)) {
            next.suspendWrites();
        }
    }

    public void resumeWrites() {
        log.trace("resume");
        state |= FLAG_WRITE_RESUMED;
        next.resumeWrites();
    }

    /**
     * Queues the header (if needed) and, unless one is already pending or writes are terminated,
     * an empty SEND_BODY_CHUNK flush packet. Then attempts to write all queued data.
     *
     * @return {@code true} if all queued data was written
     */
    public boolean flush() throws IOException {
        processAJPHeader();
        if (allAreClear(state, FLAG_FLUSH_QUEUED) && !isWritesTerminated()) {
            queueFrame(new FrameCallBack() {
                @Override
                public void done() {
                    state &= ~FLAG_FLUSH_QUEUED;
                }

                @Override
                public void failed(IOException e) {
                    //the packet is gone either way; clear the flag so it cannot block later flush packets
                    state &= ~FLAG_FLUSH_QUEUED;
                }
            }, FLUSH_PACKET.duplicate());
            state |= FLAG_FLUSH_QUEUED;
        }
        return flushQueuedData();
    }

    public boolean isWriteResumed() {
        return anyAreSet(state, FLAG_WRITE_RESUMED);
    }

    public void wakeupWrites() {
        log.trace("wakeup");
        state |= FLAG_WRITE_RESUMED;
        next.wakeupWrites();
    }

    /**
     * Marks this conduit as shut down. The underlying conduit is only terminated for non-persistent
     * exchanges, so a persistent AJP connection stays open. Failures close the connection.
     */
    @Override
    protected void doTerminateWrites() throws IOException {
        try {
            if (!exchange.isPersistent()) {
                next.terminateWrites();
            }
            state |= FLAG_WRITE_SHUTDOWN;
        } catch (IOException | RuntimeException e) {
            IoUtils.safeClose(exchange.getConnection());
            throw e;
        }
    }

    @Override
    public boolean isWriteShutdown() {
        return super.isWriteShutdown() || anyAreSet(state, FLAG_WRITE_SHUTDOWN);
    }

    /**
     * Sends a request-body-chunk request packet to the peer on behalf of the request conduit.
     *
     * @param buffer         the already-framed request packet
     * @param requestChannel notified if a queued write of the packet fails
     * @return {@code true} if the packet was written immediately; {@code false} if the remainder was
     *         queued and will be written from the write listener
     * @throws IOException if this conduit has already been shut down
     */
    boolean doGetRequestBodyChunk(ByteBuffer buffer, final AjpServerRequestConduit requestChannel) throws IOException {
        //first attempt to just write out the buffer
        //if there are other frames queued they will be written out first
        if (isWriteShutdown()) {
            throw UndertowMessages.MESSAGES.channelIsClosed();
        }
        super.write(buffer);
        if (buffer.hasRemaining()) {
            //write it out in a listener
            this.state |= FLAG_WRITE_READ_BODY_CHUNK_FROM_LISTENER;
            queueFrame(new FrameCallBack() {

                @Override
                public void done() {
                    state &= ~FLAG_WRITE_READ_BODY_CHUNK_FROM_LISTENER;
                    //writes were only kept resumed for this packet; restore the user's suspended state
                    if (allAreClear(state, FLAG_WRITE_RESUMED)) {
                        next.suspendWrites();
                    }
                }

                @Override
                public void failed(IOException e) {
                    requestChannel.setReadBodyChunkError(e);
                }
            }, buffer);
            next.resumeWrites();
            return false;
        }
        return true;
    }

    /**
     * Write-ready handler for the underlying conduit. It first flushes a pending READ_BODY_CHUNK
     * request, then forwards the event to the user's handler only if the user has resumed writes.
     */
    private final class AjpServerWriteReadyHandler implements WriteReadyHandler {

        private final WriteReadyHandler delegate;

        private AjpServerWriteReadyHandler(WriteReadyHandler delegate) {
            this.delegate = delegate;
        }

        @Override
        public void writeReady() {
            if (anyAreSet(state, FLAG_WRITE_READ_BODY_CHUNK_FROM_LISTENER)) {
                try {
                    flushQueuedData();
                } catch (IOException e) {
                    log.debug("Error flushing when doing async READ_BODY_CHUNK flush", e);
                }
            }
            if (anyAreSet(state, FLAG_WRITE_RESUMED)) {
                delegate.writeReady();
            }
        }

        @Override
        public void forceTermination() {
            delegate.forceTermination();
        }

        @Override
        public void terminated() {
            delegate.terminated();
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy