All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.ml.feature.SQLTransformer.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tencent.angel.sona.ml.feature

import com.tencent.angel.sona.ml.Transformer
import com.tencent.angel.sona.ml.param.{Param, ParamMap}
import com.tencent.angel.sona.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.types.StructType

/**
  * Implements the transformations which are defined by SQL statement.
  * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__ ...'
  * where '__THIS__' represents the underlying table of the input dataset.
  * The select clause specifies the fields, constants, and expressions to display in
  * the output; it can be any select clause that Spark SQL supports. Users can also
  * use Spark SQL built-in functions and UDFs to operate on these selected columns.
  * For example, [[SQLTransformer]] supports statements like:
  * {{{
  *  SELECT a, a + b AS a_b FROM __THIS__
  *  SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5
  *  SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b
  * }}}
  */
class SQLTransformer(override val uid: String) extends Transformer
  with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("sql"))

  /**
    * SQL statement parameter. The statement is provided in string form.
    *
    * @group param
    */
  final val statement: Param[String] = new Param[String](this, "statement", "SQL statement")

  /** @group setParam */
  def setStatement(value: String): this.type = set(statement, value)

  /** @group getParam */
  def getStatement: String = $(statement)

  /** Placeholder in [[statement]] that is replaced by the name of a temporary view. */
  private val tableIdentifier: String = "__THIS__"

  /**
    * Runs [[statement]] against `dataset`.
    *
    * The output schema is validated first via `transformSchema`. Then `dataset` is
    * registered as a temporary view with a random name, and every occurrence of `__THIS__`
    * in the statement is replaced by that name. Finally the statement is executed, and the
    * view is dropped again, including when execution throws.
    *
    * @param dataset input data, exposed to the statement as `__THIS__`
    * @return the DataFrame produced by the statement
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    withTempView(dataset) { tableName =>
      dataset.sparkSession.sql($(statement).replace(tableIdentifier, tableName))
    }
  }

  /**
    * Derives the output schema of [[statement]] for an input with the given `schema`.
    *
    * A dummy DataFrame carrying `schema` is registered under a random temporary view name,
    * and the statement is run against it. Only the resulting query's `schema` is read.
    *
    * @param schema schema of the input dataset
    * @return the schema of the statement's result
    */
  override def transformSchema(schema: StructType): StructType = {
    val spark = SparkSession.builder().getOrCreate()
    val dummyDF = spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row.empty)), schema)
    withTempView(dummyDF) { tableName =>
      spark.sql($(statement).replace(tableIdentifier, tableName)).schema
    }
  }

  override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)

  /**
    * Registers `ds` as a temporary view under a fresh random name, passes that name to
    * `body`, and returns `body`'s result.
    *
    * The view is dropped afterwards even if `body` throws, so a failing statement does not
    * leave a stray view behind in the session catalog.
    *
    * NOTE(review): this calls the public `Catalog.dropTempView`. Upstream Spark's
    * SQLTransformer calls the internal `SessionCatalog.dropTempView` instead. The public
    * method is documented to uncache the view if it was cached. Because caching is
    * plan-based, a cached input dataset may therefore be unpersisted as a side effect.
    * Confirm against the Spark version in use before relying on the input staying cached.
    */
  private def withTempView[T](ds: Dataset[_])(body: String => T): T = {
    val tableName = Identifiable.randomUID(uid)
    ds.createOrReplaceTempView(tableName)
    try body(tableName)
    finally ds.sparkSession.catalog.dropTempView(tableName)
  }
}


/** Companion of [[SQLTransformer]] that provides loading through [[DefaultParamsReadable]]. */
object SQLTransformer extends DefaultParamsReadable[SQLTransformer] {

  /**
    * Loads a [[SQLTransformer]] from `path` by delegating to the inherited
    * `DefaultParamsReadable` loader.
    *
    * The override only narrows the declared return type to [[SQLTransformer]].
    *
    * @param path location to read the saved transformer from
    * @return the loaded transformer
    */
  override def load(path: String): SQLTransformer = {
    val loaded: SQLTransformer = super.load(path)
    loaded
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy