org.xnio.streams.ChannelInputStream Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of xnio-api Show documentation
Show all versions of xnio-api Show documentation
The API JAR of the XNIO project
/*
* JBoss, Home of Professional Open Source
*
* Copyright 2008 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.xnio.streams;
import static org.xnio._private.Messages.msg;
import java.io.InputStream;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import org.xnio.Bits;
import org.xnio.channels.Channels;
import org.xnio.channels.StreamSourceChannel;
/**
* An input stream which reads from a stream source channel. All read operations are directly
* performed upon the channel, so for optimal performance, a buffering input stream should be
* used to wrap this class.
*
* @apiviz.exclude
*
* @since 1.2
*/
public class ChannelInputStream extends InputStream {
    /** The channel upon which all read, skip and close operations are performed. */
    protected final StreamSourceChannel channel;
    /** Bit set of {@link #FLAG_ENTERED} and {@link #FLAG_EOF}; modified atomically through {@link #flagsUpdater}. */
    private volatile int flags;
    /** The read timeout in nanoseconds, or 0 for no timeout. */
    private volatile long timeout;

    private static final AtomicIntegerFieldUpdater<ChannelInputStream> flagsUpdater = AtomicIntegerFieldUpdater.newUpdater(ChannelInputStream.class, "flags");

    /** Set once end-of-stream has been observed, or once the stream has been closed. */
    private static final int FLAG_EOF = 2;
    /** Set while an operation is in progress; a second overlapping operation is rejected. */
    private static final int FLAG_ENTERED = 1;

    /**
     * Construct a new instance. The stream will have no read timeout.
     *
     * @param channel the channel to wrap
     */
    public ChannelInputStream(final StreamSourceChannel channel) {
        if (channel == null) {
            throw msg.nullParameter("channel");
        }
        this.channel = channel;
    }

    /**
     * Construct a new instance.
     *
     * @param channel the channel to wrap
     * @param timeout the read timeout, or 0 for none
     * @param timeoutUnit the time unit for read timeouts
     */
    public ChannelInputStream(final StreamSourceChannel channel, final long timeout, final TimeUnit timeoutUnit) {
        if (channel == null) {
            throw msg.nullParameter("channel");
        }
        if (timeoutUnit == null) {
            throw msg.nullParameter("timeoutUnit");
        }
        if (timeout < 0L) {
            throw msg.parameterOutOfRange("timeout");
        }
        this.channel = channel;
        this.timeout = toTimeoutNanos(timeout, timeoutUnit);
    }

    /**
     * Convert an already-validated, non-negative timeout to nanoseconds. Zero stays zero ("no timeout");
     * a positive timeout is clamped to at least one nanosecond so it can never be mistaken for "no timeout".
     */
    private static long toTimeoutNanos(final long timeout, final TimeUnit unit) {
        if (timeout == 0L) {
            return 0L;
        }
        final long nanos = unit.toNanos(timeout);
        return nanos < 1L ? 1L : nanos;
    }

    /**
     * Mark the start of an operation on this stream.
     * Throws the exception produced by {@code msg.concurrentAccess()} if another operation is already in progress.
     *
     * @return {@code true} if end-of-stream had already been recorded
     */
    private boolean enter() {
        int oldFlags;
        do {
            // Re-read on every attempt: a stale snapshot would make a failed CAS retry forever.
            oldFlags = flags;
            if (Bits.allAreSet(oldFlags, FLAG_ENTERED)) {
                throw msg.concurrentAccess();
            }
        } while (! flagsUpdater.compareAndSet(this, oldFlags, oldFlags | FLAG_ENTERED));
        return Bits.allAreSet(oldFlags, FLAG_EOF);
    }

    /**
     * Mark the end of an operation started with {@link #enter()}.
     *
     * @param setEof {@code true} to also record end-of-stream
     */
    private void exit(final boolean setEof) {
        int oldFlags, newFlags;
        do {
            oldFlags = flags;
            newFlags = oldFlags & ~FLAG_ENTERED;
            if (setEof) {
                newFlags |= FLAG_EOF;
            }
        } while (! flagsUpdater.compareAndSet(this, oldFlags, newFlags));
    }

    /**
     * Wait for the channel to become readable within the remaining time budget of the current operation.
     * Throws the exception produced by {@code msg.readTimeout()} once the budget is exhausted.
     *
     * @param timeout the timeout snapshot for this operation, in nanoseconds, or 0 for none
     * @param elapsed nanoseconds already spent in this operation
     * @throws IOException if waiting fails
     */
    private void awaitReadableWithTimeout(final long timeout, final long elapsed) throws IOException {
        if (timeout == 0L) {
            channel.awaitReadable();
        } else if (elapsed >= timeout) {
            // ">=" guarantees that the wait below is always strictly positive; the meaning of a zero
            // wait is left to the channel implementation.
            throw msg.readTimeout();
        } else {
            channel.awaitReadable(timeout - elapsed, TimeUnit.NANOSECONDS);
        }
    }

    /**
     * Read into the given buffer, waiting as needed until at least one byte has been read,
     * end-of-stream is reached, or the read timeout expires.
     *
     * @param buffer the destination buffer
     * @return the value of the last {@code channel.read} call (positive count, or -1 at end-of-stream)
     * @throws IOException if an I/O error occurs or the read times out
     */
    private int blockingRead(final ByteBuffer buffer) throws IOException {
        int res = channel.read(buffer);
        if (res != 0) {
            return res;
        }
        // Snapshot once so that setReadTimeout does not affect this in-progress read.
        final long timeout = this.timeout;
        final long start = System.nanoTime();
        long elapsed = 0L;
        do {
            awaitReadableWithTimeout(timeout, elapsed);
            elapsed = System.nanoTime() - start;
            res = channel.read(buffer);
        } while (res == 0);
        return res;
    }

    /**
     * Get the read timeout.
     *
     * @param unit the time unit
     * @return the timeout in the given unit
     */
    public long getReadTimeout(final TimeUnit unit) {
        if (unit == null) {
            throw msg.nullParameter("unit");
        }
        return unit.convert(timeout, TimeUnit.NANOSECONDS);
    }

    /**
     * Set the read timeout. Does not affect read operations in progress.
     *
     * @param timeout the read timeout, or 0 for none
     * @param unit the time unit
     */
    public void setReadTimeout(final long timeout, final TimeUnit unit) {
        if (timeout < 0L) {
            throw msg.parameterOutOfRange("timeout");
        }
        if (unit == null) {
            throw msg.nullParameter("unit");
        }
        this.timeout = toTimeoutNanos(timeout, unit);
    }

    /** {@inheritDoc} */
    @Override
    public int read() throws IOException {
        boolean eof = enter();
        try {
            if (eof) {
                return -1;
            }
            final byte[] array = new byte[1];
            if (blockingRead(ByteBuffer.wrap(array)) == -1) {
                eof = true;
                return -1;
            }
            return array[0] & 0xff;
        } finally {
            exit(eof);
        }
    }

    /** {@inheritDoc} */
    @Override
    public int read(final byte[] b) throws IOException {
        return read(b, 0, b.length);
    }

    /**
     * {@inheritDoc}
     *
     * @throws IndexOutOfBoundsException if {@code off} or {@code len} is negative, or
     *         {@code len} is greater than {@code b.length - off}
     */
    @Override
    public int read(final byte[] b, final int off, final int len) throws IOException {
        // Reading b.length first yields the NullPointerException required for a null array.
        final int arrayLength = b.length;
        // Written as "len > arrayLength - off" so that off + len cannot overflow.
        if (off < 0 || len < 0 || len > arrayLength - off) {
            throw new IndexOutOfBoundsException();
        }
        if (len == 0) {
            return 0;
        }
        boolean eof = enter();
        try {
            if (eof) {
                return -1;
            }
            final ByteBuffer buffer = ByteBuffer.wrap(b, off, len);
            if (blockingRead(buffer) == -1) {
                eof = true;
                return -1;
            }
            return buffer.position() - off;
        } finally {
            exit(eof);
        }
    }

    /**
     * Skip bytes in the stream.
     *
     * @param n the number of bytes to skip
     * @return the number of bytes skipped (0 if the end of stream has been reached)
     * @throws IOException if an I/O error occurs
     */
    @Override
    public long skip(final long n) throws IOException {
        if (n < 1L) {
            return 0L;
        }
        boolean eof = enter();
        try {
            if (eof) {
                return 0L;
            }
            // Cap the request so that the running total always fits in InterruptedIOException.bytesTransferred (an int).
            long remaining = Math.min(n, (long) Integer.MAX_VALUE);
            long total = 0L;
            // Snapshot once so that setReadTimeout does not affect this in-progress skip.
            final long timeout = this.timeout;
            final long start = System.nanoTime();
            long elapsed = 0L;
            while (remaining > 0L) {
                final long res = Channels.drain(channel, remaining);
                if (res == -1L) {
                    // Record end-of-stream, as read() does, so later calls short-circuit.
                    eof = true;
                    return total;
                } else if (res == 0L) {
                    try {
                        awaitReadableWithTimeout(timeout, elapsed);
                    } catch (InterruptedIOException e) {
                        // Report the progress made so far to callers that handle the interruption.
                        assert total < (long) Integer.MAX_VALUE;
                        e.bytesTransferred = (int) total;
                        throw e;
                    }
                    elapsed = System.nanoTime() - start;
                } else {
                    total += res;
                    remaining -= res;
                }
            }
            return total;
        } finally {
            exit(eof);
        }
    }

    /**
     * Shut down reads on the underlying channel and record end-of-stream, so that subsequent
     * reads return -1 and subsequent skips return 0 without touching the channel.
     * Like every other operation, this is rejected if another operation is in progress.
     */
    @Override
    public void close() throws IOException {
        enter();
        try {
            channel.shutdownReads();
        } finally {
            exit(true);
        }
    }
}