All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.datastax.data.prepare.spark.dataset.StockSimilarity.scala Maven / Gradle / Ivy

package com.datastax.data.prepare.spark.dataset

import java.net.URI
import java.text.SimpleDateFormat

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, StringType}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

object StockSimilarity extends Serializable {

  /**
   * Legacy pairwise trade similarity for the CSRC (China Securities Regulatory Commission) use case.
   *
   * Keeps the rows of `df` whose `mktTypeColumn` equals `mktType` and turns each trade date into a
   * day number (whole days since 2000-01-01). It then builds candidate account pairs
   * (account_1 > account_2, so each unordered pair appears once). A pair qualifies when both accounts
   * traded the same security within `tradeDateThreshold` days of each other, and, when
   * `limitTradeDirection` is set, with equal `tradDircColumn` values. Each account's day numbers are
   * collected into a list, and every candidate pair is scored by the UDF named `dynamicSimUDF`.
   * That UDF is not registered anywhere in this file, so it must already exist on the session.
   *
   * @return columns `<shrAcctColumn>_1`, `<shrAcctColumn>_2`, `similarity`, keeping only rows with
   *         `similarity >= minSimFlt`
   */
  @deprecated
  def stockSim(spark: SparkSession, df: DataFrame, shrAcctColumn: String, secCodeColumn: String, tradDateColumn: String, tradDircColumn: String,
               limitTradeDirection: Boolean, setSimMethod: Int, minSimFlt: Double, tradeDateThreshold: Int, mktTypeColumn: String, mktType: String): DataFrame = {
    require(shrAcctColumn != null, "shrAcctColumn 为空")
    require(secCodeColumn != null, "secCodeColumn 为空")
    require(tradDateColumn != null, "tradDateColumn 为空")
    require(tradDircColumn != null, "tradDircColumn 为空")
    val a1 = "_1"
    val a2 = "_2"
    val tradeDay = "tradeDay"
    val collectTradeColumn = "collectTradeColumn"

    val tradeColDf = df.filter(col(mktTypeColumn).equalTo(mktType)).withColumn(tradeDay, tradeDayNumber(tradDateColumn))

    // Two renamed copies of the trades for the self-join.
    val df1 = tradeColDf.select(shrAcctColumn, secCodeColumn, tradeDay, tradDircColumn)
      .toDF(shrAcctColumn + a1, secCodeColumn + a1, tradeDay + a1, tradDircColumn + a1)
    val df2 = df1.toDF(shrAcctColumn + a2, secCodeColumn + a2, tradeDay + a2, tradDircColumn + a2)

    // Each account's list of trade day numbers.
    val listTradeColDf = tradeColDf.groupBy(shrAcctColumn).agg(collect_list(tradeDay).alias(collectTradeColumn))

    // Candidate pairs: same security, ordered accounts, trade days closer than the threshold.
    val closeTrades = df1(secCodeColumn + a1) === df2(secCodeColumn + a2) &&
      df1(shrAcctColumn + a1) > df2(shrAcctColumn + a2) &&
      abs(df1(tradeDay + a1) - df2(tradeDay + a2)) < tradeDateThreshold
    val pairCondition =
      if (limitTradeDirection) closeTrades && df1(tradDircColumn + a1).equalTo(df2(tradDircColumn + a2))
      else closeTrades
    val nonZero = df1.join(df2, pairCondition).select(shrAcctColumn + a1, shrAcctColumn + a2).distinct()

    // Attach both accounts' day-number lists to each candidate pair.
    val df3 = nonZero.join(listTradeColDf, nonZero(shrAcctColumn + a1) === listTradeColDf(shrAcctColumn))
      .withColumnRenamed(collectTradeColumn, collectTradeColumn + a1)
      .select(shrAcctColumn + a1, shrAcctColumn + a2, collectTradeColumn + a1)
      .join(listTradeColDf, nonZero(shrAcctColumn + a2) === listTradeColDf(shrAcctColumn))
      .withColumnRenamed(collectTradeColumn, collectTradeColumn + a2)
      .select(shrAcctColumn + a1, collectTradeColumn + a1, shrAcctColumn + a2, collectTradeColumn + a2)

    val udfDf = df3.withColumn("similarity", callUDF("dynamicSimUDF",
      col(collectTradeColumn + a1).cast(new ArrayType(StringType, true)),
      col(collectTradeColumn + a2).cast(new ArrayType(StringType, true)),
      lit(setSimMethod)))

    // Keep pairs that reach the minimum similarity, then project the output columns.
    udfDf.filter(col("similarity") >= minSimFlt)
      .select(shrAcctColumn + a1, shrAcctColumn + a2, "similarity")
  }

  /**
   * Pairwise account similarity based on static account attributes.
   *
   * Every pair of rows in `data` is scored with `staticUDF`. Pairs are ordered by a generated id, so
   * each unordered pair appears once. The attributes are scored in this order: mob_nbr,
   * fix_or_memo_cntct_tel, cntct_addr, email_box, identifi_file_nbr, open_agt_code,
   * open_agt_net_code, open_date. All parameters other than `spark`, `data` and `minSim` are column
   * names in `data`.
   *
   * The pair generation is a non-equi self-join, so the work grows quadratically with the row count.
   * `na.drop()` after the join removes pairs with a null in any selected column (for example, a null
   * account code).
   *
   * @param minSim minimum similarity to keep, in [0, 100]
   * @return columns `<sec_acct_code>_1`, `<sec_acct_code>_2`, `similarity`
   */
  def stockStaticSim(spark: SparkSession, data: DataFrame, sec_acct_code: String, mob_nbr: String, fix_or_memo_cntct_tel: String, cntct_addr: String,
                     email_box: String, identifi_file_nbr: String, open_agt_code: String, open_agt_net_code: String,
                     open_date: String, minSim: Double): DataFrame = {
    require(sec_acct_code != null && sec_acct_code.nonEmpty, "sec_acct_code为空")
    require(mob_nbr != null && mob_nbr.nonEmpty, "mob_nbr为空")
    require(fix_or_memo_cntct_tel != null && fix_or_memo_cntct_tel.nonEmpty, "fix_or_memo_cntct_tel为空")
    require(cntct_addr != null && cntct_addr.nonEmpty, "cntct_addr为空")
    require(email_box != null && email_box.nonEmpty, "email_box为空")
    require(identifi_file_nbr != null && identifi_file_nbr.nonEmpty, "identifi_file_nbr为空")
    require(open_agt_code != null && open_agt_code.nonEmpty, "open_agt_code为空")
    require(open_agt_net_code != null && open_agt_net_code.nonEmpty, "open_agt_net_code为空")
    require(open_date != null && open_date.nonEmpty, "open_date为空")
    require(minSim >= 0.0 && minSim <= 100.0, "minSim应该在0到100之间")
    import spark.implicits._
    val a1 = "_1"
    val a2 = "_2"

    val selectDf = data.select(col(sec_acct_code), col(mob_nbr), col(fix_or_memo_cntct_tel), col(cntct_addr), col(email_box),
      col(identifi_file_nbr), col(open_agt_code), col(open_agt_net_code), col(open_date))

    // Pack the eight attributes into one array, in the positional order staticUDF expects.
    val staticArrayColumn = "staticArrayColumn"
    val arrayDf = selectDf.withColumn(staticArrayColumn, array(col(mob_nbr), col(fix_or_memo_cntct_tel), col(cntct_addr), col(email_box),
      col(identifi_file_nbr), col(open_agt_code), col(open_agt_net_code), col(open_date))).select(col(sec_acct_code), col(staticArrayColumn))

    // A generated id gives every row an order, so "id < id2" enumerates each unordered pair once.
    val autoGenrateId = "auto_genrate_id"
    val idDf = arrayDf.withColumn(autoGenrateId, monotonically_increasing_id()).select(sec_acct_code, staticArrayColumn, autoGenrateId)

    val copyDf = idDf.toDF(sec_acct_code + a2, staticArrayColumn + a2, autoGenrateId + a2)
    val pairDf = idDf.join(copyDf, idDf(autoGenrateId) < copyDf(autoGenrateId + a2), "left").na.drop()
      .select(sec_acct_code, staticArrayColumn, sec_acct_code + a2, staticArrayColumn + a2)
      .toDF(sec_acct_code + a1, staticArrayColumn + a1, sec_acct_code + a2, staticArrayColumn + a2)

    pairDf.as[(String, Seq[String], String, Seq[String])].map(s => {
      val sim = staticUDF(s._2, s._4)
      (s._1, s._3, sim)
    }).filter(_._3 >= minSim).toDF(sec_acct_code + a1, sec_acct_code + a2, "similarity")
  }

  /**
   * Static-attribute similarity score (0 to 100) of two accounts.
   *
   * Both sequences must have equal length, in the positional order built by `stockStaticSim`:
   *   0-3 mob_nbr, fix_or_memo_cntct_tel, cntct_addr, email_box; 4 identifi_file_nbr;
   *   5 open_agt_code; 6 open_agt_net_code; 7 open_date.
   *
   * Fields are compared trimmed, and a field that is null, blank or "null" on either side never
   * matches. Scoring:
   *  - 30 if any of fields 0-3 is equal.
   *  - Field 4, only when both values are longer than 6 characters: 20 for an equal 6-character
   *    prefix, otherwise 18 for an equal 4-character prefix, otherwise 12 for an equal 2-character
   *    prefix.
   *  - 25 if fields 5 and 6 both match, or 18.75 if only field 5 matches.
   *  - Field 7, only when both values parse as `yyyy-MM-dd` or `yyMMdd` (read as 20yy): 25 times
   *    1, 0.87, 0.64, 0.55, 0.52 or 0.5 for a gap of 0 to 5 days. Unparseable dates score 0.
   *
   * @return the score rounded to three decimals
   */
  def staticUDF(seq1: Seq[String], seq2: Seq[String]): Double = {
    require(seq1.length == seq2.length, "seq1的长度不等于seq2的长度")
    val n = seq1.length
    def present(i: Int): Boolean = !checkNull(seq1(i), seq2(i))
    def same(i: Int): Boolean = present(i) && seq1(i).trim.equals(seq2(i).trim)

    var sim = 0.0

    // 30 points: any contact field matches (counted once).
    if ((0 until math.min(4, n)).exists(i => same(i))) sim += 30

    // 20 points: graded prefix match on field 4.
    if (n > 4 && present(4)) {
      val s1 = seq1(4).trim
      val s2 = seq2(4).trim
      if (s1.length > 6 && s2.length > 6) {
        if (s1.substring(0, 6).equals(s2.substring(0, 6))) sim += 20
        else if (s1.substring(0, 4).equals(s2.substring(0, 4))) sim += (20 * 0.9)
        else if (s1.substring(0, 2).equals(s2.substring(0, 2))) sim += (20 * 0.6)
      }
    }

    // 25 points: same opening agent; the full score also needs the same agent branch.
    if (n > 5 && same(5)) sim += (if (same(6)) 25.0 else 25 * 0.75)

    // 25 points: opening dates close together. Both dates must parse; the old -1 sentinel made two
    // unparseable dates look identical and score the full 25.
    if (n > 7 && present(7)) {
      (epochDayOf(seq1(7).trim), epochDayOf(seq2(7).trim)) match {
        case (Some(d1), Some(d2)) =>
          val weight = math.abs(d2 - d1) match {
            case 0L => 1.0
            case 1L => 0.87
            case 2L => 0.64
            case 3L => 0.55
            case 4L => 0.52
            case 5L => 0.5
            case _ => 0.0
          }
          sim += 25 * weight
        case _ =>
      }
    }

    round3(sim)
  }

  /**
   * Day count (epoch millis of local midnight divided by one day) of a `yyyy-MM-dd` or `yyMMdd`
   * date. `yyMMdd` is read as 20yy. Returns None for any other format.
   */
  private def epochDayOf(s: String): Option[Long] =
    if (s.matches("[0-9]{4}-[0-9]{2}-[0-9]{2}")) Some(new SimpleDateFormat("yyyy-MM-dd").parse(s).getTime / 86400000)
    else if (s.matches("[0-9]{6}")) Some(new SimpleDateFormat("yyyyMMdd").parse("20" + s).getTime / 86400000)
    else None

  /** True when either value is null, blank, or the text "null" (case-insensitive, trimmed). */
  def checkNull(s1: String, s2: String): Boolean = isNullLike(s1) || isNullLike(s2)

  private def isNullLike(s: String): Boolean = s == null || s.trim.isEmpty || "null".equals(s.trim.toLowerCase)

  /**
   * Similarity of two trade-day lists whose elements must parse as Int.
   *
   * Two days `gap` apart score 100 / e^gap, rounded to three decimals. Every element of each list
   * keeps its best score against the other list.
   *  - `s = 0`: the larger of the two sides' score sums, each divided by (its count + 0.5).
   *  - `s = 1`: the mean of the two sides' averages, plus 0.5.
   *
   * @throws IllegalArgumentException if `s` is neither 0 nor 1
   * @return the similarity rounded to three decimals
   */
  def dynamicUDF1(tradeInfo1: Seq[String], tradeInfo2: Seq[String], s: Int): Double = {
    // Parse once instead of once per (i, j) pair; also avoids O(n) apply on linear Seqs.
    val days1 = tradeInfo1.map(_.toInt).toArray
    val days2 = tradeInfo2.map(_.toInt).toArray
    val maxColumns = new Array[Double](days1.length)
    val maxRows = new Array[Double](days2.length)

    for (i <- days1.indices; j <- days2.indices) {
      val termSim = round3(100 / math.exp(math.abs(days1(i) - days2(j))))
      if (maxColumns(i) < termSim) maxColumns(i) = termSim
      if (maxRows(j) < termSim) maxRows(j) = termSim
    }

    s match {
      case 0 => round3(math.max(maxColumns.sum / (maxColumns.length + 0.5), maxRows.sum / (maxRows.length + 0.5)))
      case 1 => round3(((maxColumns.sum / maxColumns.length) + (maxRows.sum / maxRows.length)) / 2 + 0.5)
      case _ => throw new IllegalArgumentException("set_similarity_method的值只能为0或者1")
    }
  }

  /**
   * Batched pairwise trade similarity over all accounts.
   *
   * Trades of rows whose `mktTypeColumn` equals `mktType` are each encoded into one Long:
   *  - bit 0 is the direction (0 for "B", 1 otherwise);
   *  - bits 1-19 are the day number (days since 2000-01-01);
   *  - the remaining high bits are the security code, so `secCodeColumn` must parse as Int.
   *
   * Trades are collected per account and the accounts get increasing ids. The result is
   * checkpointed, and accounts are processed in batches of `num` ids. Each account is compared with
   * every batch account that has a larger id, so every unordered pair is scored exactly once.
   * Pairs scoring at least `minSimFlt` are written as CSV to `tempPath/_dynamic_sim_temp/part-<batch>`.
   *
   * Side effects:
   *  - deletes `tempPath/_sim_checkpoint` and `tempPath/_dynamic_sim_temp` before starting;
   *  - sets the SparkContext checkpoint directory to `tempPath/_sim_checkpoint`.
   *
   * Files written by this run are not removed. The returned DataFrame reads from them, so clean up
   * `tempPath` only after the result has been consumed.
   *
   * @param num batch size; must be > 0
   * @return columns `<shrAcctColumn>_1`, `<shrAcctColumn>_2`, `similarity` (read back from CSV as
   *         strings; similarity has three decimals). The result is empty when there are no accounts.
   */
  def stockSim2(spark: SparkSession, df: DataFrame, shrAcctColumn: String, secCodeColumn: String, tradDateColumn: String, tradDircColumn: String,
                          limitTradeDirection: Boolean, setSimMethod: Int, minSimFlt: Double, tradeDateThreshold: Int, mktTypeColumn: String, mktType: String, tempPath: String, num: Int): DataFrame = {
    require(secCodeColumn != null, "secCodeColumn 为空")
    require(tradDateColumn != null, "tradDateColumn 为空")
    require(tradDircColumn != null, "tradDircColumn 为空")
    require(tempPath != null, "tempPath 为空")
    require(num > 0, "num 必须大于0")
    require(setSimMethod == 0 || setSimMethod == 1, "set_similarity_method的值只能为0或者1")
    val a1 = "_1"
    val a2 = "_2"
    val tradeDay = "tradeDay"
    import spark.implicits._

    // Remove output left by earlier runs. Stale part-* directories would otherwise be read back by
    // the glob at the end of this method.
    val fs = FileSystem.get(new URI(tempPath), spark.sparkContext.hadoopConfiguration)
    Seq(tempPath + "/_sim_checkpoint", tempPath + "/_dynamic_sim_temp").map(new Path(_)).foreach { p =>
      if (fs.exists(p)) fs.delete(p, true)
    }

    spark.sparkContext.setCheckpointDir(tempPath + "/_sim_checkpoint")

    // Encode each trade as (secCode << 20) | (day << 1) | direction.
    val tradeColDf = df.filter(col(mktTypeColumn).equalTo(mktType)).withColumn(tradeDay, tradeDayNumber(tradDateColumn))
      .select(shrAcctColumn, secCodeColumn, tradeDay, tradDircColumn).as[(String, String, Int, String)]
      .map(s => {
        val dayBits = (s._3 << 1) + (if ("B".equals(s._4)) 0 else 1)
        // Day plus direction must fit the 20 low bits.
        if (dayBits < 0 || dayBits > 0xFFFFF) {
          throw new IllegalStateException("到这步说明day过大,估计要到2100年才会报这个错,报错时只要修改下上面的2000")
        }
        (s._1, ((s._2.toInt & 0xFFFFFFFFL) << 20) | dayBits)
      }).toDF(a1, a2)

    // Collect each account's trades and number the accounts. The checkpoint materializes the ids,
    // so every batch pass below sees the same numbering.
    val df1 = tradeColDf.groupBy(a1).agg(collect_list(a2).alias(a2)).withColumn("_id_", monotonically_increasing_id()).as[(String, Seq[Long], Long)]
    val df2 = df1.checkpoint()

    val total = df2.count()
    println("账户数为: " + total)

    // Batches of `num` accounts in increasing id order. This relies on head() returning rows in
    // partition order, which is also increasing monotonically_increasing_id order.
    var batch = df2.head(num)
    var i = 0
    while (batch.nonEmpty) {
      println("第" + i + "次循环开始...")
      val start = System.currentTimeMillis()
      val currentBatch = batch

      val df3 = df2.flatMap(s1 => batchSimilarities(s1, currentBatch, limitTradeDirection, setSimMethod, tradeDateThreshold))
        .filter(_._3.toDouble >= minSimFlt).coalesce(100)
      df3.write.mode(SaveMode.Overwrite).csv(tempPath + "/_dynamic_sim_temp/part-" + i)

      val end = System.currentTimeMillis()
      println("第" + i + "次循环结束")
      println("第" + i + "次循环耗时:" + (end - start))

      // Advance to the accounts after this batch, if any remain.
      i += 1
      val lastId = currentBatch.last._3
      batch =
        if (i.toLong * num < total) df2.filter(_._3 > lastId).head(num)
        else Array.empty[(String, Seq[Long], Long)]
    }

    if (i == 0) {
      // No accounts: nothing was written, so there is nothing to read back.
      Seq.empty[(String, String, String)].toDF(shrAcctColumn + a1, shrAcctColumn + a2, "similarity")
    } else {
      spark.read.csv(tempPath + "/_dynamic_sim_temp/*").toDF(shrAcctColumn + a1, shrAcctColumn + a2, "similarity")
    }
  }

  /**
   * Scores `account` against every account in `batch` whose id (third field) is larger.
   *
   * Two trades can match only if they share the same key: the security code, plus the direction bit
   * when `limitTradeDirection` is set. A match `gap` days apart scores round(100 / e^gap) when
   * gap < `tradeDateThreshold`, and 0 otherwise. Each trade keeps its best match score.
   *  - Method 0 returns the larger of the two sides' score sums, each divided by (its trade count + 0.5).
   *  - Method 1 returns the mean of the two sides' averages, plus 0.5.
   *
   * @return (account, other account, similarity formatted with three decimals)
   */
  private def batchSimilarities(account: (String, Seq[Long], Long), batch: Array[(String, Seq[Long], Long)],
                                limitTradeDirection: Boolean, setSimMethod: Int,
                                tradeDateThreshold: Int): Array[(String, String, String)] = {
    val (acct1, encoded1, id1) = account
    val trades1 = encoded1.toArray
    // Positions of this account's trades, grouped by match key.
    val positionsByKey = trades1.indices.groupBy(p => tradeKey(trades1(p), limitTradeDirection))

    batch.filter(_._3 > id1).map { case (acct2, trades2, _) =>
      val maxColumns = new Array[Double](trades1.length)
      var sumRow = 0.0
      trades2.foreach { trade =>
        positionsByKey.get(tradeKey(trade, limitTradeDirection)).foreach { positions =>
          var maxRow = 0.0
          positions.foreach { p =>
            val gap = math.abs(tradeDayOf(trades1(p)) - tradeDayOf(trade))
            val termSim = if (gap < tradeDateThreshold) math.round(100 / math.exp(gap)).toDouble else 0.0
            if (maxColumns(p) < termSim) maxColumns(p) = termSim
            if (maxRow < termSim) maxRow = termSim
          }
          sumRow += maxRow
        }
      }

      val sim = setSimMethod match {
        case 0 => math.max(maxColumns.sum / (maxColumns.length + 0.5), sumRow / (trades2.length + 0.5))
        case 1 => ((maxColumns.sum / maxColumns.length) + (sumRow / trades2.length)) / 2 + 0.5
        case _ => throw new IllegalArgumentException("set_similarity_method的值只能为0或者1")
      }
      (acct1, acct2, formatSim(sim))
    }
  }

  /** Match key of an encoded trade: the security code, plus the direction bit when requested. */
  private def tradeKey(encoded: Long, withDirection: Boolean): Long =
    if (withDirection) ((encoded >> 20) << 1) + (encoded & 0x1) else encoded >> 20

  /** Day number stored in bits 1-19 of an encoded trade. */
  private def tradeDayOf(encoded: Long): Long = (encoded & 0xFFFFF) >> 1

  /** Whole days from 2000-01-01 to the date in `dateColumn`, computed exactly (leap years included). */
  private def tradeDayNumber(dateColumn: String) = datediff(col(dateColumn), to_date(lit("2000-01-01")))

  /** Formats with three decimals under a fixed locale, so the output always uses '.' as the separator. */
  private def formatSim(d: Double): String = "%.3f".formatLocal(java.util.Locale.ROOT, d)

  /** Rounds to three decimals, giving the same result as the former `formatted("%.3f").toDouble`. */
  private def round3(d: Double): Double = formatSim(d).toDouble

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy