All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.catalyst.expressions.CallMethodViaReflection.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import java.lang.reflect.{Method, Modifier}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

/**
 * An expression that invokes a method on a class via reflection.
 *
 * For now, only types defined in `CallMethodViaReflection.typeMapping` are supported (basically
 * primitives and string) as input types, and the output is turned automatically to a string.
 *
 * Note that unlike Hive's reflect function, this expression calls only static methods
 * (i.e. does not support calling non-static methods).
 *
 * We should also look into how to consolidate this expression with
 * [[org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke]] in the future.
 *
 * @param children the first element should be a literal string for the class name,
 *                 and the second element should be a literal string for the method name,
 *                 and the remaining are input arguments to the Java method.
 */
@ExpressionDescription(
  usage = "_FUNC_(class, method[, arg1[, arg2 ..]]) - Calls a method with reflection.",
  examples = """
    Examples:
      > SELECT _FUNC_('java.util.UUID', 'randomUUID');
       c33fb387-8500-4bfa-81d2-6e0e3e930df2
      > SELECT _FUNC_('java.util.UUID', 'fromString', 'a5cf6c42-0c85-418f-af6c-3e4e5b1328f2');
       a5cf6c42-0c85-418f-af6c-3e4e5b1328f2
  """)
case class CallMethodViaReflection(children: Seq[Expression])
  extends Nondeterministic with CodegenFallback {

  override def prettyName: String = "reflect"

  /**
   * Validates the arguments in order: arity, that the class and method names are foldable
   * non-null strings, that the class can be loaded, that every remaining argument has a
   * supported type, and finally that a matching static method exists.
   *
   * The order matters: `className` and `method` are only safe to evaluate once the earlier
   * checks have passed.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.size < 2) {
      TypeCheckFailure("requires at least two arguments")
    } else if (!children.take(2).forall(e => e.dataType == StringType && e.foldable)) {
      // The first two arguments must be string type.
      TypeCheckFailure("first two arguments should be string literals")
    } else if (children.take(2).exists(_.eval() == null)) {
      // A foldable string expression can still evaluate to null (e.g. a cast of NULL);
      // reject it here instead of failing with a NullPointerException when the class or
      // method name is resolved below.
      TypeCheckFailure("first two arguments should not be null")
    } else if (!classExists) {
      TypeCheckFailure(s"class $className not found")
    } else if (children.slice(2, children.length)
        .exists(e => !CallMethodViaReflection.typeMapping.contains(e.dataType))) {
      TypeCheckFailure("arguments from the third require boolean, byte, short, " +
        "integer, long, float, double or string expressions")
    } else if (method == null) {
      TypeCheckFailure(s"cannot find a static method that matches the argument types in $className")
    } else {
      TypeCheckSuccess
    }
  }

  override def nullable: Boolean = true
  override val dataType: DataType = StringType
  override protected def initializeInternal(partitionIndex: Int): Unit = {}

  /**
   * Evaluates the argument expressions against `input`, invokes the resolved static method
   * and returns its result rendered through `String.valueOf` (so a null result becomes the
   * string "null") as a `UTF8String`.
   */
  override protected def evalInternal(input: InternalRow): Any = {
    var i = 0
    while (i < argExprs.length) {
      // Convert if necessary. Based on the types defined in typeMapping, string is the only
      // type that needs conversion. If we support timestamps, dates, decimals, arrays, or maps
      // in the future, proper conversion needs to happen here too.
      buffer(i) = argExprs(i).eval(input) match {
        case s: UTF8String => s.toString
        case other => other.asInstanceOf[Object]
      }
      i += 1
    }
    // The receiver is null because only static methods are ever resolved.
    val ret = method.invoke(null, buffer : _*)
    UTF8String.fromString(String.valueOf(ret))
  }

  /** The expressions after the class and method names, i.e. the arguments to the method. */
  @transient private lazy val argExprs: Array[Expression] = children.drop(2).toArray

  /**
   * Name of the class -- this has to be called after we verify children has at least two exprs
   * and that the first one does not evaluate to null.
   */
  @transient private lazy val className = children(0).eval().asInstanceOf[UTF8String].toString

  /** True if the class exists and can be loaded. */
  @transient private lazy val classExists = CallMethodViaReflection.classExists(className)

  /** The reflection method, or null when no matching static method is found. */
  @transient lazy val method: Method = {
    val methodName = children(1).eval(null).asInstanceOf[UTF8String].toString
    CallMethodViaReflection.findMethod(className, methodName, argExprs.map(_.dataType)).orNull
  }

  /** A temporary buffer used to hold intermediate results returned by children. */
  @transient private lazy val buffer = new Array[Object](argExprs.length)
}

object CallMethodViaReflection {
  /** Mapping from Spark's type to acceptable JVM types. */
  val typeMapping = Map[DataType, Seq[Class[_]]](
    BooleanType -> Seq(classOf[java.lang.Boolean], classOf[Boolean]),
    ByteType -> Seq(classOf[java.lang.Byte], classOf[Byte]),
    ShortType -> Seq(classOf[java.lang.Short], classOf[Short]),
    IntegerType -> Seq(classOf[java.lang.Integer], classOf[Int]),
    LongType -> Seq(classOf[java.lang.Long], classOf[Long]),
    FloatType -> Seq(classOf[java.lang.Float], classOf[Float]),
    DoubleType -> Seq(classOf[java.lang.Double], classOf[Double]),
    StringType -> Seq(classOf[String])
  )

  /**
   * Returns true if the class can be found and loaded.
   */
  private def classExists(className: String): Boolean = {
    try {
      Utils.classForName(className)
      true
    } catch {
      case _: ClassNotFoundException => false
    }
  }

  /**
   * Finds a public static Java method using reflection whose name matches `methodName` and
   * whose parameter types accept the given argument types. The return type of the method is
   * not checked.
   *
   * The types sequence should only contain types defined in [[typeMapping]]; an argument
   * type that is not in the mapping matches no parameter, so no method is found.
   *
   * This is made public for unit testing.
   *
   * @param className fully qualified name of the class to search; loading it may throw if
   *                  the class does not exist
   * @param methodName name of the static method to look for
   * @param argTypes Spark types of the arguments, in call order
   * @return the first matching method, or None if there is none
   */
  def findMethod(className: String, methodName: String, argTypes: Seq[DataType]): Option[Method] = {
    val clazz: Class[_] = Utils.classForName(className)
    clazz.getMethods.find { method =>
      // Cheap checks first: the name must match and the method must be static.
      method.getName == methodName && Modifier.isStatic(method.getModifiers) && {
        val candidateTypes = method.getParameterTypes
        // Argument length must match, and each argument type must match. That is, either the
        // method's parameter type is one of the acceptable types defined in typeMapping, or
        // it is a super type of one of the acceptable types.
        candidateTypes.length == argTypes.length &&
          candidateTypes.zip(argTypes).forall { case (candidateType, argType) =>
            typeMapping.get(argType).exists(_.exists(candidateType.isAssignableFrom))
          }
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy