All downloads are free. The search and download functionality uses the official Maven repository.

com.dimajix.spark.sql.execution.CountRecordsExec.scala Maven / Gradle / Ivy

/*
 * Copyright (C) 2021 The Flowman Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.dimajix.spark.sql.execution

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.UnaryExecNode
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.LongAccumulator


/**
 * Pass-through physical operator that forwards every record of its child unchanged while counting
 * the records into a [[LongAccumulator]].
 *
 * In row mode every row adds 1 to the count; in columnar mode every batch adds its `numRows()`.
 * The output schema, partitioning, ordering, columnar capability and vector types are all taken
 * from the child.
 *
 * Counting happens lazily while the partition iterator is consumed. If a downstream operator stops
 * reading a partition early, only the records actually pulled through this node are counted.
 *
 * NOTE(review): `counter` is updated from inside an RDD transformation. Per the Spark accumulator
 * documentation, such updates may be applied more than once if tasks or stages are re-executed.
 * Confirm that callers tolerate this.
 *
 * @param child   the plan whose records are counted
 * @param counter accumulator receiving the record count; it is merged once per task, from a task
 *                completion listener
 */
case class CountRecordsExec(child: SparkPlan, counter: LongAccumulator) extends UnaryExecNode {
    // Pass-through node: schema and physical properties are exactly those of the child.
    override def output: Seq[Attribute] = child.output
    override def outputPartitioning: Partitioning = child.outputPartitioning
    override def outputOrdering: Seq[SortOrder] = child.outputOrdering

    // Mirror the child's columnar capability so this node runs in the same mode as its child.
    override def supportsColumnar: Boolean = child.supportsColumnar
    override def vectorTypes: Option[Seq[String]] = child.vectorTypes

    override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild)

    // Only the child is canonicalized; `counter` is carried over unchanged. Assuming accumulator
    // equality is by reference, plans counting into different accumulators therefore do not
    // canonicalize to the same plan.
    override protected def doCanonicalize(): SparkPlan = copy(child = child.canonicalized)

    /** Columnar path: adds the row count of every batch flowing through this node. */
    override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
        // Bind the accumulator to a local val so the closure below captures only the accumulator
        // rather than this plan node.
        val target = counter
        val input = child.executeColumnar()
        input.mapPartitions { partition =>
            // Accumulate into a task-local accumulator and fold it into `target` only once the task
            // has completed. This way the heartbeat never publishes a partially updated value, and
            // the shared accumulator is not touched (and re-serialized) for every single batch.
            val local = new LongAccumulator
            val ctx = TaskContext.get()
            ctx.addTaskCompletionListener[Unit](_ => target.merge(local))

            partition.map { batch =>
                local.add(batch.numRows())
                batch
            }
        }
    }

    /** Row path: adds 1 for every row flowing through this node. */
    override protected def doExecute(): RDD[InternalRow] = {
        // Local copy keeps `this` out of the serialized closure (same reasoning as the columnar path).
        val target = counter
        val input = child.execute()
        input.mapPartitions { partition =>
            // Task-local accumulator, merged into `target` on task completion (see doExecuteColumnar).
            val local = new LongAccumulator
            val ctx = TaskContext.get()
            ctx.addTaskCompletionListener[Unit](_ => target.merge(local))

            partition.map { row =>
                local.add(1)
                row
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy