/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.zoo.friesian.python

import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.zoo.common.PythonZoo
import com.intel.analytics.zoo.friesian.feature.Utils
import java.util.{List => JList}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.types.{ArrayType, IntegerType, DoubleType, StringType, LongType, StructField, StructType}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.{col, row_number, spark_partition_id, udf, log => sqllog, rand}
import scala.reflect.ClassTag
import scala.collection.JavaConverters._
import scala.collection.mutable.WrappedArray
import scala.util.Random
import scala.math.pow

object PythonFriesian {
  def ofFloat(): PythonFriesian[Float] = new PythonFriesian[Float]()

  def ofDouble(): PythonFriesian[Double] = new PythonFriesian[Double]()
}
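
/**
 * DataFrame-based feature engineering operations for Friesian, exposed to the Python layer
 * through [[PythonZoo]]: null filling, string indexing, log transform, clipping, column
 * crossing, history sequences, negative sampling, padding/masking and basic statistics.
 */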
class PythonFriesian[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZoo[T] {
  val numericTypes: List[String] = List("long", "double", "integer")
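
  /**
   * Fill null values in the given columns with a single constant.
   *
   * Illustrative usage (hypothetical `ratings` DataFrame with a nullable "rating" column):
   * {{{
   *   val filled = fillNa(ratings, 0, List("rating").asJava)
   * }}}
   *
   * @param df      the input DataFrame
   * @param fillVal the value used to replace nulls (default 0)
   * @param columns the columns to fill; if null, all columns of `df` are filled
   */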
  def fillNa(df: DataFrame, fillVal: Any = 0, columns: JList[String] = null): DataFrame = {
    val cols = if (columns == null) {
      df.columns
    } else {
      columns.asScala.toArray
    }
    val cols_idx = Utils.getIndex(df, cols)
    Utils.fillNaIndex(df, fillVal, cols_idx)
  }
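
  /**
   * Fill null values in integer columns with an integer constant.
   *
   * @param df      the input DataFrame
   * @param fillVal the integer used to replace nulls (default 0)
   * @param columns the columns to fill; if null, every IntegerType column is filled,
   *                and passing a non-integer column raises an IllegalArgumentException
   */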
  def fillNaInt(df: DataFrame, fillVal: Int = 0, columns: JList[String] = null): DataFrame = {
    val schema = df.schema
    val allColumns = df.columns
    val cols_idx = if (columns == null) {
      schema.zipWithIndex.filter(pair => pair._1.dataType.typeName == "integer")
        .map(pair => pair._2)
    } else {
      val cols = columns.asScala.toList
      cols.map(col_n => {
        val idx = allColumns.indexOf(col_n)
        if (idx == -1) {
          throw new IllegalArgumentException(s"The column name ${col_n} does not exist")
        }
        if (schema(idx).dataType.typeName != "integer") {
          throw new IllegalArgumentException(s"Only columns of IntegerType are supported, but " +
            s"the type of column ${col_n} is ${schema(idx).dataType.typeName}")
        }
        idx
      })
    }
    val dfUpdated = df.rdd.map(row => {
      val origin = row.toSeq.toArray
      for (idx <- cols_idx) {
        if (row.isNullAt(idx)) {
          origin.update(idx, fillVal)
        }
      }
      Row.fromSeq(origin)
    })
    val spark = df.sparkSession
    spark.createDataFrame(dfUpdated, schema)
  }
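
  /**
   * Build a string-to-index dictionary DataFrame for each of the given categorical columns.
   * Null values are dropped, values occurring less often than the frequency limit are filtered
   * out, and the remaining values are assigned contiguous ids starting from 1 by combining
   * per-partition offsets with a row_number window ordered by count.
   *
   * @param frequencyLimit either a single limit such as "10", or per-column limits such as
   *                       "col1:10,col2:20"; null keeps all values
   * @return one DataFrame per input column, containing the original column and its assigned "id"
   */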
  def generateStringIdx(df: DataFrame, columns: JList[String], frequencyLimit: String = null)
  : JList[DataFrame] = {
    var default_limit: Option[Int] = None
    val freq_map = scala.collection.mutable.Map[String, Int]()
    if (frequencyLimit != null) {
      val freq_list = frequencyLimit.split(",")
      for (fl <- freq_list) {
        val frequency_pair = fl.split(":")
        if (frequency_pair.length == 1) {
          default_limit = Some(frequency_pair(0).toInt)
        } else if (frequency_pair.length == 2) {
          freq_map += (frequency_pair(0) -> frequency_pair(1).toInt)
        }
      }
    }
    val cols = columns.asScala.toList
    cols.map(col_n => {
      val df_col = df
        .select(col_n)
        .filter(s"${col_n} is not null")
        .groupBy(col_n)
        .count()
      val df_col_filtered = if (freq_map.contains(col_n)) {
        df_col.filter(s"count >= ${freq_map(col_n)}")
      } else if (default_limit.isDefined) {
        df_col.filter(s"count >= ${default_limit.get}")
      } else {
        df_col
      }
      df_col_filtered.cache()
      val count_list: Array[(Int, Int)] = df_col_filtered.rdd.mapPartitions(Utils.getPartitionSize)
        .collect()
      val base_dict = scala.collection.mutable.Map[Int, Int]()
      var running_sum = 0
      for (count_tuple <- count_list) {
        base_dict += (count_tuple._1 -> running_sum)
        running_sum += count_tuple._2
      }
      val base_dict_bc = df_col_filtered.rdd.sparkContext.broadcast(base_dict)
      val windowSpec = Window.partitionBy("part_id").orderBy("count")
      val df_with_part_id = df_col_filtered.withColumn("part_id", spark_partition_id())
      val df_row_number = df_with_part_id.withColumn("row_number", row_number.over(windowSpec))
      val get_label = udf((part_id: Int, row_number: Int) => {
        row_number + base_dict_bc.value.getOrElse(part_id, 0)
      })
      df_row_number
        .withColumn("id", get_label(col("part_id"), col("row_number")))
        .drop("part_id", "row_number", "count")
    }).asJava
  }
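
  /**
   * Trigger materialization of the DataFrame by counting its underlying RDD.
   */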
  def compute(df: DataFrame): Unit = {
    df.rdd.count()
  }
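
  /**
   * Apply a logarithmic transform to the given numeric columns in place.
   * With `clipping` enabled (the default), negative values are first clipped to 0 and each
   * column becomes log(x + 1); otherwise each column becomes log(x).
   */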
  def log(df: DataFrame, columns: JList[String], clipping: Boolean = true): DataFrame = {
    val colsIdx = Utils.getIndex(df, columns.asScala.toArray)
    for (i <- 0 until columns.size()) {
      val colName = columns.get(i)
      val colType = df.schema(colsIdx(i)).dataType.typeName
      if (!Utils.checkColumnNumeric(df, colName)) {
        throw new IllegalArgumentException(s"Unsupported data type $colType of column $colName")
      }
    }
    var resultDF = df
    val zeroThreshold = (value: Int) => {
      if (value < 0) 0 else value
    }
    val zeroThresholdUDF = udf(zeroThreshold)
    for (i <- 0 until columns.size()) {
      val colName = columns.get(i)
      if (clipping) {
        resultDF = resultDF.withColumn(colName, sqllog(zeroThresholdUDF(col(colName)) + 1))
      } else {
        resultDF = resultDF.withColumn(colName, sqllog(col(colName)))
      }
    }
    resultDF
  }
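
  /**
   * Clip the values of the given numeric columns to the range defined by `min` and/or `max`.
   * At least one of `min` and `max` must be non-null; supported column types are
   * integer, long and double.
   */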
  def clip(df: DataFrame, columns: JList[String], min: Any = null, max: Any = null):
  DataFrame = {
    if (min == null && max == null) {
      throw new IllegalArgumentException(s"min and max cannot be both null")
    }
    var resultDF = df
    val cols = columns.asScala.toArray
    val colsType = Utils.getIndex(df, cols).map(idx => df.schema(idx).dataType.typeName)
    (cols zip colsType).foreach(nameAndType => {
      if (!Utils.checkColumnNumeric(df, nameAndType._1)) {
        throw new IllegalArgumentException(s"Unsupported data type ${nameAndType._2} of " +
          s"column ${nameAndType._1}")
      }
    })
    for (i <- 0 until columns.size()) {
      val colName = columns.get(i)
      val colType = colsType(i)
      val minVal = Utils.castNumeric(min, colType)
      val maxVal = Utils.castNumeric(max, colType)
      val clipFuncUDF = colType match {
        case "long" => udf(Utils.getClipFunc[Long](minVal, maxVal, colType))
        case "integer" => udf(Utils.getClipFunc[Int](minVal, maxVal, colType))
        case "double" => udf(Utils.getClipFunc[Double](minVal, maxVal, colType))
        case _ => throw new IllegalArgumentException(s"Unsupported data type $colType of column" +
          s" $colName")
      }
      resultDF = resultDF.withColumn(colName, clipFuncUDF(col(colName)))
    }
    resultDF
  }
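
  /**
   * Create crossed categorical columns. For each group of columns in `crossCols`, a new column
   * named by joining the original names with "_" is added, holding the hash bucket (of the
   * corresponding size in `bucketSizes`) of the concatenated values.
   *
   * Illustrative usage (hypothetical "gender" and "age" columns):
   * {{{
   *   // adds a "gender_age" column hashed into 100 buckets
   *   val crossed = crossColumns(df, List(List("gender", "age").asJava).asJava, List(100).asJava)
   * }}}
   */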
  def crossColumns(df: DataFrame,
                   crossCols: JList[JList[String]],
                   bucketSizes: JList[Int]): DataFrame = {
    def crossColumns(bucketSize: Int) = udf((cols: WrappedArray[Any]) => {
      Utils.hashBucket(cols.mkString("_"), bucketSize = bucketSize)
    })
    var resultDF = df
    for (i <- 0 until crossCols.size()) {
      resultDF = resultDF.withColumn(crossCols.get(i).asScala.toList.mkString("_"),
        crossColumns(bucketSizes.get(i))(
          array(crossCols.get(i).asScala.toArray.map(x => col(x)): _*)
        ))
    }
    resultDF
  }
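
  /**
   * Add history sequences per user. Rows are grouped by `userCol` and sorted by `timeCol`
   * (expected to be a long), and for every position i from `minLength` to n - 1 a row is
   * emitted that keeps the other columns and adds, for each column in `cols`, a
   * `<col>_hist_seq` array built by `Utils.get1row` from the at most `maxLength`
   * preceding values.
   */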
  def addHistSeq(df: DataFrame,
                 cols: JList[String],
                 userCol: String,
                 timeCol: String,
                 minLength: Int,
                 maxLength: Int): DataFrame = {
    df.sparkSession.conf.set("spark.sql.legacy.allowUntypedScalaUDF", "true")
    val colNames: Array[String] = cols.asScala.toArray
    val colsWithType = df.schema.fields.filter(x => x.name != userCol)
    val schema = ArrayType(StructType(colsWithType.flatMap(c =>
      if (colNames.contains(c.name)) {
        Seq(c, StructField(c.name + "_hist_seq", ArrayType(c.dataType)))
      } else {
        Seq(c)
      })))
    val genHisUDF = udf(f = (his_collect: Seq[Row]) => {
      val full_rows: Array[Row] = his_collect.sortBy(x => x.getAs[Long](timeCol)).toArray
      val n = full_rows.length
      val result: Seq[Row] = (minLength to n - 1).map(i => {
        val lowerBound = if (i < maxLength) {
          0
        } else {
          i - maxLength
        }
        val rowValue: Array[Any] = colsWithType.flatMap(col => {
          if (colNames.contains(col.name)) {
            col.dataType.typeName match {
              case "integer" => Utils.get1row[Int](full_rows, col.name, i, lowerBound)
              case "double" => Utils.get1row[Double](full_rows, col.name, i, lowerBound)
              case "float" => Utils.get1row[Float](full_rows, col.name, i, lowerBound)
              case "long" => Utils.get1row[Long](full_rows, col.name, i, lowerBound)
              case _ => throw new IllegalArgumentException(
                s"Unsupported data type ${col.dataType.typeName} " +
                  s"of column ${col.name} in add_hist_seq")
            }
          } else {
            val colValue: Any = full_rows(i).getAs(col.name)
            Seq(colValue)
          }
        })
        Row.fromSeq(rowValue)
      })
      result
    }, schema)
    val allColumns = colsWithType.map(x => col(x.name))
    df.groupBy(userCol).agg(collect_list(struct(allColumns: _*)).as("friesian_his_collect"))
      .withColumn("friesian_history", explode(genHisUDF(col("friesian_his_collect"))))
      .select(userCol, "friesian_history.*")
  }
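
  /**
   * For each array column in `cols`, add a `<col>_mask` column computed by applying
   * `Utils.maskArr` with the given `maxLength` (typically a 0/1 mask marking the valid,
   * non-padded positions of a length-`maxLength` sequence).
   */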
  def mask(df: DataFrame, cols: JList[String], maxLength: Int): DataFrame = {
    var maskDF = df
    val maskUdf = udf(Utils.maskArr)
    cols.asScala.toList.foreach(c => {
      maskDF = maskDF.withColumn(c + "_mask", maskUdf(lit(maxLength), col(c)))
    })
    maskDF
  }
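
  /**
   * Add a `neg_<historyCol>` column of negative samples for an integer history sequence
   * column, built by `Utils.addNegativeList(negNum, itemSize)` with a nested
   * `ArrayType(ArrayType(...))` result schema.
   */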
  def addNegHisSeq(df: DataFrame, itemSize: Int,
                   historyCol: String,
                   negNum: Int = 5): DataFrame = {
    df.sparkSession.conf.set("spark.sql.legacy.allowUntypedScalaUDF", "true")
    val itemType = df.select(explode(col(historyCol))).schema.fields(0).dataType
    require(itemType.typeName == "integer", throw new IllegalArgumentException(
      s"Unsupported data type ${itemType.typeName} " +
        s"of column ${historyCol} in add_neg_hist_seq"))
    val schema = ArrayType(ArrayType(itemType))
    val negativeUdf = udf(Utils.addNegativeList(negNum, itemSize), schema)
    df.withColumn("neg_" + historyCol, negativeUdf(col(historyCol)))
  }
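
  /**
   * Expand the DataFrame with negative samples. For each row,
   * `Utils.addNegtiveItem(negNum, itemSize)` generates an array of (item, label) structs
   * which is exploded into separate rows; `itemCol` is replaced by the generated item
   * together with a `labelCol` column.
   */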
  def addNegSamples(df: DataFrame,
                    itemSize: Int,
                    itemCol: String = "item",
                    labelCol: String = "label",
                    negNum: Int = 1): DataFrame = {
    df.sparkSession.conf.set("spark.sql.legacy.allowUntypedScalaUDF", "true")
    val itemType = df.select(itemCol).schema.fields(0).dataType
    require(itemType.typeName == "integer", throw new IllegalArgumentException(
      s"Unsupported data type ${itemType.typeName} " +
        s"of column ${itemCol} in add_negative_samples"))
    val schema = ArrayType(StructType(Seq(StructField(itemCol, itemType),
      StructField(labelCol, itemType))))
    val negativeUdf = udf(Utils.addNegtiveItem(negNum, itemSize), schema)
    val negativedf = df.withColumn("item_label", explode(negativeUdf(col(itemCol))))
    val selectColumns = df.columns.filter(x => x != itemCol)
      .map(ele => col(ele)) ++ Seq(col("item_label.*"))
    negativedf.select(selectColumns: _*)
  }
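
  /**
   * Post-pad the given array columns to length `maxLength` (default 100) using
   * `Utils.padArr` for 1-D arrays and `Utils.padMatrix` for nested arrays.
   * Supported element types are integer, long and double.
   */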
  def postPad(df: DataFrame, cols: JList[String], maxLength: Int = 100): DataFrame = {
    val colFields = df.schema.fields.filter(x => cols.contains(x.name))
    var paddedDF = df
    colFields.foreach(c => {
      val dataType = c.dataType
      val padUdf = dataType match {
        case ArrayType(IntegerType, _) => udf(Utils.padArr[Int])
        case ArrayType(LongType, _) => udf(Utils.padArr[Long])
        case ArrayType(DoubleType, _) => udf(Utils.padArr[Double])
        case ArrayType(ArrayType(IntegerType, _), _) => udf(Utils.padMatrix[Int])
        case ArrayType(ArrayType(LongType, _), _) => udf(Utils.padMatrix[Long])
        case ArrayType(ArrayType(DoubleType, _), _) => udf(Utils.padMatrix[Double])
        case _ => throw new IllegalArgumentException(
          s"Unsupported data type $dataType of column ${c.name} in pad")
      }
      paddedDF = paddedDF.withColumn(c.name, padUdf(lit(maxLength), col(c.name)))
    })
    paddedDF
  }
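
  /**
   * Fill null values in the given numeric columns with the per-column median computed by
   * `Utils.getMedian`. If `columns` is null, all numeric columns are filled. The median is
   * cast back to the column's own type (long, integer or double).
   */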
  def fillMedian(df: DataFrame, columns: JList[String] = null): DataFrame = {
    val cols = if (columns == null) {
      df.columns.filter(column => Utils.checkColumnNumeric(df, column))
    } else {
      columns.asScala.toArray
    }
    val colsIdx = Utils.getIndex(df, cols)
    val medians = Utils.getMedian(df, cols)
    val idxMedians = (colsIdx zip medians).map(idxMedian => {
      // idxMedian._1 is the column's index in df.columns
      if (idxMedian._2 == null) {
        throw new IllegalArgumentException(
          s"Cannot compute the median of column ${df.columns(idxMedian._1)} " +
            s"since it contains only null values.")
      }
      val colType = df.schema(idxMedian._1).dataType.typeName
      colType match {
        case "long" => (idxMedian._1, idxMedian._2.asInstanceOf[Double].longValue)
        case "integer" => (idxMedian._1, idxMedian._2.asInstanceOf[Double].intValue)
        case "double" => (idxMedian._1, idxMedian._2.asInstanceOf[Double])
        case _ => throw new IllegalArgumentException(
          s"Unsupported value type $colType of column ${df.columns(idxMedian._1)}.")
      }
    })
    val dfUpdated = df.rdd.map(row => {
      val origin = row.toSeq.toArray
      for ((idx, fillV) <- idxMedians) {
        if (row.isNullAt(idx)) {
          origin.update(idx, fillV)
        }
      }
      Row.fromSeq(origin)
    })
    val spark = df.sparkSession
    spark.createDataFrame(dfUpdated, df.schema)
  }

  /* ---- Stat Operator ---- */
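
  /**
   * Compute the (approximate) median of the given numeric columns; if `columns` is null,
   * all numeric columns are used. Returns a DataFrame with a "column" (string) field and a
   * "median" (double) field, one row per input column.
   *
   * @param relativeError the relative error tolerance forwarded to `Utils.getMedian`
   *                      (default 0.00001)
   */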
  def median(df: DataFrame, columns: JList[String] = null, relativeError: Double = 0.00001):
  DataFrame = {
    val cols = if (columns == null) {
      df.columns.filter(column => Utils.checkColumnNumeric(df, column))
    } else {
      columns.asScala.toArray
    }
    Utils.getIndex(df, cols) // checks that all `columns` exist in `df`
    val medians = Utils.getMedian(df, cols, relativeError)
    val medians_data = (cols zip medians).map(cm => Row.fromSeq(Array(cm._1, cm._2)))
    val spark = df.sparkSession
    val schema = StructType(Array(
      StructField("column", StringType, nullable = true),
      StructField("median", DoubleType, nullable = true)
    ))
    spark.createDataFrame(spark.sparkContext.parallelize(medians_data), schema)
  }
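
  /**
   * Shuffle rows within each partition by sorting on a temporary random "ordinal" column,
   * which is dropped afterwards; no data is moved across partitions.
   */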
  def ordinalShufflePartition(df: DataFrame): DataFrame = {
    val shuffledDF = df.withColumn("ordinal", (rand() * pow(2, 52)).cast(LongType))
      .sortWithinPartitions(col("ordinal")).drop(col("ordinal"))
    shuffledDF
  }
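
  /**
   * Write the DataFrame as Parquet files to `path` with the given save mode
   * ("overwrite" by default).
   */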
  def dfWriteParquet(df: DataFrame, path: String, mode: String = "overwrite"): Unit = {
    df.write.mode(mode).parquet(path)
  }
}