// cn.tom.transport.nio.NioSession (Maven / Gradle / Ivy)
// The newest version!
package cn.tom.transport.nio;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import cn.tom.kit.Helper;
import cn.tom.kit.IoBuffer;
import cn.tom.kit.ObjectId;
import cn.tom.transport.Id;
import cn.tom.transport.IoAdaptor;
import cn.tom.transport.Session;
import cn.tom.transport.WriteBufPool;
import cn.tom.transport.WriteBufPool.MergeRunnable;
public class NioSession implements Session,Closeable{
/** Class-wide logger. */
private static final Logger log = LoggerFactory.getLogger(NioSession.class);
/** Session identifier, produced from a fresh ObjectId in the constructor. */
private final Long id;
/** Timestamp in milliseconds, set at construction and refreshed by updateLastAsk(). */
private long lastAsk = System.currentTimeMillis();
/** Buffer size obtained from the IoAdaptor; write(IoBuffer) rejects buffers whose limit exceeds it. */
private int bufferSize = 0;
/** Receives bytes read from the channel until the adaptor decodes them. */
private IoBuffer readBuffer = null;
/**
 * Encoded buffers waiting to be written. The queue object is also the monitor
 * used by the flush paths, so the reference is final.
 */
private final LinkedBlockingQueue<ByteBuffer> writeBufferQ = new LinkedBlockingQueue<ByteBuffer>();
/** Released by finishConnect(); awaited by waitToConnect(). */
private CountDownLatch connectLatch = new CountDownLatch(1);
private final Selectors selectors;
private final SocketChannel channel;
/** Selection key of this session; null until setRegisteredKey is called and again after close(). */
private SelectionKey registeredKey;
private SelectorThread selectorThread;
/**
 * Lazily created attribute map. Declared volatile because setAttr creates it
 * with a check / lock / re-check sequence and getAttr reads it without locking.
 */
private volatile ConcurrentMap<String, Object> attributes;
/** Write pool shared by all sessions. TODO: tune separately for client and server. */
private static WriteBufPool writeBufPool = new WriteBufPool();
private Object attachment;
private final IoAdaptor ioAdaptor;
/** Value reported by isServer(); defaults to true. */
private boolean isServer = true;
/**
 * Creates a session without an attachment.
 *
 * @param selectors selector group used later by register(int)
 * @param channel   socket channel backing this session
 * @param ioAdaptor adaptor supplying the codec, callbacks and buffer size
 */
public NioSession(Selectors selectors, SocketChannel channel, IoAdaptor ioAdaptor){
    // Delegate to the full constructor with a null attachment.
    this(selectors, channel, null, ioAdaptor);
}
public NioSession(Selectors selectors, SocketChannel channel, Object attachment, IoAdaptor ioAdaptor){
this.selectors = selectors;
this.id = new ObjectId().toLong();
this.channel = channel;
this.attachment = attachment;
this.ioAdaptor = ioAdaptor;
this.bufferSize = ioAdaptor.getReadWritebufferSize();
readBuffer = IoBuffer.allocate(bufferSize);
}
/**
 * @return the identifier assigned to this session at construction
 */
public long id(){
    return id.longValue();
}
/**
 * Encodes {@code msg} with the adaptor, queues the encoded buffer and submits
 * a MergeRunnable for this session to the shared write pool. Any exception is
 * handed to catchError.
 *
 * @param msg message to encode and queue
 */
public void write(T msg){
    try {
        IoBuffer encoded = ioAdaptor.encode(msg, this);
        write(encoded);
        // Trigger mode is disabled: regOpsWrite() is no longer called here.
        writeBufPool.exec(new MergeRunnable(this));
    } catch (Exception e) {
        catchError(e);
    }
}
/**
 * Queues an already encoded buffer for writing.
 *
 * @param buf buffer to queue; its limit must not exceed the session buffer size
 * @throws IOException if the buffer is larger than the session buffer size or
 *                     the queue rejects it
 */
public void write(IoBuffer buf) throws IOException{
    final boolean tooLarge = buf.limit() > bufferSize;
    if (tooLarge) {
        log.warn("writebuff {} > bufferSize {} ", buf.limit(), bufferSize);
        throw new IOException("writebuffSize > initbufferSize need ioadaptor set bufferSide");
    }

    if (writeBufferQ.offer(buf.buf())) {
        return;
    }

    // offer() refused the buffer: report it and fail the write.
    String msg = "Session write buffer queue is full, message count="+writeBufferQ.size();
    log.warn(msg);
    throw new IOException(msg);
}
/**
 * Adds OP_WRITE to the interest set of the registered key.
 * No longer used: it has been replaced by writeBufPool. Originally the
 * selector thread controlled switching between read and write; now
 * writeBufPool performs the writes and the selector thread performs the reads.
 *
 * @throws IOException if the session has no registered key yet
 */
private void regOpsWrite() throws IOException{
    if (this.registeredKey != null) {
        // Keep the current interest bits and OR in the write bit.
        selectorThread.interestOps(registeredKey, registeredKey.interestOps() | SelectionKey.OP_WRITE);
        return;
    }
    throw new IOException("Session not registered yet:"+this);
}
/**
 * Encodes {@code msg} and writes it to the channel immediately, bypassing the
 * write queue. An IOException is handed to catchError.
 *
 * @param msg message to encode and send
 */
@Override
public void writeAndFlush(T msg) {
    try {
        ByteBuffer encoded = ioAdaptor.encode(msg, this).buf();
        doWrite(encoded);
    } catch (IOException e) {
        catchError(e);
    }
}
/**
 * Drains the write queue to the channel while holding the queue's monitor.
 * <ul>
 *   <li>An empty queue ends the flush.</li>
 *   <li>A single queued buffer is written as is, then the flush ends.</li>
 *   <li>Otherwise up to 256 buffers are merged into one buffer that is sized
 *       to the bytes actually pending, written, and the loop repeats.</li>
 * </ul>
 * The merge copies via {@code ByteBuffer.array()}, so queued buffers must be
 * array-backed, as in the previous implementation.
 *
 * @throws Exception if writing to the channel fails
 */
@Override
public void flush() throws Exception {
    synchronized (writeBufferQ) {
        while (true) {
            int size = Math.min(writeBufferQ.size(), 256);
            if (size == 0) {
                break;
            }
            if (size == 1) {
                ByteBuffer single = writeBufferQ.poll();
                // close() clears the queue without this lock, so poll() may find nothing.
                if (single != null) {
                    doWrite(single);
                }
                break;
            }

            // Take the batch first so the merge buffer is sized from the real
            // payload instead of assuming every buffer has the same capacity.
            ByteBuffer[] batch = new ByteBuffer[size];
            int count = 0;
            int total = 0;
            while (count < size) {
                ByteBuffer buf = writeBufferQ.poll();
                if (buf == null) {
                    break; // queue was cleared concurrently
                }
                batch[count++] = buf;
                total += buf.remaining();
            }
            if (count == 0) {
                break;
            }
            if (total == 0) {
                continue; // nothing to send in this batch
            }

            IoBuffer merged = IoBuffer.allocate(total);
            for (int i = 0; i < count; i++) {
                ByteBuffer buf = batch[i];
                // Copy only the unread region [position, limit) of each buffer.
                // NOTE(review): assumes writeBytes(byte[], offset, length) — confirm against IoBuffer.
                merged.writeBytes(buf.array(), buf.arrayOffset() + buf.position(), buf.remaining());
            }
            merged.flip();
            doWrite(merged.buf());
        }
    }
}
/**
 * Writes {@code buffer} to the channel until no bytes remain, pausing 1 ms
 * between attempts while data is still pending.
 * <p>
 * If the thread is interrupted during a pause, the write still runs to
 * completion (as before), but the interrupt status is restored before this
 * method returns instead of being silently lost.
 *
 * @param buffer data to send; fully consumed on normal return
 * @throws IOException if the channel write fails
 */
protected void doWrite(ByteBuffer buffer) throws IOException {
    boolean interrupted = false;
    try {
        while (true) {
            this.channel.write(buffer);
            try {
                if (!next(buffer)) {
                    return;
                }
            } catch (InterruptedException e) {
                // Remember the interrupt; finish the write first so the stream stays consistent.
                interrupted = true;
            }
        }
    } finally {
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }
}
/**
 * Decides whether another write attempt is needed, sleeping 1 ms before
 * asking for a retry.
 *
 * @param buffer buffer being written
 * @return false when the buffer is fully written, true after the pause otherwise
 * @throws InterruptedException if the pause is interrupted
 */
private boolean next(ByteBuffer buffer) throws InterruptedException {
    if (buffer.remaining() == 0) {
        return false;
    }
    Thread.sleep(1);
    return true;
}
/**
 * Writes queued buffers to the channel without blocking, while holding the
 * queue's monitor. When the queue is empty the key's interest set is reset to
 * OP_READ. When the channel accepts no bytes for a non-empty buffer, the
 * method returns and leaves that buffer at the head of the queue.
 *
 * @return total number of bytes written by this call
 * @throws IOException if the channel write fails
 */
protected int doWrite() throws IOException{
    int written = 0;
    synchronized (writeBufferQ) {
        ByteBuffer head;
        while ((head = writeBufferQ.peek()) != null) {
            int count = this.channel.write(head);
            if (count == 0 && head.remaining() > 0) {
                // Can break out directly; no need to re-add the OP, since read/write
                // cycles and OP events outnumber the loop iterations.
                return written;
            }
            written += count;
            // buf.remaining tells whether the buffer is finished: if not, keep
            // writing; if nothing more fits, break out and read first.
            if (head.remaining() == 0) {
                head.clear();
                writeBufferQ.remove();
            }
        }
        // Nothing left to send: go back to read-only interest.
        selectorThread.interestOps(registeredKey, SelectionKey.OP_READ);
    }
    return written;
}
/**
 * Runs doRead() and routes any exception to catchError.
 */
@Override
public void read(){
    try {
        this.doRead();
    } catch (Exception failure) {
        catchError(failure);
    }
}
/**
 * Reads all currently available bytes into the read buffer, decodes as many
 * messages as possible, dispatches each to onMessage, and refreshes the
 * last-activity timestamp.
 * <p>
 * When the channel reports end-of-stream (-1), bytes read earlier in the same
 * call are still decoded and dispatched before the session is asked to close;
 * previously they were dropped.
 *
 * @throws IOException if reading from the channel or requesting the close fails
 */
public void doRead() throws IOException {
    int n;
    boolean readed = false;
    /* Read in one go; readBuffer is not grown repeatedly, which bounds the transfer size. */
    /* Further reads rely on the SelectionKey being triggered again. */
    while ((n = channel.read(readBuffer.buf())) > 0) {
        readed = true;
    }
    // -1 means the peer disconnected — important to detect.
    final boolean eof = (n == -1);

    if (readed) {
        IoBuffer buffer = readBuffer.flip();
        while (true) {
            final T msg = ioAdaptor.decode(buffer, this);
            if (msg == null) {
                readBuffer.compact(); // move the undecoded remainder to the front
                break;
            }
            // For multithreaded execution, hand off to the business layer:
            // ThreadPool.exec(onMessage(msg))
            onMessage(msg);
        }
        updateLastAsk();
    }

    if (eof) {
        NioSession.this.asyncClose();
    }
}
/**
 * Passes a decoded message to the adaptor. An exception from the handler is
 * forwarded to the adaptor's onException; if that raises an IOException the
 * session is closed.
 *
 * @param msg decoded message
 */
void onMessage(T msg){
    try {
        ioAdaptor.onMessage(msg, this);
    } catch (Exception handlerError) {
        try {
            ioAdaptor.onException(handlerError, this);
        } catch (IOException reportError) {
            close();
        }
    }
}
/**
 * @return the hash code of the session id
 */
@Override
public int hashCode() {
    return this.id.hashCode();
}
/**
 * Returns the peer address formatted as {@code host:port}.
 *
 * @return the remote address, or null when the channel is not open or has no
 *         remote address yet (open but not connected)
 */
public String getRemoteAddress() {
    if (!isActive()) {
        return null;
    }
    InetAddress addr = this.channel.socket().getInetAddress();
    if (addr == null) {
        // The socket reports no address while unconnected; avoid an NPE here,
        // which would also break toString() inside error messages.
        return null;
    }
    return String.format("%s:%d", addr.getHostAddress(), channel.socket().getPort());
}
/**
 * @return the local address as computed by Helper.localAddress for this
 *         channel, or null when the session is not active
 */
public String getLocalAddress() {
    return isActive() ? Helper.localAddress(this.channel) : null;
}
/**
 * @return the interest set of the registered selection key
 * @throws IOException if the session has no registered key
 */
public int interestOps() throws IOException{
    if (this.registeredKey != null) {
        return this.registeredKey.interestOps();
    }
    throw new IOException("Session not registered yet:"+this);
}
/**
 * Registers this session with the selector group.
 *
 * @param interestOps interest set to register with
 * @throws IOException if registration fails
 */
public void register(int interestOps) throws IOException{
    this.selectors.registerSession(interestOps, this);
}
/**
 * @return the selection key currently stored for this session, possibly null
 */
public SelectionKey getRegisteredKey() {
    return this.registeredKey;
}
/**
 * Stores the selection key for this session.
 *
 * @param key the key to store
 */
public void setRegisteredKey(SelectionKey key) {
    registeredKey = key;
}
/**
 * Closes the session if it is still active: notifies the adaptor, cancels the
 * selection key, closes the channel, removes the session from the write pool,
 * clears both buffers and closes the adaptor.
 * <p>
 * Cleanup runs in {@code finally} blocks, so a failing destroy callback or a
 * failing channel close no longer skips the remaining steps. IOExceptions are
 * logged, not propagated.
 */
@Override
public void close() {
    try{
        if(!isActive()){
            return;
        }
        try {
            ioAdaptor.onSessionDestroyed(this);
        } finally {
            // Cancel the key before channel.close to avoid ClosedChannelException.
            if(this.registeredKey != null){
                this.registeredKey.cancel();
                this.registeredKey = null;
            }
            try {
                if(this.channel != null){
                    this.channel.close();
                }
            } finally {
                writeBufPool.shutdown(this);
                writeBufferQ.clear();
                readBuffer.buf.clear();
                ioAdaptor.close();
            }
        }
    }catch(IOException e){
        log.error(e.getMessage(), e);
    }
}
/**
 * Asks the selector thread to unregister this session.
 *
 * @throws IOException if the selector thread reports a failure
 */
public void asyncClose() throws IOException{
    this.selectorThread.unregisterSession(this);
}
/**
 * @return true when the channel exists and is open
 */
public boolean isActive(){
    return channel != null && channel.isOpen();
}
/**
 * @return the socket channel backing this session
 */
public SocketChannel getChannel() {
    return this.channel;
}
/**
 * @return the selector group this session was created with
 */
public Selectors selectors() {
    return this.selectors;
}
/**
 * Marks the connection as established by counting down the connect latch,
 * then sets the core size of the shared write pool to 2.
 */
public void finishConnect(){
    connectLatch.countDown();
    // Initialize the writer threads once the connection is complete.
    writeBufPool.setCorePoolSize(2);
}
/**
 * Waits until finishConnect() has been called or the timeout elapses.
 *
 * @param millis maximum time to wait, in milliseconds
 * @return true if the connection completed in time; false on timeout or if
 *         the waiting thread was interrupted (its interrupt status is restored)
 */
public boolean waitToConnect(long millis){
    try {
        return this.connectLatch.await(millis, TimeUnit.MILLISECONDS);
    }catch (InterruptedException e) {
        log.error(e.getMessage(), e);
        // Keep the interruption visible to the caller instead of swallowing it.
        Thread.currentThread().interrupt();
    }
    return false;
}
/**
 * Looks up a session attribute.
 *
 * @param key attribute name
 * @return the stored value cast to the caller's type, or null when no
 *         attribute map exists yet or the key is absent
 */
@SuppressWarnings("unchecked")
public V getAttr(String key){
    return this.attributes == null ? null : (V) this.attributes.get(key);
}
/**
 * Stores a session attribute, creating the attribute map on first use.
 *
 * @param key   attribute name
 * @param value attribute value
 */
public void setAttr(String key, V value){
    if (this.attributes == null) {
        // Create the map under the session lock, re-checking after acquiring it.
        synchronized (this) {
            if (this.attributes == null) {
                this.attributes = new ConcurrentHashMap();
            }
        }
    }
    attributes.put(key, value);
}
/**
 * @return a description containing the remote address, the active flag and the id
 */
@Override
public String toString() {
    StringBuilder text = new StringBuilder("Session [");
    text.append("remote=").append(getRemoteAddress());
    text.append(", isActive=").append(isActive());
    text.append(", id=").append(id);
    text.append("]");
    return text.toString();
}
/**
 * @return the user object attached to this session, possibly null
 */
public Object getAttachment() {
    return this.attachment;
}
/**
 * Replaces the user object attached to this session.
 *
 * @param attachment the new attachment, may be null
 */
public void setAttachment(Object attachment) {
    this.attachment = attachment;
    // No other state depends on the attachment.
}
/**
 * @return the adaptor this session was created with
 */
public IoAdaptor getIoAdaptor() {
    return this.ioAdaptor;
}
/**
 * @return the selector thread stored for this session, possibly null
 */
public SelectorThread getSelectorThread() {
    return this.selectorThread;
}
/**
 * Stores the selector thread used by asyncClose() and the interest-ops updates.
 *
 * @param selectorThread the selector thread to store
 */
public void setSelectorThread(SelectorThread selectorThread) {
    this.selectorThread = selectorThread;
    // Only the reference is kept; nothing is registered here.
}
/**
 * Logs the error and closes the session when the error is an IOException.
 *
 * @param e the error to handle
 */
@Override
public void catchError(Throwable e) {
    log.error(e.getMessage(), e);
    final boolean ioFailure = e instanceof IOException;
    if (!ioFailure) {
        return;
    }
    close();
}
/**
 * @return the server flag of this session (true unless changed via setServer)
 */
@Override
public boolean isServer() {
    return this.isServer;
}
/**
 * Sets the server flag reported by isServer().
 *
 * @param isServer the new flag value
 */
public void setServer(boolean isServer) {
    this.isServer = isServer;
    // The flag is only stored; no other state changes.
}
/**
 * @return the timestamp in milliseconds last recorded by updateLastAsk(), or
 *         the construction time if it was never updated
 */
public long getLastAsk() {
    return this.lastAsk;
}
/**
 * Records the current time as the session's last-activity timestamp.
 */
public void updateLastAsk() {
    final long now = System.currentTimeMillis();
    this.lastAsk = now;
}
}