All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.mongodb.spark.connection.MongoClientCache.scala Maven / Gradle / Ivy

There is a newer version: 10.2.3
Show newest version
//scalastyle:off
/*
 * Copyright 2016 MongoDB, Inc.
 * Copyright 2014-2015, DataStax, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
//scalastyle:on
package com.mongodb.spark.connection

import java.util.concurrent.{Executors, ThreadFactory, TimeUnit}

import com.mongodb.client.MongoClient

import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.collection.concurrent.TrieMap
import scala.concurrent.duration.Duration
import scala.util.{Failure, Success, Try}
import com.mongodb.spark.{Logging, MongoClientFactory}

/**
 * A lockless cache for MongoClients.
 *
 * Allows multiple users to share MongoClients. A `MongoClient` is closed once every user has released it and its
 * keep-alive period has elapsed (or immediately, if released with a zero delay or after `shutdown`).
 *
 * @param keepAlive the duration to keep alive any given MongoClient so that it can be reused by another consumer
 */
private[spark] final class MongoClientCache(keepAlive: Duration) extends Logging {

  private val refCounter = new MongoClientRefCounter
  private val cache = new TrieMap[MongoClientFactory, MongoClient]
  private val clientToKey = new TrieMap[MongoClient, MongoClientFactory]
  private val deferredReleases = new TrieMap[MongoClient, ReleaseTask]

  /**
   * Acquires a MongoClient for the given factory, reusing the cached client when it can still be acquired.
   *
   * When nothing is cached, a new client is created and published with `putIfAbsent`. If another thread published
   * a client first, the freshly created one is released and closed, and the published client is used instead.
   *
   * @param mongoClientFactory the factory used to create the client; also the cache key
   * @return a MongoClient that has been registered with the reference counter for the caller
   */
  @tailrec
  def acquire(mongoClientFactory: MongoClientFactory): MongoClient = {
    cache.get(mongoClientFactory) match {
      case Some(mongoClient) =>
        // If canAcquire is false (presumably the client is concurrently being released), retry until it is evicted.
        if (refCounter.canAcquire(mongoClient)) mongoClient else acquire(mongoClientFactory)
      case None =>
        val createdMongoClient = mongoClientFactory.create()
        logClient(createdMongoClient)
        refCounter.acquire(createdMongoClient)
        cache.putIfAbsent(mongoClientFactory, createdMongoClient) match {
          case None =>
            clientToKey.put(createdMongoClient, mongoClientFactory)
            createdMongoClient
          case Some(existingMongoClient) =>
            // Lost the race: drop our reference first so a failing close() cannot leave a stale counter entry.
            refCounter.release(createdMongoClient)
            logClient(createdMongoClient, closing = true)
            createdMongoClient.close()
            if (refCounter.canAcquire(existingMongoClient)) existingMongoClient else acquire(mongoClientFactory)
        }
    }
  }

  /**
   * Releases a previously acquired mongoClient. Once the mongoClient has been released by all users and the
   * `releaseDelay` has elapsed, it is removed from the cache and closed.
   *
   * @param mongoClient  the client previously returned by `acquire`
   * @param releaseDelay how long to defer the release; a zero delay, or a cache that has been shut down,
   *                     releases immediately
   */
  def release(mongoClient: MongoClient, releaseDelay: Duration = keepAlive): Unit = {
    if (releaseDelay.toMillis == 0 || scheduledExecutorService.isShutdown) {
      releaseImmediately(mongoClient)
    } else {
      releaseDeferred(mongoClient, releaseDelay, 1)
    }
  }

  /**
   * Shuts down the background deferred `release` scheduler and forces all pending release tasks to be executed.
   * A task that fails is logged and does not prevent the remaining tasks from running.
   */
  def shutdown(): Unit = {
    scheduledExecutorService.shutdown()
    while (deferredReleases.nonEmpty) {
      for ((mongoClient, task) <- deferredReleases.snapshot()) {
        // Conditional remove guarantees a task runs at most once even if the scheduler thread is still active.
        if (deferredReleases.remove(mongoClient, task)) runSafely(task)
      }
    }
  }

  /**
   * Decrements the client's reference count by `count`; when it reaches zero the client is evicted and closed.
   * Failures from the reference counter are logged as warnings.
   */
  private def releaseImmediately(mongoClient: MongoClient, count: Int = 1): Unit = {
    Try(refCounter.release(mongoClient, count)) match {
      case Success(0) =>
        // Only evict the cache entry if it still maps to this exact client; tolerate a missing reverse mapping.
        clientToKey.remove(mongoClient).foreach(key => cache.remove(key, mongoClient))
        logClient(mongoClient, closing = true)
        mongoClient.close()
      case Failure(e) => logWarning(e.getMessage)
      case _          => // still referenced by other users
    }
  }

  /**
   * Schedules (or merges into an existing schedule) a release of `count` references after `releaseDelay`.
   * When merging, the later of the two deadlines wins and the counts are summed.
   */
  @tailrec
  private def releaseDeferred(mongoClient: MongoClient, releaseDelay: Duration, count: Int): Unit = {
    val newTime = System.nanoTime() + releaseDelay.toNanos
    val newTask = deferredReleases.remove(mongoClient) match {
      case Some(oldTask) =>
        // nanoTime values must be compared by difference, not directly, because they may overflow.
        val scheduledTime = if (newTime - oldTask.scheduledTime > 0) newTime else oldTask.scheduledTime
        ReleaseTask(mongoClient, oldTask.count + count, scheduledTime)
      case None => ReleaseTask(mongoClient, count, newTime)
    }
    deferredReleases.putIfAbsent(mongoClient, newTask) match {
      // Another thread scheduled a task in the meantime: retry, carrying our accumulated count into a merge.
      case Some(_) => releaseDeferred(mongoClient, releaseDelay, newTask.count)
      case None    =>
    }
  }

  /**
   * Called periodically by `scheduledExecutorService`; runs every deferred release whose deadline has passed.
   */
  private def processPendingReleases(): Unit = {
    val now = System.nanoTime()
    for ((mongoClient, task) <- deferredReleases) {
      // Overflow-safe nanoTime comparison: "scheduledTime <= now".
      if (task.scheduledTime - now <= 0 && deferredReleases.remove(mongoClient, task)) runSafely(task)
    }
  }

  /**
   * Runs a release task, logging instead of propagating non-fatal failures. An exception escaping the periodic
   * task would make `scheduleAtFixedRate` suppress all subsequent executions, stranding every later release.
   */
  private def runSafely(task: ReleaseTask): Unit =
    Try(task.run()) match {
      case Failure(e) => logWarning(s"Failed to release MongoClient: ${e.getMessage}")
      case Success(_) =>
    }

  private val processPendingReleasesTask = new Runnable() {
    override def run(): Unit = processPendingReleases()
  }

  private val scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(new ThreadFactory() {
    override def newThread(r: Runnable): Thread = {
      val thread = Executors.defaultThreadFactory().newThread(r)
      // Daemon thread so the scheduler never keeps the JVM alive on its own.
      thread.setDaemon(true)
      thread
    }
  })

  // This must be high enough so it doesn't cause too much CPU usage,
  // but also low enough to allow for acceptable release-delay resolution.
  private val period = 100
  scheduledExecutorService.scheduleAtFixedRate(processPendingReleasesTask, period, period, TimeUnit.MILLISECONDS)

  /** A pending release of `count` references to `mongoClient`, due at `scheduledTime` (a `System.nanoTime` value). */
  private case class ReleaseTask(mongoClient: MongoClient, count: Int, scheduledTime: Long) extends Runnable {
    override def run(): Unit = releaseImmediately(mongoClient, count)
  }

  /** Logs the creation or closing of a client together with the addresses of its known servers. */
  private def logClient(mongoClient: MongoClient, closing: Boolean = false): Unit = {
    val status = if (closing) "Closing" else "Creating"
    val addresses = mongoClient.getClusterDescription.getServerDescriptions.asScala
      .map(_.getAddress.toString)
      .mkString("[", ",", "]")
    logInfo(s"""$status MongoClient: $addresses""")
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy