/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package streaming.dsl.mmlib.algs
import net.liftweb.json.NoTypeHints
import net.liftweb.{json => SJSon}
import org.apache.spark.ml.param.Param
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession, functions => F}
import streaming.dsl.mmlib._
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
/**
* Builds tree structures from flat (id, parentId) relations stored in a table.
*
* 2018-12-12 WilliamZhu([email protected])
*/
class SQLTreeBuildExt(override val uid: String) extends SQLAlg with Functions with WowParams {
def this() = this(BaseParams.randomUID())
override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
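// Resolve parameters: idCol and parentIdCol are required, the others fall back to defaults.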
params.get(idCol.name).
map(m => set(idCol, m)).getOrElse {
throw new MLSQLException("idCol is required")
}
params.get(parentIdCol.name).
map(m => set(parentIdCol, m)).getOrElse {
throw new MLSQLException("parentIdCol is required")
}
params.get(topLevelMark.name).
map(m => set(topLevelMark, m)).getOrElse {
set(topLevelMark, null)
}
params.get(treeType.name).
map(m => set(treeType, m)).getOrElse {
set(treeType, "treePerRow")
}
params.get(recurringDependencyBreakTimes.name).
map(m => set(recurringDependencyBreakTimes, m.toInt)).getOrElse {
set(recurringDependencyBreakTimes, 1000)
}
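// Depth cap used to break recurring (cyclic) dependencies during traversal.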
val maxTimes = $(recurringDependencyBreakTimes)
val parentIdColType = df.schema.filter(f => f.name == $(parentIdCol)).head
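// Coerce topLevelMark to the parent-id column's data type so the equality checks below behave correctly.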
val t = if ($(topLevelMark) != null) {
parentIdColType.dataType match {
case s: IntegerType => $(topLevelMark).toInt
case s: LongType => $(topLevelMark).toLong
case s: DoubleType => $(topLevelMark).toDouble
case s: ShortType => $(topLevelMark).toShort
case _ => $(topLevelMark)
}
} else {
null
}
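// Distinct (id, parentId) pairs; rows where a node references itself are dropped.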
val items = df.select($(idCol), $(parentIdCol)).distinct().rdd
.filter(row => row.get(0) != row.get(1))
.map { row => IDParentID(row.get(0), row.get(1), ArrayBuffer()) }
.collect()
val ROOTS = ArrayBuffer[IDParentID]()
val tempMap = scala.collection.mutable.HashMap[Any, Int]()
val itemsWithIndex = items.zipWithIndex
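// First pass: index every node by id. Second pass: attach each node to its parent or collect it as a root.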
itemsWithIndex.foreach { case (item, index) =>
tempMap(item.id) = index
}
itemsWithIndex.foreach { case (item, index) =>
if (item.parentID != null && item.parentID != t) {
items(tempMap(item.parentID)).children += item
} else {
ROOTS += item
}
}
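// Serialize each root's tree to JSON and re-read it so Spark infers the nested children schema.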
implicit val formats = SJSon.Serialization.formats(NoTypeHints)
val rdd = df.sparkSession.sparkContext.parallelize(ROOTS.map(f => SJSon.Serialization.write(f)))
var newdf = df.sparkSession.read.json(rdd)
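// Recursively measure the depth of the nested children structure, capped at maxTimes to break cycles.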
val computeLevel1 = (a: Seq[Row], level: Int) => {
val computeLevel = new ((Seq[Row], Int) => Int) {
def apply(a: Seq[Row], level: Int): Int = {
if (a.isEmpty) return level
if (level > maxTimes) return level
// All rows share the same schema, so locate the nested "children" field once.
val index = a.head.schema.fieldIndex("children")
return a.map { row =>
val value = row.getSeq[Row](index)
apply(value, level + 1)
}.max
}
}
computeLevel(a, level)
}
val computeLevelUDF = F.udf(computeLevel1, IntegerType)
newdf = newdf.withColumn("level", computeLevelUDF(F.col("children"), F.lit(0)))
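// treePerRow keeps one row per root tree; nodeTreePerRow emits one row per node with its level and descendant ids.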
$(treeType) match {
case "treePerRow" => newdf
case "nodeTreePerRow" =>
val rdd = df.sparkSession.sparkContext.parallelize(items.toSeq).map { item =>
val resultset = new mutable.HashSet[IDParentID]()
var level = 0
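// Collect all descendants of a node into resultset; the HashSet dedupes repeated visits.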
val collectAll = new ((IDParentID) => Int) {
def apply(a: IDParentID): Int = {
if (a.children.size == 0) {
resultset += a
1
} else {
resultset ++= a.children
a.children.map(f => apply(f)).sum
}
}
}
val computeLevel = new ((IDParentID, Int) => Int) {
def apply(a: IDParentID, level: Int): Int = {
if (a.children.size == 0) return level
if (level > maxTimes) return level
return a.children.map { row =>
apply(row, level + 1)
}.max
}
}
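// Compute the node's depth first; only materialize its descendant set when below the cycle cap.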
level = computeLevel(item, 0)
if (level < maxTimes) {
if (item.children.size > 0) {
collectAll(item)
}
}
Row.fromSeq(Seq(item.id.toString, level, resultset.map(f => f.id.toString).toSeq))
}
df.sparkSession.createDataFrame(rdd, StructType(Seq(
StructField(name = "id", dataType = StringType),
StructField(name = "level", dataType = IntegerType),
StructField(name = "children", dataType = ArrayType(StringType))
)))
}
}
override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
throw new RuntimeException("register is not supported by this estimator/transformer")
}
override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
throw new RuntimeException("register is not supported by this estimator/transformer")
}
override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
train(df, path, params)
}
override def explainParams(sparkSession: SparkSession): DataFrame = _explainParams(sparkSession)
override def modelType: ModelType = ProcessType
override def doc: Doc = Doc(HtmlDoc,
"""
| TreeBuildExt is used to build a tree when a table contains parent-child relationships;
| please check the codeExample to see how to use it.
""".stripMargin)
override def codeExample: Code = Code(SQLCode, CodeExampleText.jsonStr +
"""
|```sql
|set jsonStr = '''
|{"id":0,"parentId":null}
|{"id":1,"parentId":null}
|{"id":2,"parentId":1}
|{"id":3,"parentId":3}
|{"id":7,"parentId":0}
|{"id":199,"parentId":1}
|{"id":200,"parentId":199}
|{"id":201,"parentId":199}
|''';
|
|load jsonStr.`jsonStr` as data;
|run data as TreeBuildExt.`` where idCol="id" and parentIdCol="parentId" and treeType="nodeTreePerRow" as result;
|```
|
|Here is the result:
|
|```
|+---+-----+------------------+
||id |level|children          |
|+---+-----+------------------+
||200|0    |[]                |
||0  |1    |[7]               |
||1  |2    |[200, 2, 201, 199]|
||7  |0    |[]                |
||201|0    |[]                |
||199|1    |[200, 201]        |
||2  |0    |[]                |
|+---+-----+------------------+
|```
|
|Notice that ids are converted to string in the final result. That means id is string type and
|children is an array of strings, so be careful when comparing values.
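|
|For example, you can test membership with Spark SQL's array_contains (a sketch; adapt the table name to your own):
|
|```sql
|select * from result where array_contains(children, "199") as filtered;
|```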
|
|The max level should be lower than 1000 (configurable via the recurringDependencyBreakTimes parameter).
|When you find rows that look weird, where level >= 1000 and children is empty, it means
|there is a recurring (cyclic) dependency, which we cannot handle yet.
|
|Here is the example:
|
|```
|+---+-----+------------------+
||id |level|children          |
|+---+-----+------------------+
||7  |1000 |[]                |
|+---+-----+------------------+
|```
|
|If treeType is treePerRow (the default), the output is one row per root tree.
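|
|For example (same data as above, with treeType switched):
|
|```sql
|run data as TreeBuildExt.`` where idCol="id" and parentIdCol="parentId" and treeType="treePerRow" as result;
|```
|
|Then the result is: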
|
|```
|+----------------------------------------+---+--------+-----+
||children                                |id |parentID|level|
|+----------------------------------------+---+--------+-----+
||[[[], 7, 0]]                            |0  |null    |1    |
||[[[[[], 200, 199]], 199, 1], [[], 2, 1]]|1  |null    |2    |
|+----------------------------------------+---+--------+-----+
|```
|
|Notice that the children column's datatype is nested Rows; you can convert it to JSON so that you can process it with Python.
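|
|For example, a sketch using Spark SQL's to_json (it supports struct and array-of-struct columns):
|
|```sql
|select id, to_json(children) as children_json from result as output;
|```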
|
""".stripMargin)
override def coreCompatibility: Seq[CoreVersion] = super.coreCompatibility
final val idCol: Param[String] = new Param[String](this, "idCol", "Required. The column that contains the node id")
final val parentIdCol: Param[String] = new Param[String](this, "parentIdCol", "Required. The column that contains the parent node id")
final val topLevelMark: Param[String] = new Param[String](this, "topLevelMark", "Optional. The parentId value that marks a root node; null by default")
final val treeType: Param[String] = new Param[String](this, "treeType", "Output mode: treePerRow|nodeTreePerRow. Default: treePerRow")
final val recurringDependencyBreakTimes: Param[Int] = new Param[Int](this, "recurringDependencyBreakTimes",
"Default: 1000. The max level should be lower than this value; " +
"when traversing a tree, once the depth exceeds this value the subtree is ignored, which breaks recurring dependencies")
}
case class IDParentID(id: Any, parentID: Any, children: scala.collection.mutable.ArrayBuffer[IDParentID])
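
// A minimal usage sketch (illustrative only; assumes a DataFrame `df` with `id` and
// `parentId` columns obtained elsewhere):
//
// val et = new SQLTreeBuildExt()
// val result = et.train(df, "", Map(
//   "idCol" -> "id",
//   "parentIdCol" -> "parentId",
//   "treeType" -> "nodeTreePerRow"))
// result.show(truncate = false)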