org.apache.spark.sql.execution.ColumnarBatchScan.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
/**
* Helper trait for abstracting scan functionality using [[ColumnarBatch]]es.
*/
private[sql] trait ColumnarBatchScan extends CodegenSupport {

  /**
   * Optional class names of the column vectors backing each output column.
   *
   * When defined, `produceBatches` declares one field per entry using that class name and casts
   * `batch.column(i)` to it. When `None`, the generic [[ColumnVector]] class name is used for
   * every output column.
   */
  def vectorTypes: Option[Seq[String]] = None

  /**
   * Selects the code path used by `doProduce`: `true` treats every element of the input iterator
   * as a [[ColumnarBatch]], `false` treats every element as an `InternalRow`.
   */
  protected def supportsBatch: Boolean = true

  /**
   * Only consulted on the row path. When `true`, no input row is handed to `consume` (only the
   * per-column variables), so that the framework can build an [[UnsafeRow]] from them; when
   * `false`, the input row itself is passed through.
   */
  protected def needsUnsafeRowConversion: Boolean = true

  /** Metrics updated by the generated code: rows produced and time spent fetching batches. */
  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
    "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))

  /**
   * Generates the code that reads a single value out of a column vector so that our parent can
   * consume it as part of a row.
   *
   * @param ctx       codegen context used to allocate fresh variable names and comments
   * @param columnVar name of the generated variable holding the column vector
   * @param ordinal   name of the generated variable holding the row index inside the vector
   * @param dataType  data type of the column
   * @param nullable  whether the column may contain nulls; if not, no `isNullAt` call is
   *                  generated and the null flag is the constant `false`
   * @return the generated code together with its null-flag and value variables
   */
  private def genCodeColumnVector(
      ctx: CodegenContext,
      columnVar: String,
      ordinal: String,
      dataType: DataType,
      nullable: Boolean): ExprCode = {
    val javaType = CodeGenerator.javaType(dataType)
    val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal)
    val isNullVar = if (nullable) {
      JavaCode.isNullVariable(ctx.freshName("isNull"))
    } else {
      FalseLiteral
    }
    val valueVar = ctx.freshName("value")
    val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
    val code = code"${ctx.registerComment(str)}" + (if (nullable) {
      // Nullable column: read the null flag first and fall back to the type's default value.
      code"""
        boolean $isNullVar = $columnVar.isNullAt($ordinal);
        $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);
      """
    } else {
      code"$javaType $valueVar = $value;"
    })
    ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType))
  }

  /**
   * Produce code to process the input iterator, either as [[ColumnarBatch]]es or as rows
   * depending on `supportsBatch`.
   */
  // TODO: return ColumnarBatch.Rows instead
  override protected def doProduce(ctx: CodegenContext): String = {
    // Only the first input iterator is read.
    val input = ctx.addMutableState("scala.collection.Iterator", "input",
      v => s"$v = inputs[0];")
    if (supportsBatch) {
      produceBatches(ctx, input)
    } else {
      produceRows(ctx, input)
    }
  }

  /**
   * Generates the loop that pulls [[ColumnarBatch]]es from `input` and feeds every row of every
   * batch to the parent operator, one column variable per output attribute.
   *
   * @param ctx   codegen context
   * @param input name of the generated field holding the input iterator
   * @return the generated Java source for the scan loop
   */
  private def produceBatches(ctx: CodegenContext, input: String): String = {
    // metrics
    val numOutputRows = metricTerm(ctx, "numOutputRows")
    val scanTimeMetric = metricTerm(ctx, "scanTime")
    val scanTimeTotalNs =
      ctx.addMutableState(CodeGenerator.JAVA_LONG, "scanTime") // init as scanTime = 0

    val columnarBatchClz = classOf[ColumnarBatch].getName
    val batch = ctx.addMutableState(columnarBatchClz, "batch")

    val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0
    val columnVectorClzs = vectorTypes.getOrElse(
      Seq.fill(output.length)(classOf[ColumnVector].getName))
    // One mutable field per column vector, re-assigned every time a new batch is fetched.
    val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map {
      case (columnVectorClz, i) =>
        val name = ctx.addMutableState(columnVectorClz, s"colInstance$i")
        (name, s"$name = ($columnVectorClz) $batch.column($i);")
    }.unzip

    // Helper that fetches the next batch (if any), updates the row-count metric, resets the
    // in-batch index and accumulates the elapsed time in nanoseconds.
    val nextBatch = ctx.freshName("nextBatch")
    val nextBatchFuncName = ctx.addNewFunction(nextBatch,
      s"""
         |private void $nextBatch() throws java.io.IOException {
         | long getBatchStart = System.nanoTime();
         | if ($input.hasNext()) {
         | $batch = ($columnarBatchClz)$input.next();
         | $numOutputRows.add($batch.numRows());
         | $idx = 0;
         | ${columnAssigns.mkString("", "\n", "\n")}
         | }
         | $scanTimeTotalNs += System.nanoTime() - getBatchStart;
         |}""".stripMargin)

    ctx.currentVars = null
    val rowidx = ctx.freshName("rowIdx")
    val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
      genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
    }
    val localIdx = ctx.freshName("localIdx")
    val localEnd = ctx.freshName("localEnd")
    val numRows = ctx.freshName("numRows")
    // When stopping mid-batch, remember the next row index so that the following invocation of
    // the generated code resumes right after the row that was just consumed.
    val shouldStop = if (parent.needStopCheck) {
      s"if (shouldStop()) { $idx = $rowidx + 1; return; }"
    } else {
      "// shouldStop check is eliminated"
    }
    s"""
       |if ($batch == null) {
       | $nextBatchFuncName();
       |}
       |while ($batch != null) {
       | int $numRows = $batch.numRows();
       | int $localEnd = $numRows - $idx;
       | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
       | int $rowidx = $idx + $localIdx;
       | ${consume(ctx, columnsBatchInput).trim}
       | $shouldStop
       | }
       | $idx = $numRows;
       | $batch = null;
       | $nextBatchFuncName();
       |}
       |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000));
       |$scanTimeTotalNs = 0;
     """.stripMargin
  }

  /**
   * Generates the loop that pulls `InternalRow`s from `input` one at a time and feeds them to the
   * parent operator.
   *
   * The `shouldStop()` check is emitted only when `parent.needStopCheck` is true, which is the
   * same rule `produceBatches` applies.
   *
   * @param ctx   codegen context
   * @param input name of the generated field holding the input iterator
   * @return the generated Java source for the scan loop
   */
  private def produceRows(ctx: CodegenContext, input: String): String = {
    val numOutputRows = metricTerm(ctx, "numOutputRows")
    val row = ctx.freshName("row")

    ctx.INPUT_ROW = row
    ctx.currentVars = null
    // Always provide `outputVars`, so that the framework can help us build unsafe row if the input
    // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true.
    val outputVars = output.zipWithIndex.map { case (a, i) =>
      BoundReference(i, a.dataType, a.nullable).genCode(ctx)
    }
    val inputRow = if (needsUnsafeRowConversion) null else row
    val shouldStopCheck = if (parent.needStopCheck) {
      "if (shouldStop()) return;"
    } else {
      "// shouldStop check is eliminated"
    }
    s"""
       |while ($input.hasNext()) {
       | InternalRow $row = (InternalRow) $input.next();
       | $numOutputRows.add(1);
       | ${consume(ctx, outputVars, inputRow).trim}
       | $shouldStopCheck
       |}
     """.stripMargin
  }
}