All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.sparkutils.quality.impl.id.Base64Expressions.scala Maven / Gradle / Ivy

package com.sparkutils.quality.impl.id

import java.util.Base64

import com.sparkutils.quality.impl.id.model.GuaranteedUniqueIDType
import org.apache.spark.sql.ShimUtils.{toSQLExpr, toSQLType, mismatch, toSQLValue}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.shim.expressions.InputTypeChecks
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

/**
 * for a given string returns the length of the given id in longs
 * @param child
 */
case class SizeOfIDString(child: Expression) extends UnaryExpression with InputTypeChecks with CodegenFallback {
  protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)

  override def nullSafeEval(input: Any): Any = try {
    val bytes = Base64.getDecoder.decode(input.toString)
    model.idTypeOf(bytes(0)) match {
      case GuaranteedUniqueIDType => 2
      case _ =>
        model.lengthOfID(bytes)
    }
  } catch {
    case _: Throwable => null
  }

  override def nullable: Boolean = true

  override def dataType: DataType = IntegerType

  override def inputDataTypes: Seq[Seq[DataType]] = Seq(Seq(StringType))
}

trait IDStructChecker extends InputTypeChecks {

  def child: Expression

  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
    case StructType(fields) if fields.head.name.endsWith("_base") && fields.head.dataType == IntegerType &&
      fields.drop(1).zipWithIndex.forall{ case (f, i) => f.name.endsWith(s"_i$i") && f.dataType == LongType } =>
      TypeCheckSuccess
    case _ => mismatch(
      errorSubClass = "UNEXPECTED_INPUT_TYPE",
      messageParameters = Map(
        "paramIndex" -> "1",
        "requiredType" -> "<.._base: INT, .._i0: BIGINT, .._i1: BIGINT>",
        "inputSql" -> toSQLExpr(child),
        "inputType" -> toSQLType(child.dataType)
      )
    )
  }

  override def inputDataTypes: Seq[Seq[DataType]] = Seq()

}


/**
 * For an id structure generates a string
 * @param child
 */
case class AsBase64Struct(child: Expression) extends UnaryExpression with IDStructChecker with CodegenFallback {
  protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)

  /**
   * used with SizeOfIDString to verify strings are comparable
   */
  lazy val size = child.dataType.asInstanceOf[StructType].fields.size - 1

  override def nullSafeEval(input: Any): Any = {
    val struct = input.asInstanceOf[InternalRow]
    val base = struct.getInt(0)
    val longs = (1 to size).map(i => struct.getLong(i)).toArray
    UTF8String.fromString( model.base64( model.bitLength(size), base, longs) )
  }

  override def dataType: DataType = StringType
}

/**
 * For an id structure generates a string, first arg is an int, the others long
 * @param children
 */
case class AsBase64Fields(children: Seq[Expression]) extends Expression with InputTypeChecks with CodegenFallback {

  /**
   * used with SizeOfIDString to verify strings are comparable
   */
  lazy val size = children.size - 1

  override def eval(input: InternalRow = null): Any = {
    val base = children(0).eval(input)
    val longs = (1 to size).map(i => children(i).eval(input)).toArray
    if (base == null || longs.exists(_ == null))
      null
    else
      UTF8String.fromString( model.base64(model.bitLength(size), base.asInstanceOf[Int], longs.map(_.asInstanceOf[Long])) )
  }

  override def dataType: DataType = StringType

  override def checkInputDataTypes(): TypeCheckResult = children.map(_.dataType) match {
    case s: Seq[DataType] if s.size < 3 =>
      mismatch(
        errorSubClass = "VALUE_OUT_OF_RANGE",
        messageParameters = Map(
          "exprName" -> "arguments",
          "valueRange" -> s"[3, positive]",
          "currentValue" -> toSQLValue(s.size, IntegerType)
        )
      )
    case IntegerType +: tail =>
      tail.zipWithIndex.map{
        case (t, _) if t == LongType => TypeCheckSuccess
        case (t, i) => mismatch(
          errorSubClass = "UNEXPECTED_INPUT_TYPE",
          messageParameters = Map(
            "paramIndex" -> (i + 2).toString, // 1 is first param
            "requiredType" -> toSQLType(LongType),
            "inputSql" -> toSQLExpr(children(i + 1)),
            "inputType" -> toSQLType(t)
          )
        )
      }.find(_.isFailure).getOrElse(TypeCheckSuccess)

    case _ => mismatch(
      errorSubClass = "UNEXPECTED_INPUT_TYPE",
      messageParameters = Map(
        "paramIndex" -> "1",
        "requiredType" -> toSQLType(IntegerType),
        "inputSql" -> toSQLExpr(children(0)),
        "inputType" -> toSQLType(children(0).dataType)
      )
    )
  }

  override def nullable: Boolean = children.exists(_.nullable)

  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren)

  override def inputDataTypes: Seq[Seq[DataType]] = Seq()
}

/**
 * Generates an unprefixed 'raw' id structure of a given size.  Note that size is fixed, the type can't change on the plan during the plan.
 *
 * @param child the base64 strings that must have the same size, will return null if it's not the right size, or cannot parse it.
 * @param size the size for the number of longs to have, 2 longs is 160 bit and the default
 */
case class IDFromBase64(child: Expression, size: Int) extends UnaryExpression with InputTypeChecks with CodegenFallback {

  override def nullSafeEval(child: Any): Any =  try {
    val id = model.parseID(child.toString).asInstanceOf[BaseWithLongs]
    val ar = id.array
    if (ar.length != size)
      null
    else
      InternalRow.fromSeq(id.base +: ar.toSeq)
  } catch {
    case _: Throwable => null
  }

  override def nullable: Boolean = true

  override def dataType: DataType = model.rawType(size)

  override def inputDataTypes: Seq[Seq[DataType]] = Seq(Seq(StringType))

  protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
}

/**
 * Converts any prefixed id back to rawType (base, i0, i1 etc.)
 * @param child
 */
case class IDToRawIDDataType(child: Expression) extends UnaryExpression with IDStructChecker with CodegenFallback {

  override def nullSafeEval(child: Any): Any = child

  override def nullable: Boolean = true

  override def dataType: DataType = model.rawType(child.dataType.asInstanceOf[StructType].fields.length - 1)

  protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy