All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.sparkutils.quality.impl.PackExpression.scala Maven / Gradle / Ivy

package com.sparkutils.quality.impl

import com.sparkutils.quality.Id
import com.sparkutils.quality.impl.imports.RuleResultsImports.packId
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription, NullIntolerant, UnaryExpression}
import org.apache.spark.sql.shim.expressions.InputTypeChecks
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructField, StructType}

/**
 * Factory methods for [[PackExpression]], offered for both Catalyst
 * expressions and DataFrame columns.
 */
object Pack {
  /**
   * Builds the packing expression directly from Catalyst expressions.
   *
   * @param id      expression supplying the id part
   * @param version expression supplying the version part
   * @return a [[PackExpression]] over the two inputs
   */
  def apply(id: Expression, version: Expression) =
    PackExpression( id, version )

  /**
   * Column flavoured variant: unwraps both columns, packs their expressions
   * and wraps the result in a new Column.
   *
   * @param id      column supplying the id part
   * @param version column supplying the version part
   * @return a Column backed by a [[PackExpression]]
   */
  def apply(id: Column, version: Column): Column = {
    val packed = PackExpression( id.expr, version.expr )
    new Column( packed )
  }
}

/**
 * Binary expression combining two integer inputs (an id and a version) into a
 * single long value.
 *
 * @param left  expression producing the id
 * @param right expression producing the version
 */
@ExpressionDescription(
  usage = "packInts(expr, expr) - Returns a long with two packed ints.",
  examples = """
    Examples:
      > SELECT packInts(1, 2);
       10231L
  """)
case class PackExpression(left: Expression, right: Expression) extends BinaryExpression
  with NullIntolerant with InputTypeChecks {

  /** The packed result is always a long. */
  override def dataType: DataType = LongType

  /** Both inputs are declared as IntegerType. */
  override def inputDataTypes: Seq[Seq[DataType]] = Seq.fill(2)(Seq(IntegerType))

  /**
   * Interpreted evaluation: both values are cast to Int, wrapped in an Id and
   * handed to packId.
   */
  override def nullSafeEval(id: Any, version: Any): Any = {
    val toPack = Id(id.asInstanceOf[Int], version.asInstanceOf[Int])
    packId(toPack)
  }

  /**
   * Code generation: emits Java that constructs an Id from the two child
   * values and delegates to PackId.packId.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val packCall: (String, String) => String = (l, r) =>
      s"(Long)com.sparkutils.quality.impl.PackId.packId(new com.sparkutils.quality.Id((Integer)($l), (Integer)($r)))"
    defineCodeGen(ctx, ev, packCall)
  }

  // NOTE(review): deliberately declared without `override`, as in the original;
  // presumably for compatibility across Spark versions - confirm before changing.
  protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
    copy(left = newLeft, right = newRight)
}

/**
 * Factory methods for [[UnPackExpression]] plus the row conversion helper
 * that the expression uses in both interpreted and generated code.
 */
object UnPack {
  /**
   * Builds the unpacking expression directly from a Catalyst expression.
   *
   * @param packed expression producing the packed long
   * @return an [[UnPackExpression]] over the input
   */
  def apply( packed: Expression ) =
    UnPackExpression( packed )

  /**
   * Column flavoured variant: unwraps the column, unpacks its expression and
   * wraps the result in a new Column.
   *
   * @param packed column holding the packed long
   * @return a Column backed by an [[UnPackExpression]]
   */
  def apply(packed: Column): Column = {
    val unpacked = UnPackExpression( packed.expr )
    new Column( unpacked )
  }

  /**
   * Converts an Id into a two field row: id first, then version.
   *
   * @param id the Id to convert
   * @return an InternalRow of (id, version)
   */
  def toRow( id: Id ) =
    InternalRow(id.id, id.version)
}

/**
 * Unary expression that unpacks a long into a struct of two ints, `id` and
 * `version`, via PackId.unpack.
 *
 * The `examples` text previously repeated the packInts example (a call to
 * packInts returning a long); it now shows a call to unpack itself.
 * NOTE(review): the example assumes unpack is the inverse of packInts and that
 * both are registered under the names used in the usage strings - confirm.
 * The leading/trailing whitespace layout of `examples` is kept exactly as it
 * was, in case the expression metadata is validated against that layout.
 *
 * @param child expression producing the packed long
 */
@ExpressionDescription(
  usage = "unpack(expr) - Returns a struct with two ints. id and version",
  examples = """
    Examples:
      > SELECT unpack(packInts(1, 2));
       {"id":1,"version":2}
  """)
case class UnPackExpression(child: Expression) extends UnaryExpression
  with NullIntolerant with InputTypeChecks {

  /**
   * Interpreted evaluation: casts the input to Long, unpacks it into an Id and
   * converts that Id to a row of (id, version).
   */
  override def nullSafeEval(packed: Any): Any = {
    val id = PackId.unpack(packed.asInstanceOf[Long])
    UnPack.toRow(id)
  }

  /**
   * Code generation: emits Java calling PackId.unpack followed by
   * UnPack.toRow, the same helpers used by nullSafeEval.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, (c) => s"com.sparkutils.quality.impl.UnPack.toRow( com.sparkutils.quality.impl.PackId.unpack((Long)$c) )")

  /** Result struct: two integer fields, `id` and `version`. */
  override def dataType: DataType = StructType( Seq(
    StructField(name = "id", dataType = IntegerType),
    StructField(name = "version", dataType = IntegerType)
  ))

  /** The single input is declared as LongType. */
  override def inputDataTypes: Seq[Seq[DataType]] = Seq(Seq(LongType))

  // NOTE(review): deliberately declared without `override`, as in the original;
  // presumably for compatibility across Spark versions - confirm before changing.
  protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
}

/**
 * Factory methods for [[UnPackIdTripleExpression]] plus the row conversion
 * helper used by both its interpreted and generated code paths.
 */
object UnPackIdTriple {
  /**
   * Builds the triple-unpacking expression from a Catalyst expression.
   *
   * @param packed expression producing the struct of three packed longs
   * @return an [[UnPackIdTripleExpression]] over the input
   */
  def apply( packed: Expression ) =
    UnPackIdTripleExpression( packed )

  /**
   * Column flavoured variant: unwraps the column, builds the expression and
   * wraps the result in a new Column.
   *
   * @param packed column holding the struct of three packed longs
   * @return a Column backed by an [[UnPackIdTripleExpression]]
   */
  def apply(packed: Column): Column = {
    val unpacked = UnPackIdTripleExpression( packed.expr )
    new Column( unpacked )
  }

  /**
   * Expands a row of three packed longs into a row of six ints. Each long is
   * unpacked with PackId.unpack and contributes its id followed by its version,
   * in the order of positions 0, 1 and 2 of the input row.
   *
   * @param packed an InternalRow with longs at positions 0 to 2, or null
   * @return the six field row, or null when the input is null
   */
  def toRow( packed: Any ): InternalRow =
    if (packed != null) {
      val source = packed.asInstanceOf[InternalRow]

      // unpack in positional order: rule suite, rule set, rule
      val suite = PackId.unpack(source.getLong(0))
      val set = PackId.unpack(source.getLong(1))
      val rule = PackId.unpack(source.getLong(2))

      InternalRow(
        suite.id, suite.version,
        set.id, set.version,
        rule.id, rule.version
      )
    } else
      null // null input is passed straight through as a null row
}

/**
 * Unary expression that expands a struct of three packed longs into a struct
 * of six ints: an id and a version for each of the rule suite, rule set and
 * rule.
 *
 * The usage text previously described a two int (id, version) result, which
 * contradicted the six field `dataType` declared below; it now lists the six
 * fields.
 * NOTE(review): the `examples` text is still the packInts example and does not
 * demonstrate this function. A real example needs the field names of
 * fullRuleIdType, which are not visible from this file - replace once confirmed.
 * NOTE(review): the function name `unpackLong` in the usage text is kept as it
 * was; verify it against the name the function is actually registered under.
 *
 * @param child expression producing the struct of three packed longs
 */
@ExpressionDescription(
  usage = "unpackLong(expr) - Returns a struct with six ints: ruleSuiteId, ruleSuiteVersion, ruleSetId, ruleSetVersion, ruleId and ruleVersion",
  examples = """
    Examples:
      > SELECT packInts(1, 2);
       10231L
  """)
case class UnPackIdTripleExpression(child: Expression) extends UnaryExpression with InputTypeChecks {

  /** Interpreted evaluation: delegates to UnPackIdTriple.toRow. */
  override def nullSafeEval(packed: Any): Any =
    UnPackIdTriple.toRow(packed)

  /** Code generation: emits Java calling UnPackIdTriple.toRow on the child value. */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, (c) => s"com.sparkutils.quality.impl.UnPackIdTriple.toRow( $c )")

  /** Always reported as nullable, regardless of the child. */
  override def nullable: Boolean = true

  /** Result struct: id and version ints for the rule suite, rule set and rule. */
  override def dataType: DataType = StructType( Seq(
    StructField(name = "ruleSuiteId", dataType = IntegerType),
    StructField(name = "ruleSuiteVersion", dataType = IntegerType),
    StructField(name = "ruleSetId", dataType = IntegerType),
    StructField(name = "ruleSetVersion", dataType = IntegerType),
    StructField(name = "ruleId", dataType = IntegerType),
    StructField(name = "ruleVersion", dataType = IntegerType)
  ))

  /** The single input is declared as fullRuleIdType. */
  override def inputDataTypes: Seq[Seq[DataType]] = Seq(Seq(com.sparkutils.quality.types.fullRuleIdType))

  // NOTE(review): deliberately declared without `override`, as in the original;
  // presumably for compatibility across Spark versions - confirm before changing.
  protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy