All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.ml.spark.recommendation.SARModel.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.recommendation

import com.microsoft.ml.spark.core.contracts.Wrappable
import com.microsoft.ml.spark.core.env.InternalWrapper
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Model}
import org.apache.spark.ml.param.{DataFrameParam, ParamMap}
import org.apache.spark.ml.recommendation.{BaseRecommendationModel, Constants}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{StructField => SF, _}
import org.apache.spark.sql.{DataFrame, Dataset, Row}

/** SAR Model.
  *
  * Produces personalized item recommendations from two factor DataFrames: a user affinity
  * DataFrame (`userDataFrame`) and an item-to-item similarity DataFrame (`itemDataFrame`).
  * Scores are obtained by multiplying the user affinity matrix with the item similarity matrix.
  *
  * @param uid The id of the module
  */
@InternalWrapper
class SARModel(override val uid: String) extends Model[SARModel]
  with BaseRecommendationModel with Wrappable with SARParams with ComplexParamsWritable {

  /** @group setParam */
  def setUserDataFrame(value: DataFrame): this.type = set(userDataFrame, value)

  /** User affinity factors. `recommendForAll` reads column 0 as the user id (a Double) and
    * column 1 as the user's affinity list (a Seq[Float]).
    */
  val userDataFrame = new DataFrameParam(this, "userDataFrame", "User affinity DataFrame")

  /** @group getParam */
  def getUserDataFrame: DataFrame = $(userDataFrame)

  /** @group setParam */
  def setItemDataFrame(value: DataFrame): this.type = set(itemDataFrame, value)

  /** Item-to-item similarity factors. `recommendForAll` reads column 0 as the item id (a Double)
    * and column 1 as the item's similarity list (a Seq[Float]).
    */
  val itemDataFrame =
    new DataFrameParam(this, "itemDataFrame", "Item-to-item similarity DataFrame")

  /** @group getParam */
  def getItemDataFrame: DataFrame = $(itemDataFrame)

  def this() = this(Identifiable.randomUID("SARModel"))

  /**
    * Returns top `numItems` items recommended for each user, for all users.
    *
    * @param numItems max number of recommendations for each user
    * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
    *         stored as an array of (itemCol: Int, rating: Float) Rows.
    */
  def recommendForAllUsers(numItems: Int): DataFrame = {
    recommendForAll(getUserDataFrame, getItemDataFrame, getUserCol, getItemCol, numItems)
  }

  /**
    * Returns top `numItems` items recommended for each user id in the input data set. Note that if
    * there are duplicate ids in the input dataset, only one set of recommendations per unique id
    * will be returned.
    *
    * @param dataset  a Dataset containing a column of user ids. The column name must match `userCol`.
    * @param numItems max number of recommendations for each user.
    * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
    *         stored as an array of (itemCol: Int, rating: Float) Rows.
    */
  def recommendForUserSubset(dataset: Dataset[_], numItems: Int): DataFrame = {
    val srcFactorSubset = getSourceFactorSubset(dataset, getUserDataFrame, getUserCol)
    recommendForAll(srcFactorSubset, getItemDataFrame, getUserCol, getItemCol, numItems)
  }

  /**
    * Returns a subset of a factor DataFrame limited to only those unique ids contained
    * in the input dataset.
    *
    * @param dataset input Dataset containing the id column used to filter factors.
    * @param factors factor DataFrame to filter; its id column is always looked up as `userCol`.
    * @param column  column name containing the ids in the input dataset.
    * @return DataFrame of (`userCol`, "flatList") containing factors only for those ids present in
    *         both the input dataset and the factor DataFrame.
    */
  private def getSourceFactorSubset(
    dataset: Dataset[_],
    factors: DataFrame,
    column: String): DataFrame = {
    // left_semi keeps each factor row at most once, however often its id repeats in `dataset`.
    factors
      .join(dataset.select(column), factors(getUserCol) === dataset(column), joinType = "left_semi")
      .select(factors(getUserCol), factors("flatList"))
  }

  /**
    * Computes the top `num` recommendations for every source row.
    *
    * Personalized recommendations for a single user are obtained by multiplying the user's
    * affinity vector (a row of the affinity matrix) with the item-to-item similarity matrix.
    * Both factor DataFrames are read positionally: column 0 is the row id (a Double, truncated to
    * a Long matrix row index) and column 1 is a Seq[Float] whose positions become column indices.
    *
    * @param srcFactors      factor rows forming the user affinity matrix
    * @param dstFactors      factor rows forming the item-to-item similarity matrix
    * @param srcOutputColumn name given to the id column of the output
    * @param dstOutputColumn name of the item field inside each recommendation struct
    * @param num             max number of recommendations kept per source row
    * @return a DataFrame of (srcOutputColumn, recommendations), where recommendations is an array
    *         of (dstOutputColumn: Int, rating: Float) structs sorted by descending score. The item
    *         value is the column index in the product matrix.
    */
  private def recommendForAll(
    srcFactors: DataFrame,
    dstFactors: DataFrame,
    srcOutputColumn: String,
    dstOutputColumn: String,
    num: Int): DataFrame = {

    // One MatrixEntry per list element: (row id, position in list, value).
    def toMatrixEntries(factors: DataFrame) =
      factors.rdd.flatMap { row =>
        row.getAs[Seq[Float]](1).zipWithIndex.map { case (value, colIndex) =>
          MatrixEntry(row.getDouble(0).toLong, colIndex.toLong, value.toDouble)
        }
      }

    val sourceMatrix = new CoordinateMatrix(toMatrixEntries(srcFactors)).toBlockMatrix()
    val destMatrix = new CoordinateMatrix(toMatrixEntries(dstFactors)).toBlockMatrix()

    val userToItemMatrix = sourceMatrix
      .multiply(destMatrix)
      .toIndexedRowMatrix()
      .rows
      .map(indexedRow => (indexedRow.index.toInt, indexedRow.vector))

    // toIndexedRowMatrix may emit SparseVector rows (e.g. when most column blocks of a row are
    // missing), so accept any mllib Vector instead of assuming DenseVector.
    val orderAndTakeTopK = udf((scores: org.apache.spark.mllib.linalg.Vector) => {
      scores.toArray.zipWithIndex
        .map { case (score, itemIndex) => (itemIndex, score) }
        .sortWith(_._2 > _._2)
        .take(num)
    })

    // Renames the UDF's (_1: Int, _2: Double) tuple fields and narrows the score to Float.
    val recommendationArrayType =
      ArrayType(new StructType(Array(SF(dstOutputColumn, IntegerType), SF(Constants.RatingCol, FloatType))))

    srcFactors.sparkSession.createDataFrame(userToItemMatrix)
      .toDF(id, ratings)
      .select(
        col(id).as(srcOutputColumn),
        orderAndTakeTopK(col(ratings)).cast(recommendationArrayType).as(recommendations))
  }

  // Intermediate column names used by recommendForAll.
  private val id = Constants.IdCol
  private val ratings = Constants.RatingCol + "s"
  private val recommendations = Constants.Recommendations

  /** Creates a copy with the same uid and param values (overridden by `extra`) and same parent. */
  override def copy(extra: ParamMap): SARModel = {
    val copied = new SARModel(uid)
    copyValues(copied, extra).setParent(parent)
  }

  /** Delegates to the shared `transform` overload using `rank` and both factor DataFrames. */
  override def transform(dataset: Dataset[_]): DataFrame = {
    transform($(rank), $(userDataFrame), $(itemDataFrame), dataset)
  }

  /**
    * Validates that `userCol` and `itemCol` exist and are numeric.
    * NOTE(review): the schema is returned unchanged; confirm that this matches the columns the
    * delegated `transform` actually adds.
    */
  override def transformSchema(schema: StructType): StructType = {
    checkNumericType(schema, $(userCol))
    checkNumericType(schema, $(itemCol))
    schema
  }

  /**
    * Check whether the given schema contains a column of the numeric data type.
    *
    * @param schema  schema to inspect; a missing column fails in `schema(colName)`
    * @param colName column name
    * @param msg     optional extra text appended to the failure message (null or blank ignored)
    */
  private def checkNumericType(
    schema: StructType,
    colName: String,
    msg: String = ""): Unit = {
    val actualDataType = schema(colName).dataType
    val message = Option(msg).filter(_.trim.nonEmpty).map(" " + _).getOrElse("")
    require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " +
      s"NumericType but was actually of type $actualDataType.$message")
  }
}

/** Companion of [[SARModel]]. It mixes in `ComplexParamsReadable[SARModel]`, the reader
  * counterpart to the `ComplexParamsWritable` trait mixed into the class.
  */
object SARModel extends ComplexParamsReadable[SARModel] {
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy