io.netty.channel.socket.nio.NioDomainSocketChannel Maven / Gradle / Ivy
Go to download
This artifact provides a single jar that contains all classes required to use remote EJB and JMS, including
all dependencies. It is intended for use by those not using Maven; Maven users should just import the EJB and
JMS BOMs instead (shaded JARs cause lots of problems with Maven, as it is very easy to inadvertently end up
with different versions of classes on the class path).
/*
* Copyright 2024 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel.socket.nio;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelConfig;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelOutboundBuffer;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelConfig;
import io.netty.channel.EventLoop;
import io.netty.channel.FileRegion;
import io.netty.channel.MessageSizeEstimator;
import io.netty.channel.RecvByteBufAllocator;
import io.netty.channel.WriteBufferWaterMark;
import io.netty.channel.nio.AbstractNioByteChannel;
import io.netty.channel.socket.DuplexChannel;
import io.netty.channel.socket.DuplexChannelConfig;
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.SocketUtils;
import io.netty.util.internal.SuppressJava6Requirement;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.io.IOException;
import java.lang.reflect.Method;
import java.net.SocketAddress;
import java.net.StandardSocketOptions;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import static io.netty.channel.ChannelOption.SO_RCVBUF;
import static io.netty.channel.ChannelOption.SO_SNDBUF;
import static io.netty.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD;
/**
* {@link DuplexChannel} which uses NIO selector based implementation to support
* UNIX Domain Sockets. This is only supported when using Java 16+.
*/
public final class NioDomainSocketChannel extends AbstractNioByteChannel
implements DuplexChannel {
// Logger used to report an exception that is suppressed while completing shutdown(...).
private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioDomainSocketChannel.class);
// Provider used by the no-arg constructor.
private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider();
// Reflective handle looked up once via SelectorProviderUtil.findOpenMethod("openSocketChannel") and handed to
// SelectorProviderUtil.newDomainSocketChannel(...) whenever a new channel is opened.
private static final Method OPEN_SOCKET_CHANNEL_WITH_FAMILY =
SelectorProviderUtil.findOpenMethod("openSocketChannel");
// Per-channel configuration; assigned in the (parent, socket) constructor.
private final ChannelConfig config;
// Set to true once shutdownInput0() has shut down the input side of the JDK channel.
private volatile boolean isInputShutdown;
// Set to true once doShutdownOutput() has shut down the output side of the JDK channel.
private volatile boolean isOutputShutdown;
/**
 * Opens a new JDK {@link SocketChannel} for UNIX Domain Sockets through the given provider.
 *
 * @param provider the {@link SelectorProvider} used to open the channel
 * @return the freshly opened channel, never {@code null}
 * @throws UnsupportedOperationException if running on a Java version older than 16
 * @throws ChannelException if the channel could not be opened
 */
private static SocketChannel newChannel(SelectorProvider provider) {
    if (PlatformDependent.javaVersion() < 16) {
        throw new UnsupportedOperationException("Only supported on java 16+");
    }
    final SocketChannel opened;
    try {
        opened = SelectorProviderUtil.newDomainSocketChannel(OPEN_SOCKET_CHANNEL_WITH_FAMILY, provider);
    } catch (IOException e) {
        throw new ChannelException("Failed to open a socket.", e);
    }
    // The helper may hand back null instead of throwing; treat that as a failure as well.
    if (opened == null) {
        throw new ChannelException("Failed to open a socket.");
    }
    return opened;
}
/**
 * Create a new instance backed by a channel opened through the default
 * {@link SelectorProvider} ({@code SelectorProvider.provider()}).
 *
 * @throws UnsupportedOperationException if running on a Java version older than 16
 * @throws ChannelException if the underlying socket could not be opened
 */
public NioDomainSocketChannel() {
this(DEFAULT_SELECTOR_PROVIDER);
}
/**
 * Create a new instance using the given {@link SelectorProvider}.
 *
 * @param provider the provider used to open the underlying {@link SocketChannel}
 * @throws UnsupportedOperationException if running on a Java version older than 16
 * @throws ChannelException if the underlying socket could not be opened
 */
public NioDomainSocketChannel(SelectorProvider provider) {
this(newChannel(provider));
}
/**
 * Create a new instance using the given {@link SocketChannel}. The channel is created without a parent.
 *
 * @param socket the {@link SocketChannel} which will be used
 * @throws UnsupportedOperationException if running on a Java version older than 16
 */
public NioDomainSocketChannel(SocketChannel socket) {
this(null, socket);
}
/**
 * Create a new instance
 *
 * @param parent the {@link Channel} which created this instance or {@code null} if it was created by the user
 * @param socket the {@link SocketChannel} which will be used
 * @throws UnsupportedOperationException if running on a Java version older than 16
 */
public NioDomainSocketChannel(Channel parent, SocketChannel socket) {
super(parent, socket);
// The super(...) call has to come first, so the version check can only run afterwards.
if (PlatformDependent.javaVersion() < 16) {
throw new UnsupportedOperationException("Only supported on java 16+");
}
config = new NioDomainSocketChannelConfig(this, socket);
}
/**
 * Returns the parent of this channel, narrowed to {@link ServerSocketChannel}.
 *
 * NOTE(review): the cast throws {@link ClassCastException} when the parent is not a
 * {@link ServerSocketChannel} - confirm that the server-side domain socket channel implements it.
 *
 * @return the parent channel, or {@code null} if there is none
 */
@Override
public ServerSocketChannel parent() {
    final Channel owner = super.parent();
    return (ServerSocketChannel) owner;
}
/**
 * Returns the configuration object that was created for this channel in the constructor.
 */
@Override
public ChannelConfig config() {
return config;
}
/**
 * Returns the underlying JDK channel, narrowed to {@link SocketChannel}.
 */
@Override
protected SocketChannel javaChannel() {
return (SocketChannel) super.javaChannel();
}
/**
 * Tells whether this channel is active, i.e. the underlying JDK channel is both open and connected.
 *
 * @return {@code true} if the JDK channel is open and connected
 */
@Override
public boolean isActive() {
    final SocketChannel socket = javaChannel();
    if (!socket.isOpen()) {
        return false;
    }
    return socket.isConnected();
}
/**
 * Tells whether the output side is shut down: either it was shut down explicitly
 * or the channel is not active.
 */
@Override
public boolean isOutputShutdown() {
    if (isOutputShutdown) {
        return true;
    }
    return !isActive();
}
/**
 * Tells whether the input side is shut down: either it was shut down explicitly
 * or the channel is not active.
 */
@Override
public boolean isInputShutdown() {
    if (isInputShutdown) {
        return true;
    }
    return !isActive();
}
/**
 * Tells whether the channel is fully shut down: both directions are shut down,
 * or the channel is not active.
 */
@Override
public boolean isShutdown() {
    if (isInputShutdown() && isOutputShutdown()) {
        return true;
    }
    return !isActive();
}
/**
 * Shuts down the output side of the underlying JDK channel and records that fact.
 *
 * @throws Exception if the JDK channel fails to shut down its output; the flag is not set in that case
 */
@SuppressJava6Requirement(reason = "guarded by version check")
@Override
protected void doShutdownOutput() throws Exception {
javaChannel().shutdownOutput();
// Only reached when the JDK call above did not throw.
isOutputShutdown = true;
}
/**
 * Shuts down the output side of this channel, using a newly created promise.
 *
 * @return the future that is notified once the operation completes
 */
@Override
public ChannelFuture shutdownOutput() {
    final ChannelPromise promise = newPromise();
    return shutdownOutput(promise);
}
/**
 * Shuts down the output side of this channel. The work is delegated to the channel's unsafe and
 * always runs on the channel's {@link EventLoop}: directly when called from it, otherwise as a
 * submitted task.
 *
 * @param promise the promise passed on to the unsafe's {@code shutdownOutput(...)}
 * @return the given {@code promise}
 */
@Override
public ChannelFuture shutdownOutput(final ChannelPromise promise) {
    final EventLoop executor = eventLoop();
    if (!executor.inEventLoop()) {
        // Called from a foreign thread: hand the work over to the event loop.
        executor.execute(new Runnable() {
            @Override
            public void run() {
                ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
            }
        });
        return promise;
    }
    ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
    return promise;
}
/**
 * Shuts down the input side of this channel, using a newly created promise.
 *
 * @return the future that is notified once the operation completes
 */
@Override
public ChannelFuture shutdownInput() {
    final ChannelPromise promise = newPromise();
    return shutdownInput(promise);
}
/**
 * Hook for the superclass; simply reports {@link #isInputShutdown()}.
 */
@Override
protected boolean isInputShutdown0() {
return isInputShutdown();
}
/**
 * Shuts down the input side of this channel. The work always runs on the channel's
 * {@link EventLoop}: directly when called from it, otherwise as a submitted task.
 *
 * @param promise the promise that is completed with the outcome of the shutdown
 * @return the given {@code promise}
 */
@Override
public ChannelFuture shutdownInput(final ChannelPromise promise) {
    final EventLoop executor = eventLoop();
    if (!executor.inEventLoop()) {
        // Called from a foreign thread: hand the work over to the event loop.
        executor.execute(new Runnable() {
            @Override
            public void run() {
                shutdownInput0(promise);
            }
        });
        return promise;
    }
    shutdownInput0(promise);
    return promise;
}
/**
 * Shuts down both the output and the input side of this channel, using a newly created promise.
 *
 * @return the future that is notified once both directions have been processed
 */
@Override
public ChannelFuture shutdown() {
    final ChannelPromise promise = newPromise();
    return shutdown(promise);
}
/**
 * Shuts down both directions of this channel: first the output side and, once that has
 * completed (successfully or not), the input side.
 *
 * @param promise the promise that is completed after both shutdown attempts are done
 * @return the given {@code promise}
 */
@Override
public ChannelFuture shutdown(final ChannelPromise promise) {
    final ChannelFuture outputFuture = shutdownOutput();
    if (!outputFuture.isDone()) {
        // Output shutdown is still pending; continue with the input side when it finishes.
        outputFuture.addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception {
                shutdownOutputDone(shutdownOutputFuture, promise);
            }
        });
        return promise;
    }
    shutdownOutputDone(outputFuture, promise);
    return promise;
}
/**
 * Second stage of {@code shutdown(promise)}: starts the input shutdown and, once it is done,
 * combines both outcomes into {@code promise}.
 *
 * @param shutdownOutputFuture the already completed output-shutdown future
 * @param promise the promise of the overall shutdown operation
 */
private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) {
    final ChannelFuture inputFuture = shutdownInput();
    if (!inputFuture.isDone()) {
        // Input shutdown is still pending; finish the overall operation when it completes.
        inputFuture.addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception {
                shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
            }
        });
        return;
    }
    shutdownDone(shutdownOutputFuture, inputFuture, promise);
}
/**
 * Completes {@code promise} from the outcome of the two shutdown futures. A failure of the output
 * shutdown takes precedence; if the input shutdown failed as well, its cause is only logged.
 *
 * @param shutdownOutputFuture the completed output-shutdown future
 * @param shutdownInputFuture the completed input-shutdown future
 * @param promise the promise to succeed or fail
 */
private static void shutdownDone(ChannelFuture shutdownOutputFuture,
                                 ChannelFuture shutdownInputFuture,
                                 ChannelPromise promise) {
    final Throwable outputCause = shutdownOutputFuture.cause();
    final Throwable inputCause = shutdownInputFuture.cause();
    if (outputCause == null) {
        if (inputCause == null) {
            promise.setSuccess();
        } else {
            promise.setFailure(inputCause);
        }
        return;
    }
    if (inputCause != null) {
        // Only one cause can be reported through the promise, so the second one is logged.
        logger.debug("Exception suppressed because a previous exception occurred.",
                inputCause);
    }
    promise.setFailure(outputCause);
}
/**
 * Shuts down the input side and reports the outcome through {@code promise}.
 *
 * @param promise succeeded when the shutdown worked, failed with the thrown {@link Throwable} otherwise
 */
private void shutdownInput0(final ChannelPromise promise) {
try {
shutdownInput0();
promise.setSuccess();
} catch (Throwable t) {
// Any failure (including one thrown by setSuccess() above) is routed to the promise.
promise.setFailure(t);
}
}
/**
 * Shuts down the input side of the underlying JDK channel and records that fact.
 *
 * @throws Exception if the JDK channel fails to shut down its input; the flag is not set in that case
 */
@SuppressJava6Requirement(reason = "Usage guarded by java version check")
private void shutdownInput0() throws Exception {
javaChannel().shutdownInput();
// Only reached when the JDK call above did not throw.
isInputShutdown = true;
}
/**
 * Looks up the local address of the underlying JDK channel on a best-effort basis.
 *
 * @return the local address, or {@code null} if the lookup failed
 */
@SuppressJava6Requirement(reason = "Usage guarded by java version check")
@Override
protected SocketAddress localAddress0() {
    try {
        return javaChannel().getLocalAddress();
    } catch (Exception ignored) {
        // Deliberately swallowed: a failed lookup is reported as "no address".
        return null;
    }
}
/**
 * Looks up the remote address of the underlying JDK channel on a best-effort basis.
 *
 * @return the remote address, or {@code null} if the lookup failed
 */
@SuppressJava6Requirement(reason = "Usage guarded by java version check")
@Override
protected SocketAddress remoteAddress0() {
    try {
        return javaChannel().getRemoteAddress();
    } catch (Exception ignored) {
        // Deliberately swallowed: a failed lookup is reported as "no address".
        return null;
    }
}
/**
 * Binds the underlying JDK channel to the given local address via {@code SocketUtils.bind(...)}.
 *
 * @param localAddress the address to bind to
 * @throws Exception if binding fails
 */
@Override
protected void doBind(SocketAddress localAddress) throws Exception {
    final SocketChannel socket = javaChannel();
    SocketUtils.bind(socket, localAddress);
}
/**
 * Starts connecting to {@code remoteAddress}, binding to {@code localAddress} first when one is given.
 * If the connect does not complete immediately, interest in {@link SelectionKey#OP_CONNECT} is
 * registered on the selection key. The channel is closed if anything in the connect step throws.
 *
 * @param remoteAddress the address to connect to
 * @param localAddress the address to bind to first, or {@code null} to skip binding
 * @return {@code true} if the connection was established immediately, {@code false} if it is pending
 * @throws Exception if binding or connecting fails
 */
@Override
protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
    if (localAddress != null) {
        doBind(localAddress);
    }

    boolean finished = false;
    try {
        if (SocketUtils.connect(javaChannel(), remoteAddress)) {
            finished = true;
            return true;
        }
        // Connection is still in progress: wait for the selector to signal completion.
        selectionKey().interestOps(SelectionKey.OP_CONNECT);
        finished = true;
        return false;
    } finally {
        if (!finished) {
            doClose();
        }
    }
}
/**
 * Completes a pending connect on the underlying JDK channel.
 *
 * @throws Exception if {@code finishConnect()} itself fails
 * @throws Error if {@code finishConnect()} reports that the connection is still pending; the error
 *         type is kept for backward compatibility but now carries a message describing the cause
 */
@Override
protected void doFinishConnect() throws Exception {
    if (!javaChannel().finishConnect()) {
        // finishConnect() returns false only while a non-blocking connect has not completed yet.
        throw new Error("finishConnect() returned false: the connection attempt has not completed yet");
    }
}
/**
 * Disconnecting is implemented as a full close of the channel.
 *
 * @throws Exception if closing fails
 */
@Override
protected void doDisconnect() throws Exception {
doClose();
}
/**
 * Closes this channel: runs the superclass close logic, closes the underlying JDK channel and, if
 * the channel had a local address, removes the corresponding socket file.
 *
 * @throws Exception if the superclass close logic, closing the JDK channel, or deleting the socket
 *         file fails
 */
@Override
protected void doClose() throws Exception {
    // Resolve the local address while the socket is still open: once it is closed, the JDK lookup
    // throws ClosedChannelException, which localAddress0() swallows and reports as null, so the
    // socket file would never be deleted.
    final SocketAddress local = localAddress();
    try {
        super.doClose();
    } finally {
        // Always release the JDK channel, even if super.doClose() failed.
        javaChannel().close();
        if (local != null) {
            NioDomainSocketUtil.deleteSocketFile(local);
        }
    }
}
/**
 * Reads bytes from the underlying JDK channel into {@code byteBuf}, attempting to fill all of its
 * writable bytes. The attempted amount is recorded on the receive-buffer allocator handle first.
 *
 * @param byteBuf the buffer to read into
 * @return the value returned by {@code ByteBuf.writeBytes(...)} for this read
 * @throws Exception if the read fails
 */
@Override
protected int doReadBytes(ByteBuf byteBuf) throws Exception {
    final RecvByteBufAllocator.Handle handle = unsafe().recvBufAllocHandle();
    final int writable = byteBuf.writableBytes();
    handle.attemptedBytesRead(writable);
    final SocketChannel source = javaChannel();
    return byteBuf.writeBytes(source, handle.attemptedBytesRead());
}
/**
 * Writes all readable bytes of {@code buf} to the underlying JDK channel in a single attempt.
 *
 * @param buf the buffer to write from
 * @return the value returned by {@code ByteBuf.readBytes(...)} for this write
 * @throws Exception if the write fails
 */
@Override
protected int doWriteBytes(ByteBuf buf) throws Exception {
    final int pending = buf.readableBytes();
    final SocketChannel target = javaChannel();
    return buf.readBytes(target, pending);
}
/**
 * Transfers the given {@link FileRegion} to the underlying JDK channel, continuing from the
 * position that has already been transferred.
 *
 * @param region the region to write
 * @return the value returned by {@code FileRegion.transferTo(...)} for this call
 * @throws Exception if the transfer fails
 */
@Override
protected long doWriteFileRegion(FileRegion region) throws Exception {
    final long alreadyTransferred = region.transferred();
    final SocketChannel target = javaChannel();
    return region.transferTo(target, alreadyTransferred);
}
/**
 * Adapts the configured maximum number of bytes per gathering write to what the last write achieved.
 * The limit initially tracks SO_SNDBUF, but some OSes change the effective amount that can be
 * written at once, so this is a best-effort adjustment:
 * a fully successful write may double the limit, a write that managed less than half of a
 * sufficiently large attempt halves it.
 *
 * @param attempted number of bytes the write attempted
 * @param written number of bytes actually written
 * @param oldMaxBytesPerGatheringWrite the limit that was in effect for this write
 */
private void adjustMaxBytesPerGatheringWrite(int attempted, int written, int oldMaxBytesPerGatheringWrite) {
    final NioDomainSocketChannelConfig channelConfig = (NioDomainSocketChannelConfig) config;
    if (attempted == written) {
        // Everything went through: allow twice as much next time if that exceeds the old limit.
        final int doubled = attempted << 1;
        if (doubled > oldMaxBytesPerGatheringWrite) {
            channelConfig.setMaxBytesPerGatheringWrite(doubled);
        }
        return;
    }
    final int halved = attempted >>> 1;
    if (attempted > MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD && written < halved) {
        channelConfig.setMaxBytesPerGatheringWrite(halved);
    }
}
/**
 * Flushes the pending outbound data to the underlying JDK channel.
 *
 * Each loop iteration asks the outbound buffer for NIO buffers and then either falls back to the
 * superclass write (no NIO buffers available), performs a plain write (exactly one buffer) or a
 * gathering write (several buffers). The loop runs until the outbound buffer is empty, the socket
 * accepts no more data, or the write spin count from the channel configuration is used up.
 *
 * @param in the outbound buffer holding the pending writes
 * @throws Exception if a write to the JDK channel fails
 */
@Override
protected void doWrite(ChannelOutboundBuffer in) throws Exception {
SocketChannel ch = javaChannel();
int writeSpinCount = config().getWriteSpinCount();
do {
if (in.isEmpty()) {
// All written so clear OP_WRITE
clearOpWrite();
// Directly return here so incompleteWrite(...) is not called.
return;
}
// Ensure the pending writes are made of ByteBufs only.
int maxBytesPerGatheringWrite = ((NioDomainSocketChannelConfig) config).getMaxBytesPerGatheringWrite();
ByteBuffer[] nioBuffers = in.nioBuffers(1024, maxBytesPerGatheringWrite);
int nioBufferCnt = in.nioBufferCount();
// Always use nioBuffers() to workaround data-corruption.
// See https://github.com/netty/netty/issues/2761
switch (nioBufferCnt) {
case 0:
// We have something else beside ByteBuffers to write so fallback to normal writes.
writeSpinCount -= doWrite0(in);
break;
case 1: {
// Only one ByteBuf so use non-gathering write
// Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
// to check if the total size of all the buffers is non-zero.
ByteBuffer buffer = nioBuffers[0];
int attemptedBytes = buffer.remaining();
final int localWrittenBytes = ch.write(buffer);
// Nothing was accepted by the socket: stop and let incompleteWrite(true) handle the rest.
if (localWrittenBytes <= 0) {
incompleteWrite(true);
return;
}
adjustMaxBytesPerGatheringWrite(attemptedBytes, localWrittenBytes, maxBytesPerGatheringWrite);
in.removeBytes(localWrittenBytes);
--writeSpinCount;
break;
}
default: {
// Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
// to check if the total size of all the buffers is non-zero.
// We limit the max amount to int above so cast is safe
long attemptedBytes = in.nioBufferSize();
final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
// Nothing was accepted by the socket: stop and let incompleteWrite(true) handle the rest.
if (localWrittenBytes <= 0) {
incompleteWrite(true);
return;
}
// Casting to int is safe because we limit the total amount of data in the nioBuffers to int above.
adjustMaxBytesPerGatheringWrite((int) attemptedBytes, (int) localWrittenBytes,
maxBytesPerGatheringWrite);
in.removeBytes(localWrittenBytes);
--writeSpinCount;
break;
}
}
} while (writeSpinCount > 0);
// Spin budget exhausted while data is still pending; the flag is true only when the fallback
// write in case 0 drove the counter below zero.
incompleteWrite(writeSpinCount < 0);
}
/**
 * Creates the unsafe implementation for this channel, a new {@code NioSocketChannelUnsafe}.
 */
@Override
protected AbstractNioUnsafe newUnsafe() {
return new NioSocketChannelUnsafe();
}
/**
 * Unsafe implementation returned by {@code newUnsafe()}; adds nothing on top of {@code NioByteUnsafe}.
 */
private final class NioSocketChannelUnsafe extends NioByteUnsafe {
// Only extending it so we create a new instance in newUnsafe() and return it.
}
private final class NioDomainSocketChannelConfig extends DefaultChannelConfig
implements DuplexChannelConfig {
private volatile boolean allowHalfClosure;
private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE;
private final SocketChannel javaChannel;
private NioDomainSocketChannelConfig(NioDomainSocketChannel channel, SocketChannel javaChannel) {
super(channel);
this.javaChannel = javaChannel;
calculateMaxBytesPerGatheringWrite();
}
@Override
public boolean isAllowHalfClosure() {
return allowHalfClosure;
}
@Override
public NioDomainSocketChannelConfig setAllowHalfClosure(boolean allowHalfClosure) {
this.allowHalfClosure = allowHalfClosure;
return this;
}
@Override
public Map, Object> getOptions() {
List> options = new ArrayList>();
options.add(SO_RCVBUF);
options.add(SO_SNDBUF);
for (ChannelOption> opt : NioChannelOption.getOptions(jdkChannel())) {
options.add(opt);
}
return getOptions(super.getOptions(), options.toArray(new ChannelOption[0]));
}
@SuppressWarnings("unchecked")
@Override
public T getOption(ChannelOption option) {
if (option == SO_RCVBUF) {
return (T) Integer.valueOf(getReceiveBufferSize());
}
if (option == SO_SNDBUF) {
return (T) Integer.valueOf(getSendBufferSize());
}
if (option instanceof NioChannelOption) {
return NioChannelOption.getOption(jdkChannel(), (NioChannelOption) option);
}
return super.getOption(option);
}
@Override
public boolean setOption(ChannelOption option, T value) {
if (option == SO_RCVBUF) {
validate(option, value);
setReceiveBufferSize((Integer) value);
} else if (option == SO_SNDBUF) {
validate(option, value);
setSendBufferSize((Integer) value);
} else if (option instanceof NioChannelOption) {
return NioChannelOption.setOption(jdkChannel(), (NioChannelOption) option, value);
} else {
return super.setOption(option, value);
}
return true;
}
@SuppressJava6Requirement(reason = "Usage guarded by java version check")
private int getReceiveBufferSize() {
try {
return javaChannel.getOption(StandardSocketOptions.SO_RCVBUF);
} catch (IOException e) {
throw new ChannelException(e);
}
}
@SuppressJava6Requirement(reason = "Usage guarded by java version check")
private NioDomainSocketChannelConfig setReceiveBufferSize(int receiveBufferSize) {
try {
javaChannel.setOption(StandardSocketOptions.SO_RCVBUF, receiveBufferSize);
} catch (IOException e) {
throw new ChannelException(e);
}
return this;
}
@SuppressJava6Requirement(reason = "Usage guarded by java version check")
private int getSendBufferSize() {
try {
return javaChannel.getOption(StandardSocketOptions.SO_SNDBUF);
} catch (IOException e) {
throw new ChannelException(e);
}
}
@SuppressJava6Requirement(reason = "Usage guarded by java version check")
private NioDomainSocketChannelConfig setSendBufferSize(int sendBufferSize) {
try {
javaChannel.setOption(StandardSocketOptions.SO_SNDBUF, sendBufferSize);
} catch (IOException e) {
throw new ChannelException(e);
}
return this;
}
@Override
public NioDomainSocketChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) {
super.setConnectTimeoutMillis(connectTimeoutMillis);
return this;
}
@Override
@Deprecated
public NioDomainSocketChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) {
super.setMaxMessagesPerRead(maxMessagesPerRead);
return this;
}
@Override
public NioDomainSocketChannelConfig setWriteSpinCount(int writeSpinCount) {
super.setWriteSpinCount(writeSpinCount);
return this;
}
@Override
public NioDomainSocketChannelConfig setAllocator(ByteBufAllocator allocator) {
super.setAllocator(allocator);
return this;
}
@Override
public NioDomainSocketChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) {
super.setRecvByteBufAllocator(allocator);
return this;
}
@Override
public NioDomainSocketChannelConfig setAutoRead(boolean autoRead) {
super.setAutoRead(autoRead);
return this;
}
@Override
public NioDomainSocketChannelConfig setAutoClose(boolean autoClose) {
super.setAutoClose(autoClose);
return this;
}
@Override
public NioDomainSocketChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark) {
super.setWriteBufferHighWaterMark(writeBufferHighWaterMark);
return this;
}
@Override
public NioDomainSocketChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark) {
super.setWriteBufferLowWaterMark(writeBufferLowWaterMark);
return this;
}
@Override
public NioDomainSocketChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark) {
super.setWriteBufferWaterMark(writeBufferWaterMark);
return this;
}
@Override
public NioDomainSocketChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) {
super.setMessageSizeEstimator(estimator);
return this;
}
@Override
protected void autoReadCleared() {
clearReadPending();
}
void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) {
this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite;
}
int getMaxBytesPerGatheringWrite() {
return maxBytesPerGatheringWrite;
}
private void calculateMaxBytesPerGatheringWrite() {
// Multiply by 2 to give some extra space in case the OS can process write data faster than we can provide.
int newSendBufferSize = getSendBufferSize() << 1;
if (newSendBufferSize > 0) {
setMaxBytesPerGatheringWrite(newSendBufferSize);
}
}
private SocketChannel jdkChannel() {
return javaChannel;
}
}
}