// io.undertow.websockets.extensions.PerMessageDeflateFunction Maven / Gradle / Ivy
/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.undertow.websockets.extensions;
import io.undertow.connector.ByteBufferPool;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.util.ImmediatePooledByteBuffer;
import io.undertow.websockets.core.StreamSinkFrameChannel;
import io.undertow.websockets.core.StreamSourceFrameChannel;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.WebSocketLogger;
import io.undertow.websockets.core.WebSocketMessages;
import org.xnio.Buffers;
import org.xnio.IoUtils;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
/**
 * Implementation of the {@code permessage-deflate} WebSocket extension.
 * <p>
 * This implementation supports the parameters {@code server_no_context_takeover} and
 * {@code client_no_context_takeover}.
 * <p>
 * This implementation does not support the parameters {@code server_max_window_bits} and
 * {@code client_max_window_bits}.
 * <p>
 * It uses the DEFLATE algorithm implementation packaged in the {@link Deflater} and {@link Inflater} classes
 * (both in raw, "nowrap" mode).
 *
 * @author Lucas Ponce
 * @see <a href="https://tools.ietf.org/html/rfc7692">Compression Extensions for WebSocket (RFC 7692)</a>
 */
public class PerMessageDeflateFunction implements ExtensionFunction {

    /**
     * The {@code 00 00 FF FF} trailer that is appended to the end of every compressed message before
     * it is inflated.
     */
    private static final byte[] TAIL = new byte[]{0x00, 0x00, (byte) 0xFF, (byte) 0xFF};

    private final int deflaterLevel;
    /** When {@code false}, the deflater is reset after each complete outgoing message. */
    private final boolean compressContextTakeover;
    /**
     * Negotiated decompressor context-takeover flag.
     * NOTE(review): it is not consulted by {@link #transformForRead}. The inflater keeps its window across
     * messages, which still decodes streams from a peer that resets its own compressor.
     */
    private final boolean decompressContextTakeover;
    private final Inflater decompress;
    private final Deflater compress;
    /** Source channel of a compressed message whose last fragment has not been seen yet, or {@code null}. */
    private StreamSourceFrameChannel currentReadChannel;

    /**
     * Create a new {@code PerMessageDeflateFunction} instance.
     *
     * @param deflaterLevel             the level of configuration of DEFLATE algorithm implementation
     * @param compressContextTakeover   flag for compressor context takeover or without compressor context
     * @param decompressContextTakeover flag for decompressor context takeover or without decompressor context
     */
    public PerMessageDeflateFunction(final int deflaterLevel, boolean compressContextTakeover, boolean decompressContextTakeover) {
        this.deflaterLevel = deflaterLevel;
        this.decompress = new Inflater(true);
        this.compress = new Deflater(this.deflaterLevel, true);
        this.compressContextTakeover = compressContextTakeover;
        this.decompressContextTakeover = decompressContextTakeover;
    }

    /**
     * Marks outgoing frames as compressed by setting the RSV1 bit.
     *
     * @param rsv the RSV bits computed so far
     * @return {@code rsv} with {@link #RSV1} set
     */
    @Override
    public int writeRsv(int rsv) {
        return rsv | RSV1;
    }

    /**
     * @return always {@code false}; this extension defines no extension opcodes
     */
    @Override
    public boolean hasExtensionOpCode() {
        return false;
    }

    /**
     * Compresses one frame's payload.
     * <p>
     * The input buffer is always released by this method. On the last frame of a message the deflater is
     * sync-flushed, and the compressor context is discarded unless context takeover is enabled.
     *
     * @param pooledBuffer the uncompressed payload; released by this method
     * @param channel      the frame channel being written, used to reach the buffer pool
     * @param lastFrame    whether this frame ends the message
     * @return a flipped buffer holding the compressed bytes; ownership passes to the caller
     * @throws IOException as declared by {@link ExtensionFunction}
     */
    @Override
    public synchronized PooledByteBuffer transformForWrite(PooledByteBuffer pooledBuffer, StreamSinkFrameChannel channel, boolean lastFrame) throws IOException {
        final WebSocketChannel webSocketChannel = channel.getWebSocketChannel();
        PooledByteBuffer inputBuffer = null;
        PooledByteBuffer output = null;
        boolean success = false;
        try {
            ByteBuffer input = pooledBuffer.getBuffer();
            if (!input.hasArray()) {
                // Deflater only accepts byte arrays, so copy non-array-backed input first
                inputBuffer = toArrayBacked(input, webSocketChannel.getBufferPool());
                input = inputBuffer.getBuffer();
            }
            compress.setInput(input.array(), input.arrayOffset() + input.position(), input.remaining());

            output = allocateBufferWithArray(webSocketChannel, 0); // first pass
            ByteBuffer outputBuffer = output.getBuffer();
            boolean onceOnly = true;
            while ((!compress.needsInput() && !compress.finished()) || !outputBuffer.hasRemaining() || (onceOnly && lastFrame)) {
                onceOnly = false;
                // The hasRemaining() check is needed because a deflate/flush call that exactly fills the output
                // buffer may still have pending output even though needsInput() reports true.
                if (!outputBuffer.hasRemaining()) {
                    output = largerBuffer(output, webSocketChannel, outputBuffer.capacity() * 2);
                    outputBuffer = output.getBuffer();
                }
                int n = compress.deflate(
                        outputBuffer.array(),
                        outputBuffer.arrayOffset() + outputBuffer.position(),
                        outputBuffer.remaining(), lastFrame ? Deflater.SYNC_FLUSH : Deflater.NO_FLUSH);
                outputBuffer.position(outputBuffer.position() + n);
            }
            if (lastFrame) {
                // The sync flush has left the stream byte-aligned. This zero byte begins another empty stored block
                // (BFINAL=0, BTYPE=00); its LEN/NLEN come from the 00 00 FF FF tail the receiver appends.
                // The loop above guarantees there is room for it.
                outputBuffer.put((byte) 0);
                if (!compressContextTakeover) {
                    compress.reset();
                }
            }
            outputBuffer.flip();
            success = true;
        } finally {
            // Free the input AFTER compression so it doesn't get re-used out from under us
            IoUtils.safeClose(pooledBuffer, inputBuffer);
            if (!success && output != null) {
                // Do not leak the (possibly enlarged) output buffer if compression failed part-way
                output.close();
            }
        }
        return output;
    }

    /**
     * Copies the remaining bytes of {@code buffer} into an array-backed buffer. This consumes {@code buffer}.
     * The copy comes from the pool when it fits; otherwise it is a plain heap buffer.
     */
    private PooledByteBuffer toArrayBacked(ByteBuffer buffer, ByteBufferPool pool) {
        if (pool.getBufferSize() < buffer.remaining()) {
            return new ImmediatePooledByteBuffer(ByteBuffer.wrap(Buffers.take(buffer)));
        }
        PooledByteBuffer newBuf = pool.getArrayBackedPool().allocate();
        newBuf.getBuffer().put(buffer);
        newBuf.getBuffer().flip();
        return newBuf;
    }

    /**
     * Moves the bytes written so far into a buffer of {@code newSize}. It then releases {@code smaller} and
     * returns the new buffer, positioned after the copied bytes.
     */
    private PooledByteBuffer largerBuffer(PooledByteBuffer smaller, WebSocketChannel channel, int newSize) {
        ByteBuffer smallerBuffer = smaller.getBuffer();
        smallerBuffer.flip();
        PooledByteBuffer larger = allocateBufferWithArray(channel, newSize);
        larger.getBuffer().put(smallerBuffer);
        smaller.close();
        return larger;
    }

    /**
     * Allocates an array-backed buffer. If {@code size} exceeds the pool's buffer size, an unpooled heap buffer
     * is returned; otherwise (including {@code size <= 0}) a buffer is taken from the array-backed pool.
     */
    private PooledByteBuffer allocateBufferWithArray(WebSocketChannel channel, int size) {
        if (size > 0 && size > channel.getBufferPool().getBufferSize()) {
            // TODO use newer XNIO sized pool thingies smartly
            return new ImmediatePooledByteBuffer(ByteBuffer.allocate(size));
        }
        return channel.getBufferPool().getArrayBackedPool().allocate();
    }

    /**
     * Decompresses one frame's payload, if the frame's RSV1 bit marks it as compressed.
     * <p>
     * Uncompressed frames are returned as-is. For compressed frames the input buffer is always released.
     * On the last fragment, the RFC 7692 tail is fed to the inflater to complete the message.
     *
     * @param pooledBuffer          the frame payload
     * @param channel               the source frame channel the payload was read from
     * @param lastFragmentOfMessage whether this fragment ends the message
     * @return a flipped buffer with the decompressed bytes, or {@code pooledBuffer} itself if RSV1 is clear
     * @throws IOException if the compressed payload is malformed
     */
    @Override
    public synchronized PooledByteBuffer transformForRead(PooledByteBuffer pooledBuffer, StreamSourceFrameChannel channel, boolean lastFragmentOfMessage) throws IOException {
        if ((channel.getRsv() & RSV1) == 0) {
            //rsv bit not set, this message is not compressed
            return pooledBuffer;
        }
        final WebSocketChannel webSocketChannel = channel.getWebSocketChannel();
        PooledByteBuffer inputBuffer = null;
        PooledByteBuffer output;
        try {
            ByteBuffer input = pooledBuffer.getBuffer();
            if (!input.hasArray()) {
                // Inflater only accepts byte arrays, so copy non-array-backed input first
                inputBuffer = toArrayBacked(input, webSocketChannel.getBufferPool());
                input = inputBuffer.getBuffer();
            }
            output = allocateBufferWithArray(webSocketChannel, 0); // first pass
            // From here on decompress() releases the output buffer itself if it fails
            if (currentReadChannel != null && currentReadChannel != channel) {
                // New channel: the previous message's last fragment never arrived (this can happen sometimes),
                // so terminate it before inflating the new data
                decompress.setInput(TAIL);
                output = decompress(webSocketChannel, output);
            }
            decompress.setInput(input.array(), input.arrayOffset() + input.position(), input.remaining());
            output = decompress(webSocketChannel, output);
        } finally {
            // Free the input AFTER decompression so it doesn't get re-used out from under us
            IoUtils.safeClose(inputBuffer, pooledBuffer);
        }
        if (lastFragmentOfMessage) {
            decompress.setInput(TAIL);
            output = decompress(webSocketChannel, output);
            currentReadChannel = null;
        } else {
            currentReadChannel = channel;
        }
        output.getBuffer().flip();
        return output;
    }

    /**
     * Drains the inflater's pending input into {@code pooled}, doubling the output buffer whenever it fills up.
     * <p>
     * If this method throws, it first releases whichever buffer it currently holds: either {@code pooled} or
     * its enlarged replacement. Callers must therefore not close it again.
     *
     * @return the buffer holding all output so far; may be a different instance than {@code pooled}
     * @throws IOException if the compressed data is malformed
     */
    private PooledByteBuffer decompress(WebSocketChannel channel, PooledByteBuffer pooled) throws IOException {
        boolean success = false;
        try {
            ByteBuffer buffer = pooled.getBuffer();
            while (!decompress.needsInput() && !decompress.finished()) {
                if (!buffer.hasRemaining()) {
                    pooled = largerBuffer(pooled, channel, buffer.capacity() * 2);
                    buffer = pooled.getBuffer();
                }
                int n;
                try {
                    n = decompress.inflate(buffer.array(),
                            buffer.arrayOffset() + buffer.position(),
                            buffer.remaining());
                } catch (DataFormatException e) {
                    WebSocketLogger.EXTENSION_LOGGER.debug(e.getMessage(), e);
                    throw WebSocketMessages.MESSAGES.badCompressedPayload(e);
                }
                buffer.position(buffer.position() + n);
            }
            success = true;
            return pooled;
        } finally {
            if (!success) {
                // A malformed payload must not leak pooled memory
                pooled.close();
            }
        }
    }

    /**
     * Releases the native zlib resources held by the deflater and inflater.
     */
    @Override
    public void dispose() {
        // Call end so that native zlib resources can be immediately released rather than relying on finalizer
        compress.end();
        decompress.end();
    }
}