All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.undertow.conduits.DeflatingStreamSinkConduit Maven / Gradle / Ivy

Go to download

This artifact provides a single jar that contains all classes required to use remote EJB and JMS, including all dependencies. It is intended for those not using Maven; Maven users should import the EJB and JMS BOMs instead (shaded JARs cause many problems with Maven, as it is very easy to inadvertently end up with different versions of classes on the class path).

There is a newer version: 34.0.0.Final
Show newest version
/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package io.undertow.conduits;

import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.allAreSet;
import static org.xnio.Bits.anyAreSet;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;
import java.util.zip.Deflater;

import io.undertow.server.Connectors;
import org.xnio.IoUtils;
import io.undertow.connector.PooledByteBuffer;
import org.xnio.XnioIoThread;
import org.xnio.XnioWorker;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.ConduitWritableByteChannel;
import org.xnio.conduits.Conduits;
import org.xnio.conduits.StreamSinkConduit;
import org.xnio.conduits.WriteReadyHandler;

import io.undertow.UndertowLogger;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.ConduitFactory;
import io.undertow.util.NewInstanceObjectPool;
import io.undertow.util.ObjectPool;
import io.undertow.util.Headers;
import io.undertow.util.PooledObject;
import io.undertow.util.SimpleObjectPool;

/**
 * Channel that handles deflate compression
 *
 * @author Stuart Douglas
 */
/**
 * Stream sink conduit that compresses everything written to it with a {@link Deflater} before
 * handing the compressed bytes to the next conduit in the chain.
 * <p>
 * The next conduit is created lazily through the supplied {@link ConduitFactory} the first time
 * compressed output has to be written. If at that point the deflater has already finished and the
 * trailer has been queued, the full compressed size is known and a {@code Content-Length} header is
 * set (unless a {@code Transfer-Encoding} header is present); otherwise any {@code Content-Length}
 * header is removed.
 * <p>
 * Subclasses may hook in through {@link #preDeflate(byte[])}, which sees each chunk of uncompressed
 * input, and {@link #getTrailer()}, whose bytes are appended after the compressed data.
 *
 * @author Stuart Douglas
 */
public class DeflatingStreamSinkConduit implements StreamSinkConduit {

    /** The deflater in use; set to {@code null} once it has been released by {@link #freeBuffer()}. */
    protected volatile Deflater deflater;

    /** Pooled wrapper the {@link #deflater} was obtained from; closed when the deflater is released. */
    protected final PooledObject<Deflater> pooledObject;
    /** Creates the next conduit in the chain; used lazily by {@link #createNextChannel()}. */
    private final ConduitFactory conduitFactory;
    private final HttpServerExchange exchange;

    /** The downstream conduit, or {@code null} until {@link #createNextChannel()} has been called. */
    private StreamSinkConduit next;
    /** Handler invoked by {@link #queueWriteListener()} while writes are resumed and {@link #next} is null. */
    private WriteReadyHandler writeReadyHandler;


    /**
     * Buffer holding compressed output not yet written to {@link #next}. Released by
     * {@link #freeBuffer()} once the stream has been fully flushed, or on error.
     */
    protected PooledByteBuffer currentBuffer;
    /**
     * Compressed output (or trailer bytes) that did not fit into {@link #currentBuffer}; written
     * after it and cleared once flushed.
     */
    private ByteBuffer additionalBuffer;

    /** Bit set of the flags below. */
    private int state = 0;

    /** {@link #terminateWrites()} has been called; no further input is accepted. */
    private static final int SHUTDOWN = 1;
    /** All output has been handed to {@link #next} and its writes have been terminated. */
    private static final int NEXT_SHUTDOWN = 1 << 1;
    /** {@link #currentBuffer} has been flipped to read mode and is being drained into {@link #next}. */
    private static final int FLUSHING_BUFFER = 1 << 2;
    /** Writes were resumed while no {@link #next} conduit existed yet. */
    private static final int WRITES_RESUMED = 1 << 3;
    /** {@link #truncateWrites()} has been called. */
    private static final int CLOSED = 1 << 4;
    /** The bytes from {@link #getTrailer()} have been queued for output. */
    private static final int WRITTEN_TRAILER = 1 << 5;

    /**
     * Creates a conduit with an unpooled deflater at level {@code Deflater.DEFLATED}.
     * <p>
     * NOTE(review): {@code Deflater.DEFLATED} is the compression <em>method</em> constant (value 8),
     * not a level constant; it is accepted because 8 happens to be a valid level. It is kept so the
     * produced output does not change. Consider {@link Deflater#DEFAULT_COMPRESSION}.
     */
    public DeflatingStreamSinkConduit(final ConduitFactory conduitFactory, final HttpServerExchange exchange) {
        this(conduitFactory, exchange, Deflater.DEFLATED);
    }

    /**
     * Creates a conduit whose deflater comes from {@link #newInstanceDeflaterPool(int)}.
     *
     * @param deflateLevel compression level passed to {@link Deflater#Deflater(int, boolean)}
     */
    public DeflatingStreamSinkConduit(final ConduitFactory conduitFactory, final HttpServerExchange exchange, int deflateLevel) {
        this(conduitFactory, exchange, newInstanceDeflaterPool(deflateLevel));
    }

    /**
     * Creates a conduit that takes its deflater from {@code deflaterPool} and its output buffer from
     * the connection's byte buffer pool.
     *
     * @param conduitFactory factory used to create the next conduit when output is first written
     * @param exchange       the exchange whose response is being compressed
     * @param deflaterPool   pool supplying the deflater; the pooled object is closed when released
     */
    public DeflatingStreamSinkConduit(final ConduitFactory conduitFactory, final HttpServerExchange exchange, ObjectPool<Deflater> deflaterPool) {
        this.pooledObject = deflaterPool.allocate();
        this.deflater = pooledObject.getObject();
        this.exchange = exchange;
        this.conduitFactory = conduitFactory;
        try {
            this.currentBuffer = exchange.getConnection().getByteBufferPool().allocate();
            setWriteReadyHandler(new WriteReadyHandler.ChannelListenerHandler<>(Connectors.getConduitSinkChannel(exchange)));
        } catch (RuntimeException | Error e) {
            // don't leak the pooled deflater (or the buffer) if construction fails part-way
            freeBuffer();
            throw e;
        }
    }

    /**
     * Returns a pool that creates a new raw-DEFLATE ({@code nowrap = true}) {@link Deflater} per
     * allocation, with {@link Deflater#end()} supplied as the release callback.
     *
     * @param deflateLevel compression level for the created deflaters
     */
    public static ObjectPool<Deflater> newInstanceDeflaterPool(int deflateLevel) {
        return new NewInstanceObjectPool<>(() -> new Deflater(deflateLevel, true), Deflater::end);
    }

    /**
     * Returns a bounded pool of raw-DEFLATE ({@code nowrap = true}) deflaters. It is given
     * {@link Deflater#reset()} and {@link Deflater#end()} as callbacks; presumably these are applied
     * on return to the pool and on disposal respectively (confirm against {@code SimpleObjectPool}).
     *
     * @param poolSize     maximum number of pooled deflaters
     * @param deflateLevel compression level for the created deflaters
     */
    public static ObjectPool<Deflater> simpleDeflaterPool(int poolSize, int deflateLevel) {
        return new SimpleObjectPool<>(poolSize, () -> new Deflater(deflateLevel, true), Deflater::reset, Deflater::end);
    }


    /**
     * Feeds the remaining bytes of {@code src} to the deflater and buffers as much compressed output
     * as possible.
     *
     * @return the number of uncompressed bytes consumed; 0 if {@code src} is empty or if pending
     *         output or pending deflater input could not be drained first
     * @throws ClosedChannelException if the conduit was shut down or truncated, or has released its buffer
     */
    @Override
    public int write(final ByteBuffer src) throws IOException {
        if (anyAreSet(state, SHUTDOWN | CLOSED) || currentBuffer == null) {
            throw new ClosedChannelException();
        }
        try {
            if (!performFlushIfRequired()) {
                return 0;
            }
            if (!src.hasRemaining()) {
                return 0;
            }
            // input from an earlier write may still be pending in the deflater; compress it first
            if (!deflater.needsInput()) {
                deflateData(false);
                if (!deflater.needsInput()) {
                    return 0;
                }
            }
            // a fresh array is needed each time: the deflater keeps referencing its input until consumed
            final byte[] data = new byte[src.remaining()];
            src.get(data);
            preDeflate(data);
            deflater.setInput(data);
            // deflateData counts the compressed bytes; subtract the uncompressed length
            // (presumably to offset an upstream count of it - TODO confirm)
            Connectors.updateResponseBytesSent(exchange, -data.length);
            deflateData(false);
            return data.length;
        } catch (IOException | RuntimeException | Error e) {
            freeBuffer();
            throw e;
        }
    }

    /**
     * Hook invoked with each chunk of uncompressed input just before it is handed to the deflater.
     * The default implementation does nothing.
     *
     * @param data the uncompressed bytes about to be deflated
     */
    protected void preDeflate(byte[] data) {

    }

    /**
     * Writes each buffer in turn via {@link #write(ByteBuffer)}, stopping at the first buffer that
     * could not be consumed.
     *
     * @return total number of uncompressed bytes consumed
     */
    @Override
    public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException {
        if (anyAreSet(state, SHUTDOWN | CLOSED) || currentBuffer == null) {
            throw new ClosedChannelException();
        }
        try {
            int total = 0;
            for (int i = offset; i < offset + length; ++i) {
                if (srcs[i].hasRemaining()) {
                    int ret = write(srcs[i]);
                    total += ret;
                    if (ret == 0) {
                        return total;
                    }
                }
            }
            return total;
        } catch (IOException | RuntimeException | Error e) {
            freeBuffer();
            throw e;
        }
    }

    /** Delegates to {@code Conduits.writeFinalBasic}. */
    @Override
    public int writeFinal(ByteBuffer src) throws IOException {
        return Conduits.writeFinalBasic(this, src);
    }

    /** Delegates to {@code Conduits.writeFinalBasic}. */
    @Override
    public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException {
        return Conduits.writeFinalBasic(this, srcs, offset, length);
    }

    /**
     * Transfers file data through {@link #write(ByteBuffer)} so that it is compressed.
     *
     * @throws ClosedChannelException if the conduit was shut down or truncated, or has released its buffer
     */
    @Override
    public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
        if (anyAreSet(state, SHUTDOWN | CLOSED) || currentBuffer == null) {
            throw new ClosedChannelException();
        }
        if (!performFlushIfRequired()) {
            return 0;
        }
        return src.transferTo(position, count, new ConduitWritableByteChannel(this));
    }


    /**
     * Transfers channel data through {@link #write(ByteBuffer)} so that it is compressed. The closed
     * check runs before any data is read from {@code source}, so no input is consumed on a dead conduit.
     *
     * @throws ClosedChannelException if the conduit was shut down or truncated, or has released its buffer
     */
    @Override
    public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
        if (anyAreSet(state, SHUTDOWN | CLOSED) || currentBuffer == null) {
            throw new ClosedChannelException();
        }
        if (!performFlushIfRequired()) {
            return 0;
        }
        return IoUtils.transfer(source, count, throughBuffer, new ConduitWritableByteChannel(this));
    }

    @Override
    public XnioWorker getWorker() {
        return exchange.getConnection().getWorker();
    }

    /** Clears the local resume flag while there is no next conduit; otherwise delegates. */
    @Override
    public void suspendWrites() {
        if (next == null) {
            state = state & ~WRITES_RESUMED;
        } else {
            next.suspendWrites();
        }
    }


    /** Reports the local resume flag while there is no next conduit; otherwise delegates. */
    @Override
    public boolean isWriteResumed() {
        if (next == null) {
            return anyAreSet(state, WRITES_RESUMED);
        } else {
            return next.isWriteResumed();
        }
    }

    /** Equivalent to {@link #resumeWrites()} while there is no next conduit; otherwise delegates. */
    @Override
    public void wakeupWrites() {
        if (next == null) {
            resumeWrites();
        } else {
            next.wakeupWrites();
        }
    }

    /**
     * While there is no next conduit, records the resumption and schedules the write-ready handler
     * on the I/O thread; otherwise delegates.
     */
    @Override
    public void resumeWrites() {
        if (next == null) {
            state |= WRITES_RESUMED;
            queueWriteListener();
        } else {
            next.resumeWrites();
        }
    }

    /**
     * Schedules {@link #writeReadyHandler} on the connection's I/O thread. The task reschedules itself
     * for as long as writes remain resumed and no next conduit exists.
     */
    private void queueWriteListener() {
        exchange.getConnection().getIoThread().execute(() -> {
            if (writeReadyHandler == null) {
                return;
            }
            try {
                writeReadyHandler.writeReady();
            } finally {
                //if writes are still resumed queue up another one
                if (next == null && isWriteResumed()) {
                    queueWriteListener();
                }
            }
        });
    }

    /**
     * Marks the conduit as shut down and tells the deflater to finish. The remaining compressed
     * output and the trailer are emitted by subsequent {@link #flush()} calls.
     */
    @Override
    public void terminateWrites() throws IOException {
        if (deflater != null) {
            deflater.finish();
        }
        state |= SHUTDOWN;
    }

    @Override
    public boolean isWriteShutdown() {
        return anyAreSet(state, SHUTDOWN);
    }

    /** Returns immediately while there is no next conduit; otherwise delegates. */
    @Override
    public void awaitWritable() throws IOException {
        if (next != null) {
            next.awaitWritable();
        }
    }

    /** Returns immediately while there is no next conduit; otherwise delegates. */
    @Override
    public void awaitWritable(final long time, final TimeUnit timeUnit) throws IOException {
        if (next != null) {
            next.awaitWritable(time, timeUnit);
        }
    }

    @Override
    public XnioIoThread getWriteThread() {
        return exchange.getConnection().getIoThread();
    }

    @Override
    public void setWriteReadyHandler(final WriteReadyHandler handler) {
        this.writeReadyHandler = handler;
    }

    /**
     * Pushes buffered output downstream.
     * <p>
     * Before shutdown, this performs a {@code SYNC_FLUSH} of the deflater and drains the buffer into
     * the next conduit. After {@link #terminateWrites()}, it drains the deflater completely, queues
     * the trailer, writes everything out, releases the buffer and deflater, and terminates the next
     * conduit.
     *
     * @return {@code true} if everything (including the next conduit's flush) completed
     */
    @Override
    public boolean flush() throws IOException {
        if (currentBuffer == null) {
            // resources already released: only the downstream conduit may still need flushing
            if (anyAreSet(state, NEXT_SHUTDOWN)) {
                return next.flush();
            } else {
                return true;
            }
        }
        try {
            boolean nextCreated = false;
            try {
                if (anyAreSet(state, SHUTDOWN)) {
                    if (anyAreSet(state, NEXT_SHUTDOWN)) {
                        return next.flush();
                    } else {
                        // drain whatever is already queued before producing more output
                        if (!performFlushIfRequired()) {
                            return false;
                        }
                        //if the deflater has not been fully flushed we need to flush it
                        if (!deflater.finished()) {
                            deflateData(false);
                            //if could not fully flush
                            if (!deflater.finished()) {
                                return false;
                            }
                        }
                        final ByteBuffer buffer = currentBuffer.getBuffer();
                        if (allAreClear(state, WRITTEN_TRAILER)) {
                            state |= WRITTEN_TRAILER;
                            queueTrailer(buffer);
                        }

                        //ok the deflater is flushed, now we need to flush the buffer
                        if (!anyAreSet(state, FLUSHING_BUFFER)) {
                            buffer.flip();
                            state |= FLUSHING_BUFFER;
                            if (next == null) {
                                nextCreated = true;
                                this.next = createNextChannel();
                            }
                        }
                        if (performFlushIfRequired()) {
                            state |= NEXT_SHUTDOWN;
                            freeBuffer();
                            next.terminateWrites();
                            return next.flush();
                        } else {
                            return false;
                        }
                    }
                } else {
                    if (allAreClear(state, FLUSHING_BUFFER)) {
                        if (next == null) {
                            nextCreated = true;
                            this.next = createNextChannel();
                        }
                        deflateData(true);
                        // deflateData may already have switched the buffer to flushing (read) mode
                        if (allAreClear(state, FLUSHING_BUFFER)) {
                            currentBuffer.getBuffer().flip();
                            this.state |= FLUSHING_BUFFER;
                        }
                    }
                    if (!performFlushIfRequired()) {
                        return false;
                    }
                    return next.flush();
                }
            } finally {
                // a freshly created next conduit must inherit a resumeWrites() recorded locally
                if (nextCreated && anyAreSet(state, WRITES_RESUMED) && !anyAreSet(state, NEXT_SHUTDOWN)) {
                    try {
                        next.resumeWrites();
                    } catch (Throwable e) {
                        UndertowLogger.REQUEST_LOGGER.debug("Failed to resume", e);
                    }
                }
            }
        } catch (IOException | RuntimeException | Error e) {
            freeBuffer();
            throw e;
        }
    }

    /**
     * Queues the bytes from {@link #getTrailer()} after the compressed data. The trailer is counted
     * as sent and placed in {@code buffer} when it fits there, otherwise in {@link #additionalBuffer}.
     *
     * @param buffer the output buffer, in read mode if {@code FLUSHING_BUFFER} is set, else in write mode
     */
    private void queueTrailer(final ByteBuffer buffer) {
        final byte[] data = getTrailer();
        if (data == null) {
            return;
        }
        Connectors.updateResponseBytesSent(exchange, data.length);
        if (additionalBuffer != null) {
            // overflow is already pending: append the trailer to it to keep the byte order
            final ByteBuffer merged = ByteBuffer.allocate(additionalBuffer.remaining() + data.length);
            merged.put(additionalBuffer);
            merged.put(data);
            merged.flip();
            this.additionalBuffer = merged;
        } else if (anyAreSet(state, FLUSHING_BUFFER) && buffer.capacity() - buffer.remaining() >= data.length) {
            // buffer is in read mode: compact, append, then switch back to read mode
            buffer.compact();
            buffer.put(data);
            buffer.flip();
        } else if (data.length <= buffer.remaining() && !anyAreSet(state, FLUSHING_BUFFER)) {
            // buffer is still in write mode and has room
            buffer.put(data);
        } else {
            additionalBuffer = ByteBuffer.wrap(data);
        }
    }

    /**
     * Called once, after the deflater has finished and before the final output is written. The
     * returned bytes are appended after the compressed data. The default returns {@code null} (no trailer).
     */
    protected byte[] getTrailer() {
        return null;
    }

    /**
     * If the buffer is in flushing mode, writes it (and any {@link #additionalBuffer}) to the next
     * conduit. When everything has been written, clears the buffer and leaves flushing mode.
     * Otherwise just returns true.
     *
     * @return false if there is still more to flush
     */
    private boolean performFlushIfRequired() throws IOException {
        if (anyAreSet(state, FLUSHING_BUFFER)) {
            final ByteBuffer[] bufs = new ByteBuffer[additionalBuffer == null ? 1 : 2];
            long totalLength = 0;
            bufs[0] = currentBuffer.getBuffer();
            totalLength += bufs[0].remaining();
            if (additionalBuffer != null) {
                bufs[1] = additionalBuffer;
                totalLength += bufs[1].remaining();
            }
            if (totalLength > 0) {
                long total = 0;
                long res;
                do {
                    res = next.write(bufs, 0, bufs.length);
                    total += res;
                    if (res == 0) {
                        return false;
                    }
                } while (total < totalLength);
            }
            additionalBuffer = null;
            currentBuffer.getBuffer().clear();
            state = state & ~FLUSHING_BUFFER;
        }
        return true;
    }


    /**
     * Fixes up the length headers and creates the next conduit.
     * <p>
     * If the deflater has finished and the trailer has been queued, everything left to send is in
     * {@link #currentBuffer} and {@link #additionalBuffer}. Their combined remaining size becomes the
     * {@code Content-Length}, unless a {@code Transfer-Encoding} header is present. Otherwise the
     * compressed size is not yet known, so any {@code Content-Length} header is removed.
     */
    private StreamSinkConduit createNextChannel() {
        if (deflater.finished() && allAreSet(state, WRITTEN_TRAILER)) {
            //the deflater was fully flushed before we created the channel. This means that what is in the buffer is
            //all there is
            int remaining = currentBuffer.getBuffer().remaining();
            if (additionalBuffer != null) {
                remaining += additionalBuffer.remaining();
            }
            if (!exchange.getResponseHeaders().contains(Headers.TRANSFER_ENCODING)) {
                exchange.getResponseHeaders().put(Headers.CONTENT_LENGTH, Integer.toString(remaining));
            }
        } else {
            exchange.getResponseHeaders().remove(Headers.CONTENT_LENGTH);
        }
        return conduitFactory.create();
    }

    /**
     * Runs pending input through the deflater, copying compressed output into {@link #currentBuffer}.
     * <p>
     * When the buffer fills, any overflow goes to {@link #additionalBuffer}. The buffer is then
     * switched to flushing mode, the next conduit is created if needed, and a flush is attempted. If
     * that flush cannot complete, the method returns early. Compressed bytes are counted as sent.
     *
     * @param force if {@code true}, deflate with {@code Deflater.SYNC_FLUSH} until a call produces no
     *              output, even when the deflater needs no more input
     * @throws IOException if writing to the next conduit fails
     */
    private void deflateData(boolean force) throws IOException {
        //we don't need to flush here, as this should have been called already by the time we get to
        //this point
        boolean nextCreated = false;
        try (PooledByteBuffer arrayPooled = this.exchange.getConnection().getByteBufferPool().getArrayBackedPool().allocate()) {
            PooledByteBuffer pooled = this.currentBuffer;
            final ByteBuffer outputBuffer = pooled.getBuffer();

            final boolean shutdown = anyAreSet(state, SHUTDOWN);
            // array-backed scratch buffer: Deflater.deflate needs a byte[] to write into
            ByteBuffer buf = arrayPooled.getBuffer();
            while (force || !deflater.needsInput() || (shutdown && !deflater.finished())) {
                int count = deflater.deflate(buf.array(), buf.arrayOffset(), buf.remaining(), force ? Deflater.SYNC_FLUSH : Deflater.NO_FLUSH);
                Connectors.updateResponseBytesSent(exchange, count);
                if (count != 0) {
                    int remaining = outputBuffer.remaining();
                    if (remaining > count) {
                        outputBuffer.put(buf.array(), buf.arrayOffset(), count);
                    } else {
                        if (remaining == count) {
                            outputBuffer.put(buf.array(), buf.arrayOffset(), count);
                        } else {
                            // fill the output buffer and keep the overflow for the next write
                            outputBuffer.put(buf.array(), buf.arrayOffset(), remaining);
                            additionalBuffer = ByteBuffer.allocate(count - remaining);
                            additionalBuffer.put(buf.array(), buf.arrayOffset() + remaining, count - remaining);
                            additionalBuffer.flip();
                        }
                        // output buffer is full: switch to read mode and try to push it downstream
                        outputBuffer.flip();
                        this.state |= FLUSHING_BUFFER;
                        if (next == null) {
                            nextCreated = true;
                            this.next = createNextChannel();
                        }
                        if (!performFlushIfRequired()) {
                            return;
                        }
                    }
                } else {
                    // nothing more produced for this flush request
                    force = false;
                }
            }
        } finally {
            // a freshly created next conduit must inherit a resumeWrites() recorded locally
            if (nextCreated) {
                if (anyAreSet(state, WRITES_RESUMED)) {
                    next.resumeWrites();
                }
            }
        }
    }


    /**
     * Releases the buffer and deflater, marks the conduit closed, and truncates the next conduit if
     * one has been created.
     */
    @Override
    public void truncateWrites() throws IOException {
        freeBuffer();
        state |= CLOSED;
        // next is created lazily and is still null if no output was ever produced; guard against NPE.
        // NOTE(review): confirm nothing further downstream needs truncating when next was never created.
        if (next != null) {
            next.truncateWrites();
        }
    }

    /**
     * Closes the pooled output buffer (leaving flushing mode) and releases the deflater by closing its
     * pooled wrapper. Each resource is released at most once, so repeated calls are harmless.
     */
    private void freeBuffer() {
        if (currentBuffer != null) {
            currentBuffer.close();
            currentBuffer = null;
            state = state & ~FLUSHING_BUFFER;
        }
        if (deflater != null) {
            deflater = null;
            pooledObject.close();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy