io.undertow.conduits.DeflatingStreamSinkConduit Maven / Gradle / Ivy
Go to download
This artifact provides a single jar that contains all classes required to use remote EJB and JMS, including
all dependencies. It is intended for use by those not using Maven; Maven users should just import the EJB and
JMS BOMs instead (shaded JARs cause lots of problems with Maven, as it is very easy to inadvertently end up
with different versions of classes on the class path).
/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.undertow.conduits;
import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.allAreSet;
import static org.xnio.Bits.anyAreSet;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;
import java.util.zip.Deflater;
import io.undertow.server.Connectors;
import org.xnio.IoUtils;
import io.undertow.connector.PooledByteBuffer;
import org.xnio.XnioIoThread;
import org.xnio.XnioWorker;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.ConduitWritableByteChannel;
import org.xnio.conduits.Conduits;
import org.xnio.conduits.StreamSinkConduit;
import org.xnio.conduits.WriteReadyHandler;
import io.undertow.UndertowLogger;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.ConduitFactory;
import io.undertow.util.NewInstanceObjectPool;
import io.undertow.util.ObjectPool;
import io.undertow.util.Headers;
import io.undertow.util.PooledObject;
import io.undertow.util.SimpleObjectPool;
/**
 * Stream sink conduit that compresses everything written to it with a {@link Deflater} before passing it on to
 * the next conduit in the chain.
 * <p>
 * Compressed output is staged in a pooled buffer ({@link #currentBuffer}); compressed bytes that do not fit are
 * kept in an overflow buffer. The next conduit is created lazily through the {@link ConduitFactory} supplied at
 * construction. If the deflate stream (including any trailer) is already complete when that happens, a
 * {@code Content-Length} header is set from the staged size (unless a {@code Transfer-Encoding} header is
 * present); otherwise any {@code Content-Length} header is removed.
 * <p>
 * Subclasses may override {@link #preDeflate(byte[])} to observe the uncompressed input and {@link #getTrailer()}
 * to append bytes after the compressed stream.
 *
 * @author Stuart Douglas
 */
public class DeflatingStreamSinkConduit implements StreamSinkConduit {

    /** The deflater in use; set to {@code null} once it has been released by {@link #freeBuffer()}. */
    protected volatile Deflater deflater;
    /** Pool handle the deflater was obtained from; closed to release the deflater. */
    protected final PooledObject<Deflater> pooledObject;
    /** Factory for the next conduit in the chain; only invoked once data must be written downstream. */
    private final ConduitFactory<StreamSinkConduit> conduitFactory;
    private final HttpServerExchange exchange;
    /** The downstream conduit, or {@code null} until it has been created. */
    private StreamSinkConduit next;
    private WriteReadyHandler writeReadyHandler;

    /**
     * The buffer compressed output is staged in. It is freed once the next conduit has been shut down, or when an
     * error occurs.
     */
    protected PooledByteBuffer currentBuffer;
    /**
     * Overflow for compressed data (or trailer bytes) that did not fit into {@link #currentBuffer}.
     */
    private ByteBuffer additionalBuffer;

    /** Bit set made up of the flags below. */
    private int state = 0;

    /** {@link #terminateWrites()} has been called; no further application data is accepted. */
    private static final int SHUTDOWN = 1;
    /** {@code terminateWrites()} has been called on the next conduit. */
    private static final int NEXT_SHUTDOWN = 1 << 1;
    /** {@link #currentBuffer} has been flipped to read mode and is being drained to the next conduit. */
    private static final int FLUSHING_BUFFER = 1 << 2;
    /** Writes were resumed while the next conduit did not exist yet. */
    private static final int WRITES_RESUMED = 1 << 3;
    /** {@link #truncateWrites()} has been called. */
    private static final int CLOSED = 1 << 4;
    /** The trailer from {@link #getTrailer()} has been queued for writing. */
    private static final int WRITTEN_TRAILER = 1 << 5;

    /**
     * Creates a conduit that uses a fresh deflater for this instance.
     * <p>
     * NOTE(review): {@link Deflater#DEFLATED} is the compression <em>method</em> constant (value 8), not a
     * compression level, so this effectively requests level 8. It is kept unchanged to avoid altering output;
     * {@link Deflater#DEFAULT_COMPRESSION} may have been what was intended.
     *
     * @param conduitFactory factory for the next conduit in the chain
     * @param exchange       the exchange whose response is being compressed
     */
    public DeflatingStreamSinkConduit(final ConduitFactory<StreamSinkConduit> conduitFactory, final HttpServerExchange exchange) {
        this(conduitFactory, exchange, Deflater.DEFLATED);
    }

    /**
     * Creates a conduit that uses a fresh deflater with the given level for this instance.
     *
     * @param conduitFactory factory for the next conduit in the chain
     * @param exchange       the exchange whose response is being compressed
     * @param deflateLevel   compression level passed to {@link Deflater#Deflater(int, boolean)}
     */
    public DeflatingStreamSinkConduit(final ConduitFactory<StreamSinkConduit> conduitFactory, final HttpServerExchange exchange, int deflateLevel) {
        this(conduitFactory, exchange, newInstanceDeflaterPool(deflateLevel));
    }

    /**
     * Creates a conduit that borrows its deflater from {@code deflaterPool} and its staging buffer from the
     * connection's byte buffer pool.
     *
     * @param conduitFactory factory for the next conduit in the chain
     * @param exchange       the exchange whose response is being compressed
     * @param deflaterPool   pool to allocate the deflater from
     */
    public DeflatingStreamSinkConduit(final ConduitFactory<StreamSinkConduit> conduitFactory, final HttpServerExchange exchange, ObjectPool<Deflater> deflaterPool) {
        this.pooledObject = deflaterPool.allocate();
        this.deflater = pooledObject.getObject();
        this.currentBuffer = exchange.getConnection().getByteBufferPool().allocate();
        this.exchange = exchange;
        this.conduitFactory = conduitFactory;
        setWriteReadyHandler(new WriteReadyHandler.ChannelListenerHandler<>(Connectors.getConduitSinkChannel(exchange)));
    }

    /**
     * Returns a pool that creates a new {@code nowrap} (raw deflate, no ZLIB header/checksum) {@link Deflater} per
     * allocation, with {@link Deflater#end()} registered as the release action.
     *
     * @param deflateLevel compression level for created deflaters
     */
    public static ObjectPool<Deflater> newInstanceDeflaterPool(int deflateLevel) {
        return new NewInstanceObjectPool<>(() -> new Deflater(deflateLevel, true), Deflater::end);
    }

    /**
     * Returns a {@code SimpleObjectPool} of size {@code poolSize} holding {@code nowrap} deflaters, registered with
     * {@link Deflater#reset()} and {@link Deflater#end()} as its recycle/destroy actions (exact semantics are
     * defined by {@code SimpleObjectPool}).
     *
     * @param poolSize     pool size passed to {@code SimpleObjectPool}
     * @param deflateLevel compression level for created deflaters
     */
    public static ObjectPool<Deflater> simpleDeflaterPool(int poolSize, int deflateLevel) {
        return new SimpleObjectPool<>(poolSize, () -> new Deflater(deflateLevel, true), Deflater::reset, Deflater::end);
    }

    /**
     * Compresses the remaining bytes of {@code src}.
     * <p>
     * Returns 0 without consuming anything in two cases: a previous buffer flush is still pending, or the deflater
     * still holds input it cannot compress yet. Otherwise all of {@code src} is consumed and handed to the
     * deflater. On any failure the staging buffer and deflater are released before the exception is rethrown.
     *
     * @return the number of uncompressed bytes consumed
     * @throws ClosedChannelException if writes were shut down or truncated, or the buffer has been released
     */
    @Override
    public int write(final ByteBuffer src) throws IOException {
        if (anyAreSet(state, SHUTDOWN | CLOSED) || currentBuffer == null) {
            throw new ClosedChannelException();
        }
        try {
            if (!performFlushIfRequired()) {
                return 0;
            }
            if (src.remaining() == 0) {
                return 0;
            }
            //we may already have some input, if so compress it
            if (!deflater.needsInput()) {
                deflateData(false);
                if (!deflater.needsInput()) {
                    return 0;
                }
            }
            byte[] data = new byte[src.remaining()];
            src.get(data);
            preDeflate(data);
            deflater.setInput(data);
            // offset the raw length here (presumably counted by the caller); deflateData() adds the compressed
            // byte counts instead
            Connectors.updateResponseBytesSent(exchange, 0 - data.length);
            deflateData(false);
            return data.length;
        } catch (IOException | RuntimeException | Error e) {
            freeBuffer();
            throw e;
        }
    }

    /**
     * Hook invoked with each chunk of uncompressed data just before it is passed to the deflater. The default
     * implementation does nothing.
     *
     * @param data the uncompressed bytes about to be compressed
     */
    protected void preDeflate(byte[] data) {
    }

    /**
     * Compresses the given buffers in order, stopping at the first buffer that could not be consumed.
     * On any failure the staging buffer and deflater are released before the exception is rethrown.
     *
     * @return the total number of uncompressed bytes consumed
     */
    @Override
    public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException {
        if (anyAreSet(state, SHUTDOWN | CLOSED) || currentBuffer == null) {
            throw new ClosedChannelException();
        }
        try {
            // long accumulator: the sum of several int-sized writes can exceed Integer.MAX_VALUE
            long total = 0;
            for (int i = offset; i < offset + length; ++i) {
                if (srcs[i].hasRemaining()) {
                    int ret = write(srcs[i]);
                    total += ret;
                    if (ret == 0) {
                        return total;
                    }
                }
            }
            return total;
        } catch (IOException | RuntimeException | Error e) {
            freeBuffer();
            throw e;
        }
    }

    @Override
    public int writeFinal(ByteBuffer src) throws IOException {
        return Conduits.writeFinalBasic(this, src);
    }

    @Override
    public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException {
        return Conduits.writeFinalBasic(this, srcs, offset, length);
    }

    /**
     * Transfers file data through this conduit (and therefore through the deflater).
     *
     * @return the number of bytes transferred, or 0 if a pending buffer flush could not complete
     */
    @Override
    public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
        if (anyAreSet(state, SHUTDOWN | CLOSED)) {
            throw new ClosedChannelException();
        }
        if (!performFlushOrFreeBuffer()) {
            return 0;
        }
        return src.transferTo(position, count, new ConduitWritableByteChannel(this));
    }

    /**
     * Transfers data from {@code source} through this conduit (and therefore through the deflater).
     *
     * @return the number of bytes transferred, or 0 if a pending buffer flush could not complete
     */
    @Override
    public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
        if (anyAreSet(state, SHUTDOWN | CLOSED)) {
            throw new ClosedChannelException();
        }
        if (!performFlushOrFreeBuffer()) {
            return 0;
        }
        return IoUtils.transfer(source, count, throughBuffer, new ConduitWritableByteChannel(this));
    }

    @Override
    public XnioWorker getWorker() {
        return exchange.getConnection().getWorker();
    }

    /**
     * Suspends writes on the next conduit, or clears the local resumed flag if it has not been created yet.
     */
    @Override
    public void suspendWrites() {
        if (next == null) {
            state &= ~WRITES_RESUMED;
        } else {
            next.suspendWrites();
        }
    }

    @Override
    public boolean isWriteResumed() {
        if (next == null) {
            return anyAreSet(state, WRITES_RESUMED);
        } else {
            return next.isWriteResumed();
        }
    }

    @Override
    public void wakeupWrites() {
        if (next == null) {
            resumeWrites();
        } else {
            next.wakeupWrites();
        }
    }

    /**
     * Resumes writes on the next conduit. If it does not exist yet, records the resumed state (propagated when
     * the next conduit is created) and schedules the write listener on the IO thread.
     */
    @Override
    public void resumeWrites() {
        if (next == null) {
            state |= WRITES_RESUMED;
            queueWriteListener();
        } else {
            next.resumeWrites();
        }
    }

    /**
     * Runs the write-ready handler on the connection's IO thread. It re-queues itself for as long as writes stay
     * resumed and the next conduit has not been created.
     */
    private void queueWriteListener() {
        exchange.getConnection().getIoThread().execute(() -> {
            if (writeReadyHandler != null) {
                try {
                    writeReadyHandler.writeReady();
                } finally {
                    //if writes are still resumed queue up another one
                    if (next == null && isWriteResumed()) {
                        queueWriteListener();
                    }
                }
            }
        });
    }

    /**
     * Marks this conduit as shut down and tells the deflater to finish the stream. The remaining compressed
     * data and trailer are written by subsequent {@link #flush()} calls.
     */
    @Override
    public void terminateWrites() throws IOException {
        if (deflater != null) {
            deflater.finish();
        }
        state |= SHUTDOWN;
    }

    @Override
    public boolean isWriteShutdown() {
        return anyAreSet(state, SHUTDOWN);
    }

    /** Waits on the next conduit if it exists; returns immediately otherwise. */
    @Override
    public void awaitWritable() throws IOException {
        if (next != null) {
            next.awaitWritable();
        }
    }

    /** Waits on the next conduit if it exists; returns immediately otherwise. */
    @Override
    public void awaitWritable(final long time, final TimeUnit timeUnit) throws IOException {
        if (next != null) {
            next.awaitWritable(time, timeUnit);
        }
    }

    @Override
    public XnioIoThread getWriteThread() {
        return exchange.getConnection().getIoThread();
    }

    @Override
    public void setWriteReadyHandler(final WriteReadyHandler handler) {
        this.writeReadyHandler = handler;
    }

    /**
     * Pushes staged compressed data to the next conduit.
     * <p>
     * Before shutdown, this runs the deflater with {@link Deflater#SYNC_FLUSH} and drains the staged output. After
     * {@link #terminateWrites()}, it completes these steps in order:
     * <ol>
     * <li>finishes the deflate stream;</li>
     * <li>queues the trailer, if any;</li>
     * <li>drains everything staged;</li>
     * <li>releases the buffer and deflater;</li>
     * <li>terminates and flushes the next conduit.</li>
     * </ol>
     * On any failure the buffer and deflater are released before the exception is rethrown.
     *
     * @return {@code true} if everything has been flushed, {@code false} if the caller must retry later
     */
    @Override
    public boolean flush() throws IOException {
        if (currentBuffer == null) {
            if (anyAreSet(state, NEXT_SHUTDOWN)) {
                return next.flush();
            } else {
                return true;
            }
        }
        try {
            boolean nextCreated = false;
            try {
                if (anyAreSet(state, SHUTDOWN)) {
                    if (anyAreSet(state, NEXT_SHUTDOWN)) {
                        return next.flush();
                    } else {
                        if (!performFlushIfRequired()) {
                            return false;
                        }
                        //if the deflater has not been fully flushed we need to flush it
                        if (!deflater.finished()) {
                            deflateData(false);
                            //if could not fully flush
                            if (!deflater.finished()) {
                                return false;
                            }
                        }
                        final ByteBuffer buffer = currentBuffer.getBuffer();
                        if (allAreClear(state, WRITTEN_TRAILER)) {
                            state |= WRITTEN_TRAILER;
                            byte[] data = getTrailer();
                            if (data != null) {
                                Connectors.updateResponseBytesSent(exchange, data.length);
                                if (additionalBuffer != null) {
                                    // overflow already pending: the trailer must follow it
                                    ByteBuffer combined = ByteBuffer.allocate(additionalBuffer.remaining() + data.length);
                                    combined.put(additionalBuffer);
                                    combined.put(data);
                                    combined.flip();
                                    this.additionalBuffer = combined;
                                } else if (anyAreSet(state, FLUSHING_BUFFER) && buffer.capacity() - buffer.remaining() >= data.length) {
                                    // buffer is in read mode: compact, append, then flip back to read mode
                                    buffer.compact();
                                    buffer.put(data);
                                    buffer.flip();
                                } else if (data.length <= buffer.remaining() && !anyAreSet(state, FLUSHING_BUFFER)) {
                                    // buffer is still in write mode and has room
                                    buffer.put(data);
                                } else {
                                    additionalBuffer = ByteBuffer.wrap(data);
                                }
                            }
                        }
                        //ok the deflater is flushed, now we need to flush the buffer
                        if (!anyAreSet(state, FLUSHING_BUFFER)) {
                            buffer.flip();
                            state |= FLUSHING_BUFFER;
                            if (next == null) {
                                nextCreated = true;
                                this.next = createNextChannel();
                            }
                        }
                        if (performFlushIfRequired()) {
                            state |= NEXT_SHUTDOWN;
                            freeBuffer();
                            next.terminateWrites();
                            return next.flush();
                        } else {
                            return false;
                        }
                    }
                } else {
                    if (allAreClear(state, FLUSHING_BUFFER)) {
                        if (next == null) {
                            nextCreated = true;
                            this.next = createNextChannel();
                        }
                        deflateData(true);
                        if (allAreClear(state, FLUSHING_BUFFER)) {
                            //deflateData can cause this to be changed
                            currentBuffer.getBuffer().flip();
                            this.state |= FLUSHING_BUFFER;
                        }
                    }
                    if (!performFlushIfRequired()) {
                        return false;
                    }
                    return next.flush();
                }
            } finally {
                // a freshly created next conduit must inherit a resume requested before it existed
                if (nextCreated) {
                    if (anyAreSet(state, WRITES_RESUMED) && !anyAreSet(state, NEXT_SHUTDOWN)) {
                        try {
                            next.resumeWrites();
                        } catch (Throwable e) {
                            UndertowLogger.REQUEST_LOGGER.debug("Failed to resume", e);
                        }
                    }
                }
            }
        } catch (IOException | RuntimeException | Error e) {
            freeBuffer();
            throw e;
        }
    }

    /**
     * Called before the stream is finally flushed, to obtain bytes to append after the compressed data.
     *
     * @return the trailer bytes, or {@code null} for none (the default)
     */
    protected byte[] getTrailer() {
        return null;
    }

    /**
     * If the buffer is in the flushing state, writes it (and any overflow) to the next conduit; otherwise does
     * nothing. Once fully written, the overflow is dropped, the buffer is cleared and the flushing state is left.
     *
     * @return false if there is still more to flush
     */
    private boolean performFlushIfRequired() throws IOException {
        if (anyAreSet(state, FLUSHING_BUFFER)) {
            final ByteBuffer[] bufs = new ByteBuffer[additionalBuffer == null ? 1 : 2];
            long totalLength = 0;
            bufs[0] = currentBuffer.getBuffer();
            totalLength += bufs[0].remaining();
            if (additionalBuffer != null) {
                bufs[1] = additionalBuffer;
                totalLength += bufs[1].remaining();
            }
            if (totalLength > 0) {
                long total = 0;
                long res;
                do {
                    res = next.write(bufs, 0, bufs.length);
                    total += res;
                    if (res == 0) {
                        return false;
                    }
                } while (total < totalLength);
            }
            additionalBuffer = null;
            currentBuffer.getBuffer().clear();
            state &= ~FLUSHING_BUFFER;
        }
        return true;
    }

    /**
     * Same as {@link #performFlushIfRequired()}, but releases the buffer and deflater if it fails, matching the
     * error handling of the {@code write} methods.
     */
    private boolean performFlushOrFreeBuffer() throws IOException {
        try {
            return performFlushIfRequired();
        } catch (IOException | RuntimeException | Error e) {
            freeBuffer();
            throw e;
        }
    }

    /**
     * Creates the next conduit, first fixing up the response length headers.
     * <ul>
     * <li>If the deflate stream and trailer are already complete, the staged byte count becomes the
     * {@code Content-Length}, unless {@code Transfer-Encoding} is set.</li>
     * <li>Otherwise {@code Content-Length} is removed, since the compressed length is not yet known.</li>
     * </ul>
     */
    private StreamSinkConduit createNextChannel() {
        if (deflater.finished() && allAreSet(state, WRITTEN_TRAILER)) {
            //the deflater was fully flushed before we created the channel. This means that what is in the buffer is
            //all there is
            int remaining = currentBuffer.getBuffer().remaining();
            if (additionalBuffer != null) {
                remaining += additionalBuffer.remaining();
            }
            if (!exchange.getResponseHeaders().contains(Headers.TRANSFER_ENCODING)) {
                exchange.getResponseHeaders().put(Headers.CONTENT_LENGTH, Integer.toString(remaining));
            }
        } else {
            exchange.getResponseHeaders().remove(Headers.CONTENT_LENGTH);
        }
        return conduitFactory.create();
    }

    /**
     * Runs the current data through the deflater. As much as possible this is buffered in the current output
     * buffer. When it fills, the buffer (plus overflow) is written to the next conduit, creating that conduit if
     * needed. If that write cannot complete, this returns early with the data still staged.
     *
     * @param force if {@code true}, use {@link Deflater#SYNC_FLUSH} and keep deflating until the deflater
     *              produces no more output, even if it needs input
     * @throws IOException if writing to the next conduit fails
     */
    private void deflateData(boolean force) throws IOException {
        //we don't need to flush here, as this should have been called already by the time we get to
        //this point
        boolean nextCreated = false;
        try (PooledByteBuffer arrayPooled = this.exchange.getConnection().getByteBufferPool().getArrayBackedPool().allocate()) {
            PooledByteBuffer pooled = this.currentBuffer;
            final ByteBuffer outputBuffer = pooled.getBuffer();
            final boolean shutdown = anyAreSet(state, SHUTDOWN);
            ByteBuffer buf = arrayPooled.getBuffer();
            while (force || !deflater.needsInput() || (shutdown && !deflater.finished())) {
                int count = deflater.deflate(buf.array(), buf.arrayOffset(), buf.remaining(), force ? Deflater.SYNC_FLUSH : Deflater.NO_FLUSH);
                Connectors.updateResponseBytesSent(exchange, count);
                if (count != 0) {
                    int remaining = outputBuffer.remaining();
                    if (remaining > count) {
                        outputBuffer.put(buf.array(), buf.arrayOffset(), count);
                    } else {
                        if (remaining == count) {
                            outputBuffer.put(buf.array(), buf.arrayOffset(), count);
                        } else {
                            // fill the output buffer and keep the rest in the overflow buffer
                            outputBuffer.put(buf.array(), buf.arrayOffset(), remaining);
                            additionalBuffer = ByteBuffer.allocate(count - remaining);
                            additionalBuffer.put(buf.array(), buf.arrayOffset() + remaining, count - remaining);
                            additionalBuffer.flip();
                        }
                        // output buffer is full: switch to read mode and push it downstream
                        outputBuffer.flip();
                        this.state |= FLUSHING_BUFFER;
                        if (next == null) {
                            nextCreated = true;
                            this.next = createNextChannel();
                        }
                        if (!performFlushIfRequired()) {
                            return;
                        }
                    }
                } else {
                    force = false;
                }
            }
        } finally {
            // a freshly created next conduit must inherit a resume requested before it existed
            if (nextCreated) {
                if (anyAreSet(state, WRITES_RESUMED)) {
                    next.resumeWrites();
                }
            }
        }
    }

    /**
     * Discards all pending data, releases the buffer and deflater, marks this conduit closed and truncates the
     * next conduit.
     */
    @Override
    public void truncateWrites() throws IOException {
        freeBuffer();
        state |= CLOSED;
        if (next == null) {
            // Nothing has been passed downstream yet. Obtain the next conduit directly so the truncation still
            // propagates. createNextChannel() cannot be used here: the deflater has already been released.
            next = conduitFactory.create();
        }
        next.truncateWrites();
    }

    /**
     * Releases the staging buffer (leaving the flushing state) and returns the deflater to its pool. Safe to call
     * more than once.
     */
    private void freeBuffer() {
        if (currentBuffer != null) {
            currentBuffer.close();
            currentBuffer = null;
            state &= ~FLUSHING_BUFFER;
        }
        if (deflater != null) {
            deflater = null;
            pooledObject.close();
        }
    }
}