All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.utils.ThreadPool.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.utils

import java.util.concurrent._

import com.intel.analytics.bigdl.mkl.hardware.Affinity
import com.intel.analytics.bigdl.mkl.{MKL, MklDnn => BackendMklDnn}
import org.apache.commons.lang.exception.ExceptionUtils
import org.apache.log4j.Logger

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}

/**
 * A thread pool wrapper that provides helper functions for multi-threading.
 *
 * TODO `ThreadPool` currently offers two versions of the thread pool API: one uses the Scala
 * version (`invokeAndWait`), the other is provided for Java (`invokeAndWait2`). The design is
 * awkward and this class should be refactored later.
 */
class ThreadPool(private var poolSize: Int) {

  import ThreadPool._

  // Size last requested through setMKLThread / setMKLThreadOfMklDnnBackend; setPoolSize
  // re-applies it (via setMKLThread) after the pool has been rebuilt.
  private var mklPoolSize: Option[Int] = None

  // Executor used directly by the Java-future API (`invokeAndWait2`, `invoke2`) and, when the
  // pool size is not 1, by the Scala execution context as well.
  private var threadPool: ExecutorService = null

  // Must be initialised after `threadPool`: spawnThreadPool assigns that field as a side effect.
  private var context: ExecutionContext = spawnThreadPool(poolSize)

  /**
   * Builds a fixed-size executor whose worker threads are daemon threads.
   *
   * @param size   number of worker threads
   * @param nameOf computes the name given to each freshly created worker thread
   * @return the new executor
   */
  private def newDaemonPool(size: Int)(nameOf: Thread => String): ExecutorService = {
    val delegate = Executors.defaultThreadFactory()
    Executors.newFixedThreadPool(size, new ThreadFactory {
      override def newThread(r: Runnable): Thread = {
        val t = delegate.newThread(r)
        t.setName(nameOf(t))
        t.setDaemon(true)
        t
      }
    })
  }

  /**
   * Replaces `threadPool` with a new executor of the given size and returns the execution
   * context that Scala futures created by this class should run on.
   *
   * For a size of 1 the returned context is the shared `singleThreadPool`, which runs tasks on
   * the calling thread; otherwise the returned context submits tasks to `threadPool`.
   *
   * @param size number of worker threads of the new executor
   * @return the execution context matching the new pool size
   */
  private def spawnThreadPool(size: Int): ExecutionContext = {
    val previous = threadPool

    // Create the replacement first: if the size is rejected (e.g. size <= 0) the exception
    // propagates while the previous executor is still usable.
    val newContext =
      if (size == 1) {
        threadPool = newDaemonPool(size)(_ => "single-thread-computing")
        singleThreadPool
      } else {
        threadPool = newDaemonPool(size)(t => "default-thread-computing " + t.getId)
        new ExecutionContext {
          // Reads the `threadPool` field on every call rather than capturing the executor.
          override def execute(runnable: Runnable): Unit = {
            threadPool.submit(runnable)
          }

          override def reportFailure(t: Throwable): Unit = {
            logger.error("Error: " + ExceptionUtils.getStackTrace(t))
          }
        }
      }

    // Retire the executor being replaced in both branches so its idle workers can terminate;
    // shutdown() still lets already-submitted tasks run to completion.
    if (previous != null) previous.shutdown()

    newContext
  }

  /** Runs `task`, logging the stack trace of anything it throws before rethrowing it. */
  private def runLogged[T](task: () => T): T = {
    try {
      task()
    } catch {
      case t: Throwable =>
        logger.error("Error: " + ExceptionUtils.getStackTrace(t))
        throw t
    }
  }

  /** @return the current pool size */
  def getPoolSize : Int = poolSize

  /**
   * Set MKL thread pool size.
   *
   * Submits `poolSize` tasks to the current execution context, each calling
   * `MKL.setNumThreads(size)`, and blocks until all of them have completed. The size is
   * remembered so that `setPoolSize` can apply it again after a resize.
   *
   * NOTE(review): nothing here pins one task to each worker thread; presumably the intent is
   * that every worker ends up running one of them - confirm.
   *
   * @param size number of threads passed to `MKL.setNumThreads`
   * @return this thread pool
   */
  def setMKLThread(size: Int): this.type = this.synchronized {
    require(MKL.isMKLLoaded, "MKL is not loaded")
    mklPoolSize = Some(size)
    (1 to poolSize).map(_ => Future {
      MKL.setNumThreads(size)
      val tid = Thread.currentThread().getId()
      logger.info(s"Set mkl threads to $size on thread $tid")
    }(context)).foreach(Await.result(_, Duration.Inf))
    this
  }

  /**
   * Set the MKL and MKL-DNN backend thread counts.
   *
   * Runs a single task on the underlying executor (through `invokeAndWait2`) which, in order:
   * sets the flush-denormal state unless the system property `bigdl.flushDenormalState` is
   * `false`, requires both native libraries to be loaded, sets both thread counts to `size`,
   * and sets the OMP affinity unless the system property `bigdl.disableOmpAffinity` is `true`.
   *
   * @param size number of threads passed to both backends
   * @return this thread pool
   */
  def setMKLThreadOfMklDnnBackend(size: Int): this.type = this.synchronized {
    mklPoolSize = Some(size)

    val configure: () => Unit = () => {
      if (System.getProperty("bigdl.flushDenormalState", "true").toBoolean) {
        BackendMklDnn.setFlushDenormalState()
      }

      require(MKL.isMKLLoaded, "MKL is not loaded")
      require(BackendMklDnn.isLoaded, "MKL-DNN backend is not loaded")

      MKL.setNumThreads(size)
      BackendMklDnn.setNumThreads(size)
      if (!System.getProperty("bigdl.disableOmpAffinity", "false").toBoolean) {
        Affinity.setOmpAffinity()
      }
    }

    this.invokeAndWait2(Seq(configure))

    this
  }

  /**
   * Invoke a batch of tasks and wait for all of them to finish.
   *
   * Each task is started as a Scala future on the current execution context. A task that
   * throws has its stack trace logged, and the failure is rethrown by the corresponding wait.
   *
   * @param tasks tasks to run
   * @param timeout maximum time to wait for each individual future
   * @tparam T return type of the tasks
   * @return the task results, in the order of `tasks`
   */
  def invokeAndWait[T](tasks: Seq[() => T], timeout: Duration = Duration.Inf): Seq[T] = {
    tasks.map(task => Future {
      runLogged(task)
    }(context)).map(future => {
      Await.result(future, timeout)
    })
  }

  private type JavaFuture[T] = java.util.concurrent.Future[T]

  /**
   * Use java future to execute the tasks. It will be blocking until tasks completed.
   * If any task throws an exception, it will throw that exception in caller.
   *
   * A task failure is rethrown as its original cause (or as the wrapping `ExecutionException`
   * when no cause is recorded). If the calling thread is interrupted while waiting, its
   * interrupt flag is restored and the `InterruptedException` is rethrown. Exceptions of any
   * other type raised while collecting a result are not intercepted.
   *
   * @param tasks task sequence. each task's return type is T
   * @param timeout the maximum time to wait
   * @param timeUnit the time unit for the timeout
   * @tparam T return type of tasks
   * @return a sequence of Futures representing the tasks.
   */
  def invokeAndWait2[T](tasks: Seq[() => T], timeout: Long = Long.MaxValue,
    timeUnit: TimeUnit = TimeUnit.NANOSECONDS): mutable.Buffer[JavaFuture[T]] = {
    val callables = tasks.map(task => new Callable[T] {
      override def call(): T = {
        task()
      }
    })

    val resultFutures = threadPool.invokeAll(callables.asJava, timeout, timeUnit)

    // Check every future in order; the first failure found is surfaced to the caller.
    resultFutures.asScala.foreach { future =>
      try {
        future.get()
      } catch {
        case e: ExecutionException =>
          // getCause may be null; never `throw null`, which would surface as an NPE.
          throw Option(e.getCause).getOrElse(e)
        case e: InterruptedException =>
          // An InterruptedException usually has no cause, so rethrow it as-is and keep the
          // thread's interrupted status visible to code further up the stack.
          Thread.currentThread().interrupt()
          throw e
      }
    }

    resultFutures.asScala
  }

  /**
   * Submit a batch of tasks to the underlying executor without waiting for them.
   *
   * A task that throws has its stack trace logged before the exception is rethrown inside the
   * task, so it is reported through the returned future.
   *
   * @param tasks tasks to run
   * @tparam T return type of the tasks
   * @return one Java future per task, in the order of `tasks`
   */
  def invoke2[T](tasks: Seq[() => T]): Seq[JavaFuture[T]] = {
    tasks.map(task => new Callable[T] {
      override def call(): T = {
        runLogged(task)
      }
    }).map(threadPool.submit(_))
  }

  /**
   * Invoke a batch of tasks without waiting for them.
   *
   * @param tasks tasks to run on the current execution context
   * @tparam T return type of the tasks
   * @return one Scala future per task, in the order of `tasks`
   */
  def invoke[T](tasks: Seq[() => T]): Seq[Future[T]] = {
    tasks.map(task => Future {
      runLogged(task)
    }(context))
  }

  /**
   * Invoke a single task without waiting for it.
   *
   * @param task task to run on the current execution context
   * @tparam T return type of the task
   * @return the Scala future of the task
   */
  def invoke[T](task: () => T): Future[T] = {
    Future {
      runLogged(task)
    }(context)
  }

  /**
   * Wait for all the given futures to finish, one after another.
   *
   * @param futures futures to wait for
   * @param timeout maximum time to wait for each individual future
   */
  def sync(futures: Seq[Future[_]], timeout: Duration = Duration.Inf): Unit = {
    futures.foreach(f => {
      Await.result(f, timeout)
    })
  }

  /**
   * Set pool size.
   *
   * When `size` differs from the current size, the underlying executor and execution context
   * are rebuilt, and a previously requested MKL thread count is applied again through
   * `setMKLThread`.
   *
   * @param size new number of worker threads
   * @return this thread pool
   */
  def setPoolSize(size: Int): this.type = this.synchronized {
    if (size != poolSize) {
      context = spawnThreadPool(size)
      poolSize = size
      mklPoolSize.foreach(mklSize => this.setMKLThread(mklSize))
    }
    this
  }
}

object ThreadPool {
  private val logger = Logger.getLogger(getClass)

  /**
   * Execution context that runs every submitted task immediately on the calling thread and
   * ignores reported failures.
   */
  val singleThreadPool: ExecutionContext = new ExecutionContext {
    override def execute(runnable: Runnable): Unit = runnable.run()

    override def reportFailure(t: Throwable): Unit = ()
  }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy