All downloads are free. Search and download functionality uses the official Maven repository.

spark.network.ConnectionManager.scala Maven / Gradle / Ivy

The newest version!
package spark.network

import spark._

import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.net._
import java.util.concurrent.Executors

import scala.collection.mutable.HashMap
import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
import scala.collection.mutable.Queue
import scala.collection.mutable.ArrayBuffer

import akka.dispatch.{Await, Promise, ExecutionContext, Future}
import akka.util.Duration
import akka.util.duration._

/**
 * Identifies a ConnectionManager endpoint by host and port.
 *
 * @param host host name or address of the endpoint
 * @param port port of the endpoint's server socket
 */
private[spark] case class ConnectionManagerId(host: String, port: Int) {

  /**
   * Builds a socket address for this endpoint. Note that the `InetSocketAddress(String, Int)`
   * constructor attempts to resolve `host`.
   */
  def toSocketAddress(): InetSocketAddress = {
    val address = new InetSocketAddress(host, port)
    address
  }
}

private[spark] object ConnectionManagerId {

  /**
   * Converts a socket address into a [[ConnectionManagerId]].
   *
   * The host is taken from `getHostName`, which may perform a reverse lookup when the address
   * was created from a literal IP.
   *
   * @param socketAddress address whose host name and port identify the endpoint
   * @return the corresponding id
   */
  def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
    val hostName = socketAddress.getHostName
    val portNumber = socketAddress.getPort
    new ConnectionManagerId(hostName, portNumber)
  }
}
  
/**
 * Non-blocking, selector-based messaging endpoint.
 *
 * A daemon selector thread ("connection-manager-thread") drives the sockets. On each pass it:
 *  - opens requested outgoing connections;
 *  - flushes queued sends;
 *  - applies requested interest-op changes;
 *  - dispatches ready keys.
 *
 * Received messages are handled on `handleMessageExecutor`.
 *
 * @param port local port the server socket binds to; the actually bound port is reported in `id`
 */
private[spark] class ConnectionManager(port: Int) extends Logging {

  /**
   * Delivery state of one message sent through `sendMessageReliably`.
   *
   * `markDone()` hands the status to `completionHandler`. The code paths in this class that
   * finish a status update the flags under `this.synchronized` immediately before calling it.
   */
  class MessageStatus(
      val message: Message,
      val connectionManagerId: ConnectionManagerId,
      completionHandler: MessageStatus => Unit) {

    // The ack received for `message`, if one arrived.
    var ackMessage: Option[Message] = None
    // True once the message was acked or abandoned because its connection was removed.
    var attempted = false
    // True only when an ack was received.
    var acked = false

    def markDone() { completionHandler(this) }
  }

  val selector = SelectorProvider.provider.openSelector()
  // Runs handleMessage (and therefore the user callback) off the selector thread.
  val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt)
  val serverChannel = ServerSocketChannel.open()
  // Every registered connection, keyed by its selection key.
  val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
  // Outgoing connections, keyed by the remote manager they talk to.
  val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
  // In-flight reliable messages keyed by message id; always accessed under messageStatuses.synchronized.
  val messageStatuses = new HashMap[Int, MessageStatus]
  // Outgoing connections requested by sendMessage; opened later on the selector thread.
  val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
  // Interest-op changes requested by connections; applied on the selector thread.
  val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
  // Outgoing messages; guarded by sendMessageRequests.synchronized and drained by the selector thread.
  val sendMessageRequests = new Queue[(Message, SendingConnection)]

  // Implicitly available to the akka futures/promises created in this class (e.g. sendMessageReliably).
  implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())

  // Callback for incoming non-ack messages; null until onReceiveMessage is called.
  var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null

  serverChannel.configureBlocking(false)
  serverChannel.socket.setReuseAddress(true)
  serverChannel.socket.setReceiveBufferSize(256 * 1024)

  serverChannel.socket.bind(new InetSocketAddress(port))
  serverChannel.register(selector, SelectionKey.OP_ACCEPT)

  val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
  logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)

  val selectorThread = new Thread("connection-manager-thread") {
    override def run() = ConnectionManager.this.run()
  }
  selectorThread.setDaemon(true)
  selectorThread.start()

  /** Body of the selector thread. Loops until the thread is interrupted or an unexpected error occurs. */
  private def run() {
    try {
      while(!selectorThread.isInterrupted) {
        // Iterate over a snapshot. Removing entries from a HashMap while iterating it is unsafe,
        // and sendMessage may add new requests concurrently. Removal happens only after
        // addConnection, so a concurrent sendMessage always finds the connection in one of
        // the two maps.
        val pendingRequests = connectionRequests.synchronized { connectionRequests.toList }
        for ((connectionManagerId, sendingConnection) <- pendingRequests) {
          sendingConnection.connect()
          addConnection(sendingConnection)
          connectionRequests -= connectionManagerId
        }
        sendMessageRequests.synchronized {
          while (!sendMessageRequests.isEmpty) {
            val (message, connection) = sendMessageRequests.dequeue
            connection.send(message)
          }
        }

        while (!keyInterestChangeRequests.isEmpty) {
          val (key, ops) = keyInterestChangeRequests.dequeue
          applyKeyInterestChange(key, ops)
        }

        val selectedKeysCount = selector.select()
        if (selectedKeysCount == 0) {
          logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
        }
        if (selectorThread.isInterrupted) {
          logInfo("Selector thread was interrupted!")
          return
        }

        val selectedKeys = selector.selectedKeys().iterator()
        while (selectedKeys.hasNext()) {
          val key = selectedKeys.next
          selectedKeys.remove()
          if (key.isValid) {
            if (key.isAcceptable) {
              acceptConnection(key)
            } else {
              handleSelectedConnectionKey(key)
            }
          }
        }
      }
    } catch {
      case e: Exception => logError("Error in select loop", e)
    }
  }

  /** Renders interest-op bits as e.g. "READ | WRITE", or a single space when no bit is set. */
  private def intToOpStr(op: Int): String = {
    val opStrs = ArrayBuffer[String]()
    if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
    if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
    if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
    if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
    if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
  }

  /**
   * Applies one queued interest-op change on the selector thread.
   *
   * A request is skipped when its connection has already been removed or its key has been
   * cancelled. Throwing here would escape run()'s single try/catch and stop the selector
   * thread for good.
   */
  private def applyKeyInterestChange(key: SelectionKey, ops: Int) {
    connectionsByKey.get(key) match {
      case Some(connection) if key.isValid =>
        try {
          val lastOps = key.interestOps()
          key.interestOps(ops)
          logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId +
            "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
        } catch {
          case e: CancelledKeyException =>
            logDebug("Ignoring interest change for cancelled key of connection to [" +
              connection.remoteConnectionManagerId + "]")
        }
      case _ =>
        logDebug("Ignoring interest change for key with no live connection: " + key)
    }
  }

  /**
   * Dispatches a ready, non-accept key to the connection registered for it.
   *
   * Connections are registered with the selector and added to connectionsByKey on this thread
   * before select(). A selected key with no entry is therefore treated as orphaned. It is
   * cancelled so that it neither throws out of the select loop nor is reselected forever.
   */
  private def handleSelectedConnectionKey(key: SelectionKey) {
    connectionsByKey.get(key) match {
      case Some(connection) =>
        if (key.isConnectable) {
          connection.asInstanceOf[SendingConnection].finishConnect()
        } else if (key.isReadable) {
          connection.read()
        } else if (key.isWritable) {
          connection.write()
        }
      case None =>
        logWarning("No connection registered for selected key " + key + "; cancelling it")
        key.cancel()
    }
  }

  /**
   * Accepts a pending inbound connection and registers it as a ReceivingConnection.
   *
   * In non-blocking mode, `accept()` returns null when no connection is actually pending.
   * That case is ignored rather than allowed to raise an NPE.
   */
  private def acceptConnection(key: SelectionKey) {
    val acceptingChannel = key.channel.asInstanceOf[ServerSocketChannel]
    Option(acceptingChannel.accept()) match {
      case Some(newChannel) =>
        val newConnection = new ReceivingConnection(newChannel, selector)
        newConnection.onReceive(receiveMessage)
        // NOTE(review): addConnection registers onClose(removeConnection) again; harmless only
        // if onClose replaces rather than appends callbacks -- confirm in Connection.
        newConnection.onClose(removeConnection)
        addConnection(newConnection)
        logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
      case None =>
        logDebug("Accept was selected but no connection was pending")
    }
  }

  /** Indexes a connection (and, for outgoing ones, its remote id) and installs the lifecycle callbacks. */
  private def addConnection(connection: Connection) {
    connectionsByKey += ((connection.key, connection))
    connection match {
      case sendingConnection: SendingConnection =>
        connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection))
      case _ =>
    }
    connection.onKeyInterestChange(changeConnectionKeyInterest)
    connection.onException(handleConnectionError)
    connection.onClose(removeConnection)
  }

  /**
   * Unregisters a connection and fails every in-flight reliable message bound for the peer.
   *
   * When a receiving connection is removed, the matching outgoing connection is closed and
   * unregistered too.
   */
  private def removeConnection(connection: Connection) {
    connectionsByKey -= connection.key
    connection match {
      case sendingConnection: SendingConnection =>
        val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId
        logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
        connectionsById -= sendingConnectionManagerId
        failPendingMessages(sendingConnectionManagerId)

      case receivingConnection: ReceivingConnection =>
        val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId
        logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)

        // NOTE(review): the match is by host only, so with several managers on one host
        // this may pick the wrong outgoing connection.
        val sendingConnectionManagerId = connectionsById.synchronized {
          connectionsById.keys.find(_.host == remoteConnectionManagerId.host)
        }
        sendingConnectionManagerId match {
          case None =>
            logError("Corresponding SendingConnectionManagerId not found")
          case Some(sendingId) =>
            logInfo("Corresponding SendingConnectionManagerId is " + sendingId)
            // Another thread may already have removed it, so look it up without apply().
            connectionsById.get(sendingId).foreach(_.close())
            connectionsById -= sendingId
            failPendingMessages(sendingId)
        }

      case _ =>
    }
  }

  /**
   * Completes, as not acknowledged, every in-flight reliable message addressed to `remoteId`,
   * then forgets them.
   *
   * Runs under the `messageStatuses` lock, which handleMessage also takes when it claims an
   * acked status, so a given status is finished at most once.
   */
  private def failPendingMessages(remoteId: ConnectionManagerId) {
    messageStatuses.synchronized {
      val pending = messageStatuses.values.filter(_.connectionManagerId == remoteId).toList
      pending.foreach(status => {
        logInfo("Notifying " + status)
        status.synchronized {
          status.attempted = true
          status.acked = false
          status.markDone()
        }
      })
      messageStatuses.retain((i, status) => status.connectionManagerId != remoteId)
    }
  }

  private def handleConnectionError(connection: Connection, e: Exception) {
    logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
    removeConnection(connection)
  }

  /**
   * Queues an interest-op change for the selector thread and wakes the selector.
   *
   * The queue is drained only before select(), so without the wakeup a change requested
   * while the selector is blocked would wait for some unrelated event.
   */
  private def changeConnectionKeyInterest(connection: Connection, ops: Int) {
    keyInterestChangeRequests += ((connection.key, ops))
    selector.wakeup()
  }

  /** Hands a received message to `handleMessageExecutor`, logging queueing and handling delays. */
  private def receiveMessage(connection: Connection, message: Message) {
    val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
    logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
    val runnable = new Runnable() {
      val creationTime = System.currentTimeMillis
      def run() {
        logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
        handleMessage(connectionManagerId, message)
        logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
      }
    }
    handleMessageExecutor.execute(runnable)
  }

  /**
   * Processes one received message on a handler thread.
   *
   * An ack (a BufferMessage with an ack id) completes the matching MessageStatus. Any other
   * BufferMessage is passed to `onReceiveCallback`, and the reply (or a default built by
   * `Message.createBufferMessage(bufferMessage.id)`) is sent back.
   *
   * @throws Exception for an ack with no in-flight message, or for a non-buffer message
   */
  private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
    logDebug("Handling [" + message + "] from [" + connectionManagerId + "]")
    message match {
      case bufferMessage: BufferMessage => {
        if (bufferMessage.hasAckId) {
          // Claim the status atomically so failPendingMessages cannot also finish it.
          val sentMessageStatus = messageStatuses.synchronized {
            messageStatuses.remove(bufferMessage.ackId).getOrElse {
              throw new Exception("Could not find reference for received ack message " + message.id)
            }
          }
          sentMessageStatus.synchronized {
            sentMessageStatus.ackMessage = Some(message)
            sentMessageStatus.attempted = true
            sentMessageStatus.acked = true
            sentMessageStatus.markDone()
          }
        } else {
          val ackMessage = if (onReceiveCallback != null) {
            logDebug("Calling back")
            onReceiveCallback(bufferMessage, connectionManagerId)
          } else {
            logDebug("Not calling back as callback is null")
            None
          }

          ackMessage match {
            case Some(ack: BufferMessage) =>
              if (!ack.hasAckId) {
                logDebug("Response to " + bufferMessage + " does not have ack id set")
                ack.ackId = bufferMessage.id
              }
            case Some(other) =>
              logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + other.getClass())
            case None =>
          }

          sendMessage(connectionManagerId, ackMessage.getOrElse {
            Message.createBufferMessage(bufferMessage.id)
          })
        }
      }
      case _ => throw new Exception("Unknown type message received")
    }
  }

  /**
   * Queues `message` for the peer and wakes the selector thread, which performs the send.
   *
   * If no connection is registered for the peer, a new SendingConnection is requested; it is
   * opened on the selector thread.
   */
  private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
    def startNewConnection(): SendingConnection = {
      val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
      val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId,
          new SendingConnection(inetSocketAddress, selector, connectionManagerId))
      newConnection
    }
    // NOTE(review): the lookup uses an id normalized through fromSocketAddress (getHostName),
    // but a new connection is requested under the caller's id as given. If the two differ
    // (e.g. IP vs host name), connectionsById may never hit -- confirm.
    val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
    val connection = connectionsById.getOrElse(lookupKey, startNewConnection())
    message.senderAddress = id.toSocketAddress()
    logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
    sendMessageRequests.synchronized {
      sendMessageRequests += ((message, connection))
    }
    selector.wakeup()
  }

  /**
   * Sends `message` and tracks it until it is acknowledged.
   *
   * @return a future completed with `Some(ack)` when an ack arrives, or `None` if the
   *         connection to the peer is removed first
   */
  def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
      : Future[Option[Message]] = {
    val promise = Promise[Option[Message]]
    val status = new MessageStatus(message, connectionManagerId, s => promise.success(s.ackMessage))
    messageStatuses.synchronized {
      messageStatuses += ((message.id, status))
    }
    sendMessage(connectionManagerId, message)
    promise.future
  }

  /** Blocking form of `sendMessageReliably`; waits without a timeout. */
  def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = {
    Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf)
  }

  /** Installs the callback invoked for each incoming non-ack message; its result is sent back as the ack. */
  def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
    onReceiveCallback = callback
  }

  /**
   * Stops the selector thread, closes the selector and every registered connection, and shuts
   * down the handler pool.
   */
  def stop() {
    selectorThread.interrupt()
    selectorThread.join()
    selector.close()
    // Snapshot first. If close() fires the onClose callback installed by addConnection,
    // removeConnection mutates connectionsByKey, which is unsafe while iterating its live view.
    val connections = connectionsByKey.synchronized { connectionsByKey.values.toList }
    connections.foreach(_.close())
    if (connectionsByKey.size != 0) {
      logWarning("All connections not cleaned up")
    }
    handleMessageExecutor.shutdown()
    logInfo("ConnectionManager stopped")
  }
}


/** Manual throughput checks for ConnectionManager; every test sends messages to the manager itself. */
private[spark] object ConnectionManager {

  /** Starts a manager on port 9999 that prints every message it receives, then runs a send test. */
  def main(args: Array[String]) {
    val manager = new ConnectionManager(9999)
    manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
      println("Received [" + msg + "] from [" + id + "]")
      None
    })

    /*testSequentialSending(manager)*/
    /*System.gc()*/

    /*testParallelSending(manager)*/
    /*System.gc()*/

    /*testParallelDecreasingSending(manager)*/
    /*System.gc()*/

    testContinuousSending(manager)
    System.gc()
  }

  /** Allocates a `size`-byte buffer holding `(i.toByte)` at index i, flipped so it is ready to be read. */
  private def filledBuffer(size: Int): ByteBuffer = {
    val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
    buffer.flip
    buffer
  }

  /** Waits up to one second for each ack future in turn, printing "Failed" for each one that completes without an ack. */
  private def awaitAcks(futures: Seq[Future[Option[Message]]]) {
    futures.foreach(f => {
      val g = Await.result(f, 1 second)
      if (!g.isDefined) println("Failed")
    })
  }

  /** Sends `count` 10 MB messages one at a time, waiting for each ack before sending the next. */
  def testSequentialSending(manager: ConnectionManager) {
    println("--------------------------")
    println("Sequential Sending")
    println("--------------------------")
    val size = 10 * 1024 * 1024
    val count = 10

    val buffer = filledBuffer(size)

    (0 until count).foreach(i => {
      val bufferMessage = Message.createBufferMessage(buffer.duplicate)
      manager.sendMessageReliablySync(manager.id, bufferMessage)
    })
    println("--------------------------")
    println()
  }

  /** Sends `count` 10 MB messages without waiting, then awaits all acks and reports throughput. */
  def testParallelSending(manager: ConnectionManager) {
    println("--------------------------")
    println("Parallel Sending")
    println("--------------------------")
    val size = 10 * 1024 * 1024
    val count = 10

    val buffer = filledBuffer(size)

    val startTime = System.currentTimeMillis
    awaitAcks((0 until count).map(i => {
      val bufferMessage = Message.createBufferMessage(buffer.duplicate)
      manager.sendMessageReliably(manager.id, bufferMessage)
    }))
    val finishTime = System.currentTimeMillis

    val mb = size * count / 1024.0 / 1024.0
    val ms = finishTime - startTime
    val tput = mb * 1000.0 / ms
    println("--------------------------")
    println("Started at " + startTime + ", finished at " + finishTime)
    println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
    println("--------------------------")
    println()
  }

  /** Like testParallelSending, but the messages shrink from `count * 10` MB down to 10 MB. */
  def testParallelDecreasingSending(manager: ConnectionManager) {
    println("--------------------------")
    println("Parallel Decreasing Sending")
    println("--------------------------")
    val size = 10 * 1024 * 1024
    val count = 10
    val buffers = Array.tabulate(count)(i => filledBuffer(size * (i + 1)))
    val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0

    val startTime = System.currentTimeMillis
    awaitAcks((0 until count).map(i => {
      val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
      manager.sendMessageReliably(manager.id, bufferMessage)
    }))
    val finishTime = System.currentTimeMillis

    val ms = finishTime - startTime
    val tput = mb * 1000.0 / ms
    println("--------------------------")
    /*println("Started at " + startTime + ", finished at " + finishTime) */
    println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
    println("--------------------------")
    println()
  }

  /**
   * Repeatedly sends batches of `count` 10 MB messages, reporting each batch's throughput and
   * pausing one second between batches. Never returns.
   */
  def testContinuousSending(manager: ConnectionManager) {
    println("--------------------------")
    println("Continuous Sending")
    println("--------------------------")
    val size = 10 * 1024 * 1024
    val count = 10

    val buffer = filledBuffer(size)

    while(true) {
      // Timed per batch. Previously startTime was taken once before the loop, so each batch's
      // MB was divided by the total elapsed time, including earlier batches and sleeps.
      val startTime = System.currentTimeMillis
      awaitAcks((0 until count).map(i => {
        val bufferMessage = Message.createBufferMessage(buffer.duplicate)
        manager.sendMessageReliably(manager.id, bufferMessage)
      }))
      val finishTime = System.currentTimeMillis
      Thread.sleep(1000)
      val mb = size * count / 1024.0 / 1024.0
      val ms = finishTime - startTime
      val tput = mb * 1000.0 / ms
      println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
      println("--------------------------")
      println()
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy