All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.mauricio.async.db.mysql.codec.MySQLConnectionHandler.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2013 Maurício Linhares
 *
 * Maurício Linhares licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package com.github.mauricio.async.db.mysql.codec

import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.util.concurrent.TimeUnit
import com.github.mauricio.async.db.Configuration
import com.github.mauricio.async.db.exceptions.DatabaseException
import com.github.mauricio.async.db.general.MutableResultSet
import com.github.mauricio.async.db.mysql.binary.BinaryRowDecoder
import com.github.mauricio.async.db.mysql.message.client._
import com.github.mauricio.async.db.mysql.message.server._
import com.github.mauricio.async.db.mysql.util.CharsetMapper
import com.github.mauricio.async.db.util.ChannelFutureTransformer._
import com.github.mauricio.async.db.util._
import com.google.common.cache._
import io.netty.bootstrap.Bootstrap
import io.netty.buffer.{ByteBuf, ByteBufAllocator, Unpooled}
import io.netty.channel._
import io.netty.handler.codec.CodecException

import scala.annotation.switch
import scala.collection.Seq
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent._
import scala.concurrent.duration.Duration
import scala.util._

object Stmt {

  val StmtPool: LoadingCache[String, String] = CacheBuilder
    .newBuilder()
    .maximumSize(4096)
    .expireAfterAccess(60, TimeUnit.SECONDS)
    .build(
      new CacheLoader[String, String] {
        def load(k: String) = k
      }
    )

  def pooled(stmt: String) = {
    StmtPool.get(stmt)
  }
}

class MySQLConnectionHandler(
  configuration: Configuration,
  charsetMapper: CharsetMapper,
  handlerDelegate: MySQLHandlerDelegate,
  executionContext: ExecutionContext,
  connectionId: String
) extends SimpleChannelInboundHandler[Object] {

  private final val group             = configuration.eventLoopGroup
  private implicit val internalPool   = executionContext
  private final val log               = Log.getByName(s"[connection-handler]${connectionId}")
  private final val bootstrap         = new Bootstrap().group(this.group)
  private final val connectionPromise = Promise[MySQLConnectionHandler]
  private final val decoder =
    new MySQLFrameDecoder(configuration.charset, connectionId)
  private final val encoder =
    new MySQLOneToOneEncoder(configuration.charset, charsetMapper)
  private final val sendLongDataEncoder = new SendLongDataEncoder()
  private final val currentParameters =
    new ArrayBuffer[ColumnDefinitionMessage]()
  private final val currentColumns = new ArrayBuffer[ColumnDefinitionMessage]()
  private final val parsedStatements: Cache[String, PreparedStatementHolder] =
    CacheBuilder
      .newBuilder()
      .maximumSize(configuration.preparedStatementCacheSize)
      .expireAfterAccess(
        configuration.preparedStatementExpireTime.toSeconds,
        TimeUnit.SECONDS
      )
      .removalListener(new RemovalListener[String, PreparedStatementHolder] {
        def onRemoval(
          removal: RemovalNotification[String, PreparedStatementHolder]
        ) = {
          log.debug("Closing preparestatement...")
          closePreparedStatment(removal.getValue().statementId)
        }
      })
      .build()

  private final val binaryRowDecoder = new BinaryRowDecoder()

  private var currentPreparedStatementHolder: PreparedStatementHolder = null
  private var currentPreparedStatement: PreparedStatement             = null
  private var currentQuery: MutableResultSet[ColumnDefinitionMessage] = null
  private var currentContext: ChannelHandlerContext                   = null

  def connect: Future[MySQLConnectionHandler] = {
    this.bootstrap.channel(NettyUtils.SocketChannelClass)
    this.bootstrap.handler(new ChannelInitializer[io.netty.channel.Channel]() {

      override def initChannel(channel: io.netty.channel.Channel): Unit = {
        channel.pipeline.addLast(
          decoder,
          encoder,
          sendLongDataEncoder,
          MySQLConnectionHandler.this
        )
      }

    })

    this.bootstrap.option[java.lang.Boolean](ChannelOption.SO_KEEPALIVE, true)
    this.bootstrap.option[ByteBufAllocator](
      ChannelOption.ALLOCATOR,
      LittleEndianByteBufAllocator.INSTANCE
    )

    this.bootstrap
      .connect(new InetSocketAddress(configuration.host, configuration.port))
      .asScala
      .onComplete {
        case Failure(exception) => this.connectionPromise.tryFailure(exception)
        case _                  =>
      }

    this.connectionPromise.future
  }

  override def channelRead0(ctx: ChannelHandlerContext, message: Object) = {
    message match {
      case m: ServerMessage => {
        (m.kind: @switch) match {
          case ServerMessage.ServerProtocolVersion => {
            handlerDelegate.onHandshake(m.asInstanceOf[HandshakeMessage])
          }
          case ServerMessage.Ok => {
            this.clearQueryState
            handlerDelegate.onOk(m.asInstanceOf[OkMessage])
          }
          case ServerMessage.Error => {
            this.clearQueryState
            handlerDelegate.onError(m.asInstanceOf[ErrorMessage])
          }
          case ServerMessage.EOF => {
            this.handleEOF(m)
          }
          case ServerMessage.ColumnDefinition => {
            val message = m.asInstanceOf[ColumnDefinitionMessage]

            if (currentPreparedStatementHolder != null && this.currentPreparedStatementHolder.needsAny) {
              this.currentPreparedStatementHolder.add(message)
            }

            this.currentColumns += message
          }
          case ServerMessage.ColumnDefinitionFinished => {
            this.onColumnDefinitionFinished()
          }
          case ServerMessage.PreparedStatementPrepareResponse => {
            this.onPreparedStatementPrepareResponse(
              m.asInstanceOf[PreparedStatementPrepareResponse]
            )
          }
          case ServerMessage.Row => {
            val message = m.asInstanceOf[ResultSetRowMessage]
            val items   = new Array[Any](message.size)

            var x = 0
            while (x < message.size) {
              items(x) = if (message(x) == null) {
                null
              } else {
                val columnDescription = this.currentQuery.columnTypes(x)
                val buf               = message(x)
                try {
                  columnDescription.textDecoder.decode(
                    columnDescription,
                    buf,
                    configuration.charset
                  )
                } finally {
                  buf.release()
                }
              }
              x += 1
            }

            this.currentQuery.addRow(items)
          }
          case ServerMessage.BinaryRow => {
            val message = m.asInstanceOf[BinaryRowMessage]
            val decoded =
              this.binaryRowDecoder.decode(message.buffer, this.currentColumns)
            message.buffer.release()
            this.currentQuery.addRow(decoded)
          }
          case ServerMessage.ParamProcessingFinished => {}
          case ServerMessage.ParamAndColumnProcessingFinished => {
            this.onColumnDefinitionFinished()
          }
        }
      }
    }

  }

  override def channelActive(ctx: ChannelHandlerContext): Unit = {
    log.debug("Channel became active")
    handlerDelegate.connected(ctx)
  }

  override def channelInactive(ctx: ChannelHandlerContext) = {
    log.debug("Channel became inactive")
  }

  override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) = {
    // unwrap CodecException if needed
    cause match {
      case t: CodecException => handleException(t.getCause)
      case _                 => handleException(cause)
    }

  }

  private def handleException(cause: Throwable) = {
    if (!this.connectionPromise.isCompleted) {
      this.connectionPromise.failure(cause)
    }
    handlerDelegate.exceptionCaught(cause)
  }

  override def handlerAdded(ctx: ChannelHandlerContext) = {
    this.currentContext = ctx
  }

  def write(message: QueryMessage): ChannelFuture = {
    this.decoder.queryProcessStarted()
    writeAndHandleError(message)
  }

  def sendPreparedStatement(
    query: String,
    values: Seq[Any]
  ): Future[ChannelFuture] = {
    val preparedStatement = new PreparedStatement(query, values)

    this.currentColumns.clear()
    this.currentParameters.clear()

    this.currentPreparedStatement = preparedStatement

    Option(
      this.parsedStatements
        .getIfPresent(Stmt.pooled(preparedStatement.statement))
    ) match {
      case Some(item) => {
        this.executePreparedStatement(
          item.statementId,
          item.columns.size,
          preparedStatement.values,
          item.parameters.toSeq
        )
      }
      case None => {
        decoder.preparedStatementPrepareStarted()
        writeAndHandleError(
          new PreparedStatementPrepareMessage(preparedStatement.statement)
        ).asScala
      }
    }
  }

  def write(message: HandshakeResponseMessage): ChannelFuture = {
    decoder.hasDoneHandshake = true
    writeAndHandleError(message)
  }

  def write(message: AuthenticationSwitchResponse): ChannelFuture =
    writeAndHandleError(message)

  def write(message: QuitMessage): ChannelFuture = {
    writeAndHandleError(message)
  }

  def disconnect: ChannelFuture = this.currentContext.close()

  def clearQueryState = {
    this.currentColumns.clear()
    this.currentParameters.clear()
    this.currentQuery = null
  }

  def isConnected: Boolean = {
    if (this.currentContext != null && this.currentContext.channel() != null) {
      this.currentContext.channel.isActive
    } else {
      false
    }
  }

  private def closePreparedStatment(statementId: Array[Byte]) = {
    writeAndHandleError(new PreparedStatementCloseMessage(statementId))
  }

  private def executePreparedStatement(
    statementId: Array[Byte],
    columnsCount: Int,
    values: Seq[Any],
    parameters: Seq[ColumnDefinitionMessage]
  ): Future[ChannelFuture] = {
    decoder.preparedStatementExecuteStarted(columnsCount, parameters.size)
    this.currentColumns.clear()
    this.currentParameters.clear()

    val (nonLongIndicesOpt, longValuesOpt) = values.zipWithIndex.map {
      case (Some(value), index) if isLong(value) => (None, Some(index, value))
      case (value, index) if isLong(value)       => (None, Some(index, value))
      case (_, index)                            => (Some(index), None)
    }.unzip
    val nonLongIndices: Seq[Int]    = nonLongIndicesOpt.flatten
    val longValues: Seq[(Int, Any)] = longValuesOpt.flatten

    if (longValues.nonEmpty) {
      val (firstIndex, firstValue) = longValues.head
      var channelFuture: Future[ChannelFuture] =
        sendLongParameter(statementId, firstIndex, firstValue)
      longValues.tail foreach {
        case (index, value) =>
          channelFuture = channelFuture.flatMap { _ =>
            sendLongParameter(statementId, index, value)
          }
      }
      channelFuture flatMap { _ =>
        writeAndHandleError(
          new PreparedStatementExecuteMessage(
            statementId,
            values,
            nonLongIndices.toSet,
            parameters
          )
        ).asScala
      }
    } else {
      writeAndHandleError(
        new PreparedStatementExecuteMessage(
          statementId,
          values,
          nonLongIndices.toSet,
          parameters
        )
      ).asScala
    }
  }

  private def isLong(value: Any): Boolean = {
    value match {
      case v: Array[Byte] => v.length > SendLongDataEncoder.LONG_THRESHOLD
      case v: ByteBuffer  => v.remaining() > SendLongDataEncoder.LONG_THRESHOLD
      case v: ByteBuf     => v.readableBytes() > SendLongDataEncoder.LONG_THRESHOLD

      case _ => false
    }
  }

  private def sendLongParameter(
    statementId: Array[Byte],
    index: Int,
    longValue: Any
  ): Future[ChannelFuture] = {
    longValue match {
      case v: Array[Byte] =>
        sendBuffer(Unpooled.wrappedBuffer(v), statementId, index).asScala

      case v: ByteBuffer =>
        sendBuffer(Unpooled.wrappedBuffer(v), statementId, index).asScala

      case v: ByteBuf =>
        sendBuffer(v, statementId, index).asScala
    }
  }

  private def sendBuffer(
    buffer: ByteBuf,
    statementId: Array[Byte],
    paramId: Int
  ): ChannelFuture = {
    writeAndHandleError(new SendLongDataMessage(statementId, buffer, paramId))
  }

  private def onPreparedStatementPrepareResponse(
    message: PreparedStatementPrepareResponse
  ) = {
    this.currentPreparedStatementHolder = new PreparedStatementHolder(
      this.currentPreparedStatement.statement,
      message
    )
  }

  def onColumnDefinitionFinished() = {

    val columns = if (this.currentPreparedStatementHolder != null) {
      this.currentPreparedStatementHolder.columns
    } else {
      this.currentColumns
    }

    this.currentQuery = new MutableResultSet[ColumnDefinitionMessage](columns)

    if (this.currentPreparedStatementHolder != null) {
      this.parsedStatements.put(
        Stmt.pooled(this.currentPreparedStatementHolder.statement),
        this.currentPreparedStatementHolder
      )
      this.executePreparedStatement(
        this.currentPreparedStatementHolder.statementId,
        this.currentPreparedStatementHolder.columns.size,
        this.currentPreparedStatement.values,
        this.currentPreparedStatementHolder.parameters
      )
      this.currentPreparedStatementHolder = null
      this.currentPreparedStatement = null
    }
  }

  private def writeAndHandleError(message: Any): ChannelFuture = {
    if (this.currentContext.channel().isActive) {
      val res = this.currentContext.writeAndFlush(message)

      res.asScala.onComplete {
        case Failure(e) => handleException(e)
        case _          =>
      }

      res
    } else {
      val error = new DatabaseException(
        "This channel is not active and can't take messages"
      )
      handleException(error)
      this.currentContext.channel().newFailedFuture(error)
    }
  }

  private def handleEOF(m: ServerMessage) = {
    m match {
      case eof: EOFMessage => {
        val resultSet = this.currentQuery
        this.clearQueryState

        if (resultSet != null) {
          handlerDelegate.onResultSet(resultSet, eof)
        } else {
          handlerDelegate.onEOF(eof)
        }
      }
      case authenticationSwitch: AuthenticationSwitchRequest => {
        handlerDelegate.switchAuthentication(authenticationSwitch)
      }
    }
  }

  def schedule(block: => Unit, duration: Duration): Unit = {
    this.currentContext
      .channel()
      .eventLoop()
      .schedule(new Runnable {
        override def run(): Unit = block
      }, duration.toMillis, TimeUnit.MILLISECONDS)
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy