All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.nvidia.dataframe_udfs.scala Maven / Gradle / Ivy

/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.nvidia

import java.lang.invoke.SerializedLambda

import org.apache.spark.sql.Column
import org.apache.spark.sql.api.java._
import org.apache.spark.util.Utils

/**
 * Common interface of the DataFrame-UDF wrappers in this file: a UDF whose result is described
 * by a [[Column]] expression built from its argument columns.
 */
trait DFUDF {

  /** Builds the result Column from `input`, the UDF's argument columns in call order. */
  def apply(input: Array[Column]): Column
}

/**
 * Wraps a Scala function that builds a [[Column]] expression from no input columns, so it can
 * be registered with Spark as a 0-argument UDF ([[UDF0]]).
 *
 * `call` only satisfies the [[UDF0]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f function producing the result column
 */
case class DFUDF0(f: Function0[Column])
  extends UDF0[Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(): Any = {
    throw new IllegalStateException(
      "DFUDF0 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by invoking `f`.
   *
   * @param input argument columns; must be empty
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.isEmpty, s"DFUDF0 expects no input columns, but got ${input.length}")
    f()
  }
}

/**
 * Wraps a Scala function that builds a [[Column]] expression from 1 input column, so it can be
 * registered with Spark as a 1-argument UDF ([[UDF1]]).
 *
 * `call` only satisfies the [[UDF1]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f function producing the result column from the input column
 */
case class DFUDF1(f: Function1[Column, Column])
  extends UDF1[Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any): Any = {
    throw new IllegalStateException(
      "DFUDF1 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by applying `f` to the input column.
   *
   * @param input argument columns; must contain exactly 1 element
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 1, s"DFUDF1 expects 1 input column, but got ${input.length}")
    f(input(0))
  }
}

/**
 * Wraps a Scala function that builds a [[Column]] expression from 2 input columns, so it can be
 * registered with Spark as a 2-argument UDF ([[UDF2]]).
 *
 * `call` only satisfies the [[UDF2]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f function producing the result column from the input columns
 */
case class DFUDF2(f: Function2[Column, Column, Column])
  extends UDF2[Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any): Any = {
    throw new IllegalStateException(
      "DFUDF2 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by applying `f` to the input columns, in order.
   *
   * @param input argument columns; must contain exactly 2 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 2, s"DFUDF2 expects 2 input columns, but got ${input.length}")
    f(input(0), input(1))
  }
}

/**
 * Wraps a Scala function that builds a [[Column]] expression from 3 input columns, so it can be
 * registered with Spark as a 3-argument UDF ([[UDF3]]).
 *
 * `call` only satisfies the [[UDF3]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f function producing the result column from the input columns
 */
case class DFUDF3(f: Function3[Column, Column, Column, Column])
  extends UDF3[Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any): Any = {
    throw new IllegalStateException(
      "DFUDF3 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by applying `f` to the input columns, in order.
   *
   * @param input argument columns; must contain exactly 3 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 3, s"DFUDF3 expects 3 input columns, but got ${input.length}")
    f(input(0), input(1), input(2))
  }
}

/**
 * Wraps a Scala function that builds a [[Column]] expression from 4 input columns, so it can be
 * registered with Spark as a 4-argument UDF ([[UDF4]]).
 *
 * `call` only satisfies the [[UDF4]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f function producing the result column from the input columns
 */
case class DFUDF4(f: Function4[Column, Column, Column, Column, Column])
  extends UDF4[Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any): Any = {
    throw new IllegalStateException(
      "DFUDF4 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by applying `f` to the input columns, in order.
   *
   * @param input argument columns; must contain exactly 4 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 4, s"DFUDF4 expects 4 input columns, but got ${input.length}")
    f(input(0), input(1), input(2), input(3))
  }
}

/**
 * Wraps a Scala function that builds a [[Column]] expression from 5 input columns, so it can be
 * registered with Spark as a 5-argument UDF ([[UDF5]]).
 *
 * `call` only satisfies the [[UDF5]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f function producing the result column from the input columns
 */
case class DFUDF5(f: Function5[Column, Column, Column, Column, Column, Column])
  extends UDF5[Any, Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any): Any = {
    throw new IllegalStateException(
      "DFUDF5 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by applying `f` to the input columns, in order.
   *
   * @param input argument columns; must contain exactly 5 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 5, s"DFUDF5 expects 5 input columns, but got ${input.length}")
    f(input(0), input(1), input(2), input(3), input(4))
  }
}

/**
 * Wraps a Scala function that builds a [[Column]] expression from 6 input columns, so it can be
 * registered with Spark as a 6-argument UDF ([[UDF6]]).
 *
 * `call` only satisfies the [[UDF6]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f function producing the result column from the input columns
 */
case class DFUDF6(f: Function6[Column, Column, Column, Column, Column, Column, Column])
  extends UDF6[Any, Any, Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any): Any = {
    throw new IllegalStateException(
      "DFUDF6 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by applying `f` to the input columns, in order.
   *
   * @param input argument columns; must contain exactly 6 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 6, s"DFUDF6 expects 6 input columns, but got ${input.length}")
    f(input(0), input(1), input(2), input(3), input(4), input(5))
  }
}

/**
 * Wraps a Scala function that builds a [[Column]] expression from 7 input columns, so it can be
 * registered with Spark as a 7-argument UDF ([[UDF7]]).
 *
 * `call` only satisfies the [[UDF7]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f function producing the result column from the input columns
 */
case class DFUDF7(f: Function7[Column, Column, Column, Column, Column, Column, Column, Column])
  extends UDF7[Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any): Any = {
    throw new IllegalStateException(
      "DFUDF7 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by applying `f` to the input columns, in order.
   *
   * @param input argument columns; must contain exactly 7 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 7, s"DFUDF7 expects 7 input columns, but got ${input.length}")
    f(input(0), input(1), input(2), input(3), input(4), input(5), input(6))
  }
}

/**
 * Wraps a Scala function that builds a [[Column]] expression from 8 input columns, so it can be
 * registered with Spark as an 8-argument UDF ([[UDF8]]).
 *
 * `call` only satisfies the [[UDF8]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f function producing the result column from the input columns
 */
case class DFUDF8(f: Function8[Column, Column, Column, Column, Column, Column, Column, Column,
  Column])
  extends UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any,
                    t8: Any): Any = {
    throw new IllegalStateException(
      "DFUDF8 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by applying `f` to the input columns, in order.
   *
   * @param input argument columns; must contain exactly 8 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 8, s"DFUDF8 expects 8 input columns, but got ${input.length}")
    f(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7))
  }
}

/**
 * Wraps a Scala function that builds a [[Column]] expression from 9 input columns, so it can be
 * registered with Spark as a 9-argument UDF ([[UDF9]]).
 *
 * `call` only satisfies the [[UDF9]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f function producing the result column from the input columns
 */
case class DFUDF9(f: Function9[Column, Column, Column, Column, Column, Column, Column, Column,
  Column, Column])
  extends UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any,
                    t9: Any): Any = {
    throw new IllegalStateException(
      "DFUDF9 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by applying `f` to the input columns, in order.
   *
   * @param input argument columns; must contain exactly 9 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 9, s"DFUDF9 expects 9 input columns, but got ${input.length}")
    f(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7), input(8))
  }
}

/**
 * Wraps a Scala function that builds a [[Column]] expression from 10 input columns, so it can
 * be registered with Spark as a 10-argument UDF ([[UDF10]]).
 *
 * `call` only satisfies the [[UDF10]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f function producing the result column from the input columns
 */
case class DFUDF10(f: Function10[Column, Column, Column, Column, Column, Column, Column, Column,
  Column, Column, Column])
  extends UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any,
                    t9: Any, t10: Any): Any = {
    throw new IllegalStateException(
      "DFUDF10 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by applying `f` to the input columns, in order.
   *
   * @param input argument columns; must contain exactly 10 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 10, s"DFUDF10 expects 10 input columns, but got ${input.length}")
    f(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7), input(8),
      input(9))
  }
}

/**
 * Adapts a Java [[UDF0]] that builds a [[Column]] expression from no input columns, so it can
 * be registered with Spark as a 0-argument UDF while still exposing the [[DFUDF]] interface.
 *
 * `call` only satisfies the [[UDF0]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f Java UDF producing the result column
 */
case class JDFUDF0(f: UDF0[Column])
  extends UDF0[Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(): Any = {
    throw new IllegalStateException(
      "JDFUDF0 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by calling `f`.
   *
   * @param input argument columns; must be empty
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.isEmpty, s"JDFUDF0 expects no input columns, but got ${input.length}")
    f.call()
  }
}

/**
 * Adapts a Java [[UDF1]] that builds a [[Column]] expression from 1 input column, so it can be
 * registered with Spark as a 1-argument UDF while still exposing the [[DFUDF]] interface.
 *
 * `call` only satisfies the [[UDF1]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f Java UDF producing the result column from the input column
 */
case class JDFUDF1(f: UDF1[Column, Column])
  extends UDF1[Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any): Any = {
    throw new IllegalStateException(
      "JDFUDF1 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by calling `f` on the input column.
   *
   * @param input argument columns; must contain exactly 1 element
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 1, s"JDFUDF1 expects 1 input column, but got ${input.length}")
    f.call(input(0))
  }
}

/**
 * Adapts a Java [[UDF2]] that builds a [[Column]] expression from 2 input columns, so it can be
 * registered with Spark as a 2-argument UDF while still exposing the [[DFUDF]] interface.
 *
 * `call` only satisfies the [[UDF2]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f Java UDF producing the result column from the input columns
 */
case class JDFUDF2(f: UDF2[Column, Column, Column])
  extends UDF2[Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any): Any = {
    throw new IllegalStateException(
      "JDFUDF2 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by calling `f` on the input columns, in order.
   *
   * @param input argument columns; must contain exactly 2 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 2, s"JDFUDF2 expects 2 input columns, but got ${input.length}")
    f.call(input(0), input(1))
  }
}

/**
 * Adapts a Java [[UDF3]] that builds a [[Column]] expression from 3 input columns, so it can be
 * registered with Spark as a 3-argument UDF while still exposing the [[DFUDF]] interface.
 *
 * `call` only satisfies the [[UDF3]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f Java UDF producing the result column from the input columns
 */
case class JDFUDF3(f: UDF3[Column, Column, Column, Column])
  extends UDF3[Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any): Any = {
    throw new IllegalStateException(
      "JDFUDF3 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by calling `f` on the input columns, in order.
   *
   * @param input argument columns; must contain exactly 3 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 3, s"JDFUDF3 expects 3 input columns, but got ${input.length}")
    f.call(input(0), input(1), input(2))
  }
}

/**
 * Adapts a Java [[UDF4]] that builds a [[Column]] expression from 4 input columns, so it can be
 * registered with Spark as a 4-argument UDF while still exposing the [[DFUDF]] interface.
 *
 * `call` only satisfies the [[UDF4]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f Java UDF producing the result column from the input columns
 */
case class JDFUDF4(f: UDF4[Column, Column, Column, Column, Column])
  extends UDF4[Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any): Any = {
    throw new IllegalStateException(
      "JDFUDF4 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by calling `f` on the input columns, in order.
   *
   * @param input argument columns; must contain exactly 4 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 4, s"JDFUDF4 expects 4 input columns, but got ${input.length}")
    f.call(input(0), input(1), input(2), input(3))
  }
}

/**
 * Adapts a Java [[UDF5]] that builds a [[Column]] expression from 5 input columns, so it can be
 * registered with Spark as a 5-argument UDF while still exposing the [[DFUDF]] interface.
 *
 * `call` only satisfies the [[UDF5]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f Java UDF producing the result column from the input columns
 */
case class JDFUDF5(f: UDF5[Column, Column, Column, Column, Column, Column])
  extends UDF5[Any, Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any): Any = {
    throw new IllegalStateException(
      "JDFUDF5 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by calling `f` on the input columns, in order.
   *
   * @param input argument columns; must contain exactly 5 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 5, s"JDFUDF5 expects 5 input columns, but got ${input.length}")
    f.call(input(0), input(1), input(2), input(3), input(4))
  }
}

/**
 * Adapts a Java [[UDF6]] that builds a [[Column]] expression from 6 input columns, so it can be
 * registered with Spark as a 6-argument UDF while still exposing the [[DFUDF]] interface.
 *
 * `call` only satisfies the [[UDF6]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f Java UDF producing the result column from the input columns
 */
case class JDFUDF6(f: UDF6[Column, Column, Column, Column, Column, Column, Column])
  extends UDF6[Any, Any, Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any): Any = {
    throw new IllegalStateException(
      "JDFUDF6 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by calling `f` on the input columns, in order.
   *
   * @param input argument columns; must contain exactly 6 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 6, s"JDFUDF6 expects 6 input columns, but got ${input.length}")
    f.call(input(0), input(1), input(2), input(3), input(4), input(5))
  }
}

/**
 * Adapts a Java [[UDF7]] that builds a [[Column]] expression from 7 input columns, so it can be
 * registered with Spark as a 7-argument UDF while still exposing the [[DFUDF]] interface.
 *
 * `call` only satisfies the [[UDF7]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f Java UDF producing the result column from the input columns
 */
case class JDFUDF7(f: UDF7[Column, Column, Column, Column, Column, Column, Column, Column])
  extends UDF7[Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any): Any = {
    throw new IllegalStateException(
      "JDFUDF7 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by calling `f` on the input columns, in order.
   *
   * @param input argument columns; must contain exactly 7 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 7, s"JDFUDF7 expects 7 input columns, but got ${input.length}")
    f.call(input(0), input(1), input(2), input(3), input(4), input(5), input(6))
  }
}

/**
 * Adapts a Java [[UDF8]] that builds a [[Column]] expression from 8 input columns, so it can be
 * registered with Spark as an 8-argument UDF while still exposing the [[DFUDF]] interface.
 *
 * `call` only satisfies the [[UDF8]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f Java UDF producing the result column from the input columns
 */
case class JDFUDF8(f: UDF8[Column, Column, Column, Column, Column, Column, Column, Column,
  Column])
  extends UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any,
                    t8: Any): Any = {
    throw new IllegalStateException(
      "JDFUDF8 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by calling `f` on the input columns, in order.
   *
   * @param input argument columns; must contain exactly 8 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 8, s"JDFUDF8 expects 8 input columns, but got ${input.length}")
    f.call(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7))
  }
}

/**
 * Adapts a Java [[UDF9]] that builds a [[Column]] expression from 9 input columns, so it can be
 * registered with Spark as a 9-argument UDF while still exposing the [[DFUDF]] interface.
 *
 * `call` only satisfies the [[UDF9]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f Java UDF producing the result column from the input columns
 */
case class JDFUDF9(f: UDF9[Column, Column, Column, Column, Column, Column, Column, Column,
  Column, Column])
  extends UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any,
                    t9: Any): Any = {
    throw new IllegalStateException(
      "JDFUDF9 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by calling `f` on the input columns, in order.
   *
   * @param input argument columns; must contain exactly 9 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 9, s"JDFUDF9 expects 9 input columns, but got ${input.length}")
    f.call(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7),
      input(8))
  }
}

/**
 * Adapts a Java [[UDF10]] that builds a [[Column]] expression from 10 input columns, so it can
 * be registered with Spark as a 10-argument UDF while still exposing the [[DFUDF]] interface.
 *
 * `call` only satisfies the [[UDF10]] interface and always throws; the wrapper is meant to be
 * found (e.g. via `DFUDF.getDFUDF`) and replaced with the expression built by [[apply]].
 *
 * @param f Java UDF producing the result column from the input columns
 */
case class JDFUDF10(f: UDF10[Column, Column, Column, Column, Column, Column, Column, Column,
  Column, Column, Column])
  extends UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
  /** Not meant to be called; always throws because this wrapper should have been replaced. */
  override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any,
                    t9: Any, t10: Any): Any = {
    throw new IllegalStateException(
      "JDFUDF10 was invoked as a row-by-row UDF, but it should have been replaced by its " +
        "Column expression before execution")
  }

  /**
   * Builds the UDF's result by calling `f` on the input columns, in order.
   *
   * @param input argument columns; must contain exactly 10 elements
   * @return the Column expression produced by `f`
   */
  override def apply(input: Array[Column]): Column = {
    assert(input.length == 10, s"JDFUDF10 expects 10 input columns, but got ${input.length}")
    f.call(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7),
      input(8), input(9))
  }
}

object DFUDF {
  /**
   * Determine if the UDF function implements the DFUDF.
   *
   * Two shapes are recognised:
   *  - `function` is itself a [[DFUDF]];
   *  - `function` is a serializable lambda whose only captured argument is a [[DFUDF]]. This may
   *    be the lambda that Spark's UDFRegistration wraps around a Java UDF instance; relying on
   *    it depends on Spark implementation details that could change between versions.
   *
   * The lambda probe is best-effort. A `null` input, a class that does not look like a lambda,
   * or any reflective failure all produce `None` instead of an exception.
   *
   * @param function the function object backing a UDF
   * @return the wrapped [[DFUDF]], or `None` if none could be found
   */
  def getDFUDF(function: AnyRef): Option[DFUDF] = function match {
    case null => None
    case f: DFUDF => Some(f)
    case f => findCapturedDFUDF(f)
  }

  /**
   * Looks for a [[DFUDF]] captured by `f`, assuming `f` is a serializable lambda.
   *
   * Serializable lambdas typically expose a synthetic `writeReplace` method that returns a
   * [[SerializedLambda]]. Its captured arguments can be inspected without deserializing.
   */
  private def findCapturedDFUDF(f: AnyRef): Option[DFUDF] = {
    try {
      val clazz = f.getClass
      // Cheap name check before doing any reflection. Locale.ROOT keeps the lower-casing
      // independent of the JVM's default locale.
      val looksLikeLambda =
        Utils.getSimpleName(clazz).toLowerCase(java.util.Locale.ROOT).contains("lambda")
      if (!looksLikeLambda) {
        None
      } else {
        val writeReplace = clazz.getDeclaredMethod("writeReplace")
        writeReplace.setAccessible(true)
        // Match instead of casting, so an unexpected writeReplace result is simply "not found".
        writeReplace.invoke(f) match {
          case lambda: SerializedLambda if lambda.getCapturedArgCount == 1 =>
            Option(lambda.getCapturedArg(0)).collect { case c: DFUDF => c }
          case _ => None
        }
      }
    } catch {
      // ReflectiveOperationException covers NoSuchMethodException (no writeReplace), plus
      // IllegalAccessException and InvocationTargetException thrown by invoke(). The
      // RuntimeException case covers SecurityException and, on JDK 9+, the
      // InaccessibleObjectException that setAccessible may throw. That exception cannot be
      // named directly while Java 8 is still supported.
      case _: ReflectiveOperationException | _: RuntimeException => None
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy