package com.datastax.data.prepare.spark.dataset
import java.net.URI
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.ml.fpm.FPGrowth
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.{DataFrame, Dataset, SaveMode, SparkSession}
import scala.collection.mutable.ListBuffer
object StockBasicCompute extends Serializable {
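  // Overview (added for clarity, inferred from the implementation below): hcCompute runs an
  // agglomerative clustering loop over a pairwise-similarity table (xCol, yCol, simCol). Each
  // iteration merges the pair with the highest similarity into a combined node named "x_y",
  // recomputes similarities via `loop` (the two old similarities to the merged pair are summed
  // and halved), and repeats until the best remaining similarity falls below `threshold` or only
  // one row is left. The result keeps rows with similarity >= minSim whose node name contains
  // "_", i.e. clusters that were actually merged.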
  // Original hierarchical (agglomerative) clustering
  def hcCompute(spark: SparkSession, df: DataFrame, xCol: String, yCol: String, simCol: String,
                threshold: Double, minSim: Double, tempPath: String): DataFrame = {
    val uri = new URI(tempPath)
    val path = new Path(uri.getPath)
    val fs = FileSystem.get(uri, new Configuration())
    import spark.implicits._
    spark.sparkContext.setCheckpointDir(tempPath + "/_hc_checkpoint")
    val df1 = df.select(col(xCol), col(yCol), col(simCol).cast(DoubleType)).as[(String, String, Double)]
    val accu = spark.sparkContext.longAccumulator // used to check whether the dataset still has more than one row
    var maxRow = df1.reduce((s1, s2) => {
      accu.add(1)
      if (s1._3 >= s2._3) s1 else s2
    })
    var df2 = df1
    var i = 0
    println("df2 begin count: " + df2.count())
    while (maxRow._3 >= threshold && accu.value > 0) {
      println("maxRow:(" + maxRow._1 + ", " + maxRow._2 + ", " + maxRow._3 + ")")
      val start = System.currentTimeMillis()
      val df3 = loop(spark, df2, maxRow, xCol, yCol, simCol).rdd
      // clear old checkpoint files
      if ((i + 1) % 200 == 0) {
        println("clear checkpoint files begin...")
        if (fs.exists(path)) {
          fs.delete(path, true)
        }
        println("clear checkpoint files end")
      }
      // checkpoint the RDD
      if ((i + 1) % 20 == 0) {
        df3.checkpoint()
      }
      // Why recreate the Dataset? Without it, the time per iteration keeps growing;
      // the effect seems to be similar to checkpointing.
      df2 = spark.createDataset(df3)
        .withColumnRenamed("_1", xCol)
        .withColumnRenamed("_2", yCol)
        .withColumnRenamed("_3", simCol)
        .as[(String, String, Double)]
      accu.reset()
      maxRow = df2.reduce((s1, s2) => {
        accu.add(1)
        if (s1._3 >= s2._3) s1 else s2
      })
      val end = System.currentTimeMillis()
      println("loop(" + i + ") time: " + (end - start))
      i += 1
    }
    if (accu.value == 0) {
      df2 = Seq((maxRow._1 + "_" + maxRow._2, "", maxRow._3)).toDF(xCol, yCol, simCol).as[(String, String, Double)]
    }
    println("df2 result count: " + df2.count())
    // fs.close()
    df2.filter(col(simCol) >= minSim and (col(xCol).contains("_") or col(yCol).contains("_")))
      .select(col(simCol), concat_ws(";", col(xCol), col(yCol)))
  }
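  // `loop` performs one merge step: rows touching either member of the best pair are relabeled to
  // the combined node "x_y"; the groupBy then sums duplicate edges to the new node and halves the
  // total. Untouched rows are doubled before the aggregation so the final divide-by-two leaves
  // their similarity unchanged.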
  private def loop(spark: SparkSession, df2: Dataset[(String, String, Double)], maxRow: (String, String, Double),
                   xCol: String, yCol: String, simCol: String): Dataset[(String, String, Double)] = {
    import spark.implicits._
    val t1 = maxRow._1 + "_" + maxRow._2
    val df3 = df2.filter(a => !(a._1.equals(maxRow._1) && a._2.equals(maxRow._2))).mapPartitions(as => {
      as.map(a => {
        if (a._1.equals(maxRow._1) || a._1.equals(maxRow._2)) {
          (t1, a._2, a._3)
        } else if (a._2.equals(maxRow._1) || a._2.equals(maxRow._2)) {
          (t1, a._1, a._3)
        } else {
          (a._1, a._2, a._3 * 2)
        }
      })
    }).groupBy(col("_1"), col("_2")).agg(sum(col("_3")).divide(2)).toDF(xCol, yCol, simCol).as[(String, String, Double)]
    df3
  }
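  // For each distinct value of `groupCol`, fpgCompute fits a Spark ML FPGrowth model on
  // `targetCol` (an array-of-strings column), keeps frequent itemsets with freq >= minFreq,
  // merges overlapping itemsets via `merge`/FPMerge, and writes one CSV per group under `path`.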
  def fpgCompute(spark: SparkSession, data: DataFrame, groupCol: String, targetCol: String, minSupport: Double,
                 numPartitions: Int, minFreq: Long, p: Double, minItems: Int, path: String): Unit = {
    import spark.implicits._
    println("FPGrowth started...")
    require(data != null)
    require(groupCol != null && groupCol.trim.length != 0)
    require(targetCol != null && targetCol.trim.length != 0)
    require(minSupport > 0.0 && minSupport <= 1.0)
    if (!checkExist(data.schema.fieldNames, groupCol, targetCol)) {
      throw new IllegalArgumentException("One or more of the given column names do not exist in the dataset")
    }
    val df = data.select(groupCol, targetCol)
    val fpg = new FPGrowth().setItemsCol(targetCol).setMinSupport(minSupport)
    if (numPartitions > 0) {
      fpg.setNumPartitions(numPartitions)
    }
    val groups = df.select(groupCol).distinct().map(_.getString(0)).collect()
    if (groups.length != 0) {
      groups.filter(_ != null).foreach(g => {
        try {
          val start = System.currentTimeMillis()
          println(g + ": running FPGrowth...")
          val t = df.where(col(groupCol).equalTo(g)).select(targetCol)
          val model = fpg.fit(t)
          println("FPGrowth finished, merging frequent itemsets...")
          val freqItemsets = model.freqItemsets.filter(col("freq") >= minFreq) // columns: items, freq
          val result = merge(spark, freqItemsets.select("items"), p, minItems)
          result.withColumn("stock", lit(g))
            .select(col("stock"), concat_ws(";", col("class_member")).alias("class_member"))
            .repartition(1)
            .write.option("header", true).mode(SaveMode.Overwrite)
            .csv(path + "/" + g)
          // freqItemsets.select(lit(g), concat_ws(";", col("items")), col("freq")).repartition(4).write.mode(SaveMode.Overwrite).csv(path + "/pre-merge/" + g)
          println(g + ": done")
          val end = System.currentTimeMillis()
          println(g + ": FPGrowth elapsed time (ms): " + (end - start))
          println("====================")
        } catch {
          case e: Exception => {
            println(g + " failed")
            e.printStackTrace()
          }
        }
      })
    } else {
      println(groupCol + " has no data")
    }
    println("FPGrowth finished")
  }
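  // Merges frequent itemsets that overlap sufficiently: itemsets of length 1 are dropped, the
  // rest are collected into a single partition, sorted by descending length, and folded into an
  // FPMerge accumulator; merged groups shorter than `minItems` are filtered out of the result.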
  def merge(spark: SparkSession, data: DataFrame, p: Double, minItems: Int): DataFrame = {
    import spark.implicits._
    // val bc = spark.sparkContext.broadcast(new FPMerge(p))
    // data.as[Seq[String]].foreach(s => bc.value.add(s))
    // val result = bc.value.getList()
    // bc.unpersist()
    // result.toDF("class_member").filter(_.getAs[Seq[String]](0).length >= minItems)
    val rdd = data.as[Seq[String]].filter(_.length != 1).rdd
    val result = rdd.repartition(1)
      .sortBy(s => s.length, ascending = false)
      .aggregate(new FPMerge(p))((fp, seq) => fp.add(seq), (fp1, fp2) => fp1.merge(fp2))
      .getList()
    result.toDF("class_member").filter(_.getAs[Seq[String]](0).length >= minItems)
  }
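  // Returns true only when every requested column name exists in `fields`.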
  private def checkExist(fields: Array[String], cols: String*): Boolean = {
    // The original version used `foreach`, which discards the lambda's result and therefore
    // always returned true; `forall` actually propagates the per-column check.
    cols.forall(c => fields.contains(c))
  }
}
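// FPMerge incrementally groups itemsets: a new itemset is absorbed into the first buffered group
// it overlaps by at least p * |itemset| elements (the group is replaced by the union), skipped if
// it is already a subset of an existing group, or appended as a new group otherwise. `merge`
// combines two accumulators by re-adding every group of the other instance.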
class FPMerge(p: Double) extends Serializable {
  val buffer = new ListBuffer[Seq[String]]

  def add(seq: Seq[String]): this.type = {
    if (seq.nonEmpty) {
      if (buffer.isEmpty) {
        buffer.append(seq)
      } else {
        var i = 0
        var flag = true   // true while no buffered group has absorbed `seq`
        var flag2 = false // true if `seq` is already a subset of a buffered group
        var unionSet: Seq[String] = null
        while (i < buffer.size && flag) {
          val s = buffer.apply(i)
          unionSet = (s ++ seq).distinct
          if (s.length != unionSet.length) {
            // overlap = |s| + |seq| - |s union seq|; merge when it covers at least p * |seq|
            if (s.length + seq.length - unionSet.length >= p * seq.length) {
              flag = false
            } else {
              i += 1
            }
          } else {
            // the union is no larger than s, so seq is a subset of s: nothing to add
            flag2 = true
            flag = false
          }
        }
        if (!flag2) {
          if (!flag) {
            buffer.update(i, unionSet) // replace the matched group with the union
          } else {
            buffer.append(seq)         // no group overlaps enough: start a new one
          }
        }
      }
    }
    this
  }

  def merge(other: FPMerge): this.type = {
    other.buffer.foreach(s => add(s))
    this
  }

  def getList(): Seq[Seq[String]] = {
    buffer.toList
  }
}
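
// A minimal usage sketch (not part of the original source): it assumes a local Spark run on an
// illustrative pairwise-similarity table. The object name, column names, threshold values and the
// temp path below are hypothetical and only show how hcCompute is wired together.
object StockBasicComputeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("hc-compute-sketch").getOrCreate()
    import spark.implicits._
    // Pairwise similarities between three hypothetical symbols
    val pairs = Seq(("A", "B", 0.9), ("A", "C", 0.4), ("B", "C", 0.3)).toDF("x", "y", "sim")
    val clusters = StockBasicCompute.hcCompute(spark, pairs, "x", "y", "sim",
      threshold = 0.5, minSim = 0.2, tempPath = "/tmp/hc_example")
    clusters.show(false)
    spark.stop()
  }
}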