com.datastax.data.prepare.spark.dataset.StockBasicCompute.scala

package com.datastax.data.prepare.spark.dataset

import java.net.URI

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.ml.fpm.FPGrowth
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.{DataFrame, Dataset, SaveMode, SparkSession}

import scala.collection.mutable.ListBuffer

object StockBasicCompute extends Serializable {

  // Original hierarchical clustering
  def hcCompute(spark: SparkSession, df: DataFrame, xCol: String, yCol: String, simCol: String,
              threshold: Double, minSim: Double, tempPath: String): DataFrame = {
    val uri = new URI(tempPath)
    val path = new Path(uri.getPath)
    val fs = FileSystem.get(uri, new Configuration())

    import spark.implicits._
    spark.sparkContext.setCheckpointDir(tempPath + "/_hc_checkpoint")
    val df1 = df.select(col(xCol), col(yCol), col(simCol).cast(DoubleType)).as[(String, String, Double)]
    val accu = spark.sparkContext.longAccumulator  // used to check whether the dataset is empty
    var maxRow = df1.reduce((s1, s2) => {
      accu.add(1)
      if(s1._3 >= s2._3) s1 else s2
    })
    var df2 = df1
    var i = 0
    println("df2 begin count: " + df2.count())
    while(maxRow._3 >= threshold && accu.value > 0) {
      println("maxRow:(" + maxRow._1 + ", " + maxRow._2 + ", " + maxRow._3 + ")")
      val start = System.currentTimeMillis()
      val df3 = loop(spark, df2, maxRow, xCol, yCol, simCol).rdd

      // clear the checkpoint files
      if((i + 1) % 200 == 0) {
        println("clear checkpoint files begin...")
        if(fs.exists(path)) {
          fs.delete(path, true)
        }
        println("clear checkpoint files end")
      }

      // checkpoint the RDD
      if((i + 1) % 20 == 0 ) {
        df3.checkpoint()
      }

      // Why re-create the Dataset? Without it, each loop iteration takes longer and longer (the lineage keeps growing); the effect should be much the same as a checkpoint _( '-' _)⌒)_
      df2 = spark.createDataset(df3)
        .withColumnRenamed("_1", xCol)
        .withColumnRenamed("_2", yCol)
        .withColumnRenamed("_3", simCol)
        .as[(String, String, Double)]

      accu.reset()
      maxRow = df2.reduce((s1, s2) => {
        accu.add(1)
        if(s1._3 >= s2._3) s1 else s2
      })

      val end = System.currentTimeMillis()
      println("loop(" + i + ") time: " + (end - start))
      i += 1

    }
    if(accu.value == 0) {
      df2 = Seq((maxRow._1 + "_" + maxRow._2, "", maxRow._3)).toDF(xCol, yCol, simCol).as[(String, String, Double)]
    }
    println("df2 result count: " + df2.count())
//    fs.close()

    df2.filter(col(simCol) >= minSim and (col(xCol).contains("_") or col(yCol).contains("_")))
      .select(col(simCol), concat_ws(";", col(xCol), col(yCol)))
  }
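
  // Illustrative walk-through of a single `loop` call below (assumed toy data, not from the
  // original source): with maxRow = ("A", "B", 0.9) the pair is merged into cluster "A_B".
  //   ("A", "C", 0.4) -> ("A_B", "C", 0.4)
  //   ("B", "C", 0.6) -> ("A_B", "C", 0.6)
  //   ("C", "D", 0.5) -> ("C", "D", 1.0)   // doubled here, halved again by the aggregation
  // After groupBy + sum/2, ("A_B", "C") gets (0.4 + 0.6) / 2 = 0.5, i.e. the average
  // similarity of the merged cluster to C, while untouched pairs keep their original value.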

  private def loop(spark: SparkSession, df2: Dataset[(String, String, Double)], maxRow: (String, String, Double), xCol: String, yCol: String, simCol: String): Dataset[(String, String, Double)] = {
    import spark.implicits._
    val t1 = maxRow._1 + "_" + maxRow._2
    val df3 = df2.filter(a => !(a._1.equals(maxRow._1) && a._2.equals(maxRow._2))).mapPartitions(as => {
      as.map(a => {
        if(a._1.equals(maxRow._1) || a._1.equals(maxRow._2)) {
          (t1, a._2, a._3)
        } else if(a._2.equals(maxRow._1) || a._2.equals(maxRow._2)) {
          (t1, a._1, a._3)
        } else {
          (a._1, a._2, a._3 * 2)
        }
      })
    }).groupBy(col("_1"), col("_2"))
      .agg(sum(col("_3")).divide(2))
      .toDF(xCol, yCol, simCol)
      .as[(String, String, Double)]
    df3
  }
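
  // Note on `fpgCompute` below (descriptive comment added for clarity): `targetCol` is expected
  // to be an array-of-strings column, since it is passed to FPGrowth.setItemsCol; for every
  // distinct value of `groupCol` a separate FPGrowth model is fit, its frequent itemsets with
  // freq >= minFreq are merged via FPMerge, and the result is written as a single-part CSV
  // under path + "/" + <group>.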


  def fpgCompute(spark: SparkSession, data: DataFrame, groupCol: String, targetCol: String, minSupport: Double,
                 numPartitions: Int, minFreq: Long, p: Double, minItems: Int, path: String): Unit = {
    import spark.implicits._
    println("FPGrowth开始....")
    require(data != null)
    require(groupCol != null && groupCol.trim.length != 0)
    require(targetCol != null && targetCol.trim.length != 0)
    require(minSupport > 0.0 && minSupport <= 1.0)
    if(!checkExist(data.schema.fieldNames, groupCol, targetCol)) {
      throw new IllegalArgumentException("One or more of the given column names do not exist in the dataset")
    }

    val df = data.select(groupCol, targetCol)
    val fpg = new FPGrowth().setItemsCol(targetCol).setMinSupport(minSupport)
    if(numPartitions > 0) {
      fpg.setNumPartitions(numPartitions)
    }
    val groups = df.select(groupCol).distinct().map(_.getString(0)).collect()
    if(groups.length != 0) {
      groups.filter(_ != null).foreach(g => {
        try {
          val start = System.currentTimeMillis()
          println(g + "开始执行FPGrowth...")
          val t = df.where(col(groupCol).equalTo(g)).select(targetCol)
          val model = fpg.fit(t)
          println("FPGrowth 结束, 开始合并频繁集...")
          val freqItemsets = model.freqItemsets.filter(col("freq") >= minFreq)  //items freq
          val result = merge(spark, freqItemsets.select("items"), p, minItems)
          result.withColumn("stock", lit(g)).select(col("stock"), concat_ws(";", col("class_member")).alias("class_member")).repartition(1).write.option("header", true).mode(SaveMode.Overwrite).csv(path + "/" + g)
//          freqItemsets.select(lit(g), concat_ws(";", col("items")), col("freq")).repartition(4).write.mode(SaveMode.Overwrite).csv(path + "/pre-merge/" + g)
          println(g + " 执行结束")
          val end = System.currentTimeMillis()
          println(g + " 执行FPGrowth消耗的时间为:" + (end - start))
          println("====================")
        } catch {
          case e: Exception => {
            println(g + " 出错")
            e.printStackTrace()
          }
        }

      })
    } else {
      println(groupCol + "数据为空")
    }
    println("FPGrowth结束")
  }

  def merge(spark: SparkSession, data: DataFrame, p: Double, minItems: Int): DataFrame = {
    import spark.implicits._
//    val bc = spark.sparkContext.broadcast(new FPMerge(p))
//    data.as[Seq[String]].foreach(s => bc.value.add(s))
//    val result = bc.value.getList()
//    bc.unpersist()
//    result.toDF("class_member").filter(_.getAs[Seq[String]](0).length >= minItems)

    val rdd = data.as[Seq[String]].filter(_.length != 1).rdd
    val result = rdd.repartition(1)
      .sortBy(s => s.length, false)
      .aggregate(new FPMerge(p))((fp, seq) => fp.add(seq), (fp1, fp2) => fp1.merge(fp2))
      .getList()
    result.toDF("class_member").filter(_.getAs[Seq[String]](0).length >= minItems)

  }

  private def checkExist(fields: Array[String], cols: String*): Boolean = {
    // true only if every requested column exists in the dataset's schema
    cols.forall(c => fields.contains(c))
  }

}
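
// A minimal, hypothetical usage sketch (not part of the original file): the SparkSession
// settings, input data, column names and paths below are assumptions chosen for illustration.
object StockBasicComputeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("StockBasicComputeExample")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Pairwise similarities between stocks (toy values).
    val sim = Seq(
      ("600000", "600004", 0.92),
      ("600000", "600006", 0.35),
      ("600004", "600006", 0.41)
    ).toDF("x", "y", "sim")

    // Hierarchical clustering: keep merging while the best similarity is at least 0.8,
    // and only report merged clusters whose similarity is at least 0.3.
    val clusters = StockBasicCompute.hcCompute(
      spark, sim, "x", "y", "sim",
      threshold = 0.8, minSim = 0.3, tempPath = "/tmp/hc_temp")
    clusters.show(false)

    spark.stop()
  }
}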



class FPMerge(p: Double) extends Serializable {

  val buffer = new ListBuffer[Seq[String]]

  def add(seq: Seq[String]): this.type = {
    if(seq.nonEmpty) {
      if(buffer.isEmpty) {
        buffer.append(seq)
      } else {
        var i = 0
        var flag = true
        var flag2 = false
        var unionSet: Seq[String] = null
        while(i < buffer.size && flag) {
          val s = buffer.apply(i)
          unionSet = (s ++ seq).distinct
          if(s.length != unionSet.length) {
            if(s.length + seq.length - unionSet.length >= p * seq.length) {
              flag = false
            } else {
              i += 1
            }
          } else {
            flag2 = true
            flag = false
          }

        }
        if(!flag2) {
          if(!flag) {
            buffer.update(i, unionSet)
          } else {
            buffer.append(seq)
          }
        }

      }

    }
    this

  }

  def merge(other: FPMerge): this.type = {
    other.buffer.foreach(s => add(s))
    this
  }

  def getList(): Seq[Seq[String]] = {
    buffer.toList
  }
}
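
// A minimal, hypothetical example of FPMerge on its own (assumed data, for illustration):
// with p = 0.5, a new itemset is folded into an existing cluster when at least half of its
// items are already in that cluster; otherwise it starts a new cluster.
object FPMergeExample {
  def main(args: Array[String]): Unit = {
    val merged = new FPMerge(0.5)
      .add(Seq("a", "b", "c"))
      .add(Seq("b", "c", "d"))   // overlap {b, c} >= 0.5 * 3 -> merged into {a, b, c, d}
      .add(Seq("x", "y"))        // no sufficient overlap -> new cluster
      .getList()
    merged.foreach(println)      // prints List(a, b, c, d) and List(x, y)
  }
}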



