All downloads are free. The search and download functionality uses the official Maven repository.

com.datastax.data.prepare.spark.dataset.StockHierarchicalCluster.scala Maven / Gradle / Ivy

package com.datastax.data.prepare.spark.dataset

import java.util

import com.datastax.insight.core.driver.SparkContextBuilder
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, RowFactory}
import org.apache.spark.storage.StorageLevel

import scala.collection.mutable

object StockHierarchicalCluster extends Serializable {

  /**
   * Builds a schema of three nullable string columns named after the given
   * class-pair and similarity column names.
   */
  def schema(class_one: String, class_two: String, sim: String): StructType = {
    val headerColumns = Array(class_one, class_two, sim)
    val fields = headerColumns.map(fieldName => StructField(fieldName, StringType, nullable = true))
    StructType(fields)
  }

  /**
   * One hierarchical-clustering step: merge the classes appearing in the top
   * `intera` rows by similarity into a single cluster, recompute the average
   * similarity between that cluster and every remaining class, and return the
   * untouched rows unioned with the recomputed rows.
   *
   * @param dataset   pairwise similarity rows (class_one, class_two, sim);
   *                  sim is stored as a parseable-double string
   * @param intera    number of top-similarity rows merged per iteration
   * @param threshold unused here; kept for interface compatibility
   */
  def cluster(dataset: DataFrame, class_one: String, class_two: String, sim: String, intera: Int, threshold: Double): DataFrame = {

    val order_data = dataset.orderBy(desc(sim))
    order_data.persist(StorageLevel.MEMORY_AND_DISK_SER)
    val new_class = order_data.take(intera).toSeq

    // BUG FIX: monotonically_increasing_id() is not contiguous across
    // partitions, so `id > intera` skipped the wrong rows on multi-partition
    // data. row_number() over a global ordering yields contiguous ranks 1..n.
    // (An unpartitioned window collapses to a single partition; acceptable for
    // data this step already take()s and collect()s on.)
    val except_intera_data = order_data
      .withColumn("id", row_number().over(Window.orderBy(desc(sim))))
      .filter(col("id") > intera)
      .select(col(class_one), col(class_two), col(sim))
    order_data.unpersist(true)

    val new_class_dif = cluster_distinct(new_class)

    // Rows touching none of the merged classes. A single isin-based filter
    // replaces the original per-class chain of filters, which grew the query
    // plan linearly with the number of merged classes.
    val filter_data =
      if (new_class_dif.isEmpty) except_intera_data
      else except_intera_data.filter(
        !col(class_one).isin(new_class_dif: _*) && !col(class_two).isin(new_class_dif: _*))

    // Rows that do touch a merged class: their similarities are recomputed
    // against the new cluster below.
    val class_data = except_intera_data.except(filter_data)

    val props_dif = prop_distinct(filter_data, class_one, class_two, sim)

    val new_data = createNewData(class_data, class_one, class_two, sim, props_dif, new_class_dif)
    filter_data.union(new_data)
  }

  /**
   * Builds one row per remaining class, pairing the newly merged cluster
   * (member names joined with ";") with that class and the recomputed average
   * similarity, and returns them as a DataFrame with the usual 3-column schema.
   */
  def createNewData(dataset: DataFrame, class_one: String, class_two: String, sim: String, props_dif: Seq[String], new_class_dif: Seq[String]): DataFrame = {
    val new_col = "new_col"
    val data = dataset.withColumn(new_col, concat_ws("_", col(class_one), col(class_two))).select(new_col, sim)
    data.persist(StorageLevel.MEMORY_AND_DISK)
    val rows = new util.ArrayList[Row]
    // Name of the merged cluster; hoisted out of the loop since it does not
    // depend on the loop variable.
    val new_clu_one = new_class_dif.mkString(";")
    var i = 0
    println("prop start!!")
    props_dif.foreach(t => {
      val new_sim = simOperator(data, class_one, class_two, sim, new_class_dif, t)
      rows.add(RowFactory.create(new_clu_one, t, new_sim))
      // Progress trace every 100 classes; simOperator collects per class.
      if (i % 100 == 0) {
        println("prop " + i + " is: " + t)
      }
      i = i + 1
    })
    data.unpersist(true)
    println("prop end!!")
    SparkContextBuilder.getSession.createDataFrame(rows, schema(class_one, class_two, sim))
  }

  /**
   * Average similarity between class `u_id` and the merged cluster: the sum of
   * sims over every "a_b" pair joining u_id with a cluster member, divided by
   * the cluster size. Returned formatted to 4 decimal places.
   *
   * NOTE(review): assumes class names contain no "_" — the pair key is built
   * with concat_ws("_", ...), so an embedded underscore is ambiguous (true of
   * the original code as well).
   */
  def simOperator(dataset: DataFrame, class_one: String, class_two: String, sim: String, new_class_dif: Seq[String], u_id: String): String = {
    val ds = dataset.filter(r => {
      // BUG FIX: the original stripped u_id out of the concatenated pair via
      // String.replace, which corrupts the partner name whenever u_id occurs
      // as a substring of it (e.g. u_id "A" with pair "A_AB": "AAB" minus all
      // "A"s gives "B" — a false match). Compare the two halves explicitly.
      val parts = r.getString(0).split("_", 2)
      parts.length == 2 && {
        val partner =
          if (parts(0) == u_id) parts(1)
          else if (parts(1) == u_id) parts(0)
          else null
        partner != null && new_class_dif.contains(partner)
      }
    })
    var sim_sum: Double = 0.0
    for (row <- ds.select(col(sim)).collect()) {
      sim_sum = sim_sum + row.getString(0).toDouble
    }
    // Guard against division by zero when the merged cluster is empty;
    // f-interpolator replaces the deprecated .formatted.
    val denom = math.max(new_class_dif.size, 1)
    f"${sim_sum / denom}%.4f"
  }

  /** Distinct class names appearing in either class column of `dataset`. */
  def prop_distinct(dataset: DataFrame, class_one: String, class_two: String, sim: String): Seq[String] = {
    val clu_one = dataset.select(class_one).distinct().collect()
    val clu_two = dataset.select(class_two).distinct().collect()
    val set = mutable.HashSet[String]()
    clu_one.foreach(t => set += t.getString(0))
    clu_two.foreach(t => set += t.getString(0))
    set.toSeq
  }

  /** Distinct class names appearing in the first two columns of the taken rows. */
  def cluster_distinct(interaData: Seq[Row]): Seq[String] = {
    val set = mutable.HashSet[String]()
    interaData.foreach(t => {
      set += t.getString(0)
      set += t.getString(1)
    })
    set.toSeq
  }

  /**
   * Repeatedly runs `cluster` until the best remaining similarity drops below
   * `threshold` (or the data is exhausted), then emits (sim, "cluster") rows
   * for every pair with sim above `min_sim`.
   */
  def runCluster(dataset: DataFrame, x: String, y: String, sim: String, intera: Int, threshold: Double, min_sim: Double): DataFrame = {
    var data_new = dataset
    var flag = true
    while (flag) {
      data_new = cluster(data_new, x, y, sim, intera, threshold)
      if (data_new.count() == 0) {
        // BUG FIX: the original left `flag` set on an empty result and looped
        // forever; an exhausted dataset ends the iteration.
        flag = false
      } else {
        val cluster_sim = data_new.orderBy(desc(sim)).take(1)(0).getString(2).toDouble
        if (cluster_sim < threshold) {
          flag = false
        }
      }
    }
    val cluster_result = "cluster"
    val data_tmp = data_new.withColumn(cluster_result, concat_ws(";", col(x), col(y))).select(col(sim), col(cluster_result))
    data_tmp.filter(col(sim) > min_sim)
  }

}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy