All downloads are free. Search and download functionality uses the official Maven repository.

com.aerospike.client.async.NettyRecover Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2012-2024 Aerospike, Inc.
 *
 * Portions may be licensed to Aerospike, Inc. under one or more contributor
 * license agreements WHICH ARE COMPATIBLE WITH THE APACHE LICENSE, VERSION 2.0.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.aerospike.client.async;

import java.security.cert.X509Certificate;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.SSLSession;

import com.aerospike.client.AerospikeException;
import com.aerospike.client.Log;
import com.aerospike.client.admin.AdminCommand;
import com.aerospike.client.async.HashedWheelTimer.HashedWheelTimeout;
import com.aerospike.client.cluster.Cluster;
import com.aerospike.client.cluster.Connection;
import com.aerospike.client.cluster.Node;
import com.aerospike.client.cluster.Node.AsyncPool;
import com.aerospike.client.command.Buffer;
import com.aerospike.client.command.Command;
import com.aerospike.client.policy.TlsPolicy;
import com.aerospike.client.util.Util;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;

/**
 * Drains the unread remainder of a server response from a Netty connection taken over
 * from a {@link NettyCommand}, so the connection can be returned to the node's async
 * pool instead of being closed.
 * <p>
 * The instance installs its own inbound handler on the channel, discards incoming bytes
 * according to the state inherited from the command (connect, TLS handshake,
 * authentication or command response), and then either
 * <ul>
 * <li>{@link #recover()}s the connection (restores the regular
 * {@code NettyCommand.InboundHandler} and puts the connection back into the pool), or</li>
 * <li>{@link #abort(boolean)}s it (closes the connection) on authentication failure,
 * socket error or when the {@code timeoutDelay} timer fires.</li>
 * </ul>
 * <p>
 * NOTE(review): mutable fields are not synchronized; presumably every callback (channel
 * events, write listener, timer) runs on the connection's event loop thread - confirm.
 */
public final class NettyRecover implements TimerTask {
	private final Cluster cluster;
	private final NettyEventLoop eventLoop;
	private final Node node;
	private final EventState eventState;
	private final NettyConnection conn;

	/** Timer entry that bounds the drain; null when the constructor aborted immediately. */
	private final HashedWheelTimeout timeoutTask;

	/** Scratch buffer for header bytes and the leading bytes of a body. */
	private byte[] dataBuffer;

	/** Number of bytes of the current header or body that have been consumed so far. */
	private int offset;

	/** Size of the body currently being drained. */
	private int length;

	/** Current {@code AsyncCommand} state; {@code COMPLETE} once recovered or aborted. */
	private int state;

	private final boolean isSingle;

	/** True when the inherited buffer was small enough to go back to the event loop's buffer queue. */
	private final boolean saveBuffer;

	/** Set once a multi-record group carrying the {@code INFO3_LAST} bit has been seen. */
	private boolean isLastGroup;

	/**
	 * Takes over the connection, buffer and read position of the given command and starts
	 * draining.
	 * <p>
	 * If the command was already reading an authentication body and the result code byte
	 * that was read is non-zero, the connection is closed immediately and no handler or
	 * timer is installed. Otherwise the last pipeline handler is replaced by this task's
	 * handler and a timer of {@code policy.timeoutDelay} milliseconds is scheduled.
	 *
	 * @param cmd command whose connection should be drained
	 */
	public NettyRecover(NettyCommand cmd) {
		this.cluster = cmd.cluster;
		this.eventLoop = cmd.eventLoop;
		this.node = cmd.node;
		this.eventState = cmd.eventState;
		this.conn = cmd.conn;
		this.state = cmd.state;

		AsyncCommand a = cmd.command;
		this.dataBuffer = a.dataBuffer;  // take ownership of dataBuffer.
		this.saveBuffer = this.dataBuffer != null && this.dataBuffer.length <= AsyncCommand.MAX_BUFFER_SIZE;
		this.offset = a.dataOffset;
		this.length = a.receiveSize;
		this.isSingle = a.isSingle;

		if (cmd.state == AsyncCommand.AUTH_READ_BODY && offset >= 2 && dataBuffer[1] != 0) {
			// Authentication failed.
			// No timer has been scheduled yet, so there is nothing to cancel.
			timeoutTask = null;
			abort(false);
			return;
		}

		// Replace channel handler.
		ChannelPipeline p = conn.channel.pipeline();

		if (p.last() != null) {
			p.removeLast();
		}
		p.addLast(new InboundHandler(this));

		timeoutTask = new HashedWheelTimeout(this);
		eventLoop.timer.addTimeout(timeoutTask, System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(a.policy.timeoutDelay));
	}

	/**
	 * Timer callback: the drain did not finish in time, so close the connection.
	 */
	@Override
	public final void timeout() {
		// Command has been delayed long enough.
		// User has already been notified.
		// timeoutTask has already been removed, so avoid cancel.
		abort(false);
	}

	/**
	 * Called once the channel is usable (connected, or TLS handshake validated).
	 * Sends an authentication request when the cluster requires one and a session token
	 * is available; otherwise the connection is recovered right away.
	 */
	private void channelActive() {
		if (cluster.authEnabled) {
			byte[] token = node.getSessionToken();

			if (token != null) {
				writeAuth(token);
				return;
			}
		}

		recover();
	}

	/**
	 * Writes an authentication request and, once the write completes, switches to reading
	 * the authentication response.
	 *
	 * @param token session token passed to {@code AdminCommand.setAuthenticate}
	 */
	private void writeAuth(byte[] token) {
		state = AsyncCommand.AUTH_WRITE;
		dataBuffer = new byte[512];

		// The response header is accumulated in the fresh buffer starting at index 0;
		// a position inherited from the original command must not leak into it.
		offset = 0;

		AdminCommand admin = new AdminCommand(dataBuffer);
		int len = admin.setAuthenticate(cluster, token);

		ByteBuf byteBuffer = PooledByteBufAllocator.DEFAULT.directBuffer(len);
		byteBuffer.clear();
		byteBuffer.writeBytes(dataBuffer, 0, len);

		ChannelFuture cf = conn.channel.writeAndFlush(byteBuffer);
		cf.addListener(new ChannelFutureListener() {
			@Override
			public void operationComplete(ChannelFuture future) {
				if (state != AsyncCommand.AUTH_WRITE) {
					// Already recovered or aborted in the meantime.
					return;
				}

				if (! future.isSuccess()) {
					// The request never reached the server, so no response will arrive.
					// Close now instead of waiting for the timer to expire.
					abort(true);
					return;
				}

				state = AsyncCommand.AUTH_READ_HEADER;
				conn.channel.config().setAutoRead(true);
			}
		});
	}

	/**
	 * Consumes one inbound buffer according to the current state. Buffers that arrive in
	 * any other state are ignored. The buffer is always released before returning, also
	 * when draining throws.
	 *
	 * @param byteBuffer inbound data; ownership is taken and the buffer is released here
	 */
	public void drain(ByteBuf byteBuffer) {
		try {
			switch (state) {
			case AsyncCommand.AUTH_READ_HEADER:
				drainSingleHeader(byteBuffer, AsyncCommand.AUTH_READ_BODY);
				break;

			case AsyncCommand.AUTH_READ_BODY:
				drainSingleBody(byteBuffer);
				break;

			case AsyncCommand.COMMAND_READ_HEADER:
				if (isSingle) {
					drainSingleHeader(byteBuffer, AsyncCommand.COMMAND_READ_BODY);
				}
				else {
					drainMultiHeader(byteBuffer);
				}
				break;

			case AsyncCommand.COMMAND_READ_BODY:
				if (isSingle) {
					drainSingleBody(byteBuffer);
				}
				else {
					if (! drainMultiBody(byteBuffer)) {
						return;
					}
					drainMultiHeader(byteBuffer);
				}
				break;
			}
		}
		finally {
			byteBuffer.release();
		}
	}

	/**
	 * Accumulates the 8-byte header of a single-response message. When the header is
	 * complete, the body size is taken from its low 48 bits, the state advances to
	 * {@code nextState} and the rest of the buffer is handed to {@link #drainSingleBody}.
	 *
	 * @param byteBuffer inbound data
	 * @param nextState state to enter once the header is complete
	 */
	private final void drainSingleHeader(ByteBuf byteBuffer, int nextState) {
		int readableBytes = byteBuffer.readableBytes();
		int dataSize = offset + readableBytes;

		if (dataSize < 8) {
			// Header still incomplete; keep what arrived and wait for more.
			byteBuffer.readBytes(dataBuffer, offset, readableBytes);
			offset = dataSize;
			return;
		}

		byteBuffer.readBytes(dataBuffer, offset, 8 - offset);
		length = ((int)(Buffer.bytesToLong(dataBuffer, 0) & 0xFFFFFFFFFFFFL));

		state = nextState;
		offset = 0;
		drainSingleBody(byteBuffer);
	}

	/**
	 * Discards body bytes of a single-response message and recovers the connection once
	 * {@code length} bytes have been consumed. While reading an authentication body, the
	 * second body byte is inspected as the result code; a non-zero value closes the
	 * connection.
	 *
	 * @param byteBuffer inbound data
	 */
	private final void drainSingleBody(ByteBuf byteBuffer) {
		int readableBytes = byteBuffer.readableBytes();

		if (state == AsyncCommand.AUTH_READ_BODY && offset < 2 && offset + readableBytes >= 2) {
			// Read up to and including the second body byte. The bytes are stored from
			// index 0, so the last byte read (index len - 1) is always the result code.
			int len = 2 - offset;
			byteBuffer.readBytes(dataBuffer, 0, len);
			readableBytes -= len;
			offset += len;

			byte resultCode = dataBuffer[len - 1];

			if (resultCode != 0) {
				// Authentication failed.
				// The timer is still scheduled at this point, so cancel it as well.
				abort(true);
				return;
			}

			if (readableBytes <= 0) {
				return;
			}
		}

		byteBuffer.skipBytes(readableBytes);
		offset += readableBytes;

		if (offset >= length) {
			recover();
		}
	}

	/**
	 * Reads headers of a multi-record response in a loop. Each completed header sets
	 * {@code length}; zero-length groups are skipped, and a non-empty group switches to
	 * {@code COMMAND_READ_BODY} and is drained via {@link #drainMultiBody} as far as the
	 * buffer allows.
	 *
	 * @param byteBuffer inbound data
	 * @throws AerospikeException if a group's message type is {@code MSG_TYPE_COMPRESSED}
	 */
	private final void drainMultiHeader(ByteBuf byteBuffer) {
		int readableBytes = byteBuffer.readableBytes();
		int dataSize;

		do {
			dataSize = offset + readableBytes;

			if (dataSize < 8) {
				// Header still incomplete; keep what arrived and wait for more.
				byteBuffer.readBytes(dataBuffer, offset, readableBytes);
				offset = dataSize;
				return;
			}

			dataSize = 8 - offset;
			byteBuffer.readBytes(dataBuffer, offset, dataSize);
			readableBytes -= dataSize;
			long proto = Buffer.bytesToLong(dataBuffer, 0);
			length = ((int)(proto & 0xFFFFFFFFFFFFL));

			if (length == 0) {
				// Read next header.
				offset = 0;
				continue;
			}

			boolean compressed = ((proto >> 48) & 0xff) == Command.MSG_TYPE_COMPRESSED;

			if (compressed) {
				// Do not recover connections with compressed data because that would
				// require saving large buffers with associated state and performing decompression
				// just to drain the connection.
				throw new AerospikeException("Recovering connections with compressed multi-record data is not supported");
			}

			state = AsyncCommand.COMMAND_READ_BODY;
			offset = 0;

			if (readableBytes <= 0) {
				return;
			}

			if (! drainMultiBody(byteBuffer)) {
				return;
			}

			readableBytes = byteBuffer.readableBytes();
		} while (true);
	}

	/**
	 * Discards body bytes of the current multi-record group. The fourth body byte is
	 * inspected for the {@code INFO3_LAST} bit; when a group with that bit has been fully
	 * consumed the connection is recovered.
	 *
	 * @param byteBuffer inbound data
	 * @return true when the group was fully consumed and another header should be read
	 *         from the same buffer; false when more data is needed or draining has ended
	 */
	private final boolean drainMultiBody(ByteBuf byteBuffer) {
		int readableBytes = byteBuffer.readableBytes();

		if (offset < 4 && offset + readableBytes >= 4) {
			// Read up to and including the fourth body byte. The bytes are stored from
			// index 0, so the last byte read (index len - 1) is always info3.
			int len = 4 - offset;
			byteBuffer.readBytes(dataBuffer, 0, len);
			readableBytes -= len;
			offset += len;

			// Warning: The following code assumes multi-record responses always end with a separate proto
			// that only contains one header with the info3 last group bit.  This is always true for batch
			// and scan, but query does not conform.  Therefore, connection recovery for queries will
			// likely fail.
			byte info3 = dataBuffer[len - 1];

			if ((info3 & Command.INFO3_LAST) != 0) {
				isLastGroup = true;
			}

			if (readableBytes <= 0) {
				return false;
			}
		}

		int needBytes = length - offset;
		int dataSize = (readableBytes >= needBytes)? needBytes : readableBytes;

		byteBuffer.skipBytes(dataSize);
		offset += dataSize;

		if (offset < length) {
			return false;
		}

		if (isLastGroup) {
			recover();
			return false;
		}

		// Prepare for next group.
		state = AsyncCommand.COMMAND_READ_HEADER;
		offset = 0;
		return true;
	}

	/**
	 * Finishes a successful drain: disables auto-read, restores the regular
	 * {@code NettyCommand.InboundHandler}, puts the connection back into the node's pool
	 * and releases this task's timer and buffer. Does nothing if the task has already
	 * completed.
	 * <p>
	 * If anything fails before the connection is handed to the pool, the connection is
	 * closed so that it is not left open without an owner.
	 */
	private final void recover() {
		if (state == AsyncCommand.COMPLETE) {
			return;
		}
		state = AsyncCommand.COMPLETE;

		// Becomes true just before the pool is asked to take the connection. From then on
		// this task must not close the connection itself.
		boolean handedOff = false;

		try {
			// Assign normal InboundHandler to connection.
			SocketChannel channel = conn.channel;
			channel.config().setAutoRead(false);

			ChannelPipeline p = channel.pipeline();
			p.removeLast();

			if (cluster.keepAlive == null) {
				p.addLast(new NettyCommand.InboundHandler());
			}
			else {
				AsyncPool pool = node.getAsyncPool(eventState.index);
				p.addLast(new NettyCommand.InboundHandler(pool));
			}

			// Put connection into pool.
			conn.updateLastUsed();
			handedOff = true;
			node.putAsyncConnection(conn, eventLoop.index);

			// Close recovery command.
			close(true);
		}
		catch (Throwable e) {
			if (! eventState.closed) {
				Log.error("NettyRecover recover failed: " + Util.getStackTrace(e));
			}

			if (! handedOff) {
				// state is already COMPLETE, so neither the timer nor a later socket
				// error would close this connection. Close it here.
				closeConnection(true);
			}
		}
	}

	/**
	 * Gives up on the connection and closes it. Does nothing if the task has already
	 * completed.
	 *
	 * @param cancelTimeout whether the scheduled timer should be cancelled
	 */
	private final void abort(boolean cancelTimeout) {
		if (state == AsyncCommand.COMPLETE) {
			return;
		}
		state = AsyncCommand.COMPLETE;

		closeConnection(cancelTimeout);
	}

	/**
	 * Closes the connection through the node and releases this task's resources.
	 * Failures are logged unless the event loop state is already closed.
	 *
	 * @param cancelTimeout whether the scheduled timer should be cancelled
	 */
	private void closeConnection(boolean cancelTimeout) {
		try {
			node.closeAsyncConnection(conn, eventLoop.index);
			close(cancelTimeout);
		}
		catch (Throwable e) {
			if (! eventState.closed) {
				Log.error("NettyRecover abort failed: " + Util.getStackTrace(e));
			}
		}
	}

	/**
	 * Releases the resources held by this task: optionally cancels the timer and, when
	 * the inherited buffer qualified for reuse, adds the current buffer to the event
	 * loop's buffer queue.
	 *
	 * @param cancelTimeout whether the scheduled timer should be cancelled
	 */
	private final void close(boolean cancelTimeout) {
		// timeoutTask is null when the constructor aborted before scheduling it.
		if (cancelTimeout && timeoutTask != null) {
			timeoutTask.cancel();
		}

		if (saveBuffer) {
			eventLoop.bufferQueue.addLast(dataBuffer);
		}
	}

	/**
	 * Channel handler that forwards Netty events for the connection to the owning
	 * {@link NettyRecover} task.
	 */
	private static final class InboundHandler extends ChannelInboundHandlerAdapter {
		private final NettyRecover command;

		public InboundHandler(NettyRecover command) {
			this.command = command;
		}

		@Override
		public void channelActive(ChannelHandlerContext ctx) {
			// Mark connection ready in regular (non TLS) mode.
			// Otherwise, wait for TLS handshake to complete.
			if (command.state == AsyncCommand.CONNECT) {
				command.channelActive();
			}
		}

		@Override
		public void channelRead(ChannelHandlerContext ctx, Object msg) {
			command.drain((ByteBuf)msg);
		}

		/**
		 * Handles completion of the TLS handshake: fails on a handshake error, validates
		 * the first peer certificate against the TLS policy and the node's TLS name, and
		 * then continues as for a plain connection. All other events are ignored.
		 */
		@Override
		public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
			if (! (evt instanceof SslHandshakeCompletionEvent)) {
				return;
			}

			Throwable cause = ((SslHandshakeCompletionEvent)evt).cause();

			if (cause != null) {
				throw new AerospikeException("TLS connect failed: " + cause.getMessage(), cause);
			}

			TlsPolicy tlsPolicy = command.cluster.tlsPolicy;
			String tlsName = command.node.getHost().tlsName;
			SSLSession session = ((SslHandler)ctx.pipeline().first()).engine().getSession();
			X509Certificate cert = (X509Certificate)session.getPeerCertificates()[0];

			Connection.validateServerCertificate(tlsPolicy, tlsName, cert);

			if (command.state == AsyncCommand.TLS_HANDSHAKE) {
				command.channelActive();
			}
		}

		@Override
		public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
			// Any error on the channel ends the drain: close the connection and cancel the timer.
			command.abort(true);
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy