
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package tech.mlsql.plugins.ets

import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, SparkSession}
import streaming.dsl.ScriptSQLExec
import streaming.dsl.auth.OperateType.SELECT
import streaming.dsl.auth.TableType.SYSTEM
import streaming.dsl.auth.{DB_DEFAULT, MLSQLTable, TableAuthResult}
import streaming.dsl.mmlib.algs.param.WowParams
import streaming.dsl.mmlib.{Code, Doc, HtmlDoc, ModelType, ProcessType, SQLAlg, SQLCode}
import streaming.log.WowLog
import tech.mlsql.common.utils.log.Logging
import tech.mlsql.dsl.auth.ETAuth
import tech.mlsql.dsl.auth.dsl.mmlib.ETMethod.ETMethod
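
/**
 * An MLSQL ET plugin that expands a json string column into one top-level column
 * per json field. Usage, mirroring the codeExample below:
 *
 * {{{
 * SELECT '{"key": "value", "key_2":"value_2"}' AS col_1 AS table_1;
 * run table_1 as JsonExpandExt.`` where inputCol="col_1" AND samplingRatio = "0.5" as A2;
 * }}}
 */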
class JsonExpandExt(override val uid: String) extends SQLAlg with WowParams with Logging with WowLog with ETAuth {

  def this() = this(Identifiable.randomUID("tech.mlsql.plugins.ets.JsonExpandExt"))

  final val inputCol: Param[String] = new Param[String](parent = "JsonExpandExt"
    , name = "inputCol"
    , doc = """Required. Json column to be expanded
              |e.g. WHERE inputCol = col_1 """.stripMargin
  )

  final val samplingRatio: Param[String] = new Param[String](parent = "JsonExpandExt"
    , name = "samplingRatio"
    , doc = """Optional. Sampling ratio used by Spark to infer the schema from json, 1.0 by default.
              |e.g. WHERE samplingRatio = "0.2" """.stripMargin)

  /**
   * Expands a json column into multiple top-level columns. The json column is named by
   * the value of the inputCol parameter.
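   * For example, a json column holding `{"key": "value", "key_2": "value_2"}` is expanded
   * into two string columns, `key` and `key_2`.
   *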
   * @param df     Input dataframe.
   * @param path   Not used.
   * @param params Input parameters; must contain inputCol.
   * @return The dataframe with the non-json columns and the expanded json columns.
   * @throws IllegalArgumentException if inputCol is not present in params
   */
  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    // Parameter checking
    require(params.contains(inputCol.name),
      "inputCol is required. e.g. inputCol = col_1, where col_1 is the json column")
    val spark = df.sparkSession
    import spark.implicits._
    params.get(inputCol.name) match {
      case Some(col) =>
        val ratio = params.getOrElse(samplingRatio.name, "1.0").toDouble
        logInfo(format(s"samplingRatio ${ratio}"))
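        // Note: a samplingRatio below 1.0 infers the schema from a sample of the rows,
        // which is faster on large tables but may miss fields that only appear in
        // unsampled rows.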
        // org.apache.spark.sql.AnalysisException is thrown if the column name is wrong.
        val colValueDs = df.select(col).as[String]
        // Infer the schema from the json values
        val schema = spark.read
          .option("samplingRatio", ratio)
          .json(colValueDs).schema
        // Throw a RuntimeException if the json is invalid: schema inference adds a
        // corrupt-record column when it hits malformed json
        if (schema.exists(_.name == spark.sessionState.conf.columnNameOfCorruptRecord))
          throw new RuntimeException(s"Corrupted JSON in column ${col}")
        // Expand the json and return a new DataFrame that also keeps the non-json columns
        val expandedCols = schema.fields.map(_.name)
        val originalCols = df.schema.map(_.name)
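        // json_tuple names its output columns c0, c1, ...; toDF renames the combined
        // columns so the expanded fields get their json field names back.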
        df.select(df("*"), json_tuple(df(col), expandedCols: _*))
          .toDF((originalCols ++ expandedCols): _*)
          .drop(col)
      case None => df // Unreachable: the require above guarantees inputCol is present
    }
  }

  /**
   * Wraps this plugin as a [[streaming.dsl.auth.MLSQLTable]] and delegates authorization
   * to the configured implementation of [[streaming.dsl.auth.TableAuth]].
   */
  override def auth(etMethod: ETMethod, path: String, params: Map[String, String]): List[TableAuthResult] = {
    // Parameter checking
    require(params.contains(inputCol.name),
      "inputCol is required. e.g. inputCol = col_1, where col_1 is the json column")
    val table = MLSQLTable(db = Some(DB_DEFAULT.MLSQL_SYSTEM.toString)
      , table = Some("__json_expand_operator__")
      , columns = None
      , operateType = SELECT
      , sourceType = Some("select")
      , tableType = SYSTEM
    )
    val context = ScriptSQLExec.contextGetOrForTest()
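    // Delegate to the registered TableAuth implementation if present; grant by default otherwise.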
    context.execListener.getTableAuth match {
      case Some(tableAuth) => tableAuth.auth(List(table))
      case None => List(TableAuthResult(granted = true, ""))
    }
  }
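
  // This ET keeps no model state: load is a no-op and predict (UDF registration)
  // is not supported.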
  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {}

  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String])
  : UserDefinedFunction = ???

  override def doc: Doc = Doc(HtmlDoc,
    """
      | JsonExpandExt is used to expand json strings into columns; see
      | the codeExample to learn its usage.
      |
      | Use "load modelExample.`JsonExpandExt` as output;"
      | to see the codeExample.
    """.stripMargin)

  override def codeExample: Code = Code(SQLCode,
    """
      |```sql
      |## Generate a table named "table_1" with one json column "col_1"
      |SELECT '{"key": "value", "key_2":"value_2"}' AS col_1 AS table_1;
      |
      |## Expand the json from col_1; note that the result set
      |## contains 2 columns
      |run table_1 as JsonExpandExt.`` where inputCol="col_1" AND samplingRatio = "0.5" as A2;
      |```
      |output:
      |```
      |+-----+-------+
      ||key  |key_2  |
      |+-----+-------+
      ||value|value_2|
      |+-----+-------+
      |```
      |""".stripMargin)

  /**
   * Explains each parameter's name and doc.
   */
  override def explainParams(sparkSession: SparkSession): DataFrame = {
    _explainParams(sparkSession)
  }

  override def modelType: ModelType = ProcessType
}