All downloads are free. The search and download features use the official Maven repository.

spark.broadcast.TreeBroadcast.scala Maven / Gradle / Ivy

The newest version!
package spark.broadcast

import java.io._
import java.net._
import java.util.{Comparator, Random, UUID}

import scala.collection.mutable.{ListBuffer, Map, Set}
import scala.math

import spark._
import spark.storage.StorageLevel

private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {

  // Returns the value currently held by this broadcast instance.
  def value = value_

  // Key under which this broadcast's value is stored in the block manager.
  def blockId = "broadcast_" + id

  // At construction time, store the value in this process's block manager while
  // holding the MultiTracker monitor (readObject looks up the same key under the
  // same monitor).
  // NOTE(review): the meaning of the trailing `false` argument is defined by
  // BlockManager.putSingle, which is not visible here -- confirm.
  MultiTracker.synchronized {
    SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
  }

  // All fields below are @transient, so they are not part of the serialized form.
  // The worker side resets a subset of them in initializeWorkerVariables().

  // Blocks of the broadcast value: taken from blockifyObject in sendBroadcast(),
  // or allocated and filled while receiving in receiveBroadcast().
  @transient var arrayOfBlocks: Array[BroadcastBlock] = null
  // -1 until the size is known (set in sendBroadcast / receiveBroadcast).
  @transient var totalBytes = -1
  // -1 until the block count is known; sendObject waits while it is still -1.
  @transient var totalBlocks = -1
  // Number of blocks available locally; incremented as blocks are received.
  @transient var hasBlocks = 0

  // Monitors used with wait()/notifyAll() to signal changes of the
  // correspondingly named fields.
  @transient var listenPortLock = new Object
  @transient var guidePortLock = new Object
  @transient var totalBlocksLock = new Object
  @transient var hasBlocksLock = new Object

  // Sources known to the guide; read and modified under listOfSources.synchronized.
  @transient var listOfSources = ListBuffer[SourceInfo]()

  // Threads accepting block requests (serveMR) and guide requests (guideMR).
  @transient var serveMR: ServeMultipleRequests = null
  @transient var guideMR: GuideMultipleRequests = null

  @transient var hostAddress = Utils.localIpAddress
  // Ports stay -1 until the ServeMultipleRequests / GuideMultipleRequests threads
  // have bound their server sockets.
  @transient var listenPort = -1
  @transient var guidePort = -1

  // Set to true to make the accept loops of both server threads exit.
  @transient var stopBroadcast = false

  // Must call this after all the variables have been created/initialized,
  // because sendBroadcast() reads and writes the fields declared above.
  if (!isLocal) {
    sendBroadcast()
  }

  /**
   * Sets up the sending side of the broadcast.
   *
   * Splits the value into blocks, starts the GuideMultipleRequests and
   * ServeMultipleRequests daemon threads, waits until each has published its
   * port, records this node as the first source and registers the guide
   * address with the MultiTracker.
   *
   * Blocks the calling thread until both ports are available.
   */
  def sendBroadcast() {
    logInfo("Local host address: " + hostAddress)

    // Split the value into blocks; the sender owns every block from the start,
    // hence hasBlocks == totalBlocks.
    val variableInfo = MultiTracker.blockifyObject(value_)

    // Prepare the value being broadcasted
    arrayOfBlocks = variableInfo.arrayOfBlocks
    totalBytes = variableInfo.totalBytes
    totalBlocks = variableInfo.totalBlocks
    hasBlocks = variableInfo.totalBlocks

    guideMR = new GuideMultipleRequests
    guideMR.setDaemon(true)
    guideMR.start()
    logInfo("GuideMultipleRequests started...")

    // Must always come AFTER guideMR is created.
    // The condition is tested while holding the monitor: testing it outside
    // would let the guide thread's notifyAll() slip in between the test and
    // wait(), leaving this thread blocked forever.
    guidePortLock.synchronized {
      while (guidePort == -1) {
        guidePortLock.wait()
      }
    }

    serveMR = new ServeMultipleRequests
    serveMR.setDaemon(true)
    serveMR.start()
    logInfo("ServeMultipleRequests started...")

    // Must always come AFTER serveMR is created (same monitor discipline as above)
    listenPortLock.synchronized {
      while (listenPort == -1) {
        listenPortLock.wait()
      }
    }

    // Must always come AFTER listenPort is created
    val masterSource =
      SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
    listOfSources += masterSource

    // Register with the Tracker; note this advertises the guide port, not the
    // listen port used for block transfers.
    MultiTracker.registerBroadcast(id,
      SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
  }

  /**
   * Java serialization hook, run when this broadcast object is deserialized.
   *
   * While holding the MultiTracker monitor it first looks for the value in the
   * local block manager. If it is missing, the worker state is reset, a
   * ServeMultipleRequests thread is started and the blocks are fetched through
   * receiveBroadcast; on success the value is rebuilt and stored in the block
   * manager. A failed fetch is only logged, leaving value_ unset.
   */
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()

    MultiTracker.synchronized {
      val locallyCached = SparkEnv.get.blockManager.getSingle(blockId)
      locallyCached match {
        case None =>
          logInfo("Started reading broadcast variable " + id)
          // Transient fields arrive as null/0 after deserialization, so set them
          // up before anything else uses them.
          initializeWorkerVariables()

          logInfo("Local host address: " + hostAddress)

          // Start serving so that this node can pass blocks on while receiving
          serveMR = new ServeMultipleRequests
          serveMR.setDaemon(true)
          serveMR.start()
          logInfo("ServeMultipleRequests started...")

          val startNanos = System.nanoTime

          if (!receiveBroadcast(id)) {
            logError("Reading broadcast variable " + id + " failed")
          } else {
            value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
            SparkEnv.get.blockManager.putSingle(
              blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
          }

          val elapsedSeconds = (System.nanoTime - startNanos) / 1e9
          logInfo("Reading broadcast variable " + id + " took " + elapsedSeconds + " s")

        case Some(cached) =>
          // Already available in this process: no network transfer needed
          value_ = cached.asInstanceOf[T]
      }
    }
  }

  /**
   * Puts the receiving-side transient fields back into their initial state:
   * fresh monitor objects, no blocks, unknown sizes, no serve thread, unbound
   * listen port and stopBroadcast cleared.
   */
  private def initializeWorkerVariables(): Unit = {
    // Monitors used for wait()/notifyAll() signalling
    listenPortLock = new Object
    totalBlocksLock = new Object
    hasBlocksLock = new Object

    // Nothing received yet; -1 marks "not known yet"
    arrayOfBlocks = null
    totalBlocks = -1
    totalBytes = -1
    hasBlocks = 0

    // Serving side: address of this host, port not bound yet, no thread
    hostAddress = Utils.localIpAddress
    listenPort = -1
    serveMR = null

    stopBroadcast = false
  }

  /**
   * Receives the blocks of a broadcast variable.
   *
   * Asks the MultiTracker for the guide of `variableID`, then repeatedly
   * (at most MultiTracker.MaxRetryCount times, while blocks are still missing)
   * connects to the guide, reports this node's address, receives the source to
   * pull from, pulls blocks from it and reports the outcome back to the guide.
   *
   * @param variableID id of the broadcast variable to receive
   * @return true if all blocks have been received; false if blocks are still
   *         missing after the last attempt or the tracker answered with
   *         SourceInfo.TxOverGoToDefault
   */
  def receiveBroadcast(variableID: Long): Boolean = {
    val gInfo = MultiTracker.getGuideInfo(variableID)

    if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
      return false
    }

    // Wait until listenPort has been published by the ServeMultipleRequests
    // thread. The condition is tested while holding the monitor so that a
    // notifyAll() cannot be missed between the test and wait().
    listenPortLock.synchronized {
      while (listenPort == -1) {
        listenPortLock.wait()
      }
    }

    // Connect and receive broadcast from the specified source, retrying the
    // specified number of times in case of failures
    var retriesLeft = MultiTracker.MaxRetryCount
    do {
      var clientSocketToDriver: Socket = null
      var oosDriver: ObjectOutputStream = null
      var oisDriver: ObjectInputStream = null

      try {
        // Connect to Driver and send this worker's Information
        clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
        oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
        oosDriver.flush()
        oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)

        logDebug("Connected to Driver's guiding object")

        // Send local source information
        oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
        oosDriver.flush()

        // Receive source information from Driver
        val sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
        totalBlocks = sourceInfo.totalBlocks
        // Allocate the block array only once: hasBlocks survives across retries,
        // so re-allocating here would throw away the blocks already received
        // and leave null entries below hasBlocks.
        if (arrayOfBlocks == null) {
          arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
        }
        totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
        totalBytes = sourceInfo.totalBytes

        logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)

        val receptionSucceeded = receiveSingleTransmission(sourceInfo)

        // Updating some statistics in sourceInfo. Driver will be using them later
        if (!receptionSucceeded) {
          sourceInfo.receptionFailed = true
        }

        // Send back statistics to the Driver
        oosDriver.writeObject(sourceInfo)
        oosDriver.flush()
      } finally {
        // Release the connection even when the attempt fails part-way; the
        // exception itself still propagates to the caller as before.
        if (oisDriver != null) {
          oisDriver.close()
        }
        if (oosDriver != null) {
          oosDriver.close()
        }
        if (clientSocketToDriver != null) {
          clientSocketToDriver.close()
        }
      }

      retriesLeft -= 1
    } while (retriesLeft > 0 && hasBlocks < totalBlocks)

    hasBlocks == totalBlocks
  }

  /**
   * Makes one attempt to pull the still-missing blocks from `sourceInfo`.
   *
   * Opens a connection to the source, requests the range
   * (hasBlocks, totalBlocks) and stores every block that arrives, bumping
   * hasBlocks and waking up threads waiting on hasBlocksLock after each one.
   * Exceptions are logged, not rethrown. May be called several times by the
   * retry loop of the caller.
   *
   * @return true if at least one block was received during this attempt
   */
  private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
    var socket: Socket = null
    var toSource: ObjectOutputStream = null
    var fromSource: ObjectInputStream = null

    var gotAtLeastOneBlock = false
    try {
      // Connect to the source to get the object itself
      socket = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
      toSource = new ObjectOutputStream(socket.getOutputStream)
      toSource.flush()
      fromSource = new ObjectInputStream(socket.getInputStream)

      logDebug("Inside receiveSingleTransmission")
      logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)

      // Tell the source which blocks are wanted
      toSource.writeObject((hasBlocks, totalBlocks))
      toSource.flush()

      // The number of blocks to read is fixed at this point
      var blocksExpected = totalBlocks - hasBlocks
      while (blocksExpected > 0) {
        val recvStartTime = System.currentTimeMillis
        val bcBlock = fromSource.readObject.asInstanceOf[BroadcastBlock]
        val receptionTime = System.currentTimeMillis - recvStartTime

        logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")

        // Blocks arrive in order, so the next free slot is hasBlocks
        arrayOfBlocks(hasBlocks) = bcBlock
        hasBlocks += 1

        gotAtLeastOneBlock = true
        hasBlocksLock.synchronized { hasBlocksLock.notifyAll() }

        blocksExpected -= 1
      }
    } catch {
      case e: Exception => logError("receiveSingleTransmission had a " + e)
    } finally {
      if (fromSource != null) {
        fromSource.close()
      }
      if (toSource != null) {
        toSource.close()
      }
      if (socket != null) {
        socket.close()
      }
    }

    gotAtLeastOneBlock
  }

  class GuideMultipleRequests
  extends Thread with Logging {
    // Keep track of sources that have completed reception
    private var setOfCompletedSources = Set[SourceInfo]()

    /**
     * Accept loop of the guide.
     *
     * Binds a server socket on a free port, publishes it through guidePort
     * (waking up threads waiting on guidePortLock) and hands every accepted
     * connection to a GuideSingleRequest on a thread pool. When stopBroadcast
     * becomes true it notifies all known sources and unregisters the broadcast
     * from the MultiTracker. The server socket and the thread pool are released
     * on every exit path.
     */
    override def run() {
      val threadPool = Utils.newDaemonCachedThreadPool()
      val serverSocket = new ServerSocket(0)

      guidePort = serverSocket.getLocalPort
      logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)

      guidePortLock.synchronized { guidePortLock.notifyAll() }

      try {
        while (!stopBroadcast) {
          var clientSocket: Socket = null
          try {
            serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
            clientSocket = serverSocket.accept
          } catch {
            case e: Exception => {
              // Stop broadcast if at least one worker has connected and
              // everyone connected so far are done. Comparing with
              // listOfSources.size - 1, because it includes the Guide itself
              listOfSources.synchronized {
                setOfCompletedSources.synchronized {
                  if (listOfSources.size > 1 &&
                    setOfCompletedSources.size == listOfSources.size - 1) {
                    stopBroadcast = true
                    logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
                  }
                }
              }
            }
          }
          if (clientSocket != null) {
            logDebug("Guide: Accepted new client connection: " + clientSocket)
            try {
              threadPool.execute(new GuideSingleRequest(clientSocket))
            } catch {
              // In failure, close() the socket here; else, the thread will close() it
              case ioe: IOException => clientSocket.close()
            }
          }
        }

        logInfo("Sending stopBroadcast notifications...")
        sendStopBroadcastNotifications()

        MultiTracker.unregisterBroadcast(id)
      } finally {
        logInfo("GuideMultipleRequests now stopping...")
        try {
          serverSocket.close()
        } finally {
          // Shut the pool down even if the notifications, the unregistration
          // or closing the socket failed; already running requests still finish.
          threadPool.shutdown()
        }
      }
    }

    /**
     * Tells every source in listOfSources to stop, by connecting to it and
     * sending the pair (SourceInfo.StopBroadcast, SourceInfo.StopBroadcast)
     * in place of a block range. Failures for one source are logged and do not
     * prevent the remaining sources from being notified. Runs while holding the
     * listOfSources monitor.
     */
    private def sendStopBroadcastNotifications(): Unit = {
      listOfSources.synchronized {
        for (target <- listOfSources) {
          var socket: Socket = null
          var out: ObjectOutputStream = null
          var in: ObjectInputStream = null

          try {
            // Connect to the source
            socket = new Socket(target.hostAddress, target.listenPort)
            out = new ObjectOutputStream(socket.getOutputStream)
            out.flush()
            in = new ObjectInputStream(socket.getInputStream)

            // Send stopBroadcast signal
            out.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast))
            out.flush()
          } catch {
            case e: Exception =>
              logError("sendStopBroadcastNotifications had a " + e)
          } finally {
            if (in != null) {
              in.close()
            }
            if (out != null) {
              out.close()
            }
            if (socket != null) {
              socket.close()
            }
          }
        }
      }
    }

    class GuideSingleRequest(val clientSocket: Socket)
    extends Thread with Logging {
      // Streams over the accepted connection. The output stream is created and
      // flushed before the input stream is opened, so its stream header is on
      // the wire before this side starts reading.
      private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
      oos.flush()
      private val ois = new ObjectInputStream(clientSocket.getInputStream)

      // Source picked for the connecting worker; null until run() has selected one
      private var selectedSourceInfo: SourceInfo = null
      // Entry describing the connecting worker itself; null until run() has created it
      private var thisWorkerInfo:SourceInfo = null

      /**
       * Serves one worker that asks the guide where to pull the broadcast from.
       *
       * Reads the worker's SourceInfo, answers with a source chosen by
       * selectSuitableSource and adds the worker to listOfSources as a possible
       * source. It then waits for the worker's final report and updates
       * listOfSources and setOfCompletedSources accordingly. On any exception the
       * bookkeeping done for this worker is rolled back. Streams and socket are
       * always closed.
       */
      override def run() {
        try {
          logInfo("new GuideSingleRequest is running")
          // Connecting worker is sending in its hostAddress and listenPort it will
          // be listening to. Other fields are invalid (SourceInfo.UnusedParam)
          var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]

          listOfSources.synchronized {
            // Select a suitable source and send it back to the worker
            selectedSourceInfo = selectSuitableSource(sourceInfo)
            logDebug("Sending selectedSourceInfo: " + selectedSourceInfo)
            oos.writeObject(selectedSourceInfo)
            oos.flush()

            // Add this new (if it can finish) source to the list of sources
            thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
              sourceInfo.listenPort, totalBlocks, totalBytes)
            logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo)
            listOfSources += thisWorkerInfo
          }

          // Wait till the whole transfer is done. Then receive and update source
          // statistics in listOfSources
          sourceInfo = ois.readObject.asInstanceOf[SourceInfo]

          listOfSources.synchronized {
            // This should work since SourceInfo is a case class
            assert(listOfSources.contains(selectedSourceInfo))

            // Remove first
            // (Currently removing a source based on just one failure notification!)
            // Removal is done in place: re-assigning listOfSources to a copy would
            // change the object every thread synchronizes on and break the mutual
            // exclusion between concurrently running requests.
            listOfSources -= selectedSourceInfo

            // Update sourceInfo and put it back in, IF reception succeeded
            if (!sourceInfo.receptionFailed) {
              // Add thisWorkerInfo to sources that have completed reception
              setOfCompletedSources.synchronized {
                setOfCompletedSources += thisWorkerInfo
              }

              // Update leecher count and put it back in
              selectedSourceInfo.currentLeechers -= 1
              listOfSources += selectedSourceInfo
            }
          }
        } catch {
          case e: Exception => {
            logError("GuideSingleRequest had a " + e)

            // Remove failed worker from listOfSources and update leecherCount of
            // corresponding source worker
            listOfSources.synchronized {
              if (selectedSourceInfo != null) {
                // Remove first (in place, to keep the lock object stable)
                listOfSources -= selectedSourceInfo
                // Update leecher count and put it back in
                selectedSourceInfo.currentLeechers -= 1
                listOfSources += selectedSourceInfo
              }

              // Remove thisWorkerInfo, if it had been added at all
              if (thisWorkerInfo != null) {
                listOfSources -= thisWorkerInfo
              }
            }
          }
        } finally {
          logInfo("GuideSingleRequest is closing streams and sockets")
          ois.close()
          oos.close()
          clientSocket.close()
        }
      }

      // Picks the source the requesting worker should pull from.
      // The caller is expected to hold the monitor of listOfSources.
      // Among the sources that are not the requester itself (same hostAddress and
      // listenPort) and that have fewer than MultiTracker.MaxDegree leechers, the
      // first one with the highest leecher count wins, which fills the tree level
      // by level. The winner's leecher count is incremented before returning.
      private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
        var chosen: SourceInfo = null
        var mostLeechers = -1

        for (candidate <- listOfSources) {
          val isRequester =
            candidate.hostAddress == skipSourceInfo.hostAddress &&
              candidate.listenPort == skipSourceInfo.listenPort
          val hasCapacity = candidate.currentLeechers < MultiTracker.MaxDegree

          if (!isRequester && hasCapacity && candidate.currentLeechers > mostLeechers) {
            chosen = candidate
            mostLeechers = candidate.currentLeechers
          }
        }

        // Update leecher count
        chosen.currentLeechers += 1
        chosen
      }
    }
  }

  class ServeMultipleRequests
  extends Thread with Logging {
    
    var threadPool = Utils.newDaemonCachedThreadPool()
    
    /**
     * Accept loop for block requests.
     *
     * Binds a server socket on a free port, publishes it through listenPort
     * (waking up threads waiting on listenPortLock) and hands every accepted
     * connection to a ServeSingleRequest on the thread pool until stopBroadcast
     * becomes true. The server socket and the thread pool are released on
     * every exit path.
     */
    override def run() {
      val serverSocket = new ServerSocket(0)
      listenPort = serverSocket.getLocalPort

      logInfo("ServeMultipleRequests started with " + serverSocket)

      listenPortLock.synchronized { listenPortLock.notifyAll() }

      try {
        while (!stopBroadcast) {
          var clientSocket: Socket = null
          try {
            serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
            clientSocket = serverSocket.accept
          } catch {
            // Deliberately ignored: falling through re-checks stopBroadcast
            case e: Exception => { }
          }

          if (clientSocket != null) {
            logDebug("Serve: Accepted new client connection: " + clientSocket)
            try {
              threadPool.execute(new ServeSingleRequest(clientSocket))
            } catch {
              // In failure, close socket here; else, the thread will close it
              case ioe: IOException => clientSocket.close()
            }
          }
        }
      } finally {
        logInfo("ServeMultipleRequests now stopping...")
        try {
          serverSocket.close()
        } finally {
          // Shut the pool down on every exit path; requests that are already
          // running still finish.
          threadPool.shutdown()
        }
      }
    }

    class ServeSingleRequest(val clientSocket: Socket)
    extends Thread with Logging {
      // Streams over the accepted connection. The output stream is created and
      // flushed before the input stream is opened, so its stream header is on
      // the wire before this side starts reading.
      private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
      oos.flush()
      private val ois = new ObjectInputStream(clientSocket.getInputStream)

      // Requested block range [sendFrom, sendUntil); overwritten in run() with
      // the range received from the client.
      private var sendFrom = 0
      private var sendUntil = totalBlocks

      /**
       * Handles one incoming connection: reads an (Int, Int) pair from the
       * client and either flags the broadcast as stopped, when both values equal
       * SourceInfo.StopBroadcast, or sends the requested block range.
       * Exceptions are logged; streams and socket are always closed.
       */
      override def run() {
        try {
          logInfo("new ServeSingleRequest is running")

          // Receive range to send
          val requestedRange = ois.readObject.asInstanceOf[(Int, Int)]
          sendFrom = requestedRange._1
          sendUntil = requestedRange._2

          // A pair of StopBroadcast markers is not a range but the stop signal
          val isStopSignal =
            sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast

          if (!isStopSignal) {
            sendObject()
          } else {
            stopBroadcast = true
          }
        } catch {
          case e: Exception => logError("ServeSingleRequest had a " + e)
        } finally {
          logInfo("ServeSingleRequest is closing streams and sockets")
          ois.close()
          oos.close()
          clientSocket.close()
        }
      }

      /**
       * Sends the blocks in [sendFrom, sendUntil) to the client, in order.
       *
       * Waits until totalBlocks is known and, for each block, until that block
       * is available locally (index below hasBlocks). Sending stops at the first
       * write failure, which is logged.
       */
      private def sendObject() {
        // Wait till receiving the SourceInfo from Driver. The condition is
        // tested while holding the monitor so that a notifyAll() cannot be
        // missed between the test and wait().
        totalBlocksLock.synchronized {
          while (totalBlocks == -1) {
            totalBlocksLock.wait()
          }
        }

        val lastExclusive = sendUntil
        var i = sendFrom
        var sendFailed = false

        while (i < lastExclusive && !sendFailed) {
          // Wait until block i has actually been received. Using >= (not ==)
          // also covers a client that asks for blocks beyond what this node
          // holds so far, which would otherwise be sent as null entries.
          hasBlocksLock.synchronized {
            while (i >= hasBlocks) {
              hasBlocksLock.wait()
            }
          }

          try {
            oos.writeObject(arrayOfBlocks(i))
            oos.flush()
            logDebug("Sent block: " + i + " to " + clientSocket)
          } catch {
            case e: Exception =>
              logError("sendObject had a " + e)
              // The stream is unusable after a failed write; give up instead
              // of failing again for every remaining block.
              sendFailed = true
          }

          i += 1
        }
      }
    }
  }
}

/**
 * BroadcastFactory that produces TreeBroadcast instances and forwards
 * initialization and shutdown to the MultiTracker.
 */
private[spark] class TreeBroadcastFactory extends BroadcastFactory {

  /** Initializes the MultiTracker, passing on whether this process is the driver. */
  def initialize(isDriver: Boolean): Unit = {
    MultiTracker.initialize(isDriver)
  }

  /** Creates a TreeBroadcast wrapping `value_` with the given locality flag and id. */
  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = {
    val broadcast = new TreeBroadcast[T](value_, isLocal, id)
    broadcast
  }

  /** Stops the MultiTracker. */
  def stop(): Unit = {
    MultiTracker.stop()
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy