All downloads are free. Search and download functionality uses the official Maven repository.

streaming.dsl.mmlib.algs.SQLRateSampler.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs

import org.apache.spark.Partitioner
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, functions => F}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import streaming.dsl.mmlib.SQLAlg

/**
  * Tags every row of a DataFrame with the index of the "split" it is assigned to,
  * according to a list of sampling rates, and saves the result as Parquet.
  *
  * Parameters (all optional):
  *  - `labelCol` (default `label`): column whose values are converted to `Int` and
  *    used as the grouping / partitioning key. Int, Long, Short, Byte, Double and
  *    Float values are accepted; anything else is rejected.
  *  - `sampleRate` (default `0.9,0.1`): comma-separated rates, one per split. Rows of
  *    split `i` carry the integer `i` in the extra `__split__` column.
  *  - `isSplitWithSubLabel` (default empty): when exactly `"true"`, each label group
  *    is cut into consecutive slices whose sizes follow the rates; otherwise every
  *    row is assigned to a split at random.
  *
  * Rows that fall outside the sum of the rates (e.g. with rates `0.5,0.1`) are not
  * part of any split and are left out of the result.
  *
  * Created by allwefantasy on 7/5/2018.
  */
class SQLRateSampler extends SQLAlg with Functions {

  /**
    * Builds the split-tagged DataFrame described in the class documentation.
    *
    * @param df     input data; must contain the label column
    * @param params algorithm parameters (see the class documentation)
    * @return the columns of `df` plus an integer `__split__` column
    * @throws IllegalArgumentException if a label value is not a supported numeric type
    * @throws NumberFormatException    if `sampleRate` contains a non-numeric entry
    */
  def internal_train(df: DataFrame, params: Map[String, String]): DataFrame = {

    // Local function values rather than class methods, so the RDD closures below
    // only capture these functions and not the enclosing SQLRateSampler instance.
    val getIntFromRow = (row: Row, i: Int) => {
      row.get(i) match {
        case a: Int => a
        case a: Long => a.toInt
        case a: Short => a.toInt
        case a: Byte => a.toInt
        case a: Double => a.toInt
        case a: Float => a.toInt
        case other =>
          val found = Option(other).map(_.getClass.getName).getOrElse("null")
          throw new IllegalArgumentException(
            s"label column (index $i) must hold numeric values, but found $found")
      }
    }

    val getIntFromRowByName = (row: Row, name: String) => getIntFromRow(row, row.fieldIndex(name))

    val labelCol = params.getOrElse("labelCol", "label")
    val isSplitWithSubLabel = params.getOrElse("isSplitWithSubLabel", "")

    // e.g. "0.8,0.1,0.1": split 0 gets 80% of the rows, splits 1 and 2 get 10% each.
    val sampleRates = params.getOrElse("sampleRate", "0.9,0.1").split(",").map(_.trim.toDouble)
    // Running total of the rates, i.e. the upper bound of each split on the [0, 1)
    // axis: 0.8,0.1,0.1 -> 0.8,0.9,1.0 (up to floating-point rounding).
    val cumulativeRates = sampleRates.scanLeft(0.0)(_ + _).tail

    // Output schema shared by both strategies: the input columns plus the split index.
    val outputSchema = StructType(df.schema ++ Seq(StructField("__split__", IntegerType)))

    val labelToCountSeq = df.groupBy(labelCol)
      .agg(F.count(labelCol).as("subLabelCount"))
      .orderBy(F.asc("subLabelCount"))
      .select(labelCol, "subLabelCount")
      .collect()
      .map(f => (getIntFromRow(f, 0), f.getLong(1)))
    val forLog = labelToCountSeq.map(f => s"${f._1}:${f._2}").mkString(",")
    logInfo(format(s"computing data stat:${forLog}"))

    val labelCount = labelToCountSeq.size
    // Each label gets its own partition id; labels are ordered by ascending row count.
    val labelPartionMap = labelToCountSeq.map(_._1).zipWithIndex.toMap

    if (isSplitWithSubLabel == "true") {
      // NOTE(review): the whole dataset is collected to the driver here, so this
      // strategy only suits data that fits in driver memory.
      val rowsByLabel = df.rdd.map { f =>
        (getIntFromRowByName(f, labelCol), f)
      }.collect().groupBy(_._1)

      val splitWithSubLabel = rowsByLabel.flatMap { case (_, group) =>
        val groupCount = group.length
        // Split i covers group positions [bounds(i), bounds(i + 1)). Each split takes
        // floor(rate * groupCount) rows and starts exactly where the previous split
        // ended (slice's upper bound is exclusive), so no row is skipped in between.
        val bounds = sampleRates.scanLeft(0)((begin, rate) => begin + (rate * groupCount).toInt)
        sampleRates.indices.flatMap { splitIndex =>
          group.slice(bounds(splitIndex), bounds(splitIndex + 1)).iterator.map { case (_, row) =>
            Row.fromSeq(row.toSeq ++ Seq(splitIndex))
          }
        }
      }.toSeq

      df.sparkSession.createDataFrame(
        df.sparkSession.sparkContext.parallelize(splitWithSubLabel), outputSchema)

    } else {
      val dfWithLabelPartition = df.rdd.map { f =>
        (getIntFromRowByName(f, labelCol), f)
      }.partitionBy(new Partitioner {
        override def numPartitions: Int = labelCount

        // One partition per known label; any other key falls back to partition 0.
        override def getPartition(key: Any): Int = {
          labelPartionMap.getOrElse(key.asInstanceOf[Int], 0)
        }
      }).mapPartitions { iter =>
        // A fresh generator per partition instead of the JVM-wide shared one.
        val random = new scala.util.Random()
        iter.flatMap { case (_, row) =>
          // Split i owns the half-open interval [cumulativeRates(i - 1), cumulativeRates(i)).
          // A draw above the last bound (rates summing to less than 1) belongs to no
          // split and the row is dropped.
          val draw = random.nextDouble()
          val splitIndex = cumulativeRates.indexWhere(draw < _)
          if (splitIndex >= 0) Iterator.single(Row.fromSeq(row.toSeq ++ Seq(splitIndex)))
          else Iterator.empty
        }
      }
      df.sparkSession.createDataFrame(dfWithLabelPartition, outputSchema)
    }
  }

  /**
    * Runs `internal_train` and writes its result to `path` as Parquet, replacing any
    * data already stored there.
    *
    * @return `emptyDataFrame()(df)`; the split-tagged data is only persisted at `path`
    */
  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    val newDF = internal_train(df, params)
    newDF.write.mode(SaveMode.Overwrite).parquet(path)
    emptyDataFrame()(df)
  }

  /** Not implemented for this sampler; always throws `NotImplementedError`. */
  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = ???

  /** Not implemented for this sampler; always throws `NotImplementedError`. */
  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = ???
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy