All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.ml.spark.core.utils.AsyncUtils.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.core.utils

import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.util.{Failure, Success, Try}

object AsyncUtils {
  def bufferedAwait[T](it: Iterator[Future[T]],
                                 concurrency: Int,
                                 timeout: Duration)
                                (implicit ec: ExecutionContext): Iterator[T] = {
    bufferedAwaitSafe(it, concurrency, timeout).map{
      case Success(data) => data
      case f: Failure[T] => throw f.exception
    }
  }

  private def safeAwait[T](f: Future[T], timeout: Duration): Try[T] = {
    Try(Await.result(f, timeout))
  }

  def bufferedAwaitSafe[T](it: Iterator[Future[T]],
                       concurrency: Int,
                       timeout: Duration)
                      (implicit ec: ExecutionContext): Iterator[Try[T]] = {
    if (concurrency > 1) {
      val slidingIterator = it.sliding(concurrency - 1).withPartial(true)
      // `hasNext` will auto start the nth future in the batch
      val (initIterator, tailIterator) = slidingIterator.span(_ => slidingIterator.hasNext)
      initIterator.map(futureBatch => safeAwait(futureBatch.head, timeout)) ++
        tailIterator.flatMap(lastBatch => lastBatch.map(safeAwait(_, timeout)))
    } else if (concurrency == 1) {
      it.map(f => safeAwait(f, timeout))
    } else {
      throw new IllegalArgumentException(
        s"Concurrency needs to be at least 1, got: $concurrency")
    }
  }

  private def safeAwaitWithContext[T, C](f: (Future[T], C), timeout: Duration): (Try[T], C) = {
    (Try(Await.result(f._1, timeout)), f._2)
  }

  def bufferedAwaitSafeWithContext[T,C](it: Iterator[(Future[T],C)],
                           concurrency: Int,
                           timeout: Duration)
                          (implicit ec: ExecutionContext): Iterator[(Try[T],C)] = {
    if (concurrency > 1) {
      val slidingIterator = it.sliding(concurrency - 1).withPartial(true)
      // `hasNext` will auto start the nth future in the batch
      val (initIterator, tailIterator) = slidingIterator.span(_ => slidingIterator.hasNext)
      initIterator.map(futureBatch => safeAwaitWithContext(futureBatch.head, timeout)) ++
        tailIterator.flatMap(lastBatch => lastBatch.map(safeAwaitWithContext(_, timeout)))
    } else if (concurrency == 1) {
      it.map(f => safeAwaitWithContext(f, timeout))
    } else {
      throw new IllegalArgumentException(
        s"Concurrency needs to be at least 1, got: $concurrency")
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy