All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.scassandra.server.actors.ConnectionHandler.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (C) 2014 Christopher Batey and Dogan Narinc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.scassandra.server.actors

import akka.actor.{Actor, ActorRef, ActorRefFactory}
import akka.util.ByteString
import com.typesafe.scalalogging.slf4j.Logging
import org.scassandra.server.RegisterHandlerMessages
import org.scassandra.server.cqlmessages._
import org.scassandra.server.cqlmessages.response.UnsupportedProtocolVersion
import org.scassandra.server.priming.QueryHandlerMessages

/*
 * TODO: This class is on the verge of needing to be split up.
 *
 * This class's responsibility is:
 *  * To read full CQL messages from the data, knowing that they may not come all at once or that multiple
 *    messages may come in the same data packet
 *  * To check the opcode and forward to the correct handler
 *
 *  This could be changed to only know how to read full messages and then pass to an actor
 *  per stream that could check the opcode and forward.
 */
/**
 * Actor that owns one client connection: it reassembles CQL frames from the raw TCP
 * byte stream and dispatches each complete frame to a handler chosen by opcode.
 *
 * Only protocol versions one and two (8 byte header) are understood. A frame whose
 * first byte is the version three client code is answered with an unsupported
 * protocol version error.
 *
 * @param queryHandlerFactory      creates the query handler once a Startup message arrives
 * @param registerHandlerFactory   creates the register handler once a Startup message arrives
 * @param prepareHandler           receives Prepare and Execute messages
 * @param connectionWrapperFactory wraps the current sender in an actor used for replies
 */
class ConnectionHandler(queryHandlerFactory: (ActorRefFactory, ActorRef, CqlMessageFactory) => ActorRef,
                        registerHandlerFactory: (ActorRefFactory, ActorRef, CqlMessageFactory) => ActorRef,
                        prepareHandler: ActorRef,
                        connectionWrapperFactory: (ActorRefFactory, ActorRef) => ActorRef) extends Actor with Logging {

  import akka.io.Tcp._

  // True once a Startup message has been processed on this connection.
  var ready = false
  // True while dataFromPreviousMessage holds bytes that must be prepended to the next packet.
  var partialMessage = false
  var dataFromPreviousMessage: ByteString = _
  // Bytes still to be parsed for the packet currently being handled.
  var currentData: ByteString = _
  // Unset until a Startup (or a premature Query) tells us the protocol version.
  var messageFactory: CqlMessageFactory = _
  var registerHandler: ActorRef = _
  var queryHandler: ActorRef = _

  // Header layout read below: version(0), flags(1), stream(2), opcode(3), body length(4-7).
  val ProtocolOneOrTwoHeaderLength = 8

  def receive = {

    case Received(data: ByteString) =>
      logger.trace(s"Received a message of length ${data.length} data:: $data")

      currentData = data
      if (partialMessage) {
        currentData = dataFromPreviousMessage ++ data
      }

      val messageLength = currentData.length
      logger.trace(s"Whole message length so far is $messageLength")

      // the header could be 8 or 9 bytes now :(
      // Keep taking frames while a full v1/v2 header is available and takeMessage consumed one.
      while (currentData.length >= ProtocolOneOrTwoHeaderLength && takeMessage()) {}

      // Anything left over is an incomplete frame: stash it for the next packet.
      if (currentData.length > 0) {
        logger.trace("Not received length yet..")
        partialMessage = true
        dataFromPreviousMessage = currentData
        currentData = ByteString()
      }
    case PeerClosed =>
      logger.info("Client disconnected.")
      context stop self
    case unknown@_ =>
      logger.warn(s"Unknown message $unknown")

  }

  /** Dispatches one fully received frame body to the handler matching its opcode. */
  private def processMessage(opCode: Byte, stream: Byte, messageBody: ByteString, protocolVersion: Byte) = {
    logger.trace(s"Whole body $messageBody with length ${messageBody.length}")

    opCode match {
      case OpCodes.Startup =>
        logger.debug("Sending ready message")
        initialiseMessageFactory(protocolVersion)
        val wrappedSender = connectionWrapperFactory(context, sender)
        queryHandler = queryHandlerFactory(context, wrappedSender, messageFactory)
        registerHandler = registerHandlerFactory(context, wrappedSender, messageFactory)
        wrappedSender ! messageFactory.createReadyMessage(stream)
        ready = true
      case OpCodes.Query =>
        if (!ready) {
          initialiseMessageFactory(protocolVersion)
          logger.info("Received query before startup message, sending error")
          sender ! Write(messageFactory.createQueryBeforeErrorMessage().serialize())
        } else {
          queryHandler ! QueryHandlerMessages.Query(messageBody, stream)
        }
      case OpCodes.Register =>
        if (registerHandler == null) {
          // The register handler is only created on Startup; sending to it earlier would throw.
          logger.warn("Received register message before startup message, ignoring")
        } else {
          logger.debug("Received register message. Sending to RegisterHandler")
          registerHandler ! RegisterHandlerMessages.Register(messageBody, stream)
        }
      // NOTE(review): if Prepare/Execute is the first message on a connection, messageFactory
      // is still unset when it is passed on below - confirm the prepare handler copes with that.
      case OpCodes.Prepare =>
        logger.debug("Received prepare message. Sending to PrepareHandler")
        val wrappedSender = connectionWrapperFactory(context, sender)
        prepareHandler ! PrepareHandlerMessages.Prepare(messageBody, stream, messageFactory, wrappedSender)
      case OpCodes.Execute =>
        logger.debug("Received execute message. Sending to ExecuteHandler")
        val wrappedSender = connectionWrapperFactory(context, sender)
        prepareHandler ! PrepareHandlerMessages.Execute(messageBody, stream, messageFactory, wrappedSender)
      case opCode@_ =>
        logger.warn(s"Received unknown opcode $opCode this probably means this feature is yet to be implemented the message body is $messageBody")
    }
  }

  /** Selects the message factory for the connection: version one if requested, otherwise version two. */
  def initialiseMessageFactory(protocolVersion: Byte) = {
    messageFactory = if (protocolVersion == ProtocolVersion.ClientProtocolVersionOne) {
      logger.debug("Connection is for protocol version one")
      VersionOneMessageFactory
    } else {
      logger.debug("Connection is for protocol version two")
      VersionTwoMessageFactory
    }
  }

  /** Drops everything buffered so far and forgets any stored fragment. */
  private def discardBufferedData(): Unit = {
    currentData = ByteString()
    // Without this the next packet would be prefixed with a stale dataFromPreviousMessage.
    partialMessage = false
  }

  /**
   * Tries to consume one frame from the front of currentData.
   *
   * Should not be called if there isn't at least a header.
   *
   * @return true only when a frame was consumed and more bytes remain to be parsed;
   *         false when the buffer is exhausted, incomplete, or was discarded.
   */
  private def takeMessage(): Boolean = {

    val protocolVersion = currentData(0)

    if (protocolVersion == VersionThree.clientCode) {
      logger.warn("Received a version three message, currently only one and two supported so sending an unsupported protocol error to get the driver to use an older version of the protocol.")
      val wrappedSender = connectionWrapperFactory(context, sender)
      // we can't really send the correct stream back as it is a different type (short rather than byte)
      wrappedSender ! UnsupportedProtocolVersion(0x0)(VersionTwo)
      discardBufferedData()
      false
    } else {
      val stream: Byte = currentData(2)
      val opCode: Byte = currentData(3)

      val bodyLengthArray = currentData.take(ProtocolOneOrTwoHeaderLength).drop(4)
      logger.debug(s"Body length array $bodyLengthArray")
      val bodyLength = bodyLengthArray.asByteBuffer.getInt
      logger.debug(s"Body length $bodyLength")

      // Compare against the bytes after the header rather than computing bodyLength + header,
      // which overflows Int for very large advertised lengths.
      val bytesAfterHeader = currentData.length - ProtocolOneOrTwoHeaderLength

      if (bodyLength < 0) {
        // A negative length can never be satisfied and, left unchecked, would make the caller's
        // loop spin without consuming anything. The stream cannot be resynchronised, so drop it.
        logger.warn(s"Received a message with negative body length $bodyLength, discarding buffered data")
        discardBufferedData()
        false
      } else if (bytesAfterHeader == bodyLength) {
        logger.debug("Received exactly the whole message")
        partialMessage = false
        val messageBody = currentData.drop(ProtocolOneOrTwoHeaderLength)
        processMessage(opCode, stream, messageBody, protocolVersion)
        currentData = ByteString()
        false
      } else if (bytesAfterHeader > bodyLength) {
        partialMessage = true
        logger.debug("Received a larger message than the length specifies - assume the rest is another message")
        val messageBody = currentData.drop(ProtocolOneOrTwoHeaderLength).take(bodyLength)
        logger.debug(s"Message received ${messageBody.utf8String}")
        processMessage(opCode, stream, messageBody, protocolVersion)
        currentData = currentData.drop(ProtocolOneOrTwoHeaderLength + bodyLength)
        true
      } else {
        logger.debug(s"Not received whole message yet, currently ${currentData.length} but need ${bodyLength.toLong + ProtocolOneOrTwoHeaderLength}")
        partialMessage = true
        dataFromPreviousMessage = currentData
        false
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy