cn.tom.transport.aio.AioSession Maven / Gradle / Ivy
The newest version!
package cn.tom.transport.aio;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousSocketChannel;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import cn.tom.kit.IoBuffer;
import cn.tom.kit.ObjectId;
import cn.tom.transport.Id;
import cn.tom.transport.IoAdaptor;
import cn.tom.transport.Session;
import cn.tom.transport.WriteBufPool;
import cn.tom.transport.WriteBufPool.MergeRunnable;
/**
 * A {@link Session} backed by an NIO.2 {@link AsynchronousSocketChannel}.
 *
 * <p>Outgoing messages are encoded by the {@link IoAdaptor} and either queued in
 * {@code writeBufferQ} (via {@link #write(Object)}), after which a {@code MergeRunnable} is
 * submitted to the static, process-wide {@link WriteBufPool}, or written to the channel
 * immediately (via {@link #writeAndFlush(Object)}). Every blocking channel write is performed
 * while holding the {@code writeBufferQ} monitor, because an {@link AsynchronousSocketChannel}
 * allows only one outstanding write at a time.
 *
 * <p>Incoming bytes are accumulated in one read buffer of
 * {@code ioAdaptor.getReadWritebufferSize()} bytes and decoded in {@link #doRead()}.
 *
 * @param <T> the decoded message type produced and consumed by the {@link IoAdaptor}
 */
public class AioSession<T> implements Session<T>, Closeable {
    private static final Logger log = LoggerFactory.getLogger(AioSession.class);

    /** Maximum number of queued buffers coalesced into a single channel write. */
    private static final int MAX_MERGE_COUNT = 256;

    /** Shared by all sessions (static); executes the per-write MergeRunnable tasks. */
    private static final WriteBufPool writeBufPool = new WriteBufPool(); // TODO

    private final long id;
    private final int bufferSize;
    private final IoBuffer readBuffer;
    /** Pending outgoing buffers. Its monitor also serializes every write to {@link #channel}. */
    private final LinkedBlockingQueue<ByteBuffer> writeBufferQ = new LinkedBlockingQueue<>();
    /** Lazily created; volatile so the double-checked initialization in setAttr is safe. */
    private volatile ConcurrentMap<String, Object> attributes;
    private final CountDownLatch connectLatch = new CountDownLatch(1);
    private final AsynchronousSocketChannel channel;
    private IoAdaptor<T> ioAdaptor;
    private AioServer aioServer;
    private boolean isServer = true;
    private long lastAsk = System.currentTimeMillis();

    /**
     * Creates a session over the given channel.
     *
     * @param ioAdaptor codec/event callbacks; also supplies the read and write buffer size
     * @param channel   the underlying asynchronous socket channel
     */
    public AioSession(IoAdaptor<T> ioAdaptor, AsynchronousSocketChannel channel) {
        this.ioAdaptor = ioAdaptor;
        this.channel = channel;
        this.id = new ObjectId().toLong();
        this.bufferSize = ioAdaptor.getReadWritebufferSize();
        this.readBuffer = IoBuffer.allocate(bufferSize);
    }

    @Override
    public long id() {
        return this.id;
    }

    /**
     * Returns the peer address as {@code host:port}.
     *
     * @return the address, or {@code null} if the channel is closed, not connected, or the
     *         address cannot be read
     */
    @Override
    public String getRemoteAddress() {
        if (isActive()) {
            try {
                return formatAddress((InetSocketAddress) this.channel.getRemoteAddress());
            } catch (IOException e) {
                log.debug("Unable to read remote address of session {}", id, e);
            }
        }
        return null;
    }

    /**
     * Returns the local address as {@code host:port}.
     *
     * @return the address, or {@code null} if the channel is closed, not bound, or the
     *         address cannot be read
     */
    @Override
    public String getLocalAddress() {
        if (isActive()) {
            try {
                return formatAddress((InetSocketAddress) this.channel.getLocalAddress());
            } catch (IOException e) {
                log.debug("Unable to read local address of session {}", id, e);
            }
        }
        return null;
    }

    /** Formats an address as {@code host:port}; null when the channel reported no address. */
    private static String formatAddress(InetSocketAddress addr) {
        if (addr == null) {
            return null;
        }
        return String.format("%s:%d", addr.getHostString(), addr.getPort());
    }

    /**
     * Encodes {@code msg}, queues it and schedules a merge task on the shared pool.
     * Failures are routed to {@link #catchError(Throwable)}.
     */
    @Override
    public void write(T msg) {
        try {
            write(ioAdaptor.encode(msg, this));
            // With several writer threads, more buffers pile up in writeBufferQ before the
            // task runs, so they can be submitted together in one go.
            writeBufPool.exec(new MergeRunnable(this));
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            catchError(e);
        }
    }

    /**
     * Encodes {@code msg} and writes it to the channel immediately, blocking until every
     * byte has been written. This bypasses the write queue. Failures are routed to
     * {@link #catchError(Throwable)}.
     */
    @Override
    public void writeAndFlush(T msg) {
        try {
            doWrite(ioAdaptor.encode(msg, this).buf());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            catchError(e);
        } catch (ExecutionException e) {
            // Future.get() wraps the real failure (e.g. an IOException from a dead peer);
            // unwrap it so catchError can close the session on I/O errors.
            catchError(e.getCause() != null ? e.getCause() : e);
        }
    }

    /**
     * Queues an encoded buffer for a later write.
     *
     * @param buf encoded message; its limit must not exceed the configured buffer size
     * @throws IOException if the buffer exceeds the configured size, or the queue does not
     *         accept it within 60 seconds
     * @throws InterruptedException if interrupted while offering to the queue
     */
    protected void write(IoBuffer buf) throws IOException, InterruptedException {
        if (buf.limit() > bufferSize) {
            log.warn("writebuff {} > bufferSize {} ", buf.limit(), bufferSize);
            throw new IOException("writebuffSize > initbufferSize need ioadaptor set bufferSide");
        }
        if (!writeBufferQ.offer(buf.buf(), 60, TimeUnit.SECONDS)) {
            String msg = "Session write buffer queue is full, message count=" + writeBufferQ.size();
            log.warn(msg);
            throw new IOException(msg);
        }
    }

    /**
     * Drains the write queue, coalescing up to {@link #MAX_MERGE_COUNT} buffers into one
     * channel write per round. Only the readable bytes (position..limit) of each buffer are
     * sent. This matches how single buffers are written by {@link #doWrite(ByteBuffer)}.
     *
     * @throws InterruptedException if interrupted while waiting for a write to complete
     * @throws ExecutionException if a channel write fails
     */
    protected void doMergeWrite() throws InterruptedException, ExecutionException {
        synchronized (writeBufferQ) {
            while (true) {
                int size = Math.min(writeBufferQ.size(), MAX_MERGE_COUNT);
                if (size == 0) {
                    break;
                }
                if (size == 1) {
                    ByteBuffer single = writeBufferQ.poll();
                    if (single != null) { // null if the queue was cleared concurrently
                        doWrite(single);
                    }
                    break;
                }
                // Drain first so the merged buffer can be sized exactly.
                ByteBuffer[] batch = new ByteBuffer[size];
                int count = 0;
                int total = 0;
                while (count < size) {
                    ByteBuffer next = writeBufferQ.poll();
                    if (next == null) { // queue cleared concurrently (e.g. by asyncClose)
                        break;
                    }
                    batch[count++] = next;
                    total += next.remaining();
                }
                if (count == 0) {
                    break;
                }
                ByteBuffer merged = ByteBuffer.allocate(total);
                for (int i = 0; i < count; i++) {
                    merged.put(batch[i]);
                }
                merged.flip();
                doWrite(merged);
            }
        }
    }

    /**
     * Writes the remaining bytes of {@code buffer} to the channel, blocking until it has been
     * fully written. It pauses 1 ms between partial writes. The call holds the
     * {@code writeBufferQ} monitor, which is reentrant for the queue-draining callers. This
     * prevents two writes from being outstanding at once (otherwise
     * {@code WritePendingException}).
     *
     * @return the number of bytes written
     * @throws InterruptedException if the thread is interrupted while waiting for a write
     * @throws ExecutionException if a channel write fails
     */
    protected int doWrite(ByteBuffer buffer) throws InterruptedException, ExecutionException {
        synchronized (writeBufferQ) {
            int n = 0;
            do {
                n += this.channel.write(buffer).get();
            } while (next(buffer));
            return n;
        }
    }

    /**
     * Returns whether another write round is needed for {@code buffer}. When one is needed,
     * it sleeps briefly first. An interrupt during the pause is re-asserted so the next
     * blocking wait can observe it.
     */
    private boolean next(ByteBuffer buffer) {
        if (!buffer.hasRemaining()) {
            return false;
        }
        try {
            Thread.sleep(1);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return true;
    }

    /**
     * Drains the write queue one buffer at a time, without merging. The head is removed
     * only after it has been fully written, so a failed write leaves it queued.
     *
     * @return the total number of bytes written
     * @throws InterruptedException if interrupted while waiting for a write
     * @throws ExecutionException if a channel write fails
     */
    protected int doWrite() throws InterruptedException, ExecutionException {
        int n = 0;
        synchronized (writeBufferQ) {
            ByteBuffer head;
            while ((head = writeBufferQ.peek()) != null) {
                // doWrite(ByteBuffer) returns only once head has no remaining bytes, so the
                // head (including an empty buffer) can always be dropped afterwards.
                n += doWrite(head);
                writeBufferQ.poll();
            }
        }
        return n;
    }

    /** Starts an asynchronous read into the session's read buffer, completed by a ReadHandler. */
    @Override
    public void read() {
        try { // with multiple threads this should ideally be synchronized
            channel.read(readBuffer.buf(), this, new ReadHandler());
        } catch (Exception e) {
            catchError(e);
        }
    }

    /**
     * Decodes every complete message currently in the read buffer and dispatches each to
     * {@link #onMessage(Object)}. Unconsumed bytes are compacted for the next read. Also
     * refreshes the last-activity timestamp.
     */
    public void doRead() throws IOException { // with multiple threads this should ideally be synchronized
        IoBuffer buffer = readBuffer.flip();
        while (true) {
            final T msg = ioAdaptor.decode(buffer, this);
            if (msg == null) {
                readBuffer.compact();
                break;
            }
            // Business handling should ideally be handed off to a worker pool,
            // e.g. ThreadPool.exec(onMessage(msg)).
            onMessage(msg);
        }
        updateLastAsk();
    }

    /**
     * Delivers a decoded message to the adaptor. Exceptions are passed to
     * {@code ioAdaptor.onException}. If that handler itself fails with an
     * {@link IOException}, the session is closed.
     */
    void onMessage(T msg) {
        try {
            ioAdaptor.onMessage(msg, this);
        } catch (Exception e) {
            try {
                ioAdaptor.onException(e, this);
            } catch (IOException ee) {
                log.error(ee.getMessage(), ee);
                close();
            }
        }
    }

    /** Writes all queued buffers, merging them where possible. */
    @Override
    public void flush() throws Exception {
        doMergeWrite();
    }

    /** Returns {@code true} while the underlying channel exists and is open. */
    @Override
    public boolean isActive() {
        return channel != null && channel.isOpen();
    }

    /**
     * Closes the channel and releases session resources. It notifies the adaptor, clears the
     * buffers and releases this session from the write pool. Does nothing if the session is
     * already inactive. Cleanup runs even if closing the channel throws. Otherwise, once the
     * channel reports closed, it would never be retried.
     */
    @Override
    public void asyncClose() throws IOException {
        if (!isActive()) {
            return;
        }
        try {
            this.channel.close();
        } finally {
            ioAdaptor.onSessionDestroyed(this);
            ioAdaptor.close();
            writeBufferQ.clear();
            readBuffer.buf.clear();
            writeBufPool.shutdown(this);
        }
    }

    /**
     * Returns the attribute stored under {@code key}, cast to the caller's expected type.
     *
     * @return the value, or {@code null} if absent or no attribute was ever set
     */
    @SuppressWarnings("unchecked")
    public <V> V getAttr(String key) {
        ConcurrentMap<String, Object> attrs = this.attributes;
        return attrs == null ? null : (V) attrs.get(key);
    }

    /**
     * Stores an attribute, creating the backing map on first use.
     *
     * @param value must be non-null (the backing ConcurrentHashMap rejects null values)
     */
    public <V> void setAttr(String key, V value) {
        ConcurrentMap<String, Object> attrs = this.attributes;
        if (attrs == null) {
            synchronized (this) {
                attrs = this.attributes;
                if (attrs == null) {
                    attrs = new ConcurrentHashMap<>();
                    this.attributes = attrs;
                }
            }
        }
        attrs.put(key, value);
    }

    /**
     * Releases threads blocked in {@link #waitToConnect(long)}. It also sets the shared write
     * pool's core size to 2.
     */
    public void finishConnect() {
        this.connectLatch.countDown();
        writeBufPool.setCorePoolSize(2);
    }

    /**
     * Waits up to {@code millis} ms for {@link #finishConnect()}.
     *
     * @return {@code true} if the connection finished in time; {@code false} on timeout or
     *         interruption (the interrupt flag is preserved)
     */
    public boolean waitToConnect(long millis) {
        try {
            return this.connectLatch.await(millis, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error(e.getMessage(), e);
        }
        return false;
    }

    public void setIoAdaptor(IoAdaptor<T> ioAdaptor) {
        this.ioAdaptor = ioAdaptor;
    }

    public IoAdaptor<T> getIoAdaptor() {
        return ioAdaptor;
    }

    /** Closes the session, logging rather than propagating any {@link IOException}. */
    @Override
    public void close() {
        try {
            asyncClose();
        } catch (IOException e) {
            log.error(e.getMessage(), e);
        }
    }

    @Override
    public String toString() {
        return "Session ["
                + "remote=" + getRemoteAddress()
                + ", isActive=" + isActive()
                + ", id=" + id
                + "]";
    }

    /** Logs {@code e}; if it is an {@link IOException} the session is closed first. */
    @Override
    public void catchError(Throwable e) {
        if (e instanceof IOException) {
            close();
        }
        log.error(e.getMessage(), e);
    }

    @Override
    public boolean isServer() {
        return isServer;
    }

    public void setServer(boolean isServer) {
        this.isServer = isServer;
    }

    public AioServer getAioServer() {
        return aioServer;
    }

    public void setAioServer(AioServer aioServer) {
        this.aioServer = aioServer;
    }

    /** Returns the time (epoch millis) of the last completed {@link #doRead()} or creation. */
    public long getLastAsk() {
        return lastAsk;
    }

    /** Records the current time as the last activity timestamp. */
    public void updateLastAsk() {
        this.lastAsk = System.currentTimeMillis();
    }
}