All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cn.tom.transport.aio.AioSession Maven / Gradle / Ivy

The newest version!
package cn.tom.transport.aio;

import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousSocketChannel;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import cn.tom.kit.IoBuffer;
import cn.tom.kit.ObjectId;
import cn.tom.transport.Id;
import cn.tom.transport.IoAdaptor;
import cn.tom.transport.Session;
import cn.tom.transport.WriteBufPool;
import cn.tom.transport.WriteBufPool.MergeRunnable;


/**
 * {@link Session} backed by a NIO.2 {@link AsynchronousSocketChannel}.
 *
 * <p>Outgoing messages are encoded by the {@link IoAdaptor} and either queued by {@code write(T)}
 * (which then submits a {@code MergeRunnable} to the shared {@code writeBufPool}, presumably to
 * call {@link #flush()}), or written straight to the channel by {@code writeAndFlush(T)}. Every
 * write path in this class holds the {@code writeBufferQ} monitor while writing, so at most one
 * write is outstanding on the channel at a time (an {@link AsynchronousSocketChannel} rejects
 * overlapping writes with a {@code WritePendingException}).
 *
 * <p>NOTE(review): in this copy of the file the type variable {@code T} is used (e.g. by
 * {@code write(T)}, {@code doRead()}, {@code onMessage(T)}) but never declared; the class's type
 * parameter (presumably {@code AioSession<T> implements Session<T>} with {@code IoAdaptor<T>})
 * appears to have been lost. Confirm against upstream.
 */
public class AioSession implements Session, Closeable {
	private static final Logger log = LoggerFactory.getLogger(AioSession.class);

	/** Upper bound on how many queued buffers are coalesced into one socket write. */
	private static final int MAX_MERGE_BATCH = 256;

	/** How long {@code write(IoBuffer)} waits for queue space before failing. */
	private static final long WRITE_QUEUE_OFFER_TIMEOUT_SECONDS = 60;

	// TODO(review): a single pool instance is shared by every session.
	private static final WriteBufPool writeBufPool = new WriteBufPool();

	private final long id;
	/** Read buffer size, and the maximum size of a single encoded outgoing message. */
	private final int bufferSize;
	private final IoBuffer readBuffer;
	/** Encoded, not-yet-written outgoing buffers. Its monitor also serializes channel writes. */
	private final LinkedBlockingQueue<ByteBuffer> writeBufferQ = new LinkedBlockingQueue<>();
	/** Created lazily by setAttr; volatile so the double-checked initialization is safe. */
	private volatile ConcurrentMap<String, Object> attributes;

	private final CountDownLatch connectLatch = new CountDownLatch(1);
	private final AsynchronousSocketChannel channel;
	private IoAdaptor ioAdaptor;
	private AioServer aioServer;
	private boolean isServer = true;
	/** Epoch millis of the last activity, refreshed by updateLastAsk() (called from doRead()). */
	private long lastAsk = System.currentTimeMillis();

	/**
	 * Creates a session over an already-created channel.
	 *
	 * @param ioAdaptor codec and event callbacks; also supplies the read/write buffer size
	 * @param channel   the socket channel this session reads from and writes to
	 */
	public AioSession(IoAdaptor ioAdaptor, AsynchronousSocketChannel channel) {
		this.ioAdaptor = ioAdaptor;
		this.channel = channel;
		this.id = new ObjectId().toLong();
		this.bufferSize = ioAdaptor.getReadWritebufferSize();
		this.readBuffer = IoBuffer.allocate(bufferSize);
	}

	@Override
	public long id() {
		return this.id;
	}

	/**
	 * @return the peer address as {@code host:port}, or {@code null} when the session is inactive,
	 *         not connected, or the address cannot be read
	 */
	@Override
	public String getRemoteAddress() {
		if (!isActive()) {
			return null;
		}
		try {
			return formatAddress((InetSocketAddress) this.channel.getRemoteAddress());
		} catch (IOException ignored) {
			// Best effort: e.g. the channel was closed concurrently (ClosedChannelException).
			return null;
		}
	}

	/**
	 * @return the local address as {@code host:port}, or {@code null} when the session is
	 *         inactive, unbound, or the address cannot be read
	 */
	@Override
	public String getLocalAddress() {
		if (!isActive()) {
			return null;
		}
		try {
			return formatAddress((InetSocketAddress) this.channel.getLocalAddress());
		} catch (IOException ignored) {
			// Best effort: e.g. the channel was closed concurrently (ClosedChannelException).
			return null;
		}
	}

	/** Formats an address as {@code host:port}; {@code null} (unconnected/unbound) maps to null. */
	private static String formatAddress(InetSocketAddress addr) {
		if (addr == null) {
			return null;
		}
		return String.format("%s:%d", addr.getHostString(), addr.getPort());
	}

	/**
	 * Encodes {@code msg}, queues it, and schedules a merged flush on the shared write pool.
	 * When several threads write at once, messages accumulate in the queue and go out together.
	 * Failures are reported through {@link #catchError(Throwable)}.
	 */
	@Override
	public void write(T msg) {
		try {
			write(ioAdaptor.encode(msg, this));
			writeBufPool.exec(new MergeRunnable(this));
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt();
			catchError(e);
		} catch (Exception e) {
			catchError(e);
		}
	}

	/**
	 * Encodes {@code msg} and writes it to the channel immediately, bypassing the queue.
	 * A write failure whose cause is an {@link IOException} closes the session, via catchError.
	 */
	@Override
	public void writeAndFlush(T msg) {
		try {
			ByteBuffer data = ioAdaptor.encode(msg, this).buf();
			// Same monitor as doMergeWrite()/doWrite(): never overlap writes on the channel.
			synchronized (writeBufferQ) {
				doWrite(data);
			}
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt();
			catchError(e);
		} catch (ExecutionException e) {
			// Unwrap so an I/O failure of the write reaches catchError as an IOException.
			Throwable cause = e.getCause();
			catchError(cause != null ? cause : e);
		}
	}

	/**
	 * Queues an encoded buffer for a later flush.
	 *
	 * @throws IOException if the buffer exceeds {@code bufferSize}, or the queue stays full for
	 *                     the offer timeout
	 * @throws InterruptedException if interrupted while waiting for queue space
	 */
	protected void write(IoBuffer buf) throws IOException, InterruptedException {
		if (buf.limit() > bufferSize) {
			log.warn("writebuff {} > bufferSize {} ", buf.limit(), bufferSize);
			throw new IOException("writebuffSize  > initbufferSize need ioadaptor set bufferSide");
		}
		// writeBufferQ is created without a capacity bound, so in practice this does not time out.
		if (!writeBufferQ.offer(buf.buf(), WRITE_QUEUE_OFFER_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
			String msg = "Session write buffer queue is full, message count=" + writeBufferQ.size();
			log.warn(msg);
			throw new IOException(msg);
		}
	}

	/**
	 * Drains the write queue, coalescing up to {@value #MAX_MERGE_BATCH} queued buffers into one
	 * socket write per iteration. A single pending buffer is written as-is, without copying.
	 *
	 * @throws InterruptedException if interrupted while waiting for a write to complete
	 * @throws ExecutionException if a channel write fails
	 */
	protected void doMergeWrite() throws InterruptedException, ExecutionException {
		synchronized (writeBufferQ) {
			while (true) {
				int pending = writeBufferQ.size();
				if (pending == 0) {
					break;
				}
				if (pending == 1) {
					ByteBuffer only = writeBufferQ.poll();
					if (only != null) {
						doWrite(only);
					}
					break;
				}

				int batch = Math.min(pending, MAX_MERGE_BATCH);
				ByteBuffer[] parts = new ByteBuffer[batch];
				int count = 0;
				int total = 0;
				while (count < batch) {
					ByteBuffer part = writeBufferQ.poll();
					if (part == null) {
						// asyncClose() clears the queue without taking this monitor.
						break;
					}
					parts[count++] = part;
					total += part.remaining();
				}
				if (count == 0) {
					break;
				}

				// Size from the actual payload; queued buffers need not share the head's capacity.
				IoBuffer merged = IoBuffer.allocate(total);
				for (int i = 0; i < count; i++) {
					ByteBuffer part = parts[i];
					// Copy [position, limit) - the same bytes a direct channel write would send.
					merged.writeBytes(part.array(), part.arrayOffset() + part.position(), part.remaining());
				}
				merged.flip();
				doWrite(merged.buf());
			}
		}
	}

	/**
	 * Writes {@code buffer} to the channel until it has no bytes remaining, backing off 1 ms
	 * between partial writes. The channel write is attempted at least once, even for an empty
	 * buffer.
	 *
	 * <p>If the back-off sleep is interrupted, the write still completes (so a message is not
	 * left half-sent), and the thread's interrupt status is restored before returning.
	 *
	 * @return total number of bytes written
	 */
	protected int doWrite(ByteBuffer buffer) throws InterruptedException, ExecutionException {
		int written = 0;
		boolean interrupted = false;
		try {
			while (true) {
				written += this.channel.write(buffer).get();
				if (!buffer.hasRemaining()) {
					return written;
				}
				try {
					Thread.sleep(1);
				} catch (InterruptedException e) {
					interrupted = true;
				}
			}
		} finally {
			if (interrupted) {
				Thread.currentThread().interrupt();
			}
		}
	}

	/**
	 * Writes every queued buffer to the channel one by one, without merging. A buffer is removed
	 * from the queue only after it has been fully written, so a failed write leaves it queued.
	 *
	 * @return total number of bytes written
	 * @throws InterruptedException if interrupted while waiting for a write to complete
	 * @throws ExecutionException if a channel write fails
	 */
	protected int doWrite() throws InterruptedException, ExecutionException {
		int total = 0;
		synchronized (writeBufferQ) {
			ByteBuffer head;
			while ((head = writeBufferQ.peek()) != null) {
				// doWrite(ByteBuffer) returns only once 'head' is drained (an empty buffer
				// writes 0 bytes), so it is always safe to dequeue here.
				total += doWrite(head);
				writeBufferQ.poll();
			}
		}
		return total;
	}

	/** Starts an asynchronous read into the read buffer; completion is handled by ReadHandler. */
	@Override
	public void read() {
		// NOTE(review): original author advises synchronizing this when used from multiple threads.
		try {
			channel.read(readBuffer.buf(), this, new ReadHandler());
		} catch (Exception e) {
			catchError(e);
		}
	}

	/**
	 * Decodes and dispatches every complete message currently in the read buffer, then compacts
	 * the leftover partial bytes and refreshes the last-activity time.
	 */
	public void doRead() throws IOException {
		// NOTE(review): not synchronized; original author advises synchronizing under multiple threads.
		IoBuffer buffer = readBuffer.flip();
		T msg;
		while ((msg = ioAdaptor.decode(buffer, this)) != null) {
			// TODO: hand business processing off to a worker pool, e.g. ThreadPool.exec(onMessage(msg)).
			onMessage(msg);
		}
		readBuffer.compact();
		updateLastAsk();
	}

	/**
	 * Delivers one decoded message to the adaptor. Exceptions are routed to
	 * {@code ioAdaptor.onException}; if that throws an IOException, the session is closed.
	 */
	void onMessage(T msg) {
		try {
			ioAdaptor.onMessage(msg, this);
		} catch (Exception e) {
			try {
				ioAdaptor.onException(e, this);
			} catch (IOException ee) {
				log.warn("onException failed, closing session {}", id, ee);
				close();
			}
		}
	}

	/** Flushes queued writes, merging them (see doMergeWrite; doWrite() is the unmerged variant). */
	@Override
	public void flush() throws Exception {
		doMergeWrite();
	}

	@Override
	public boolean isActive() {
		return channel != null && channel.isOpen();
	}

	/**
	 * Closes the channel and releases session resources: notifies the adaptor, closes it, clears
	 * both buffers and deregisters from the write pool. No-op when already inactive.
	 *
	 * @throws IOException if closing the channel failed; the cleanup still runs first
	 */
	@Override
	public void asyncClose() throws IOException {
		if (!isActive()) {
			return;
		}

		IOException closeFailure = null;
		try {
			this.channel.close();
		} catch (IOException e) {
			// Defer: once close() has been attempted, a later call may see the session as inactive
			// and return early, so the cleanup below must not be skipped.
			closeFailure = e;
		}
		ioAdaptor.onSessionDestroyed(this);
		ioAdaptor.close();
		writeBufferQ.clear();
		readBuffer.buf.clear();
		writeBufPool.shutdown(this);

		if (closeFailure != null) {
			throw closeFailure;
		}
	}

	/**
	 * @return the attribute stored under {@code key}, or {@code null} if absent; the caller's
	 *         expected type is not checked
	 */
	@SuppressWarnings("unchecked")
	public <V> V getAttr(String key) {
		ConcurrentMap<String, Object> attrs = this.attributes;
		return attrs == null ? null : (V) attrs.get(key);
	}

	/** Stores an attribute, creating the attribute map on first use. */
	public <V> void setAttr(String key, V value) {
		ConcurrentMap<String, Object> attrs = this.attributes;
		if (attrs == null) {
			synchronized (this) {
				attrs = this.attributes;
				if (attrs == null) {
					attrs = new ConcurrentHashMap<>();
					this.attributes = attrs;
				}
			}
		}
		attrs.put(key, value);
	}

	/** Marks the connection as established, releasing waitToConnect callers. */
	public void finishConnect() {
		this.connectLatch.countDown();
		// NOTE(review): writeBufPool is static, so this resizes the pool shared by all sessions.
		writeBufPool.setCorePoolSize(2);
	}

	/**
	 * Waits up to {@code millis} for finishConnect().
	 *
	 * @return true if connected in time; false on timeout or interruption (the interrupt
	 *         status is preserved)
	 */
	public boolean waitToConnect(long millis) {
		try {
			return this.connectLatch.await(millis, TimeUnit.MILLISECONDS);
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt();
			log.error(e.getMessage(), e);
			return false;
		}
	}

	public void setIoAdaptor(IoAdaptor ioAdaptor) {
		this.ioAdaptor = ioAdaptor;
	}

	public IoAdaptor getIoAdaptor() {
		return ioAdaptor;
	}

	/** Same as asyncClose(), but logs instead of throwing an IOException. */
	@Override
	public void close() {
		try {
			asyncClose();
		} catch (IOException e) {
			log.error(e.getMessage(), e);
		}
	}

	@Override
	public String toString() {
		return "Session ["
				+ "remote=" + getRemoteAddress()
				+ ", isActive=" + isActive()
				+ ", id=" + id
				+ "]";
	}

	/** Logs {@code e}; an IOException additionally closes the session. */
	@Override
	public void catchError(Throwable e) {
		if (e instanceof IOException) {
			close();
		}
		log.error(e.getMessage(), e);
	}

	@Override
	public boolean isServer() {
		return isServer;
	}

	public void setServer(boolean isServer) {
		this.isServer = isServer;
	}

	public AioServer getAioServer() {
		return aioServer;
	}

	public void setAioServer(AioServer aioServer) {
		this.aioServer = aioServer;
	}

	public long getLastAsk() {
		return lastAsk;
	}

	public void updateLastAsk() {
		this.lastAsk = System.currentTimeMillis();
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy