All downloads are free. Search and download functionality uses the official Maven repository.

com.datastax.data.prepare.spark.dataset.StockOperation.scala Maven / Gradle / Ivy

package com.datastax.data.prepare.spark.dataset

import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

import scala.collection.mutable

/**
  * Spark helpers for preprocessing CSRC (China Securities Regulatory Commission) stock-trading data.
  *
  * This logic originally lived in a UDD component. The pipeline runs on a CDH cluster and the UDD
  * component's jar path was local, so CDH could not load it; the logic is implemented here instead.
  */
object StockOperation extends Serializable {

  /** True when `s` is non-null and not empty after `trim`. */
  private def nonBlank(s: String): Boolean = s != null && s.trim.nonEmpty

  /**
    * For each row, replaces the per-date id lists in `idCol` with sliding windows of `days`
    * consecutive lists, each window being the concatenation of the lists it covers.
    *
    * `dateCol` is only used for its length: output element `i` is `ids(i) ++ ... ++ ids(i + days - 1)`
    * for `i` in `0 .. dates.length - days`. When `days <= 1`, `days` exceeds the number of dates,
    * or either array is null, the id lists are returned unchanged.
    *
    * @param data      input dataset; presumably `idCol` is aligned element-by-element with `dateCol`
    *                  (TODO confirm with callers)
    * @param dateCol   name of the array-of-strings date column
    * @param idCol     name of the array-of-string-arrays id column
    * @param resultCol name of the output column
    * @param days      window length; must be >= 0
    * @return `data` with `resultCol` added (or replaced)
    */
  @Deprecated
  def mergeDays[T](data: Dataset[T], dateCol: String, idCol: String, resultCol: String, days: Int): Dataset[Row] = {
    require(days >= 0)
    require(nonBlank(dateCol))
    require(nonBlank(idCol))
    require(nonBlank(resultCol))
    val mergeDate = functions.udf((dates: Seq[String], ids: mutable.Seq[Seq[String]]) => {
      // Explicit type keeps the UDF's return type a plain Seq of Seqs for Spark's schema inference.
      val merged: scala.collection.Seq[scala.collection.Seq[String]] =
        if (dates == null || ids == null || days <= 1 || days > dates.length) ids
        else (0 to dates.length - days).map(start => (start until start + days).flatMap(ids(_)))
      merged
    })
    data.withColumn(resultCol, mergeDate(col(dateCol), col(idCol)))
  }

  /**
    * End-to-end preprocessing:
    *  1. keeps records whose `tradeCol` is >= `threshold`;
    *  2. collects the accounts (`idCol`) per (code, date, direction);
    *  3. merges each date's accounts with those of the following `windowSize - 1` date rows of the
    *     same code and direction, de-duplicating the result.
    *
    * The window is partitioned by `codeCol` and `typeCol` only. Partitioning by `dateCol` as well
    * would leave a single row per partition (the data is already grouped by date), so the window
    * could never span more than one date.
    *
    * NOTE(review): dates are ordered by the column's natural sort order; for a string column that is
    * only chronological if the format sorts lexicographically (e.g. yyyyMMdd) — confirm.
    *
    * @param data       raw trade records
    * @param threshold  minimum `tradeCol` value for a record to be kept
    * @param codeCol    security-code column
    * @param idCol      account-id column; holds the merged, de-duplicated account list in the result
    * @param dateCol    trade-date column
    * @param typeCol    trade-direction column
    * @param tradeCol   column compared against `threshold`
    * @param windowSize number of consecutive date rows merged per output row, current row included;
    *                   must be >= 1
    * @return one row per (code, date, direction) with the merged account list in `idCol`
    */
  def dataPrepare(data: DataFrame, threshold: Double, codeCol: String = "Sec_code", idCol: String = "shr_acct",
                  dateCol: String = "Trad_date", typeCol: String = "Trad_dirc", tradeCol: String = "Trad_vol",
                  windowSize: Int = 2): DataFrame = {
    val kept = filter(data, tradeCol, threshold)
    val grouped = collectList(kept, Array(codeCol, dateCol, typeCol), Array(idCol))
    // The Map-based agg names the column "collect_list(<idCol>)"; restore the original name.
    val perDate = rename(grouped, "collect_list(" + idCol + ")", idCol)
    val windowed = collectOverWindow(perDate, Seq(codeCol, typeCol), dateCol, idCol, windowSize)
    multiArray2Array(windowed, idCol)
  }

  /**
    * Keeps the rows whose `filterCol`, cast to double, is >= `value`.
    *
    * Works for string and numeric columns alike. Under Spark's default (non-ANSI) cast semantics,
    * rows whose value is null or not parseable as a double are dropped instead of failing the job.
    */
  def filter(data: DataFrame, filterCol: String, value: Double): DataFrame =
    data.filter(col(filterCol).cast("double") >= value)

  /**
    * Groups by `groupCols` and applies `collect_list` to every column in `aggCols`.
    * Null/blank names are ignored; each aggregated column is named "collect_list(<col>)".
    */
  def collectList(data: DataFrame, groupCols: Array[String], aggCols: Array[String]): DataFrame = {
    val groups = groupCols.filter(nonBlank).map(col)
    val aggs = aggCols.filter(nonBlank).map(_ -> "collect_list").toMap
    require(groups.nonEmpty, "collect_list的groups为空")
    require(aggs.nonEmpty, "collect_list的aggs为空")
    data.groupBy(groups: _*).agg(aggs)
  }

  /** Sorts `data` ascending by `column`. */
  def sort(data: DataFrame, column: String): DataFrame = {
    require(nonBlank(column), "sort的col为空")
    data.sort(column)
  }

  /** Replaces the array column `column` with one row per element, via Spark's `explode`. */
  def explodeWrap(data: DataFrame, column: String): DataFrame = {
    require(nonBlank(column), "explodeWrap的col为空")
    data.withColumn(column, explode(col(column)))
  }

  /**
    * Replaces `windowCol` with the list of its values over the current row and the following
    * `row - 1` rows of each (`partitionCol1`, `partitionCol2`, `partitionCol3`) partition,
    * ordered by `orderCol1`.
    *
    * @param row window length in rows, current row included; must be >= 1
    */
  def window(data: DataFrame, partitionCol1: String, partitionCol2: String, partitionCol3: String,
             orderCol1: String, windowCol: String, row: Int): DataFrame =
    collectOverWindow(data, Seq(partitionCol1, partitionCol2, partitionCol3), orderCol1, windowCol, row)

  /** Shared implementation of [[window]] for an arbitrary list of partition columns. */
  private def collectOverWindow(data: DataFrame, partitionCols: Seq[String], orderCol: String,
                                windowCol: String, row: Int): DataFrame = {
    // A frame of (currentRow, row - 1) is only valid for row >= 1.
    require(row >= 1, "window的row必须大于等于1")
    val spec = Window
      .partitionBy(partitionCols.map(col): _*)
      .orderBy(orderCol)
      .rowsBetween(Window.currentRow, row - 1)
    data.withColumn(windowCol, collect_list(col(windowCol)).over(spec))
  }

  /**
    * Flattens the array-of-arrays column `arrayCol` into one array of distinct values, kept in
    * first-occurrence order so the output is deterministic. A null outer array stays null and
    * null inner arrays are skipped.
    */
  def multiArray2Array(data: DataFrame, arrayCol: String): DataFrame = {
    val flattenDistinct = udf((nested: Seq[Seq[String]]) =>
      if (nested == null) null else nested.filter(_ != null).flatten.distinct
    )
    data.withColumn(arrayCol, flattenDistinct(col(arrayCol)))
  }

  /** Renames `oldCol` to `newCol`; Spark treats this as a no-op when `oldCol` does not exist. */
  def rename(data: DataFrame, oldCol: String, newCol: String): DataFrame = data.withColumnRenamed(oldCol, newCol)


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy