All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.delta.util.StateCache.scala Maven / Gradle / Ivy

There is a newer version: 0.6.1
Show newest version
/*
 * Copyright 2019 Databricks, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.delta.util

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.delta.Snapshot

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.storage.StorageLevel

/**
 * Machinery that caches the reconstructed state of a Delta table
 * using the RDD cache. The cache is designed so that the first access
 * will materialize the results. However, once uncache is called,
 * all data will be flushed and will not be cached again.
 */
trait StateCache {
  protected def spark: SparkSession

  /** True until `uncache()` runs; once false, `rddCache` no longer persists anything. */
  private var isCached = true
  /** Every RDD persisted through `rddCache`, remembered so `uncache()` can unpersist it. */
  private val cached = ArrayBuffer[RDD[_]]()

  /** Enriches a [[Dataset]] with the `rddCache` operation. */
  implicit class CacheableDS[A](ds: Dataset[A]) {
    /**
     * Returns a Dataset that reads from a persisted copy of this Dataset's rows.
     *
     * While caching is still enabled, the rows produced by the query execution are
     * copied, the resulting RDD is named and marked for persistence with
     * `MEMORY_AND_DISK_SER`, recorded for later cleanup, and wrapped in a new Dataset
     * that keeps the analyzed output attributes and the original encoder. After
     * `uncache()` has been called the receiver is returned unchanged.
     *
     * All of this happens while holding the lock on the internal RDD list, the same
     * lock `uncache()` takes.
     *
     * @param name the name assigned to the persisted RDD
     * @return a Dataset backed by the persisted RDD, or the original Dataset when
     *         caching has been turned off
     */
    def rddCache(name: String): Dataset[A] = cached.synchronized {
      if (!isCached) {
        // Caching was switched off by uncache(): hand the input back untouched.
        ds
      } else {
        // setName and persist both return the RDD they were called on, so the chain
        // yields the very same RDD that ends up registered in `cached`.
        // NOTE(review): rows are copied before persisting, presumably because the
        // upstream iterator may reuse row objects - confirm against Spark internals.
        val persistedRows = ds.queryExecution.toRdd
          .map(_.copy())
          .setName(name)
          .persist(StorageLevel.MEMORY_AND_DISK_SER)
        cached += persistedRows

        val cachedPlan = LogicalRDD(ds.queryExecution.analyzed.output, persistedRows)(spark)
        Dataset.ofRows(spark, cachedPlan).as[A](ds.exprEnc)
      }
    }
  }

  /**
   * Drop any cached data for this [[Snapshot]].
   *
   * The first call disables further caching and asks every recorded RDD to unpersist
   * without blocking; subsequent calls do nothing.
   */
  def uncache(): Unit = cached.synchronized {
    if (isCached) {
      // Flip the flag before unpersisting so the state is already "off" even if an
      // unpersist call fails part-way through.
      isCached = false
      cached.foreach { rdd => rdd.unpersist(blocking = false) }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy