// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.recommendation

import com.microsoft.ml.spark.core.contracts.Wrappable
import com.microsoft.ml.spark.core.env.InternalWrapper
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param._
import org.apache.spark.ml.recommendation._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{Model, _}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{collect_list, rank => r, _}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset}

import scala.collection.mutable
import scala.concurrent.duration.Duration
import scala.concurrent.{ExecutionContext, Future}
import scala.util.Random

@InternalWrapper
class RankingTrainValidationSplit(override val uid: String) extends Estimator[RankingTrainValidationSplitModel]
  with RankingTrainValidationSplitParams with Wrappable
  with ComplexParamsWritable with RecommendationParams {

  def this() = this(Identifiable.randomUID("RankingTrainValidationSplit"))

  /** @group setParam */
  def setUserCol(value: String): this.type = set(userCol, value)

  /** @group setParam */
  def setItemCol(value: String): this.type = set(itemCol, value)

  /** @group setParam */
  def setRatingCol(value: String): this.type = set(ratingCol, value)

  /** @group setParam */
  def setEstimator(value: Estimator[_ <: Model[_]]): this.type = set(estimator, value)

  /** @group setParam */
  def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)

  /** @group setParam */
  def setEvaluator(value: Evaluator): this.type = set(evaluator, value)

  /** @group setParam */
  def setTrainRatio(value: Double): this.type = set(trainRatio, value)

  /** @group setParam */
  def setSeed(value: Long): this.type = set(seed, value)

  /** @group setParam */
  def setMinRatingsU(value: Int): this.type = set(minRatingsU, value)

  /** @group setParam */
  def setMinRatingsI(value: Int): this.type = set(minRatingsI, value)

  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)

  /**
    * The number of threads to use when running parallel algorithms.
    * Default is 1 for serial execution.
    *
    * @group expertParam
    */
  val parallelism = new IntParam(this, "parallelism",
    "the number of threads to use when running parallel algorithms", ParamValidators.gtEq(1))

  setDefault(parallelism -> 1)

  /** @group expertGetParam */
  def getParallelism: Int = $(parallelism)

  /** @group expertSetParam */
  def setParallelism(value: Int): this.type = set(parallelism, value)

  private[ml] def getExecutionContext: ExecutionContext = {

    getParallelism match {
      case 1 =>
        SparkHelpers.getThreadUtils().sameThread
      case n =>
        ExecutionContext.fromExecutorService(SparkHelpers.getThreadUtils()
          .newDaemonCachedThreadPool(s"${this.getClass.getSimpleName}-thread-pool", n))
    }
  }

  override def fit(dataset: Dataset[_]): RankingTrainValidationSplitModel = {
    val schema = dataset.schema
    transformSchema(schema, logging = true)
    val est = getEstimator
    val eval = getEvaluator.asInstanceOf[RankingEvaluator]
    val epm = getEstimatorParamMaps
    val numModels = epm.length

    dataset.cache()
    // The ranking evaluator needs the total number of distinct items in the dataset.
    eval.setNItems(dataset.agg(countDistinct(col(getItemCol))).take(1)(0).getLong(0))
    val filteredDataset = filterRatings(dataset.dropDuplicates())

    // Stratified (per-user) split of the dataset
    val Array(trainingDataset, validationDataset): Array[DataFrame] = splitDF(filteredDataset)
    trainingDataset.cache()
    validationDataset.cache()

    val executionContext = getExecutionContext

    def calculateMetrics(model: Transformer, validationDataset: Dataset[_]): Double = model match {
      case p: PipelineModel =>
        // Assume the recommender is the last stage of the pipeline
        calculateMetrics(p.stages.last, validationDataset)
      case a: ALSModel =>
        val recs = a.recommendForAllUsers(eval.getK)
        val preparedTest: Dataset[_] = prepareTestData(validationDataset.toDF(), recs, eval.getK)
        eval.evaluate(preparedTest)
    }

    val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
      Future[Double] {
        val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
        calculateMetrics(model, validationDataset)
      }(executionContext)
    }

    val metrics = metricFutures.map(SparkHelpers.getThreadUtils().awaitResult(_, Duration.Inf))

    trainingDataset.unpersist()
    validationDataset.unpersist()

    val (bestMetric, bestIndex) =
      if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
      else metrics.zipWithIndex.minBy(_._1)

    val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
    copyValues(new RankingTrainValidationSplitModel(uid)
      .setBestModel(bestModel)
      .setValidationMetrics(metrics)
      .setParent(this))
  }

  override def copy(extra: ParamMap): RankingTrainValidationSplit = defaultCopy(extra)

  /** Keep only users that have rated at least `minRatingsU` items. */
  private def filterByItemCount(dataset: Dataset[_]): DataFrame = {
    dataset
      .groupBy(getUserCol)
      .agg(col(getUserCol), count(col(getItemCol)))
      .withColumnRenamed(s"count(${getItemCol})", "nitems")
      .where(col("nitems") >= getMinRatingsU)
      .drop("nitems")
      .cache()
  }

  /** Keep only items that have been rated by at least `minRatingsI` users. */
  private def filterByUserRatingCount(dataset: Dataset[_]): DataFrame = dataset
    .groupBy(getItemCol)
    .agg(col(getItemCol), count(col(getUserCol)))
    .withColumnRenamed(s"count(${getUserCol})", "ncustomers")
    .where(col("ncustomers") >= getMinRatingsI)
    .join(dataset, getItemCol)
    .drop("ncustomers")
    .cache()

  /** Drop users and items that do not have enough ratings to split reliably. */
  def filterRatings(dataset: Dataset[_]): DataFrame = filterByUserRatingCount(dataset)
    .join(filterByItemCount(dataset), $(userCol))

  /** Stratified split: for each user, shuffle the rated items and take the first
    * `trainRatio` fraction for training and the remainder for testing. */
  def splitDF(dataset: DataFrame): Array[DataFrame] = {
    val shuffleFlag = true
    val shuffleBC = dataset.sparkSession.sparkContext.broadcast(shuffleFlag)

    if (dataset.columns.contains(getRatingCol)) {
      val wrapColumn = udf((itemId: Double, rating: Double) => Array(itemId, rating))

      val sliceudf = udf(
        (r: Seq[Array[Double]]) => r.slice(0, math.round(r.length * $(trainRatio)).toInt))

      val shuffle = udf((r: Seq[Array[Double]]) =>
        if (shuffleBC.value) Random.shuffle(r.toSeq)
        else r
      )
      val dropudf = udf((r: Seq[Array[Double]]) => r.drop(math.round(r.length * $(trainRatio)).toInt))

      val testds = dataset
        .withColumn("itemIDRating", wrapColumn(col(getItemCol), col(getRatingCol)))
        .groupBy(col(getUserCol))
        .agg(collect_list(col("itemIDRating")))
        .withColumn("shuffle", shuffle(col("collect_list(itemIDRating)")))
        .withColumn("train", sliceudf(col("shuffle")))
        .withColumn("test", dropudf(col("shuffle")))
        .drop(col("collect_list(itemIDRating)")).drop(col("shuffle"))
        //.cache()

      val train = testds
        .select(getUserCol, "train")
        .withColumn("itemIdRating", explode(col("train")))
        .drop("train")
        .withColumn(getItemCol, col("itemIdRating").getItem(0))
        .withColumn(getRatingCol, col("itemIdRating").getItem(1))
        .drop("itemIdRating")

      val test = testds
        .select(getUserCol, "test")
        .withColumn("itemIdRating", explode(col("test")))
        .drop("test")
        .withColumn(getItemCol, col("itemIdRating").getItem(0))
        .withColumn(getRatingCol, col("itemIdRating").getItem(1))
        .drop("itemIdRating")

      Array(train, test)
    }
    else {
      val shuffle = udf((r: Seq[Double]) =>
        if (shuffleBC.value) Random.shuffle(r.toSeq)
        else r
      )
      val sliceudf = udf(
        (r: Seq[Double]) => r.slice(0, math.round(r.length * $(trainRatio)).toInt))
      val dropudf = udf((r: Seq[Double]) => r.drop(math.round(r.length * $(trainRatio)).toInt))

      val testds = dataset
        .groupBy(col(getUserCol))
        .agg(collect_list(col(getItemCol)))
        .withColumn("shuffle", shuffle(col(s"collect_list(${getItemCol})")))
        .withColumn("train", sliceudf(col("shuffle")))
        .withColumn("test", dropudf(col("shuffle")))
        .drop(col(s"collect_list(${getItemCol}")).drop(col("shuffle"))
        .cache()

      val train = testds
        .select(getUserCol, "train")
        .withColumn(getItemCol, explode(col("train")))
        .drop("train")

      val test = testds
        .select(getUserCol, "test")
        .withColumn(getItemCol, explode(col("test")))
        .drop("test")

      Array(train, test)
    }
  }

  /** Build the evaluation input: per user, the top-k recommended items ("prediction")
    * joined with the ground-truth items ("label") from the validation set. */
  def prepareTestData(validationDataset: DataFrame, recs: DataFrame, k: Int): Dataset[_] = {
    val est = $(estimator) match {
      case p: Pipeline =>
        // Assume the recommender is the last stage of the pipeline
        p.getStages.last match {
          case a: ALS => a
        }
      case a: ALS => a
    }

    val userColumn = est.getUserCol
    val itemColumn = est.getItemCol

    val perUserRecommendedItemsDF: DataFrame = recs
      .select(userColumn, "recommendations." + itemColumn)
      .withColumnRenamed(itemColumn, "prediction")

    val perUserActualItemsDF = if (validationDataset.columns.contains($(ratingCol))) {
      val windowSpec = Window.partitionBy(userColumn).orderBy(col($(ratingCol)).desc)

      validationDataset
        .select(userColumn, itemColumn, $(ratingCol))
        .withColumn("rank", r().over(windowSpec).alias("rank"))
        .where(col("rank") <= k)
        .groupBy(userColumn)
        .agg(col(userColumn), collect_list(col(itemColumn)))
        .withColumnRenamed("collect_list(" + itemColumn + ")", "label")
        .select(userColumn, "label")
    } else {
      // No rating column available: fall back to ranking items by item id.
      val windowSpec = Window.partitionBy(userColumn).orderBy(col($(itemCol)).desc)

      validationDataset
        .select(userColumn, itemColumn)
        .withColumn("rank", r().over(windowSpec).alias("rank"))
        .where(col("rank") <= k)
        .groupBy(userColumn)
        .agg(col(userColumn), collect_list(col(itemColumn)))
        .withColumnRenamed("collect_list(" + itemColumn + ")", "label")
        .select(userColumn, "label")
    }
    perUserRecommendedItemsDF
      .join(perUserActualItemsDF, userColumn)
      .drop(userColumn)
  }
}

object RankingTrainValidationSplit extends ComplexParamsReadable[RankingTrainValidationSplit]
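
// ---------------------------------------------------------------------------
// Usage sketch (not part of the original source): fits a
// RankingTrainValidationSplit over a small ALS grid and reads back the best
// model. The SparkSession, the toy ratings data, and the RankingEvaluator
// settings are illustrative assumptions, not values prescribed by this class;
// the evaluator must be a RankingEvaluator because fit() casts to it.
// ---------------------------------------------------------------------------
private object RankingTrainValidationSplitExample {
  import org.apache.spark.ml.recommendation.ALS
  import org.apache.spark.ml.tuning.ParamGridBuilder
  import org.apache.spark.sql.SparkSession

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("rtvs-example").getOrCreate()
    import spark.implicits._

    // Hypothetical ratings: (userId, itemId, rating) triples.
    val ratings = Seq(
      (0, 10, 4.0), (0, 11, 1.0), (0, 12, 5.0),
      (1, 10, 3.0), (1, 12, 2.0), (1, 13, 4.0),
      (2, 11, 5.0), (2, 12, 3.0), (2, 13, 1.0)
    ).toDF("userId", "itemId", "rating")

    val als = new ALS()
      .setUserCol("userId").setItemCol("itemId").setRatingCol("rating")

    val grid = new ParamGridBuilder()
      .addGrid(als.rank, Array(8, 16))
      .addGrid(als.regParam, Array(0.01, 0.1))
      .build()

    val rtvs = new RankingTrainValidationSplit()
      .setEstimator(als)
      .setEvaluator(new RankingEvaluator().setK(3)) // assumes RankingEvaluator's setK
      .setEstimatorParamMaps(grid)
      .setTrainRatio(0.75)
      .setUserCol("userId").setItemCol("itemId").setRatingCol("rating")
      .setMinRatingsU(1).setMinRatingsI(1) // keep all users and items in this toy set
      .setParallelism(2)                   // evaluate two param maps concurrently

    val model = rtvs.fit(ratings)
    println(s"validation metrics: ${model.getValidationMetrics.mkString(", ")}")
  }
}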

@InternalWrapper
class RankingTrainValidationSplitModel(
  override val uid: String)
  extends Model[RankingTrainValidationSplitModel] with Wrappable
    with ComplexParamsWritable {

  val validationMetrics = new ArrayParam(this, "validationMetrics",
    "Validation metric for each param map evaluated during tuning")

  /** @group setParam */
  def setValidationMetrics(value: Array[_]): this.type = set(validationMetrics, value)

  /** @group getParam */
  def getValidationMetrics: Array[_] = $(validationMetrics)

  /** @group setParam */
  def setBestModel(value: Model[_]): this.type = set(bestModel, value)

  val bestModel: TransformerParam =
    new TransformerParam(
      this,
      "bestModel", "The best model found during the train/validation split",
      { t => t.isInstanceOf[Model[_]] })

  /** @group getParam */
  def getBestModel: Model[_] = $(bestModel).asInstanceOf[Model[_]]

  def this() = this(Identifiable.randomUID("RankingTrainValidationSplitModel"))

  override def copy(extra: ParamMap): RankingTrainValidationSplitModel = {
    val copied = new RankingTrainValidationSplitModel(uid)
    copyValues(copied, extra).setParent(parent)
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)

    // Sort by prediction so the output ordering is deterministic (unit tests rely on it)
    $(bestModel).transform(dataset).sort("prediction")
  }

  override def transformSchema(schema: StructType): StructType = {
    $(bestModel).transformSchema(schema)
  }
}

object RankingTrainValidationSplitModel extends ComplexParamsReadable[RankingTrainValidationSplitModel]
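
// ---------------------------------------------------------------------------
// Persistence sketch (not part of the original source): ComplexParamsWritable /
// ComplexParamsReadable give the model the standard MLWritable round trip, and
// transform() delegates to the best model found during tuning. The path below
// is an illustrative placeholder.
// ---------------------------------------------------------------------------
private object RankingTrainValidationSplitModelExample {
  import org.apache.spark.sql.DataFrame

  def roundTrip(model: RankingTrainValidationSplitModel, scoring: DataFrame): DataFrame = {
    model.write.overwrite().save("/tmp/rtvs-model")            // hypothetical path
    val restored = RankingTrainValidationSplitModel.load("/tmp/rtvs-model")
    restored.transform(scoring)                                // scores with the best model
  }
}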



