/*
 * io.undertow.websockets.core.StreamSourceFrameChannel (Maven / Gradle / Ivy)
 *
 * This artifact provides a single JAR that contains all classes required to use remote EJB and JMS, including
 * all dependencies. It is intended for those not using Maven; Maven users should instead import the EJB and
 * JMS BOMs (shaded JARs cause many problems with Maven, because it is very easy to inadvertently end up
 * with different versions of the same classes on the class path).
 */
/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.undertow.websockets.core;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.websockets.core.function.ChannelFunction;
import io.undertow.websockets.core.function.ChannelFunctionFileChannel;
import io.undertow.websockets.core.protocol.version07.Masker;
import io.undertow.websockets.core.protocol.version07.UTF8Checker;
import io.undertow.websockets.extensions.ExtensionFunction;
import io.undertow.websockets.extensions.NoopExtensionFunction;
import org.xnio.channels.StreamSinkChannel;
import io.undertow.server.protocol.framed.AbstractFramedStreamSourceChannel;
import io.undertow.server.protocol.framed.FrameHeaderData;
/**
* Base class for processes Frame bases StreamSourceChannels.
*
* @author Norman Maurer
*/
/**
 * Base class for channels that read the payload of an incoming WebSocket frame.
 * <p>
 * Frame data is unmasked with the optional {@link Masker} and, when the frame's rsv value is
 * non-zero, transformed by the connection's {@link ExtensionFunction}. Bytes handed to the caller
 * are then inspected by the configured {@link ChannelFunction}s: vectored reads feed every
 * function, while single-buffer reads feed only the {@link UTF8Checker} (if one is configured).
 *
 * @author Norman Maurer
 */
public abstract class StreamSourceFrameChannel extends AbstractFramedStreamSourceChannel<WebSocketChannel, StreamSourceFrameChannel, StreamSinkFrameChannel> {

    /** Shared empty array used when no channel functions are supplied. */
    private static final ChannelFunction[] NO_FUNCTIONS = new ChannelFunction[0];

    protected final WebSocketFrameType type;

    private boolean finalFragment;
    private final int rsv;
    /** Never null; normalized in the constructor. */
    private final ChannelFunction[] functions;
    private final ExtensionFunction extensionFunction;
    /** May be null when the frames are not masked. */
    private final Masker masker;
    /** The last UTF8Checker found in {@link #functions}, or null if there is none. */
    private final UTF8Checker checker;
    /**
     * Set while a vectored read feeds its buffers to {@link #afterRead}, so that completion is
     * checked once after all buffers instead of after each individual buffer.
     */
    private boolean deferCompletionCheck;

    protected StreamSourceFrameChannel(WebSocketChannel wsChannel, WebSocketFrameType type, PooledByteBuffer pooled, long frameLength) {
        this(wsChannel, type, 0, true, pooled, frameLength, null);
    }

    protected StreamSourceFrameChannel(WebSocketChannel wsChannel, WebSocketFrameType type, int rsv, boolean finalFragment, PooledByteBuffer pooled, long frameLength, Masker masker, ChannelFunction... functions) {
        super(wsChannel, pooled, frameLength);
        this.type = type;
        this.finalFragment = finalFragment;
        this.rsv = rsv;
        // Accept an explicit null array so that no read path has to null-check it.
        this.functions = functions == null ? NO_FUNCTIONS : functions;
        this.masker = masker;
        this.checker = findUtf8Checker(this.functions);
        // Extension transforms are only used when an rsv value is present on the frame.
        this.extensionFunction = rsv > 0 ? wsChannel.getExtensionFunction() : NoopExtensionFunction.INSTANCE;
    }

    /** Returns the last {@link UTF8Checker} contained in {@code functions}, or {@code null} if there is none. */
    private static UTF8Checker findUtf8Checker(ChannelFunction[] functions) {
        UTF8Checker found = null;
        for (ChannelFunction func : functions) {
            if (func instanceof UTF8Checker) {
                found = (UTF8Checker) func;
            }
        }
        return found;
    }

    /**
     * Return the {@link WebSocketFrameType}, or {@code null} if it is not known at the time of the call.
     */
    public WebSocketFrameType getType() {
        return type;
    }

    /**
     * Flag to indicate if this frame is the final fragment in a message. The first fragment (frame) may also be the
     * final fragment.
     */
    public boolean isFinalFragment() {
        return finalFragment;
    }

    /**
     * Return the rsv value, which is used for extensions.
     */
    public int getRsv() {
        return rsv;
    }

    /** Returns the value of {@code getReadFrameCount()} for this channel. */
    int getWebSocketFrameCount() {
        return getReadFrameCount();
    }

    @Override
    protected WebSocketChannel getFramedChannel() {
        return super.getFramedChannel();
    }

    /** Returns the {@link WebSocketChannel} this frame channel reads from. */
    public WebSocketChannel getWebSocketChannel() {
        return getFramedChannel();
    }

    /** Calls {@code lastFrame()} and records that this channel carries the final fragment of the message. */
    public void finalFrame() {
        this.lastFrame();
        this.finalFragment = true;
    }

    /**
     * Notifies the masker and every channel function that a new frame header has been read, and marks the
     * channel as final if the header carries the final-fragment flag.
     */
    @Override
    protected void handleHeaderData(FrameHeaderData headerData) {
        super.handleHeaderData(headerData);
        if (((WebSocketFrame) headerData).isFinalFragment()) {
            finalFrame();
        }
        if (masker != null) {
            masker.newFrame(headerData);
        }
        for (ChannelFunction func : functions) {
            func.newFrame(headerData);
        }
    }

    /**
     * Transfers frame data to a file. When channel functions are configured, the target is wrapped so the
     * functions see the transferred bytes.
     */
    @Override
    public final long transferTo(long position, long count, FileChannel target) throws IOException {
        if (functions.length > 0) {
            return super.transferTo(position, count, new ChannelFunctionFileChannel(target, functions));
        }
        return super.transferTo(position, count, target);
    }

    @Override
    public final long transferTo(long count, ByteBuffer throughBuffer, StreamSinkChannel target) throws IOException {
        // use this because of XNIO bug
        // See https://issues.jboss.org/browse/XNIO-185
        return WebSocketUtils.transfer(this, count, throughBuffer, target);
    }

    /**
     * Reads into a single buffer. Newly read bytes are passed to the {@link UTF8Checker} only; the full set of
     * channel functions is asked to validate completion when the end of the stream is reached ({@code -1}).
     */
    @Override
    public int read(ByteBuffer dst) throws IOException {
        int position = dst.position();
        int r = super.read(dst);
        if (r > 0) {
            checker(dst, position, dst.position() - position, false);
        } else if (r == -1) {
            checkComplete();
        }
        return r;
    }

    @Override
    public final long read(ByteBuffer[] dsts) throws IOException {
        return read(dsts, 0, dsts.length);
    }

    /**
     * Scattering read into {@code dsts[offset]} .. {@code dsts[offset + length - 1]}. Each buffer's newly read
     * bytes are passed to {@link #afterRead}. Completion is then validated once, after every buffer has been
     * processed, so that a multi-byte sequence spanning two buffers is not reported as truncated.
     */
    @Override
    public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
        // Remember where each target buffer started so the newly written region can be inspected afterwards.
        int[] startPositions = new int[length];
        for (int i = 0; i < length; i++) {
            startPositions[i] = dsts[offset + i].position();
        }
        long b = super.read(dsts, offset, length);
        if (b > 0) {
            deferCompletionCheck = true;
            try {
                for (int i = 0; i < length; i++) {
                    ByteBuffer dst = dsts[offset + i];
                    afterRead(dst, startPositions[i], dst.position() - startPositions[i]);
                }
            } finally {
                deferCompletionCheck = false;
            }
            if (isComplete()) {
                checkComplete();
            }
        } else if (b == -1) {
            checkComplete();
        }
        return b;
    }

    /**
     * Calls {@code complete()} on every channel function, marking reads as broken if one of them reports an
     * encoding error.
     */
    private void checkComplete() throws IOException {
        try {
            for (ChannelFunction func : functions) {
                func.complete();
            }
        } catch (UnsupportedEncodingException e) {
            getFramedChannel().markReadsBroken(e);
            throw e;
        }
    }

    /**
     * Called after data was read into the {@link ByteBuffer}
     *
     * @param buffer the {@link ByteBuffer} into which the data was read
     * @param position the position it was written to
     * @param length the number of bytes that were written
     * @throws IOException thrown if an error occurs
     */
    protected void afterRead(ByteBuffer buffer, int position, int length) throws IOException {
        try {
            for (ChannelFunction func : functions) {
                func.afterRead(buffer, position, length);
            }
        } catch (UnsupportedEncodingException e) {
            getFramedChannel().markReadsBroken(e);
            throw e;
        }
        // checkComplete() marks reads broken itself, so it is kept outside the try block above.
        if (!deferCompletionCheck && isComplete()) {
            checkComplete();
        }
    }

    /**
     * Feeds the given region to the {@link UTF8Checker} (if any) and optionally checks that the data seen so
     * far forms a complete sequence. Marks reads as broken once on an encoding error.
     */
    protected void checker(ByteBuffer buffer, int position, int length, boolean complete) throws IOException {
        if (checker == null) {
            return;
        }
        try {
            checker.afterRead(buffer, position, length);
            if (complete) {
                checker.complete();
            }
        } catch (UnsupportedEncodingException e) {
            getFramedChannel().markReadsBroken(e);
            throw e;
        }
    }

    /**
     * Unmasks the remaining bytes of {@code frameData} (when a masker is configured) and runs them through the
     * extension function. Any failure marks reads broken with a {@link WebSocketFrameCorruptedException};
     * non-IO exceptions are rethrown wrapped in an {@link IOException}.
     */
    @Override
    protected PooledByteBuffer processFrameData(PooledByteBuffer frameData, boolean lastFragmentOfFrame) throws IOException {
        if (masker != null) {
            ByteBuffer buffer = frameData.getBuffer();
            masker.afterRead(buffer, buffer.position(), buffer.remaining());
        }
        try {
            return extensionFunction.transformForRead(frameData, this, lastFragmentOfFrame && isFinalFragment());
        } catch (IOException e) {
            getWebSocketChannel().markReadsBroken(new WebSocketFrameCorruptedException(e));
            throw e;
        } catch (Exception e) {
            getWebSocketChannel().markReadsBroken(new WebSocketFrameCorruptedException(e));
            throw new IOException(e);
        }
    }
}