All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.rpc.akka.AkkaRpcEnv.scala Maven / Gradle / Ivy

There is a newer version: 2.4.8
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rpc.akka

import java.util.concurrent.ConcurrentHashMap

import scala.concurrent.Future
import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address}
import akka.event.Logging.Error
import akka.pattern.{ask => akkaAsk}
import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
import akka.serialization.JavaSerializer

import org.apache.spark.{SparkException, Logging, SparkConf}
import org.apache.spark.rpc._
import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}

/**
 * A RpcEnv implementation based on Akka.
 *
 * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and
 * remove Akka from the dependencies.
 *
 * @param actorSystem
 * @param conf
 * @param boundPort
 */
private[spark] class AkkaRpcEnv private[akka] (
    val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int)
  extends RpcEnv(conf) with Logging {

  private val defaultAddress: RpcAddress = {
    val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress
    // In some test case, ActorSystem doesn't bind to any address.
    // So just use some default value since they are only some unit tests
    RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort))
  }

  override val address: RpcAddress = defaultAddress

  /**
   * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make
   * [[RpcEndpoint.self]] work.
   */
  private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]()

  /**
   * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef`
   */
  private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]()

  private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = {
    endpointToRef.put(endpoint, endpointRef)
    refToEndpoint.put(endpointRef, endpoint)
  }

  private def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = {
    val endpoint = refToEndpoint.remove(endpointRef)
    if (endpoint != null) {
      endpointToRef.remove(endpoint)
    }
  }

  /**
   * Retrieve the [[RpcEndpointRef]] of `endpoint`.
   */
  override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint)

  override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
    @volatile var endpointRef: AkkaRpcEndpointRef = null
    // Use lazy because the Actor needs to use `endpointRef`.
    // So `actorRef` should be created after assigning `endpointRef`.
    lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging {

      assert(endpointRef != null)

      override def preStart(): Unit = {
        // Listen for remote client network events
        context.system.eventStream.subscribe(self, classOf[AssociationEvent])
        safelyCall(endpoint) {
          endpoint.onStart()
        }
      }

      override def receiveWithLogging: Receive = {
        case AssociatedEvent(_, remoteAddress, _) =>
          safelyCall(endpoint) {
            endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress))
          }

        case DisassociatedEvent(_, remoteAddress, _) =>
          safelyCall(endpoint) {
            endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress))
          }

        case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) =>
          safelyCall(endpoint) {
            endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress))
          }

        case e: AssociationEvent =>
          // TODO ignore?

        case m: AkkaMessage =>
          logDebug(s"Received RPC message: $m")
          safelyCall(endpoint) {
            processMessage(endpoint, m, sender)
          }

        case AkkaFailure(e) =>
          safelyCall(endpoint) {
            throw e
          }

        case message: Any => {
          logWarning(s"Unknown message: $message")
        }

      }

      override def postStop(): Unit = {
        unregisterEndpoint(endpoint.self)
        safelyCall(endpoint) {
          endpoint.onStop()
        }
      }

      }), name = name)
    endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false)
    registerEndpoint(endpoint, endpointRef)
    // Now actorRef can be created safely
    endpointRef.init()
    endpointRef
  }

  private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = {
    val message = m.message
    val needReply = m.needReply
    val pf: PartialFunction[Any, Unit] =
      if (needReply) {
        endpoint.receiveAndReply(new RpcCallContext {
          override def sendFailure(e: Throwable): Unit = {
            _sender ! AkkaFailure(e)
          }

          override def reply(response: Any): Unit = {
            _sender ! AkkaMessage(response, false)
          }

          // Some RpcEndpoints need to know the sender's address
          override val sender: RpcEndpointRef =
            new AkkaRpcEndpointRef(defaultAddress, _sender, conf)
        })
      } else {
        endpoint.receive
      }
    try {
      pf.applyOrElse[Any, Unit](message, { message =>
        throw new SparkException(s"Unmatched message $message from ${_sender}")
      })
    } catch {
      case NonFatal(e) =>
        _sender ! AkkaFailure(e)
        if (!needReply) {
          // If the sender does not require a reply, it may not handle the exception. So we rethrow
          // "e" to make sure it will be processed.
          throw e
        }
    }
  }

  /**
   * Run `action` safely to avoid to crash the thread. If any non-fatal exception happens, it will
   * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it.
   */
  private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {
    try {
      action
    } catch {
      case NonFatal(e) => {
        try {
          endpoint.onError(e)
        } catch {
          case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e)
        }
      }
    }
  }

  private def akkaAddressToRpcAddress(address: Address): RpcAddress = {
    RpcAddress(address.host.getOrElse(defaultAddress.host),
      address.port.getOrElse(defaultAddress.port))
  }

  override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
    import actorSystem.dispatcher
    actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration).
      map(new AkkaRpcEndpointRef(defaultAddress, _, conf)).
      // this is just in case there is a timeout from creating the future in resolveOne, we want the
      // exception to indicate the conf that determines the timeout
      recover(defaultLookupTimeout.addMessageIfTimeout)
  }

  override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = {
    AkkaUtils.address(
      AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName)
  }

  override def shutdown(): Unit = {
    actorSystem.shutdown()
  }

  override def stop(endpoint: RpcEndpointRef): Unit = {
    require(endpoint.isInstanceOf[AkkaRpcEndpointRef])
    actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef)
  }

  override def awaitTermination(): Unit = {
    actorSystem.awaitTermination()
  }

  override def toString: String = s"${getClass.getSimpleName}($actorSystem)"

  override def deserialize[T](deserializationAction: () => T): T = {
    JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) {
      deserializationAction()
    }
  }
}

private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {

  def create(config: RpcEnvConfig): RpcEnv = {
    val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
      config.name, config.host, config.port, config.conf, config.securityManager)
    actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor")
    new AkkaRpcEnv(actorSystem, config.conf, boundPort)
  }
}

/**
 * Monitor errors reported by Akka and log them.
 */
private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging {

  override def preStart(): Unit = {
    context.system.eventStream.subscribe(self, classOf[Error])
  }

  override def receiveWithLogging: Actor.Receive = {
    case Error(cause: Throwable, _, _, message: String) => logError(message, cause)
  }
}

private[akka] class AkkaRpcEndpointRef(
    @transient defaultAddress: RpcAddress,
    @transient _actorRef: => ActorRef,
    @transient conf: SparkConf,
    @transient initInConstructor: Boolean = true)
  extends RpcEndpointRef(conf) with Logging {

  lazy val actorRef = _actorRef

  override lazy val address: RpcAddress = {
    val akkaAddress = actorRef.path.address
    RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host),
      akkaAddress.port.getOrElse(defaultAddress.port))
  }

  override lazy val name: String = actorRef.path.name

  private[akka] def init(): Unit = {
    // Initialize the lazy vals
    actorRef
    address
    name
  }

  if (initInConstructor) {
    init()
  }

  override def send(message: Any): Unit = {
    actorRef ! AkkaMessage(message, false)
  }

  override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
    actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap {
      // The function will run in the calling thread, so it should be short and never block.
      case msg @ AkkaMessage(message, reply) =>
        if (reply) {
          logError(s"Receive $msg but the sender cannot reply")
          Future.failed(new SparkException(s"Receive $msg but the sender cannot reply"))
        } else {
          Future.successful(message)
        }
      case AkkaFailure(e) =>
        Future.failed(e)
    }(ThreadUtils.sameThread).mapTo[T].
    recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
  }

  override def toString: String = s"${getClass.getSimpleName}($actorRef)"

  final override def equals(that: Any): Boolean = that match {
    case other: AkkaRpcEndpointRef => actorRef == other.actorRef
    case _ => false
  }

  final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode()
}

/**
 * A wrapper to `message` so that the receiver knows if the sender expects a reply.
 * @param message
 * @param needReply if the sender expects a reply message
 */
private[akka] case class AkkaMessage(message: Any, needReply: Boolean)

/**
 * A reply with the failure error from the receiver to the sender
 */
private[akka] case class AkkaFailure(e: Throwable)




© 2015 - 2024 Weber Informatics LLC | Privacy Policy