All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.catalyst.expressions.complexTypeCreator.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.UTF8String

/**
 * Returns an Array containing the evaluation of all children expressions.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, ...) - Returns an array with the given elements.",
  examples = """
    Examples:
      > SELECT _FUNC_(1, 2, 3);
       [1,2,3]
  """)
case class CreateArray(children: Seq[Expression]) extends Expression {

  override def prettyName: String = "array"

  // The array value itself is never null; only individual elements may be (see `dataType`).
  override def nullable: Boolean = false

  // Foldable exactly when every element expression is foldable.
  override def foldable: Boolean = children.forall(child => child.foldable)

  /** Delegates to `TypeUtils.checkForSameTypeInputExpr` over the children's data types. */
  override def checkInputDataTypes(): TypeCheckResult = {
    val childTypes = children.map(_.dataType)
    TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
  }

  /**
   * Array type whose element type is the common type of the children as computed by
   * `TypeCoercion.findCommonTypeDifferentOnlyInNullFlags`, falling back to `StringType` when
   * that lookup yields nothing. The array may contain nulls iff some child is nullable.
   */
  override def dataType: ArrayType = {
    val childTypes = children.map(_.dataType)
    val elementType =
      TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(childTypes).getOrElse(StringType)
    ArrayType(elementType, containsNull = children.exists(_.nullable))
  }

  /** Evaluates every child against `input` and wraps the results in a [[GenericArrayData]]. */
  override def eval(input: InternalRow): Any = {
    val elements: Array[Any] = children.map(_.eval(input)).toArray
    new GenericArrayData(elements)
  }

  /**
   * Generates the children's code and lets [[GenArrayData]] produce the allocation, the
   * per-element assignments and the finalization fragments, which are concatenated in that
   * order. The result is marked as never null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val elementType = dataType.elementType
    val elementCodes = children.map(_.genCode(ctx))
    val (allocation, elementAssignments, finalization, arrayDataName) =
      GenArrayData.genCodeToCreateArrayData(ctx, elementType, elementCodes, isMapKey = false)
    ev.copy(
      code = code"${allocation}${elementAssignments}${finalization}",
      value = JavaCode.variable(arrayDataName, dataType),
      isNull = FalseLiteral)
  }
}

private [sql] object GenArrayData {
  /**
   * Builds the Java source fragments that allocate an ArrayData and fill it with the given
   * element values.
   *
   * When `CodeGenerator.isPrimitiveType(elementType)` holds, the generated code writes the
   * elements into an `UnsafeArrayData` pointed at a byte array whose size is computed here, at
   * code generation time. Otherwise the elements are stored into an `Object[]` that the
   * post-assignment fragment wraps in a [[GenericArrayData]].
   *
   * @param ctx a [[CodegenContext]]
   * @param elementType data type of underlying array elements
   * @param elementsCode the [[ExprCode]] of each element, in array order
   * @param isMapKey if true, the generated code throws a RuntimeException when an element is
   *                 null instead of storing the null
   * @return (code pre-assignments, concatenated assignments to each array elements,
   *           code post-assignments, arrayData name)
   */
  def genCodeToCreateArrayData(
      ctx: CodegenContext,
      elementType: DataType,
      elementsCode: Seq[ExprCode],
      isMapKey: Boolean): (String, String, String, String) = {
    val arrayDataName = ctx.freshName("arrayData")
    val numElements = elementsCode.length

    if (CodeGenerator.isPrimitiveType(elementType)) {
      val arrayName = ctx.freshName("array")
      // Total backing size = header portion + element payload rounded up to a whole word.
      val headerBytes = UnsafeArrayData.calculateHeaderPortionInBytes(numElements)
      val payloadBytes =
        ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
      val unsafeArraySizeInBytes = headerBytes + payloadBytes
      val baseOffset = "Platform.BYTE_ARRAY_OFFSET"
      val setterSuffix = CodeGenerator.primitiveTypeName(elementType)

      val perElementCode = elementAssignments(elementsCode, isMapKey)(
        storeNull = idx => s"$arrayDataName.setNullAt($idx);",
        storeValue = (idx, elem) => s"$arrayDataName.set$setterSuffix($idx, ${elem.value});")
      val assignmentString = ctx.splitExpressionsWithCurrentInputs(
        expressions = perElementCode,
        funcName = "apply",
        extraArguments = Seq(("UnsafeArrayData", arrayDataName)))

      (s"""
        byte[] $arrayName = new byte[$unsafeArraySizeInBytes];
        UnsafeArrayData $arrayDataName = new UnsafeArrayData();
        Platform.putLong($arrayName, $baseOffset, $numElements);
        $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes);
      """,
       assignmentString,
       "",
       arrayDataName)
    } else {
      val arrayName = ctx.freshName("arrayObject")
      val genericArrayClass = classOf[GenericArrayData].getName

      val perElementCode = elementAssignments(elementsCode, isMapKey)(
        storeNull = idx => s"$arrayName[$idx] = null;",
        storeValue = (idx, elem) => s"$arrayName[$idx] = ${elem.value};")
      val assignmentString = ctx.splitExpressionsWithCurrentInputs(
        expressions = perElementCode,
        funcName = "apply",
        extraArguments = Seq(("Object[]", arrayName)))

      (s"Object[] $arrayName = new Object[$numElements];",
       assignmentString,
       s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);",
       arrayDataName)
    }
  }

  /**
   * Produces, for each element, its evaluation code followed by an if/else that either handles
   * a null element or stores the element value at the element's index.
   *
   * @param elementsCode the [[ExprCode]] of each element, in array order
   * @param isMapKey if true the null branch throws; otherwise it runs `storeNull`
   * @param storeNull Java statement that records a null at the given index
   * @param storeValue Java statement that stores the element's value at the given index
   */
  private def elementAssignments(
      elementsCode: Seq[ExprCode],
      isMapKey: Boolean)(
      storeNull: Int => String,
      storeValue: (Int, ExprCode) => String): Seq[String] = {
    elementsCode.zipWithIndex.map { case (elem, idx) =>
      val onNull = if (isMapKey) {
        "throw new RuntimeException(\"Cannot use null as map key!\");"
      } else {
        storeNull(idx)
      }
      elem.code + s"""
         if (${elem.isNull}) {
           $onNull
         } else {
           ${storeValue(idx, elem)}
         }
       """
    }
  }
}

/**
 * Returns a catalyst Map containing the evaluation of all children expressions as keys and values.
 * The children are a flattened sequence of kv pairs, e.g. (key1, value1, key2, value2, ...)
 */
@ExpressionDescription(
  usage = "_FUNC_(key0, value0, key1, value1, ...) - Creates a map with the given key/value pairs.",
  examples = """
    Examples:
      > SELECT _FUNC_(1.0, '2', 3.0, '4');
       {1.0:"2",3.0:"4"}
  """)
case class CreateMap(children: Seq[Expression]) extends Expression {
  // Children at even positions (0, 2, ...) are the keys, those at odd positions the values.
  lazy val keys = children.indices.filter(i => i % 2 == 0).map(i => children(i))
  lazy val values = children.indices.filter(i => i % 2 == 1).map(i => children(i))

  override def prettyName: String = "map"

  // The map value itself is never null.
  override def nullable: Boolean = false

  override def foldable: Boolean = children.forall(child => child.foldable)

  /**
   * Fails when the number of children is odd, or when the keys (respectively the values) do not
   * all share one type according to `TypeCoercion.haveSameType`; succeeds otherwise.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    // Renders the data types of the given expressions as "[t1, t2, ...]".
    def typeList(exprs: Seq[Expression]): String =
      exprs.map(_.dataType.catalogString).mkString("[", ", ", "]")

    if (children.size % 2 != 0) {
      TypeCheckResult.TypeCheckFailure(
        s"$prettyName expects a positive even number of arguments.")
    } else if (!TypeCoercion.haveSameType(keys.map(_.dataType))) {
      TypeCheckResult.TypeCheckFailure(
        "The given keys of function map should all be the same type, but they are " +
          typeList(keys))
    } else if (!TypeCoercion.haveSameType(values.map(_.dataType))) {
      TypeCheckResult.TypeCheckFailure(
        "The given values of function map should all be the same type, but they are " +
          typeList(values))
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  /**
   * Map type built from the common key type and the common value type, each falling back to
   * `StringType` when no common type is found. Values may be null iff some value expression
   * is nullable.
   */
  override def dataType: DataType = {
    val commonKeyType =
      TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType))
        .getOrElse(StringType)
    val commonValueType =
      TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(values.map(_.dataType))
        .getOrElse(StringType)
    MapType(
      keyType = commonKeyType,
      valueType = commonValueType,
      valueContainsNull = values.exists(_.nullable))
  }

  /**
   * Evaluates all keys first and rejects any null key with a RuntimeException, then evaluates
   * the values and pairs both arrays in an [[ArrayBasedMapData]].
   */
  override def eval(input: InternalRow): Any = {
    val evaluatedKeys = keys.map(_.eval(input)).toArray
    if (evaluatedKeys.exists(_ == null)) {
      throw new RuntimeException("Cannot use null as map key!")
    }
    val evaluatedValues = values.map(_.eval(input)).toArray
    val keyData = new GenericArrayData(evaluatedKeys)
    val valueData = new GenericArrayData(evaluatedValues)
    new ArrayBasedMapData(keyData, valueData)
  }

  /**
   * Generates code that builds the key array (throwing on a null key) and the value array via
   * [[GenArrayData]], then wraps both in an [[ArrayBasedMapData]].
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val mapClass = classOf[ArrayBasedMapData].getName
    val MapType(keyType, valueType, _) = dataType
    // Keys are generated before values, matching the order of the emitted code below.
    val keyCodes = keys.map(_.genCode(ctx))
    val valueCodes = values.map(_.genCode(ctx))
    val (keyAlloc, keyStores, keyFinish, keyArrayName) =
      GenArrayData.genCodeToCreateArrayData(ctx, keyType, keyCodes, isMapKey = true)
    val (valueAlloc, valueStores, valueFinish, valueArrayName) =
      GenArrayData.genCodeToCreateArrayData(ctx, valueType, valueCodes, isMapKey = false)
    ev.copy(code =
      code"""
       final boolean ${ev.isNull} = false;
       $keyAlloc
       $keyStores
       $keyFinish
       $valueAlloc
       $valueStores
       $valueFinish
       final MapData ${ev.value} = new $mapClass($keyArrayName, $valueArrayName);
      """)
  }
}

/**
 * Returns a catalyst Map containing the two arrays in children expressions as keys and values.
 */
@ExpressionDescription(
  usage = """
    _FUNC_(keys, values) - Creates a map with a pair of the given key/value arrays. All elements
      in keys should not be null""",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1.0, 3.0), array('2', '4'));
       {1.0:"2",3.0:"4"}
  """, since = "2.4.0")
case class MapFromArrays(left: Expression, right: Expression)
  extends BinaryExpression with ExpectsInputTypes {

  override def prettyName: String = "map_from_arrays"

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)

  /**
   * Map type whose key and value types are the element types of the left and right arrays;
   * values may be null iff the right array may contain nulls.
   */
  override def dataType: DataType = {
    val keyArrayType = left.dataType.asInstanceOf[ArrayType]
    val valueArrayType = right.dataType.asInstanceOf[ArrayType]
    MapType(
      keyType = keyArrayType.elementType,
      valueType = valueArrayType.elementType,
      valueContainsNull = valueArrayType.containsNull)
  }

  /**
   * Checks that both arrays have the same length and, when the key array type allows nulls,
   * that no key is null; then builds an [[ArrayBasedMapData]] from copies of both arrays.
   * Either violated check raises a RuntimeException.
   */
  override def nullSafeEval(keyArray: Any, valueArray: Any): Any = {
    val mapKeys = keyArray.asInstanceOf[ArrayData]
    val mapValues = valueArray.asInstanceOf[ArrayData]
    val numKeys = mapKeys.numElements
    if (numKeys != mapValues.numElements) {
      throw new RuntimeException("The given two arrays should have the same length")
    }
    // The per-element scan is only needed when the key array's type admits nulls at all.
    val keysMayBeNull = left.dataType.asInstanceOf[ArrayType].containsNull
    if (keysMayBeNull && (0 until numKeys).exists(idx => mapKeys.isNullAt(idx))) {
      throw new RuntimeException("Cannot use null as map key!")
    }
    new ArrayBasedMapData(mapKeys.copy(), mapValues.copy())
  }

  /**
   * Generated-code counterpart of `nullSafeEval`: same length check, the null-key loop is only
   * emitted when the key array type allows nulls, and both arrays are copied into the result.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => {
      val arrayBasedMapData = classOf[ArrayBasedMapData].getName
      val keysMayBeNull = left.dataType.asInstanceOf[ArrayType].containsNull
      val keyArrayElemNullCheck = if (keysMayBeNull) {
        val i = ctx.freshName("i")
        s"""
           |for (int $i = 0; $i < $keyArrayData.numElements(); $i++) {
           |  if ($keyArrayData.isNullAt($i)) {
           |    throw new RuntimeException("Cannot use null as map key!");
           |  }
           |}
         """.stripMargin
      } else {
        ""
      }
      s"""
         |if ($keyArrayData.numElements() != $valueArrayData.numElements()) {
         |  throw new RuntimeException("The given two arrays should have the same length");
         |}
         |$keyArrayElemNullCheck
         |${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy());
       """.stripMargin
    })
  }
}

/**
 * An expression representing a not yet available attribute name. This expression is unevaluable
 * and as its name suggests it is a temporary place holder until we're able to determine the
 * actual attribute name.
 */
case object NamePlaceholder extends LeafExpression with Unevaluable {
  // Display name, also used as the string form of this object.
  override def prettyName: String = "NamePlaceholder"
  override def toString: String = prettyName

  // Stands in for a name, so it is typed as a non-nullable string.
  override def dataType: DataType = StringType
  override def nullable: Boolean = false

  // A placeholder is never considered resolved and never foldable.
  override def foldable: Boolean = false
  override lazy val resolved: Boolean = false
}

/**
 * Returns a Row containing the evaluation of all children expressions.
 */
object CreateStruct extends FunctionBuilder {
  /**
   * Builds a [[CreateNamedStruct]] by pairing every child with a name expression:
   *  - a resolved [[NamedExpression]] contributes a literal of its own name;
   *  - an unresolved [[NamedExpression]] contributes [[NamePlaceholder]];
   *  - any other expression is named "colN", where N is its 1-based position.
   */
  def apply(children: Seq[Expression]): CreateNamedStruct = {
    val namesAndValues = children.zipWithIndex.flatMap { case (child, ordinal) =>
      val fieldName: Expression = child match {
        case named: NamedExpression if named.resolved => Literal(named.name)
        case _: NamedExpression => NamePlaceholder
        case _ => Literal(s"col${ordinal + 1}")
      }
      Seq(fieldName, child)
    }
    CreateNamedStruct(namesAndValues)
  }

  /**
   * Entry to use in the function registry.
   */
  val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = {
    val functionName = "struct"
    val info: ExpressionInfo = new ExpressionInfo(
      "org.apache.spark.sql.catalyst.expressions.NamedStruct",
      null,
      functionName,
      "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.",
      "",
      "",
      "",
      "")
    (functionName, (info, this))
  }
}

/**
 * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]].
 */
trait CreateNamedStructLike extends Expression {
  // Children alternate name, value, name, value, ...; split them into two parallel lists.
  lazy val (nameExprs, valExprs) = {
    val pairs = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList
    pairs.unzip
  }

  // Field names obtained by evaluating each name expression against an empty row.
  lazy val names = nameExprs.map(nameExpr => nameExpr.eval(EmptyRow))

  // The struct value itself is never null.
  override def nullable: Boolean = false

  // Only the value expressions decide foldability.
  override def foldable: Boolean = valExprs.forall(value => value.foldable)

  /**
   * Struct type with one field per (name, value) pair. A field takes its type and nullability
   * from the value expression, and its metadata from the value when it is a
   * [[NamedExpression]] (empty metadata otherwise).
   */
  override lazy val dataType: StructType = {
    def metadataOf(expr: Expression): Metadata = expr match {
      case named: NamedExpression => named.metadata
      case _ => Metadata.empty
    }

    val fields = names.zip(valExprs).map { case (name, expr) =>
      val metadata = metadataOf(expr)
      StructField(name.toString, expr.dataType, expr.nullable, metadata)
    }
    StructType(fields)
  }

  /**
   * Checks, in order: the number of children is even; every name expression is a foldable
   * string expression; no evaluated name is null. The first violated rule is reported.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    // Lazy so that the name expressions are only inspected once the arity check has passed.
    lazy val invalidNames = nameExprs.filter(e => !e.foldable || e.dataType != StringType)

    if (children.size % 2 != 0) {
      TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
    } else if (invalidNames.nonEmpty) {
      TypeCheckResult.TypeCheckFailure(
        s"Only foldable ${StringType.catalogString} expressions are allowed to appear at odd" +
          s" position, got: ${invalidNames.mkString(",")}")
    } else if (names.contains(null)) {
      TypeCheckResult.TypeCheckFailure("Field name should not be null")
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  /**
   * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this
   * StructType.
   */
  def flatten: Seq[NamedExpression] = valExprs.zip(names).map {
    case (value, name) => Alias(value, name.toString)()
  }

  /** Evaluates every value expression against `input` and packs the results into a row. */
  override def eval(input: InternalRow): Any = {
    val fieldValues = valExprs.map(_.eval(input))
    InternalRow(fieldValues: _*)
  }
}

/**
 * Creates a struct with the given field names and values
 *
 * @param children Seq(name1, val1, name2, val2, ...)
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.",
  examples = """
    Examples:
      > SELECT _FUNC_("a", 1, "b", 2, "c", 3);
       {"a":1,"b":2,"c":3}
  """)
// scalastyle:on line.size.limit
case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike {

  override def prettyName: String = "named_struct"

  /**
   * Generates code that evaluates each value expression into a slot of an `Object[]` (null
   * for a null value) and wraps that array in a [[GenericInternalRow]]. The local array
   * reference is cleared afterwards and the result is marked as never null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val rowClass = classOf[GenericInternalRow].getName
    val rowValues = ctx.freshName("values")
    val fieldStores = valExprs.zipWithIndex.map { case (valueExpr, slot) =>
      val fieldCode = valueExpr.genCode(ctx)
      s"""
         |${fieldCode.code}
         |if (${fieldCode.isNull}) {
         |  $rowValues[$slot] = null;
         |} else {
         |  $rowValues[$slot] = ${fieldCode.value};
         |}
       """.stripMargin
    }
    // The per-field snippets go through splitExpressionsWithCurrentInputs, which receives the
    // values array as an extra argument.
    val storeAllFields = ctx.splitExpressionsWithCurrentInputs(
      expressions = fieldStores,
      funcName = "createNamedStruct",
      extraArguments = Seq(("Object[]", rowValues)))

    val structCode =
      code"""
         |Object[] $rowValues = new Object[${valExprs.size}];
         |$storeAllFields
         |final InternalRow ${ev.value} = new $rowClass($rowValues);
         |$rowValues = null;
       """.stripMargin
    ev.copy(code = structCode, isNull = FalseLiteral)
  }
}

/**
 * Creates a struct with the given field names and values. This is a variant that returns
 * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with
 * this expression automatically at runtime.
 *
 * @param children Seq(name1, val1, name2, val2, ...)
 */
case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike {
  override def prettyName: String = "named_struct_unsafe"

  /**
   * Reuses the code and result value produced by `GenerateUnsafeProjection.createCode` for the
   * value expressions, and marks the result as never null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val projection = GenerateUnsafeProjection.createCode(ctx, valExprs)
    ExprCode(code = projection.code, isNull = FalseLiteral, value = projection.value)
  }
}

/**
 * Creates a map after splitting the input text into key/value pairs using delimiters
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(text[, pairDelim[, keyValueDelim]]) - Creates a map after splitting the text into key/value pairs using delimiters. Default delimiters are ',' for `pairDelim` and ':' for `keyValueDelim`.",
  examples = """
    Examples:
      > SELECT _FUNC_('a:1,b:2,c:3', ',', ':');
       map("a":"1","b":"2","c":"3")
      > SELECT _FUNC_('a');
       map("a":null)
  """)
// scalastyle:on line.size.limit
case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression)
  extends TernaryExpression with CodegenFallback with ExpectsInputTypes {

  // Two-argument form: the key/value delimiter defaults to ":".
  def this(child: Expression, pairDelim: Expression) = {
    this(child, pairDelim, Literal(":"))
  }

  // One-argument form: the pair delimiter defaults to "," (and the key/value one to ":").
  def this(child: Expression) = {
    this(child, Literal(","))
  }

  override def prettyName: String = "str_to_map"

  override def children: Seq[Expression] = text :: pairDelim :: keyValueDelim :: Nil

  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)

  override def dataType: DataType = MapType(StringType, StringType)

  /** Requires both delimiters to be foldable before running the inherited input type check. */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (pairDelim.foldable && keyValueDelim.foldable) {
      super.checkInputDataTypes()
    } else {
      TypeCheckResult.TypeCheckFailure(s"$prettyName's delimiters must be foldable.")
    }
  }

  /**
   * Splits the text into entries on the pair delimiter, then splits each entry into at most
   * two parts on the key/value delimiter. An entry without a second part gets a null value.
   */
  override def nullSafeEval(
      inputString: Any,
      stringDelimiter: Any,
      keyValueDelimiter: Any): Any = {
    val entries =
      inputString.asInstanceOf[UTF8String].split(stringDelimiter.asInstanceOf[UTF8String], -1)
    val keyValueSeparator = keyValueDelimiter.asInstanceOf[UTF8String]

    // Lazily turns each entry into a (key, value) pair as the map data consumes the iterator.
    val pairs: Iterator[(UTF8String, UTF8String)] = entries.iterator.map { entry =>
      val parts = entry.split(keyValueSeparator, 2)
      (parts(0), if (parts.length < 2) null else parts(1))
    }
    ArrayBasedMapData(pairs, entries.size, identity, identity)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy