Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
com.datastax.data.prepare.spark.dataset.StockOperation.scala Maven / Gradle / Ivy
package com.datastax.data.prepare.spark.dataset
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import scala.collection.mutable
/**
 * Stock preprocessing for CSRC (China Securities Regulatory Commission) data.
 * Originally implemented as a UDD component, but the pipeline runs on a CDH
 * cluster and the UDD component's jar lived on a local path that CDH could not
 * reach, so the logic was inlined here.
 */
object StockOperation extends Serializable {

  /**
   * Merges per-date id lists over a sliding window of `days` consecutive dates.
   *
   * For each window of `days` consecutive entries of `dateCol`, the matching id
   * lists of `idCol` are concatenated into one list; the list of all windows is
   * written to `resultCol`. When `days <= 1` or `days` exceeds the number of
   * dates, the id lists are passed through unchanged.
   *
   * @param data      input dataset
   * @param dateCol   column holding the date sequence (array of strings)
   * @param idCol     column holding the per-date id lists (array of string arrays)
   * @param resultCol output column receiving the merged windows
   * @param days      window size, must be non-negative
   * @return `data` with `resultCol` added
   */
  @Deprecated
  def mergeDays[T](data: Dataset[T], dateCol: String, idCol: String, resultCol: String, days: Int): Dataset[Row] = {
    require(days >= 0)
    require(dateCol != null && dateCol.trim.length != 0)
    require(idCol != null && idCol.trim.length != 0)
    require(resultCol != null && resultCol.trim.length != 0)
    val mergeDate = functions.udf((dates: Seq[String], ids: mutable.Seq[Seq[String]]) => {
      if (days <= 1 || days > dates.length) {
        ids
      } else {
        // Slide a window of `days` entries over the date axis and flatten the ids
        // of each window. Replaces the original `var` builders and the pointless
        // `.clear()` calls that followed `.result()`.
        Range(0, dates.length - days + 1).map(i => Range(0, days).flatMap(j => ids(i + j)))
      }
    })
    data.withColumn(resultCol, mergeDate(col(dateCol), col(idCol)))
  }

  /**
   * Full preparation pipeline: drop rows below the trade-volume threshold,
   * collect ids per (code, date, direction), merge ids across a 2-row window
   * per group, then flatten the nested lists into one distinct id array.
   *
   * @param threshold minimum trade volume (inclusive) a row must have to survive
   * @return frame with `idCol` holding a flat, de-duplicated id array
   */
  def dataPrepare(data: DataFrame, threshold: Double, codeCol: String = "Sec_code", idCol: String = "shr_acct", dateCol: String = "Trad_date", typeCol: String = "Trad_dirc", tradeCol: String = "Trad_vol"): DataFrame = {
    val filtered = filter(data, tradeCol, threshold)
    val grouped = collectList(filtered, Array(codeCol, dateCol, typeCol), Array(idCol))
    // groupBy(...).agg names the aggregate column "collect_list(<idCol>)"; restore the plain name.
    val renamed1 = rename(grouped, "collect_list(" + idCol + ")", idCol)
    val windowed = window(renamed1, codeCol, dateCol, typeCol, dateCol, idCol, 2)
    // NOTE(review): window() writes its result back under `idCol` via withColumn,
    // so this rename appears to be a no-op; kept for behavioral compatibility.
    val renamed2 = rename(windowed, "collect_list(" + idCol + ")", idCol)
    multiArray2Array(renamed2, idCol)
  }

  /**
   * Keeps rows whose string-typed `filterCol` parses to a double >= `value`.
   * Null cells are dropped (the original dereferenced them and crashed with an
   * NPE); non-numeric values still fail fast with NumberFormatException.
   */
  def filter(data: DataFrame, filterCol: String, value: Double): DataFrame =
    data.filter { row =>
      val cell = row.getAs[String](filterCol)
      cell != null && cell.toDouble >= value
    }

  /**
   * Groups by `groupCols` and collects each column of `aggCols` into a list.
   * Null/blank column names are ignored. Each aggregate column comes out named
   * "collect_list(<col>)".
   *
   * @throws IllegalArgumentException when no usable group or aggregate column remains
   */
  def collectList(data: DataFrame, groupCols: Array[String], aggCols: Array[String]): DataFrame = {
    val groups = groupCols.filter(s => s != null && s.trim.nonEmpty).map(col)
    val aggs = aggCols.filter(s => s != null && s.trim.nonEmpty).map(_ -> "collect_list").toMap
    require(groups.nonEmpty, "collect_list的groups为空")
    require(aggs.nonEmpty, "collect_list的aggs为空")
    data.groupBy(groups: _*).agg(aggs)
  }

  /** Sorts the frame ascending by `column`. */
  def sort(data: DataFrame, column: String): DataFrame = {
    require(column != null && column.trim.nonEmpty, "sort的col为空")
    data.sort(column)
  }

  /** Explodes the array column `column` into one row per element, in place. */
  def explodeWrap(data: DataFrame, column: String): DataFrame = {
    // Fix: the require message was copy-pasted from sort(); it now names this operation.
    require(column != null && column.trim.nonEmpty, "explode的col为空")
    data.withColumn(column, explode(col(column)))
  }

  /**
   * Replaces `windowCol` with collect_list over a window partitioned by the
   * three partition columns, ordered by `orderCol1`, spanning the current row
   * and the following `row - 1` rows.
   */
  def window(data: DataFrame, partitionCol1: String, partitionCol2: String, partitionCol3: String, orderCol1: String, windowCol: String, row: Int): DataFrame = {
    // Named `spec` so the local no longer shadows this method.
    val spec = Window
      .partitionBy(partitionCol1, partitionCol2, partitionCol3)
      .orderBy(orderCol1)
      .rowsBetween(Window.currentRow, row - 1)
    data.withColumn(windowCol, collect_list(col(windowCol)) over spec)
  }

  /**
   * Flattens a column of nested string arrays into one de-duplicated array.
   * Uses `flatten.distinct` (first-occurrence order) instead of the original
   * mutable HashSet, making the element order deterministic.
   */
  def multiArray2Array(data: DataFrame, arrayCol: String): DataFrame = {
    val arrayConcat = udf((value: Seq[Seq[String]]) => value.flatten.distinct)
    data.withColumn(arrayCol, arrayConcat(col(arrayCol)))
  }

  /** Renames `oldCol` to `newCol`; Spark treats a missing `oldCol` as a no-op. */
  def rename(data: DataFrame, oldCol: String, newCol: String): DataFrame = data.withColumnRenamed(oldCol, newCol)
}