// org.apache.spark.sql.crosssupport.package.scala (Maven / Gradle / Ivy)
package org.apache.spark.sql
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListenerStageCompleted, StageInfo}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{FractionalType, IntegralType, StringType, StructField, TimestampType}
package object crosssupport {

  /**
    * Chooses the [[FieldStatAccumulator]] factory that corresponds to the data type of `field`.
    *
    * Fractional types use the `double` factory, integral and timestamp types use the `long`
    * factory, string types use the `string` factory, and every other data type falls back to
    * the `other` factory.
    *
    * @param sparkSession session handed unchanged to the selected factory
    * @param field        schema field whose `dataType` drives the selection
    * @return the accumulator produced by the selected factory
    */
  def fieldAccumulator(sparkSession: SparkSession)(field: StructField): FieldStatAccumulator = {
    val fieldType = field.dataType
    fieldType match {
      case _: FractionalType                  => FieldStatAccumulator.double(sparkSession)
      case _: IntegralType | _: TimestampType => FieldStatAccumulator.long(sparkSession)
      case _: StringType                      => FieldStatAccumulator.string(sparkSession)
      // Disabled case, kept for reference:
      // case _: HiveStringType => FieldStatAccumulator.string(sparkSession)
      case _                                  => FieldStatAccumulator.other(sparkSession)
    }
  }

  /**
    * Adds a `deserializeRow` helper to an `ExpressionEncoder[Row]`.
    *
    * @param encoder encoder from which the deserializer is created
    */
  implicit class ExpressionEncoderWithDeserializeSupport(encoder: ExpressionEncoder[Row])
      extends Serializable {

    /** Deserializer obtained from the wrapped encoder; created on first access and then reused. */
    lazy val deserializer: ExpressionEncoder.Deserializer[Row] =
      encoder.createDeserializer()

    /**
      * Converts an [[InternalRow]] into a [[Row]] using [[deserializer]].
      *
      * @param r internal row to convert
      * @return the row returned by the deserializer
      */
    def deserializeRow(r: InternalRow): Row =
      deserializer.apply(r)
  }

  /**
    * Compares `node` and `subNode` through their `toString` renderings after every
    * `#<digits>` token has been removed (presumably the numeric expression-id suffixes;
    * confirm against Spark's plan rendering).
    *
    * History carried over from earlier revisions:
    *  - A dedicated case for two `Project` nodes,
    *    `nodeProjectList.toSet.intersect(subNodeProjectList.toSet).size == subNodeProjectList.toSet.size`,
    *    is disabled because it causes a wrong result to be returned in some cases. See the test
    *    "new column creation" in LineageWriterTest, where new column creation is not reported
    *    when that case is enabled.
    *  - `node.simpleString == subNode.simpleString` was used before, but it returned the same
    *    result when both plans were of type RepartitionByExpression, hence `toString` is used now.
    *
    * TODO(Ashish): figure out a better way to compare plans.
    *
    * @param node    first plan to compare
    * @param subNode second plan to compare
    * @return `true` when both renderings are equal once the `#<digits>` tokens are stripped
    */
  def isSubNode(node: LogicalPlan, subNode: LogicalPlan): Boolean = {
    // Renders a plan and drops every "#<digits>" token from the text.
    def renderWithoutIds(plan: LogicalPlan): String =
      plan.toString.replaceAll("#([0-9]+)", "")

    val nodeText = renderWithoutIds(node)
    val subNodeText = renderWithoutIds(subNode)
    nodeText == subNodeText
  }
}