// com.datastax.data.prepare.spark.dataset.StockSimilarity.scala
package com.datastax.data.prepare.spark.dataset
import java.net.URI
import java.text.SimpleDateFormat
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, StringType}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
object StockSimilarity extends Serializable {
  // CSRC (China Securities Regulatory Commission)
  // tradeColumn: array of trad_dirc values
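  /**
   * Pairwise trade-behaviour similarity between share accounts: self-joins trades on the same
   * security code within `tradeDateThreshold` days (optionally also requiring the same trade
   * direction), scores each candidate account pair via the registered "dynamicSimUDF", and keeps
   * pairs with similarity >= minSimFlt. Deprecated; stockSim2 below is a batched reworking.
   */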
@deprecated
def stockSim(spark: SparkSession, df: DataFrame, shrAcctColumn: String, secCodeColumn: String, tradDateColumn: String, tradDircColumn: String,
limitTradeDirection: Boolean, setSimMethod: Int, minSimFlt: Double, tradeDateThreshold: Int, mktTypeColumn: String, mktType: String): DataFrame = {
    require(secCodeColumn != null, "secCodeColumn is null")
    require(tradDateColumn != null, "tradDateColumn is null")
    require(tradDircColumn != null, "tradDircColumn is null")
val a1 = "_1"
val a2 = "_2"
// val tradeColumn = "tradeColumn"
val tradeDay = "tradeDay"
val collectTradeColumn = "collectTradeColumn"
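    // Approximate day index since 2000: (year - 2000) * 365 + day-of-year. Leap days are ignored,
    // so cross-year gaps can be off by a day or two; only gaps under tradeDateThreshold matter.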
val tradeColDf = df.filter(col(mktTypeColumn).equalTo(mktType)).withColumn(tradeDay,
(year(col(tradDateColumn)) - 2000) * 365 + dayofyear(col(tradDateColumn)))
// .withColumn(tradeColumn, concat_ws("_", col(secCodeColumn), col(tradeDay), col(tradDircColumn)))
val df1 = tradeColDf.select(shrAcctColumn, secCodeColumn, tradeDay, tradDircColumn)
.toDF(shrAcctColumn + a1, secCodeColumn + a1, tradeDay + a1, tradDircColumn + a1)
// val listTradeColDf = tradeColDf.groupBy(shrAcctColumn).agg(collect_list(tradeColumn))
// .select(shrAcctColumn, "collect_list(" + tradeColumn + ")").toDF(shrAcctColumn, collectTradeColumn)
val listTradeColDf = tradeColDf.groupBy(shrAcctColumn).agg(collect_list(tradeDay))
.select(shrAcctColumn, "collect_list(" + tradeDay + ")").toDF(shrAcctColumn, collectTradeColumn)
val df2 = df1.toDF(shrAcctColumn + a2, secCodeColumn + a2, tradeDay + a2, tradDircColumn + a2)
// var nonZero = df1.join(df2, df1(secCodeColumn + a1) === df2(secCodeColumn + a2)
// && df1(shrAcctColumn + a1) > df2(shrAcctColumn + a2) && abs(df1(tradeDay + a1) - df2(tradeDay + a2)) < tradeDateThreshold)
// .select(shrAcctColumn + a1, shrAcctColumn + a2).distinct()
    val nonZero = if (limitTradeDirection) {
      df1.join(df2, df1(secCodeColumn + a1) === df2(secCodeColumn + a2)
        && df1(shrAcctColumn + a1) > df2(shrAcctColumn + a2) && abs(df1(tradeDay + a1) - df2(tradeDay + a2)) < tradeDateThreshold
        && df1(tradDircColumn + a1).equalTo(df2(tradDircColumn + a2)))
        .select(shrAcctColumn + a1, shrAcctColumn + a2).distinct()
    } else {
      df1.join(df2, df1(secCodeColumn + a1) === df2(secCodeColumn + a2)
        && df1(shrAcctColumn + a1) > df2(shrAcctColumn + a2) && abs(df1(tradeDay + a1) - df2(tradeDay + a2)) < tradeDateThreshold)
        .select(shrAcctColumn + a1, shrAcctColumn + a2).distinct()
    }
val df3 = nonZero.join(listTradeColDf, nonZero(shrAcctColumn + a1) === listTradeColDf(shrAcctColumn))
.withColumnRenamed(collectTradeColumn, collectTradeColumn + a1)
.select(shrAcctColumn + a1, shrAcctColumn + a2, collectTradeColumn + a1)
.join(listTradeColDf, nonZero(shrAcctColumn + a2) === listTradeColDf(shrAcctColumn))
.withColumnRenamed(collectTradeColumn, collectTradeColumn + a2)
.select(shrAcctColumn + a1, collectTradeColumn + a1, shrAcctColumn + a2, collectTradeColumn + a2)
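    // NOTE: "dynamicSimUDF" is assumed to be registered elsewhere in the project, presumably
    // wrapping dynamicUDF1 below, e.g. spark.udf.register("dynamicSimUDF", dynamicUDF1 _)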
val udfDf = df3.withColumn("similarity", callUDF("dynamicSimUDF", col(collectTradeColumn + a1).cast(new ArrayType(StringType, true)),
col(collectTradeColumn + a2).cast(new ArrayType(StringType, true)), lit(setSimMethod))) //lit(limitTradeDirection), lit(tradeDateThreshold),
    // Filter out pairs below the minimum similarity
    val filterMinSimFltDf = udfDf.filter(col("similarity") >= minSimFlt)
    // Select the output columns
    val simDf = filterMinSimFltDf.select(shrAcctColumn + a1, shrAcctColumn + a2, "similarity")
simDf
}
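  // A minimal usage sketch (column names are illustrative, not mandated by this file; the method
  // itself is deprecated in favour of stockSim2 below):
  //   val sim = StockSimilarity.stockSim(spark, trades,
  //     shrAcctColumn = "shr_acct", secCodeColumn = "sec_code", tradDateColumn = "trad_date",
  //     tradDircColumn = "trad_dirc", limitTradeDirection = true, setSimMethod = 0,
  //     minSimFlt = 60.0, tradeDateThreshold = 5, mktTypeColumn = "mkt_type", mktType = "SH")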
//"sec_acct_code", "mob_nbr", "fix_or_memo_cntct_tel", "cntct_addr", "email_box", "identifi_file_nbr",
// "open_agt_code", "open_agt_net_code", "open_date"
def stockStaticSim(spark: SparkSession, data: DataFrame, sec_acct_code: String, mob_nbr: String, fix_or_memo_cntct_tel: String, cntct_addr: String,
email_box: String, identifi_file_nbr: String, open_agt_code: String, open_agt_net_code: String,
open_date: String, minSim: Double): DataFrame = {
    require(sec_acct_code != null && sec_acct_code.nonEmpty, "sec_acct_code is empty")
    require(mob_nbr != null && mob_nbr.nonEmpty, "mob_nbr is empty")
    require(fix_or_memo_cntct_tel != null && fix_or_memo_cntct_tel.nonEmpty, "fix_or_memo_cntct_tel is empty")
    require(cntct_addr != null && cntct_addr.nonEmpty, "cntct_addr is empty")
    require(email_box != null && email_box.nonEmpty, "email_box is empty")
    require(identifi_file_nbr != null && identifi_file_nbr.nonEmpty, "identifi_file_nbr is empty")
    require(open_agt_code != null && open_agt_code.nonEmpty, "open_agt_code is empty")
    require(open_agt_net_code != null && open_agt_net_code.nonEmpty, "open_agt_net_code is empty")
    require(open_date != null && open_date.nonEmpty, "open_date is empty")
    require(minSim >= 0.0 && minSim <= 100.0, "minSim must be between 0 and 100")
import spark.implicits._
val a1 = "_1"
val a2 = "_2"
val selectDf = data.select(col(sec_acct_code), col(mob_nbr), col(fix_or_memo_cntct_tel), col(cntct_addr), col(email_box),
col(identifi_file_nbr), col(open_agt_code), col(open_agt_net_code), col(open_date))
val staticArrayColumn = "staticArrayColumn"
val arrayDf = selectDf.withColumn(staticArrayColumn, array(col(mob_nbr), col(fix_or_memo_cntct_tel), col(cntct_addr), col(email_box),
col(identifi_file_nbr), col(open_agt_code), col(open_agt_net_code), col(open_date))).select(col(sec_acct_code), col(staticArrayColumn))
    val autoGenerateId = "auto_generate_id"
    val generateIdDf = arrayDf.withColumn(autoGenerateId, monotonically_increasing_id()).select(sec_acct_code, staticArrayColumn, autoGenerateId)
    val copyDf = generateIdDf.toDF(sec_acct_code + a2, staticArrayColumn + a2, autoGenerateId + a2)
    val excludeIdDf = generateIdDf.join(copyDf, generateIdDf(autoGenerateId) < copyDf(autoGenerateId + a2), "left").na.drop()
      .select(sec_acct_code, staticArrayColumn, sec_acct_code + a2, staticArrayColumn + a2)
      .toDF(sec_acct_code + a1, staticArrayColumn + a1, sec_acct_code + a2, staticArrayColumn + a2)
    val result = excludeIdDf.as[(String, Seq[String], String, Seq[String])].map(s => {
val sim = staticUDF(s._2, s._4)
(s._1, s._3, sim)
}).filter(_._3 >= minSim).toDF(sec_acct_code + a1, sec_acct_code + a2, "similarity")
result
}
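  // A minimal usage sketch, assuming the input carries the columns listed above:
  //   val staticSim = StockSimilarity.stockStaticSim(spark, accounts,
  //     "sec_acct_code", "mob_nbr", "fix_or_memo_cntct_tel", "cntct_addr", "email_box",
  //     "identifi_file_nbr", "open_agt_code", "open_agt_net_code", "open_date", minSim = 60.0)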
def staticUDF(seq1: Seq[String], seq2: Seq[String]): Double = {
var sim: Double = 0.0
    require(seq1.length == seq2.length, "seq1 and seq2 must have the same length")
var i = 0
while(i < seq1.length) {
      // 30%: a match on any of the first four contact fields scores the full 30 points
      if (i < 4) {
        val firstEq = !checkNull(seq1(i), seq2(i)) && seq1(i).trim.equals(seq2(i).trim)
        if (firstEq) {
          sim += 30
          i = 3 // jump past the remaining contact fields
        }
      }
//20%
if(i == 4) {
if(!checkNull(seq1(i), seq2(i))) {
val s1 = seq1(i).trim
val s2 = seq2(i).trim
if(s1.length > 6 && s2.length > 6) {
if(s1.substring(0, 6).equals(s2.substring(0, 6))) {
sim += 20
} else if(s1.substring(0, 4).equals(s2.substring(0, 4))) {
sim += (20 * 0.9)
} else if(s1.substring(0, 2).equals(s2.substring(0, 2))) {
sim += (20 * 0.6)
}
}
}
}
//25%
if(i == 5) {
if(!checkNull(seq1(i), seq2(i)) && seq1(i).trim.equals(seq2(i).trim)) {
if(!checkNull(seq1(i + 1), seq2(i + 1)) && seq1(i + 1).trim.equals(seq2(i + 1).trim)) {
sim += 25
} else {
sim += (25 * 0.75)
}
}
i = 6
}
//25%
if(i == 7 && !checkNull(seq1(i), seq2(i))) {
def getTime(s: String): Long = {
if(s.matches("[0-9]{4}-[0-9]{2}-[0-9]{2}")) {
new SimpleDateFormat("yyyy-MM-dd").parse(s).getTime / 86400000
} else if(s.matches("[0-9]{6}")) {
new SimpleDateFormat("yyyyMMdd").parse("20" + s).getTime / 86400000
} else {
-1
}
}
        val time1 = getTime(seq1(i).trim)
        val time2 = getTime(seq2(i).trim)
        // only score when both dates parsed successfully (getTime returns -1 on failure)
        if (time1 >= 0 && time2 >= 0) {
          val gap = math.abs(time2 - time1)
          if (gap == 0) {
            sim += 25
          } else if (gap == 1) {
            sim += (25 * 0.87)
          } else if (gap == 2) {
            sim += (25 * 0.64)
          } else if (gap == 3) {
            sim += (25 * 0.55)
          } else if (gap == 4) {
            sim += (25 * 0.52)
          } else if (gap == 5) {
            sim += (25 * 0.5)
          }
        }
}
i += 1
}
sim.formatted("%.3f").toDouble
}
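  // Worked example with hypothetical values: two accounts agreeing on every field collect the
  // full 30 + 20 + 25 + 25 = 100 points:
  //   val fields = Seq("13800000000", "02112345678", "addr", "a@b.com",
  //                    "110101199001011234", "0001", "01", "2020-01-01")
  //   staticUDF(fields, fields)  // => 100.0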
def checkNull(s1: String, s2: String): Boolean = (s1 == null || "null".equals(s1.trim.toLowerCase) || s1.trim.isEmpty) ||
(s2 == null || "null".equals(s2.trim.toLowerCase) || s2.trim.isEmpty)
def dynamicUDF1(tradeInfo1: Seq[String], tradeInfo2: Seq[String], s: Int): Double = {
val maxColumns: Array[Double] = new Array[Double](tradeInfo1.length)
val maxRows: Array[Double] = new Array[Double](tradeInfo2.length)
for(i <- tradeInfo1.indices) {
for(j <- tradeInfo2.indices) {
val termSim = (100 / math.exp(math.abs(tradeInfo1(i).toInt - tradeInfo2(j).toInt))).formatted("%.3f").toDouble
if(maxColumns(i) < termSim) {
maxColumns(i) = termSim
}
if(maxRows(j) < termSim) {
maxRows(j) = termSim
}
}
}
s match {
case 0 => math.max(maxColumns.sum / (maxColumns.length + 0.5), maxRows.sum / (maxRows.length + 0.5)).formatted("%.3f").toDouble
case 1 => (((maxColumns.sum / maxColumns.length) + (maxRows.sum / maxRows.length)) / 2 + 0.5).formatted("%.3f").toDouble
      case _ => throw new IllegalArgumentException("set_similarity_method must be 0 or 1")
}
}
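  // Worked example: dynamicUDF1(Seq("1", "2"), Seq("2"), 0). Per-pair sims are
  // 100 / e^|1-2| = 36.788 and 100 / e^|2-2| = 100, so maxColumns = [36.788, 100] and
  // maxRows = [100]; method 0 returns max(136.788 / 2.5, 100 / 1.5) = 66.667.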
  // TODO: intermediate files under tempPath are not deleted after the run; needs cleanup
def stockSim2(spark: SparkSession, df: DataFrame, shrAcctColumn: String, secCodeColumn: String, tradDateColumn: String, tradDircColumn: String,
limitTradeDirection: Boolean, setSimMethod: Int, minSimFlt: Double, tradeDateThreshold: Int, mktTypeColumn: String, mktType: String, tempPath: String, num: Int): DataFrame = {
    require(secCodeColumn != null, "secCodeColumn is null")
    require(tradDateColumn != null, "tradDateColumn is null")
    require(tradDircColumn != null, "tradDircColumn is null")
    require(tempPath != null, "tempPath is null")
val a1 = "_1"
val a2 = "_2"
val tradeDay = "tradeDay"
import spark.implicits._
    // Delete leftover intermediate files from previous runs
    val fs = FileSystem.get(new URI(tempPath), new Configuration())
    val checkpointPath = new Path(tempPath + "/_sim_checkpoint")
    val tempOutputPath = new Path(tempPath + "/_dynamic_sim_temp")
    if (fs.exists(checkpointPath)) {
      fs.delete(checkpointPath, true)
    }
    if (fs.exists(tempOutputPath)) {
      fs.delete(tempOutputPath, true)
    }
spark.sparkContext.setCheckpointDir(tempPath + "/_sim_checkpoint")
    // Pack the secCode, tradDate and tradDirc columns into a single Long: the lowest bit holds tradDirc, bits 1-19 hold the trade-day index, and the high bits hold secCode
val tradeColDf = df.filter(col(mktTypeColumn).equalTo(mktType)).withColumn(tradeDay, (year(col(tradDateColumn)) - 2000) * 365 + dayofyear(col(tradDateColumn)))
.select(shrAcctColumn, secCodeColumn, tradeDay, tradDircColumn).as[(String, String, Int, String)]
.map(s => {
        val code = Integer.toHexString(s._2.toInt)
        val dirc = if ("B".equals(s._4)) 0 else 1
        var day = Integer.toHexString((s._3 << 1) + dirc)
        if (day.length > 5) {
          throw new IllegalStateException("day overflows 5 hex digits; this should not happen before roughly year 2100, at which point the 2000 baseline above must be adjusted")
        } else {
          // left-pad to exactly 5 hex digits so secCode occupies the high bits
          for (_ <- Range(day.length, 5)) {
            day = "0" + day
          }
        }
        (s._1, java.lang.Long.parseLong(code + day, 16))
      }).toDF(a1, a2)
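    // Example: secCode 600000 (hex 927c0) traded on 2020-01-02 (tradeDay = 20 * 365 + 2 = 7302)
    // with direction "B" (0) gives day = 7302 << 1 = 14604 = hex 390c, padded to "0390c",
    // so the packed value is java.lang.Long.parseLong("927c00390c", 16)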
    // Aggregate each account's packed trades into a single list and attach a monotonically increasing id
val df1 = tradeColDf.groupBy(a1).agg(collect_list(a2).alias(a2)).withColumn("_id_", monotonically_increasing_id()).as[(String, Seq[Long], Long)]
val df2 = df1.checkpoint()
    // Iteratively compute similarities, one batch of `num` accounts at a time
    val total = df2.count()
    println("Account count: " + total)
    var list = df2.head(num)
    var max = if (list.nonEmpty) list.last._3 else 0L
var i = 0
while( i * num < total) {
println("第" + i + "次循环开始...")
val start = System.currentTimeMillis()
val df3 = df2.flatMap(s1 => {
        // Map each packed secCode (with direction bit when limitTradeDirection is set) to its positions in s1's trade list
val map = new mutable.HashMap[Long, ListBuffer[Int]]
s1._2.indices.foreach(i => {
          val t1 = if (limitTradeDirection) ((s1._2(i) >> 20) << 1) + (s1._2(i) & 0x1) else s1._2(i) >> 20
if(map.contains(t1)) {
map(t1).append(i)
} else {
val buffer = new ListBuffer[Int]
buffer.append(i)
map.put(t1, buffer)
}
})
        // For each account in the current batch with a larger id, compute the pairwise similarity
list.filter(_._3 > s1._3).map(s2 => {
val maxColumns: Array[Double] = new Array[Double](s1._2.length)
var sumRow = 0.0
s2._2.foreach(l => {
            val code = if (limitTradeDirection) ((l >> 20) << 1) + (l & 0x1) else l >> 20
if(map.contains(code)) {
var maxRow = 0.0
map(code).foreach(i => {
val gap = math.abs(((s1._2(i) & 0xFFFFF) >> 1) - ((l & 0xFFFFF) >> 1))
val termSim = if(gap < tradeDateThreshold) math.round(100 / math.exp(gap)) else 0.0
if(maxColumns(i) < termSim) {
maxColumns(i) = termSim
}
if(maxRow < termSim) {
maxRow = termSim
}
})
sumRow += maxRow
}
})
val sim = setSimMethod match {
case 0 => math.max(maxColumns.sum / (maxColumns.length + 0.5), sumRow / (s2._2.length + 0.5)).formatted("%.3f")
case 1 => (((maxColumns.sum / maxColumns.length) + (sumRow / s2._2.length)) / 2 + 0.5).formatted("%.3f")
          case _ => throw new IllegalArgumentException("set_similarity_method must be 0 or 1")
}
(s1._1, s2._1, sim)
})
}).filter(_._3.toDouble >= minSimFlt).coalesce(100)
df3.write.mode(SaveMode.Overwrite).csv(tempPath + "/_dynamic_sim_temp/part-" + i)
      i += 1
      // Load the next batch of accounts (ids above the current max) for the following iteration
      if (i * num < total) {
        list = df2.filter(_._3 > max).head(num)
        if (list.nonEmpty) max = list.last._3
      }
      val end = System.currentTimeMillis()
      println("Iteration " + (i - 1) + " finished, took " + (end - start) + " ms")
}
spark.read.csv(tempPath + "/_dynamic_sim_temp/*").toDF(shrAcctColumn + a1, shrAcctColumn + a2, "similarity")
}
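  // A minimal usage sketch (paths and column names are illustrative):
  //   val sim2 = StockSimilarity.stockSim2(spark, trades, "shr_acct", "sec_code", "trad_date",
  //     "trad_dirc", limitTradeDirection = false, setSimMethod = 0, minSimFlt = 60.0,
  //     tradeDateThreshold = 5, mktTypeColumn = "mkt_type", mktType = "SH",
  //     tempPath = "hdfs:///tmp/stock_sim", num = 10000)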
}