All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.azure.cosmos.spark.RowSerializerPoolInternal.scala Maven / Gradle / Ivy

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.StructType

import java.time.Instant
import java.util.concurrent.{ConcurrentLinkedQueue, Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicLong
import scala.collection.concurrent.TrieMap
import scala.util.control.NonFatal

/**
 * Pool of Spark row serializers, keyed by the schema they were created for.
 *
 * Spark serializers are not thread-safe and are expensive to create (dynamic code generation),
 * so this pool allows reusing them per target schema. The main purpose of pooling (vs. creating
 * new serializers in each PartitionReader) is Structured Streaming, where PartitionReaders for
 * the same schema could be created every couple of 100 milliseconds.
 *
 * - Each schema gets its own queue, which applies a soft limit (`MaxPooledSerializerCount`) to
 *   bound the memory footprint.
 * - A periodic clean-up task purges the queues of schemas from which no serializer has been
 *   borrowed for `expirationIntervalInSeconds`.
 *
 * @param serializerFactory creates a new serializer for a schema whenever the pool cannot
 *                          provide one
 */
private class RowSerializerPoolInstance(val serializerFactory: StructType => ExpressionEncoder.Serializer[Row])
  extends BasicLoggingTrait {

  /** Soft upper bound for the number of pooled serializers per schema. */
  val MaxPooledSerializerCount: Int = 256
  private[this] val cleanUpIntervalInSeconds = 300
  private[this] val expirationIntervalInSeconds = 1800
  private[this] val schemaScopedSerializerMap =
    new TrieMap[StructType, RowSerializerQueue]
  private[this] val executorService = Executors.newSingleThreadScheduledExecutor(SparkUtils.daemonThreadFactory())

  // Scheduled only after every field read by onCleanUp has been initialized above.
  executorService.scheduleWithFixedDelay(
    () => this.onCleanUp(),
    cleanUpIntervalInSeconds,
    cleanUpIntervalInSeconds,
    TimeUnit.SECONDS)

  /**
   * Borrows a pooled serializer for the schema, or creates a new one via the factory when no
   * queue is registered for the schema or the queue is empty.
   *
   * @param schema the schema the serializer has to target
   * @return a serializer for `schema`
   */
  def getOrCreateSerializer(schema: StructType): ExpressionEncoder.Serializer[Row] = {
    schemaScopedSerializerMap.get(schema) match {
      case Some(serializerQueue) => serializerQueue.borrowSerializer(schema)
      case None => serializerFactory(schema)
    }
  }

  /**
   * Offers a serializer back to the pool of its schema, registering a queue for the schema
   * first when none exists yet.
   *
   * @param schema     the schema the serializer was created for
   * @param serializer the serializer to hand back
   * @return `true` when the serializer was pooled, `false` when it was rejected because the
   *         soft limit for this schema has been reached
   */
  def returnSerializerToPool(schema: StructType, serializer: ExpressionEncoder.Serializer[Row]): Boolean = {
    val serializerQueue = schemaScopedSerializerMap.get(schema) match {
      case Some(existingQueue) => existingQueue
      case None =>
        val newQueue = new RowSerializerQueue(serializerFactory)
        // putIfAbsent yields the already registered queue when another thread won the race.
        // Using that queue (instead of the unpublished one) avoids silently dropping the
        // serializer although a pool for the schema exists.
        schemaScopedSerializerMap.putIfAbsent(schema, newQueue).getOrElse(newQueue)
    }

    serializerQueue.returnSerializer(serializer)
  }

  /**
   * Removes the queues of all schemas whose last borrow is older than the expiration interval.
   * Non-fatal failures are logged and swallowed so an exception never escapes into the
   * scheduled executor.
   */
  private[this] def onCleanUp(): Unit = {
    try {
      val expirationThreshold: Long = Instant.now.minusSeconds(expirationIntervalInSeconds).toEpochMilli

      schemaScopedSerializerMap
        .readOnlySnapshot()
        .foreach { case (schema, serializerQueue) =>
          if (serializerQueue.getLastBorrowedAny < expirationThreshold) {
            // Conditional remove: only drops the entry if it still maps to the inspected queue.
            schemaScopedSerializerMap.remove(schema, serializerQueue)
          }
        }
    } catch {
      case NonFatal(e) => logError("Callback onCleanup invocation failed.", e)
    }
  }

  /**
   * A slim wrapper around ConcurrentLinkedQueue with the purpose of
   * - having a soft-limit on how many serializers can be pooled - there is no need to have an
   *   exact limit - best effort is acceptable. When we exceed the max size we don't offer
   *   returned serializers to the pool anymore to have a limited memory footprint.
   * - keeping track of when any serializer for a certain schema was used last to allow the owner
   *   to purge serializers for schemas not used anymore.
   *
   * @param serializerFactory creates a new serializer when the queue has none to hand out
   */
  private class RowSerializerQueue(val serializerFactory: StructType => ExpressionEncoder.Serializer[Row]) {
    private[this] val objectPool = new ConcurrentLinkedQueue[ExpressionEncoder.Serializer[Row]]()
    // Tracked separately from the queue and updated around queue operations - hence "estimated".
    private[this] val estimatedSize = new AtomicLong(0)
    private[this] val lastBorrowedAny = new AtomicLong(Instant.now.toEpochMilli)

    /**
     * Records the borrow time and hands out a pooled serializer, or a newly created one when
     * the queue is empty.
     */
    def borrowSerializer(schema: StructType): ExpressionEncoder.Serializer[Row] = {
      lastBorrowedAny.set(Instant.now.toEpochMilli)
      // poll() returns null for an empty queue; Option(...) maps that to None.
      Option(objectPool.poll()) match {
        case Some(serializer) =>
          estimatedSize.decrementAndGet()
          serializer
        case None => serializerFactory(schema)
      }
    }

    /**
     * Adds the serializer to the queue unless the soft limit has been reached.
     *
     * @return `true` when the serializer was pooled, `false` when it was rejected
     */
    def returnSerializer(serializer: ExpressionEncoder.Serializer[Row]): Boolean = {
      if (estimatedSize.incrementAndGet() > MaxPooledSerializerCount) {
        // Undo the optimistic increment - the serializer is not pooled.
        estimatedSize.decrementAndGet()
        false
      } else {
        objectPool.offer(serializer)
        true
      }
    }

    /** Epoch milliseconds of the most recent borrow (or of queue creation if never borrowed). */
    def getLastBorrowedAny: Long = lastBorrowedAny.get()
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy