All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.runtime.testingUtils.TestingTaskManagerLike.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.testingUtils

import akka.actor.{ActorRef, Terminated}
import org.apache.flink.api.common.JobID
import org.apache.flink.runtime.FlinkActor
import org.apache.flink.runtime.execution.ExecutionState
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID
import org.apache.flink.runtime.messages.JobManagerMessages.{RequestLeaderSessionID, ResponseLeaderSessionID}
import org.apache.flink.runtime.messages.Messages.{Acknowledge, Disconnect}
import org.apache.flink.runtime.messages.RegistrationMessages.{AcknowledgeRegistration, AlreadyRegistered}
import org.apache.flink.runtime.messages.TaskMessages.{SubmitTask, TaskInFinalState, UpdateTaskExecutionState}
import org.apache.flink.runtime.taskmanager.TaskManager
import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.NotifyWhenJobRemoved
import org.apache.flink.runtime.testingUtils.TestingMessages._
import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages._

import scala.concurrent.duration._
import scala.language.postfixOps

/**
 * This mixin can be used to decorate a TaskManager with messages for testing purposes.
 *
 * It intercepts testing-specific messages and lets tests register listeners that are notified
 * when a task starts running, when a task reaches a final state, when all tasks of a job are
 * gone, when a job manager terminates or disconnects, when this task manager registers at a job
 * manager, or when this component shuts down. Everything not handled here is delegated to the
 * decorated TaskManager's `handleMessage`.
 */
trait TestingTaskManagerLike extends FlinkActor {
  that: TaskManager =>

  import scala.collection.JavaConverters._

  /** Listeners waiting for the given execution attempt to reach a final state. */
  val waitForRemoval = scala.collection.mutable.HashMap[ExecutionAttemptID, Set[ActorRef]]()

  /**
   * Listeners waiting for a job manager to terminate or disconnect, keyed by the job manager's
   * actor path name.
   */
  val waitForJobManagerToBeTerminated = scala.collection.mutable.HashMap[String, Set[ActorRef]]()

  /** Listeners waiting for this task manager to register at the given job manager actor. */
  val waitForRegisteredAtResourceManager =
    scala.collection.mutable.HashMap[ActorRef, Set[ActorRef]]()

  /** Listeners waiting for the given execution attempt to report RUNNING. */
  val waitForRunning = scala.collection.mutable.HashMap[ExecutionAttemptID, Set[ActorRef]]()

  /** Execution attempts for which a `TaskInFinalState` message has been processed. */
  val unregisteredTasks = scala.collection.mutable.HashSet[ExecutionAttemptID]()

  /** Map of registered task submit listeners (at most one listener per job). */
  val registeredSubmitTaskListeners = scala.collection.mutable.HashMap[JobID, ActorRef]()

  /** Actors that receive a `ComponentShutdown` message when [[shutdown]] runs. */
  val waitForShutdown = scala.collection.mutable.HashSet[ActorRef]()

  /** When `true`, incoming `Disconnect` messages are dropped without any handling. */
  var disconnectDisabled = false

  /**
   * Handler for testing related messages; falls back to the decorated TaskManager's handler.
   */
  abstract override def handleMessage: Receive = {
    handleTestingMessage orElse super.handleMessage
  }

  /**
   * Handles the testing messages. Several cases first delegate to `super.handleMessage` so the
   * regular TaskManager logic still runs, and then notify the registered test listeners.
   */
  def handleTestingMessage: Receive = {
    case Alive => sender() ! Acknowledge

    case NotifyWhenTaskIsRunning(executionID) =>
      Option(runningTasks.get(executionID)) match {
        case Some(task) if task.getExecutionState == ExecutionState.RUNNING =>
          sender() ! decorateMessage(true)

        case _ =>
          // Answered once the task reports RUNNING via UpdateTaskExecutionState.
          val listeners = waitForRunning.getOrElse(executionID, Set())
          waitForRunning += (executionID -> (listeners + sender()))
      }

    case RequestRunningTasks =>
      sender() ! decorateMessage(ResponseRunningTasks(runningTasks.asScala.toMap))

    case NotifyWhenTaskRemoved(executionID) =>
      val alreadyRemoved =
        Option(runningTasks.get(executionID)).isEmpty && unregisteredTasks.contains(executionID)

      if (alreadyRemoved) {
        sender() ! decorateMessage(true)
      } else {
        // Answered once the task reaches a final state (see the TaskInFinalState case).
        val listeners = waitForRemoval.getOrElse(executionID, Set())
        waitForRemoval += (executionID -> (listeners + sender()))
      }

    case msg @ TaskInFinalState(executionID) =>
      super.handleMessage(msg)

      waitForRemoval.remove(executionID).foreach { listeners =>
        listeners.foreach(_ ! decorateMessage(true))
      }

      unregisteredTasks += executionID

    case RequestBroadcastVariablesWithReferences =>
      sender() ! decorateMessage(
        ResponseBroadcastVariablesWithReferences(
          bcVarManager.getNumberOfVariablesWithReferences)
      )

    case RequestNumActiveConnections =>
      val numActive = if (network.isAssociated) {
        network.getConnectionManager.getNumberOfActiveConnections
      } else {
        0
      }
      sender() ! decorateMessage(ResponseNumActiveConnections(numActive))

    case NotifyWhenJobRemoved(jobID) =>
      if (runningTasks.values.asScala.exists(_.getJobID == jobID)) {
        scheduleJobRemovalCheck(jobID, sender())
      } else {
        sender() ! decorateMessage(true)
      }

    case CheckIfJobRemoved(jobID) =>
      // The check was scheduled with the original requester as sender, so sender() is the
      // actor that is waiting for the answer.
      if (runningTasks.values.asScala.forall(_.getJobID != jobID)) {
        sender() ! decorateMessage(true)
      } else {
        scheduleJobRemovalCheck(jobID, sender())
      }

    case NotifyWhenJobManagerTerminated(jobManager) =>
      val key = jobManager.path.name
      val waiting = waitForJobManagerToBeTerminated.getOrElse(key, Set())
      waitForJobManagerToBeTerminated += key -> (waiting + sender())

    case RegisterSubmitTaskListener(jobId) =>
      registeredSubmitTaskListeners.put(jobId, sender())

    case msg @ SubmitTask(tdd) =>
      registeredSubmitTaskListeners.get(tdd.getJobID).foreach { listener =>
        listener ! ResponseSubmitTaskListener(tdd)
      }

      super.handleMessage(msg)

    /*
     * Message from task manager that accumulator values changed and need to be reported
     * immediately instead of lazily through the
     * [[org.apache.flink.runtime.messages.TaskManagerMessages.Heartbeat]] message. We forward this
     * message to the job manager so that it knows it should report to the listeners.
     * If there is no current job manager, the message is dropped and the sender gets no reply.
     */
    case msg: AccumulatorsChanged =>
      currentJobManager.foreach { jobManager =>
        jobManager.forward(msg)
        sendHeartbeatToJobManager()
        sender() ! true
      }

    case msg @ Terminated(jobManager) =>
      super.handleMessage(msg)
      notifyJobManagerTerminationListeners(jobManager)

    case msg: Disconnect =>
      // While disconnects are disabled the message is dropped entirely.
      if (!disconnectDisabled) {
        super.handleMessage(msg)
        notifyJobManagerTerminationListeners(sender())
      }

    case DisableDisconnect =>
      disconnectDisabled = true

    case NotifyOfComponentShutdown =>
      waitForShutdown += sender()

    case msg @ UpdateTaskExecutionState(taskExecutionState) =>
      super.handleMessage(msg)

      if (taskExecutionState.getExecutionState == ExecutionState.RUNNING) {
        // Notify and drop the listeners, so they do not accumulate and are never notified
        // twice for the same execution attempt.
        waitForRunning.remove(taskExecutionState.getID).foreach { listeners =>
          listeners.foreach(_ ! decorateMessage(true))
        }
      }

    case RequestLeaderSessionID =>
      sender() ! ResponseLeaderSessionID(leaderSessionID.orNull)

    case NotifyWhenRegisteredAtJobManager(jobManager: ActorRef) =>
      if (isConnected && currentJobManager.exists(_ == jobManager)) {
        sender() ! true
      } else {
        val listeners = waitForRegisteredAtResourceManager.getOrElse(jobManager, Set[ActorRef]())
        waitForRegisteredAtResourceManager += jobManager -> (listeners + sender())
      }

    case msg @ (_: AcknowledgeRegistration | _: AlreadyRegistered) =>
      super.handleMessage(msg)

      // Listeners are keyed by the actor that sends the registration response
      // (presumably the job manager that accepted the registration).
      waitForRegisteredAtResourceManager.remove(sender()).foreach { listeners =>
        listeners.foreach(_ ! true)
      }
  }

  /**
   * Schedules a `CheckIfJobRemoved` message to this actor in 200 ms. The message is sent with
   * `replyTo` as its sender, so the eventual answer goes back to the original requester.
   *
   * @param jobID   job whose tasks must all be gone
   * @param replyTo actor waiting for the `true` answer
   */
  private def scheduleJobRemovalCheck(jobID: JobID, replyTo: ActorRef): Unit = {
    context.system.scheduler.scheduleOnce(
      200.milliseconds,
      self,
      decorateMessage(CheckIfJobRemoved(jobID)))(
      context.dispatcher,
      replyTo)
  }

  /**
   * Sends `JobManagerTerminated(jobManager)` to every listener registered for the given job
   * manager's actor path name, and removes those listeners.
   *
   * @param jobManager the job manager that terminated or disconnected
   */
  private def notifyJobManagerTerminationListeners(jobManager: ActorRef): Unit = {
    waitForJobManagerToBeTerminated.remove(jobManager.path.name).foreach { listeners =>
      listeners.foreach(_ ! decorateMessage(JobManagerTerminated(jobManager)))
    }
  }

  /**
   * No killing of the VM for testing: only notifies the actors registered via
   * `NotifyOfComponentShutdown` with `ComponentShutdown(self)` and clears that registry.
   */
  override protected def shutdown(): Unit = {
    log.info("Shutting down TestingTaskManager.")
    waitForShutdown.foreach(_ ! ComponentShutdown(self))
    waitForShutdown.clear()
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy