// com.github.mauricio.async.db.mysql.codec.MySQLConnectionHandler.scala Maven / Gradle / Ivy
/*
* Copyright 2013 Maurício Linhares
*
* Maurício Linhares licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.github.mauricio.async.db.mysql.codec
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.util.concurrent.TimeUnit
import com.github.mauricio.async.db.Configuration
import com.github.mauricio.async.db.exceptions.DatabaseException
import com.github.mauricio.async.db.general.MutableResultSet
import com.github.mauricio.async.db.mysql.binary.BinaryRowDecoder
import com.github.mauricio.async.db.mysql.message.client._
import com.github.mauricio.async.db.mysql.message.server._
import com.github.mauricio.async.db.mysql.util.CharsetMapper
import com.github.mauricio.async.db.util.ChannelFutureTransformer.toFuture
import com.github.mauricio.async.db.util._
import io.netty.bootstrap.Bootstrap
import io.netty.buffer.{ByteBuf, ByteBufAllocator, Unpooled}
import io.netty.channel._
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.handler.codec.CodecException
import scala.annotation.switch
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent._
import scala.concurrent.duration.Duration
/**
 * Netty inbound handler driving a single MySQL connection.
 *
 * It installs the frame decoder, the outbound encoders and itself on the channel pipeline,
 * turns decoded [[ServerMessage]]s into callbacks on the [[MySQLHandlerDelegate]], accumulates
 * result-set rows, and runs the prepare / send-long-data / execute sequence for prepared
 * statements, caching prepared statements by SQL text in `parsedStatements`.
 *
 * NOTE(review): the mutable per-query state below is not synchronized; it presumably relies on
 * netty invoking the inbound callbacks on one event loop and on callers issuing one query at a
 * time — confirm.
 *
 * @param configuration    host, port and charset used to connect and to decode text rows
 * @param charsetMapper    handed to the outbound [[MySQLOneToOneEncoder]]
 * @param handlerDelegate  receives handshake, OK, error, EOF, result-set and failure callbacks
 * @param group            event loop group the netty bootstrap runs on
 * @param executionContext implicit context for the Future callbacks created in this class
 * @param connectionId     tag used in the logger name and passed to the frame decoder
 */
class MySQLConnectionHandler(
  configuration: Configuration,
  charsetMapper: CharsetMapper,
  handlerDelegate: MySQLHandlerDelegate,
  group: EventLoopGroup,
  executionContext: ExecutionContext,
  connectionId: String
) extends SimpleChannelInboundHandler[Object] {

  private implicit val internalPool: ExecutionContext = executionContext
  private final val log = Log.getByName(s"[connection-handler]${connectionId}")
  private final val bootstrap = new Bootstrap().group(this.group)
  // NOTE(review): this class only ever fails this promise (connect / handleException); nothing
  // here completes it successfully — confirm callers use it only to observe connection failures.
  private final val connectionPromise = Promise[MySQLConnectionHandler]
  private final val decoder = new MySQLFrameDecoder(configuration.charset, connectionId)
  private final val encoder = new MySQLOneToOneEncoder(configuration.charset, charsetMapper)
  private final val sendLongDataEncoder = new SendLongDataEncoder()
  // NOTE(review): cleared in several places but never filled by this class.
  private final val currentParameters = new ArrayBuffer[ColumnDefinitionMessage]()
  // Column definitions received for the statement currently in flight.
  private final val currentColumns = new ArrayBuffer[ColumnDefinitionMessage]()
  // Statements already prepared on this connection, keyed by their SQL text.
  private final val parsedStatements = new HashMap[String, PreparedStatementHolder]()
  private final val binaryRowDecoder = new BinaryRowDecoder()

  // Set when the server answers a prepare request; reset once the execute has been issued.
  private var currentPreparedStatementHolder: PreparedStatementHolder = null
  // Statement passed to the most recent sendPreparedStatement call.
  private var currentPreparedStatement: PreparedStatement = null
  // Created when column definitions finish; reset to null by clearQueryState.
  private var currentQuery: MutableResultSet[ColumnDefinitionMessage] = null
  // Context captured in handlerAdded and used for every write.
  private var currentContext: ChannelHandlerContext = null

  /**
   * Configures the bootstrap (NIO socket channel, keep-alive, little-endian allocator, and a
   * pipeline of decoder, encoders and this handler) and starts connecting to
   * `configuration.host`:`configuration.port`.
   *
   * @return a future that fails if the TCP connect fails or if an exception reaches this handler
   */
  def connect: Future[MySQLConnectionHandler] = {
    this.bootstrap.channel(classOf[NioSocketChannel])
    this.bootstrap.handler(new ChannelInitializer[io.netty.channel.Channel]() {
      override def initChannel(channel: io.netty.channel.Channel): Unit = {
        channel.pipeline.addLast(
          decoder,
          encoder,
          sendLongDataEncoder,
          MySQLConnectionHandler.this)
      }
    })
    this.bootstrap.option[java.lang.Boolean](ChannelOption.SO_KEEPALIVE, true)
    this.bootstrap.option[ByteBufAllocator](ChannelOption.ALLOCATOR, LittleEndianByteBufAllocator.INSTANCE)

    this.bootstrap.connect(new InetSocketAddress(configuration.host, configuration.port)).onFailure {
      case exception => this.connectionPromise.tryFailure(exception)
    }

    this.connectionPromise.future
  }

  /**
   * Dispatches a decoded server message on its `kind`.
   *
   * Handshake, OK, error and EOF messages are forwarded to the delegate; column definitions and
   * rows are accumulated into `currentQuery`. A message that is not a [[ServerMessage]], or whose
   * kind is not listed, raises a MatchError, which netty routes to `exceptionCaught`.
   */
  override def channelRead0(ctx: ChannelHandlerContext, message: Object): Unit = {
    message match {
      case m: ServerMessage =>
        (m.kind: @switch) match {
          case ServerMessage.ServerProtocolVersion =>
            handlerDelegate.onHandshake(m.asInstanceOf[HandshakeMessage])

          case ServerMessage.Ok =>
            this.clearQueryState
            handlerDelegate.onOk(m.asInstanceOf[OkMessage])

          case ServerMessage.Error =>
            this.clearQueryState
            handlerDelegate.onError(m.asInstanceOf[ErrorMessage])

          case ServerMessage.EOF =>
            this.handleEOF(m)

          case ServerMessage.ColumnDefinition =>
            val column = m.asInstanceOf[ColumnDefinitionMessage]
            // While a statement is being prepared, its holder collects the definitions it still needs.
            if (currentPreparedStatementHolder != null && this.currentPreparedStatementHolder.needsAny) {
              this.currentPreparedStatementHolder.add(column)
            }
            this.currentColumns += column

          case ServerMessage.ColumnDefinitionFinished =>
            this.onColumnDefinitionFinished()

          case ServerMessage.PreparedStatementPrepareResponse =>
            this.onPreparedStatementPrepareResponse(m.asInstanceOf[PreparedStatementPrepareResponse])

          case ServerMessage.Row =>
            // Text-protocol row: decode every non-null cell with its column's text decoder.
            val row = m.asInstanceOf[ResultSetRowMessage]
            val items = new Array[Any](row.size)
            var x = 0
            while (x < row.size) {
              items(x) = if (row(x) == null) {
                null
              } else {
                val columnDescription = this.currentQuery.columnTypes(x)
                columnDescription.textDecoder.decode(columnDescription, row(x), configuration.charset)
              }
              x += 1
            }
            this.currentQuery.addRow(items)

          case ServerMessage.BinaryRow =>
            // Binary-protocol row produced by a prepared statement execution.
            val binaryRow = m.asInstanceOf[BinaryRowMessage]
            this.currentQuery.addRow(this.binaryRowDecoder.decode(binaryRow.buffer, this.currentColumns))

          case ServerMessage.ParamProcessingFinished =>
            // Nothing to do: only parameter definitions were sent, no columns follow.

          case ServerMessage.ParamAndColumnProcessingFinished =>
            this.onColumnDefinitionFinished()
        }
    }
  }

  /** Logs activation and hands the channel context to the delegate. */
  override def channelActive(ctx: ChannelHandlerContext): Unit = {
    log.debug("Channel became active")
    handlerDelegate.connected(ctx)
  }

  /** Only logs; the delegate is not notified here. */
  override def channelInactive(ctx: ChannelHandlerContext): Unit = {
    log.debug("Channel became inactive")
  }

  /**
   * Reports pipeline exceptions, unwrapping netty's [[CodecException]] so the delegate sees the
   * underlying failure. A CodecException without a cause is reported as-is, so the delegate
   * never receives `null`.
   */
  override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
    cause match {
      case t: CodecException if t.getCause != null => handleException(t.getCause)
      case _ => handleException(cause)
    }
  }

  /** Fails the pending connection promise (if still open) and notifies the delegate. */
  private def handleException(cause: Throwable): Unit = {
    // tryFailure rather than an isCompleted check followed by failure: the connect-failure callback
    // runs on the execution context and may complete the promise concurrently, in which case
    // `failure` would throw IllegalStateException.
    this.connectionPromise.tryFailure(cause)
    handlerDelegate.exceptionCaught(cause)
  }

  /** Captures the context; every later write and `disconnect` go through it. */
  override def handlerAdded(ctx: ChannelHandlerContext): Unit = {
    this.currentContext = ctx
  }

  /** Sends a text query after telling the decoder that a query response is expected. */
  def write(message: QueryMessage): ChannelFuture = {
    this.decoder.queryProcessStarted()
    writeAndHandleError(message)
  }

  /**
   * Executes `query` as a prepared statement.
   *
   * If the SQL text was prepared earlier on this connection, the cached statement is executed
   * directly. Otherwise a prepare request is sent; execution continues in
   * `onColumnDefinitionFinished` once the server has described the statement.
   */
  def sendPreparedStatement(query: String, values: Seq[Any]): Future[ChannelFuture] = {
    val preparedStatement = new PreparedStatement(query, values)

    this.currentColumns.clear()
    this.currentParameters.clear()

    this.currentPreparedStatement = preparedStatement

    this.parsedStatements.get(preparedStatement.statement) match {
      case Some(item) =>
        this.executePreparedStatement(item.statementId, item.columns.size, preparedStatement.values, item.parameters)
      case None =>
        decoder.preparedStatementPrepareStarted()
        writeAndHandleError(new PreparedStatementPrepareMessage(preparedStatement.statement))
    }
  }

  /** Sends the handshake response and marks the handshake as done in the decoder. */
  def write(message: HandshakeResponseMessage): ChannelFuture = {
    decoder.hasDoneHandshake = true
    writeAndHandleError(message)
  }

  def write(message: AuthenticationSwitchResponse): ChannelFuture = writeAndHandleError(message)

  def write(message: QuitMessage): ChannelFuture = {
    writeAndHandleError(message)
  }

  /** Closes the channel through the captured context. */
  def disconnect: ChannelFuture = this.currentContext.close()

  /** Forgets the columns, parameters and result set of the statement in flight. */
  def clearQueryState: Unit = {
    this.currentColumns.clear()
    this.currentParameters.clear()
    this.currentQuery = null
  }

  /** True when a context with a channel has been captured and that channel is active. */
  def isConnected: Boolean = {
    if (this.currentContext != null && this.currentContext.channel() != null) {
      this.currentContext.channel.isActive
    } else {
      false
    }
  }

  /**
   * Executes a prepared statement.
   *
   * Values above `SendLongDataEncoder.LONG_THRESHOLD` (byte arrays and buffers, optionally wrapped
   * in `Some`) are streamed first, one `SendLongDataMessage` after another. Then the execute message
   * is sent with the indices of the remaining parameters to encode inline.
   */
  private def executePreparedStatement(statementId: Array[Byte], columnsCount: Int, values: Seq[Any], parameters: Seq[ColumnDefinitionMessage]): Future[ChannelFuture] = {
    decoder.preparedStatementExecuteStarted(columnsCount, parameters.size)
    this.currentColumns.clear()
    this.currentParameters.clear()

    val (nonLongIndicesOpt, longValuesOpt) = values.zipWithIndex.map {
      case (Some(value), index) if isLong(value) => (None, Some((index, value)))
      case (value, index) if isLong(value) => (None, Some((index, value)))
      case (_, index) => (Some(index), None)
    }.unzip
    val nonLongIndices: Seq[Int] = nonLongIndicesOpt.flatten
    val longValues: Seq[(Int, Any)] = longValuesOpt.flatten

    def sendExecute(): ChannelFuture =
      writeAndHandleError(new PreparedStatementExecuteMessage(statementId, values, nonLongIndices.toSet, parameters))

    if (longValues.isEmpty) {
      sendExecute()
    } else {
      val (firstIndex, firstValue) = longValues.head
      // Each long value is written only after the previous write completed.
      val longDataSent = longValues.tail.foldLeft(sendLongParameter(statementId, firstIndex, firstValue)) {
        case (previous, (index, value)) =>
          previous.flatMap(_ => sendLongParameter(statementId, index, value))
      }
      longDataSent.flatMap(_ => sendExecute())
    }
  }

  /** True for byte arrays / buffers larger than `SendLongDataEncoder.LONG_THRESHOLD` bytes. */
  private def isLong(value: Any): Boolean = {
    value match {
      case v: Array[Byte] => v.length > SendLongDataEncoder.LONG_THRESHOLD
      case v: ByteBuffer => v.remaining() > SendLongDataEncoder.LONG_THRESHOLD
      case v: ByteBuf => v.readableBytes() > SendLongDataEncoder.LONG_THRESHOLD
      case _ => false
    }
  }

  /** Streams one long parameter; only called with values for which `isLong` is true. */
  private def sendLongParameter(statementId: Array[Byte], index: Int, longValue: Any): Future[ChannelFuture] = {
    longValue match {
      case v: Array[Byte] =>
        sendBuffer(Unpooled.wrappedBuffer(v), statementId, index)
      case v: ByteBuffer =>
        sendBuffer(Unpooled.wrappedBuffer(v), statementId, index)
      case v: ByteBuf =>
        sendBuffer(v, statementId, index)
    }
  }

  private def sendBuffer(buffer: ByteBuf, statementId: Array[Byte], paramId: Int): ChannelFuture = {
    writeAndHandleError(new SendLongDataMessage(statementId, buffer, paramId))
  }

  private def onPreparedStatementPrepareResponse(message: PreparedStatementPrepareResponse): Unit = {
    this.currentPreparedStatementHolder = new PreparedStatementHolder(this.currentPreparedStatement.statement, message)
  }

  /**
   * Starts a result set for the columns just described.
   *
   * If a prepare was in progress, the statement is cached in `parsedStatements`, executed with the
   * values from `sendPreparedStatement`, and the pending prepare state is cleared.
   */
  def onColumnDefinitionFinished(): Unit = {
    val columns = if (this.currentPreparedStatementHolder != null) {
      this.currentPreparedStatementHolder.columns
    } else {
      this.currentColumns
    }

    this.currentQuery = new MutableResultSet[ColumnDefinitionMessage](columns)

    if (this.currentPreparedStatementHolder != null) {
      this.parsedStatements.put(this.currentPreparedStatementHolder.statement, this.currentPreparedStatementHolder)
      this.executePreparedStatement(
        this.currentPreparedStatementHolder.statementId,
        this.currentPreparedStatementHolder.columns.size,
        this.currentPreparedStatement.values,
        this.currentPreparedStatementHolder.parameters
      )
      this.currentPreparedStatementHolder = null
      this.currentPreparedStatement = null
    }
  }

  /**
   * Writes and flushes `message` if the channel is active, reporting a failed write through
   * `handleException`. On an inactive channel it reports a [[DatabaseException]] and returns an
   * already-failed future without writing.
   */
  private def writeAndHandleError(message: Any): ChannelFuture = {
    if (this.currentContext.channel().isActive) {
      val res = this.currentContext.writeAndFlush(message)

      res.onFailure {
        case e: Throwable => handleException(e)
      }

      res
    } else {
      val error = new DatabaseException("This channel is not active and can't take messages")
      handleException(error)
      this.currentContext.channel().newFailedFuture(error)
    }
  }

  /**
   * Handles messages of kind EOF. A plain EOF completes the current result set, or is forwarded
   * as-is when none is open; an authentication-switch request is forwarded to the delegate.
   */
  private def handleEOF(m: ServerMessage): Unit = {
    m match {
      case eof: EOFMessage =>
        val resultSet = this.currentQuery
        this.clearQueryState

        if (resultSet != null) {
          handlerDelegate.onResultSet(resultSet, eof)
        } else {
          handlerDelegate.onEOF(eof)
        }
      case authenticationSwitch: AuthenticationSwitchRequest =>
        handlerDelegate.switchAuthentication(authenticationSwitch)
    }
  }

  /** Runs `block` once on the channel's event loop after `duration`, truncated to milliseconds. */
  def schedule(block: => Unit, duration: Duration): Unit = {
    this.currentContext.channel().eventLoop().schedule(new Runnable {
      override def run(): Unit = block
    }, duration.toMillis, TimeUnit.MILLISECONDS)
  }

}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy