All downloads are free. The search and download functionality uses the official Maven repository.

com.microsoft.ml.spark.recommendation.RecommendationIndexer.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.recommendation

import com.microsoft.ml.spark.core.contracts.Wrappable
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel}
import org.apache.spark.ml.param.{Param, ParamMap, Params, TransformerParam}
import org.apache.spark.ml.util._
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Estimator, Model}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{NumericType, StringType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}

/** Estimator that fits one [[StringIndexer]] on the user input column and another on the
  * item input column, and bundles both fitted indexers into a [[RecommendationIndexerModel]].
  *
  * @param uid unique identifier of this pipeline stage
  */
class RecommendationIndexer(override val uid: String)
  extends Estimator[RecommendationIndexerModel] with RecommendationIndexerBase {

  /** Creates an estimator with a randomly generated uid. */
  def this() = this(Identifiable.randomUID("RecommendationIndexer"))

  /** Fits the user and item string indexers on `dataset`.
    *
    * @param dataset data containing the configured user and item input columns
    * @return a model carrying both fitted indexers and the same four column-name params
    */
  override def fit(dataset: Dataset[_]): RecommendationIndexerModel = {
    // Users are fitted before items, matching the order in which the params are read.
    val fittedUsers = fitIndexer(dataset, getUserInputCol, getUserOutputCol)
    val fittedItems = fitIndexer(dataset, getItemInputCol, getItemOutputCol)

    val model = new RecommendationIndexerModel(uid).setParent(this)
    model
      .setUserIndexModel(fittedUsers)
      .setItemIndexModel(fittedItems)
    model
      .setUserInputCol(getUserInputCol)
      .setUserOutputCol(getUserOutputCol)
      .setItemInputCol(getItemInputCol)
      .setItemOutputCol(getItemOutputCol)
  }

  /** Fits a single StringIndexer mapping `inputCol` to `outputCol` on `dataset`. */
  private def fitIndexer(dataset: Dataset[_], inputCol: String, outputCol: String): StringIndexerModel =
    new StringIndexer()
      .setInputCol(inputCol)
      .setOutputCol(outputCol)
      .fit(dataset)

  override def copy(extra: ParamMap): Estimator[RecommendationIndexerModel] = defaultCopy(extra)

}

// Companion object of the RecommendationIndexer estimator. It mixes in ComplexParamsReadable,
// presumably the reading counterpart of the ComplexParamsWritable mixed into
// RecommendationIndexerBase -- TODO confirm against ComplexParamsReadable's definition.
object RecommendationIndexer extends ComplexParamsReadable[RecommendationIndexer]

/** Fitted model holding one [[StringIndexerModel]] for users and one for items.
  *
  * `transform` applies the user indexer and then the item indexer. The `recover*` UDFs map
  * numeric indices back to the original labels.
  *
  * @param uid unique identifier of this pipeline stage
  */
class RecommendationIndexerModel(override val uid: String) extends Model[RecommendationIndexerModel] with
  RecommendationIndexerBase with Wrappable {

  /** Creates a model with a randomly generated uid. */
  def this() = this(Identifiable.randomUID("RecommendationIndexerModel"))

  override def copy(extra: ParamMap): RecommendationIndexerModel = defaultCopy(extra)

  /** Applies the user indexer first, then the item indexer, to `dataset`. */
  override def transform(dataset: Dataset[_]): DataFrame = {
    getItemIndexModel.transform(getUserIndexModel.transform(dataset))
  }

  /** Fitted user indexer; the validator only accepts a [[StringIndexerModel]]. */
  val userIndexModel = new TransformerParam(this, "userIndexModel", "userIndexModel", {
    case _: StringIndexerModel => true
    case _ => false
  })

  def setUserIndexModel(m: StringIndexerModel): this.type = set(userIndexModel, m)

  def getUserIndexModel: StringIndexerModel = $(userIndexModel).asInstanceOf[StringIndexerModel]

  /** Fitted item indexer; the validator only accepts a [[StringIndexerModel]]. */
  val itemIndexModel = new TransformerParam(this, "itemIndexModel", "itemIndexModel", {
    case _: StringIndexerModel => true
    case _ => false
  })

  def setItemIndexModel(m: StringIndexerModel): this.type = set(itemIndexModel, m)

  def getItemIndexModel: StringIndexerModel = $(itemIndexModel).asInstanceOf[StringIndexerModel]

  /** @return map from user index (position in the indexer's `labels`) to the original label */
  def getUserIndex: Map[Int, String] = labelsByIndex(getUserIndexModel)

  /** @return map from item index (position in the indexer's `labels`) to the original label */
  def getItemIndex: Map[Int, String] = labelsByIndex(getItemIndexModel)

  /** UDF mapping a user index back to its label; unknown or null indices yield "-1". */
  def recoverUser(): UserDefinedFunction = recoverUdf(getUserIndex)

  /** UDF mapping an item index back to its label; unknown or null indices yield "-1". */
  def recoverItem(): UserDefinedFunction = recoverUdf(getItemIndex)

  /** Inverts an indexer's label array into an index -> label lookup. */
  private def labelsByIndex(model: StringIndexerModel): Map[Int, String] =
    model.labels.zipWithIndex.map(_.swap).toMap

  /** Builds a lookup UDF over a precomputed index.
    *
    * The map is built once by the caller and captured by value, so it is not rebuilt for
    * every row and the closure does not need to reference this model.
    */
  private def recoverUdf(index: Map[Int, String]): UserDefinedFunction =
    // Option(id) guards against null input, which would otherwise fail when unboxed.
    udf((id: Integer) => Option(id).flatMap(i => index.get(i.intValue)).getOrElse("-1"))

}

// Companion object of the fitted RecommendationIndexerModel. It mixes in ComplexParamsReadable,
// presumably the reading counterpart of the ComplexParamsWritable mixed into
// RecommendationIndexerBase -- TODO confirm against ComplexParamsReadable's definition.
object RecommendationIndexerModel extends ComplexParamsReadable[RecommendationIndexerModel]

/** Column-name params and schema validation shared by the recommendation indexer
  * estimator and its fitted model.
  */
trait RecommendationIndexerBase extends Params with ComplexParamsWritable {
  /** @group setParam */
  def setUserInputCol(value: String): this.type = set(userInputCol, value)

  /** @group getParam */
  def getUserInputCol: String = $(userInputCol)

  /** Name of the column holding the raw user identifiers. */
  val userInputCol = new Param[String](this, "userInputCol", "User Input Col")

  /** @group setParam */
  def setUserOutputCol(value: String): this.type = set(userOutputCol, value)

  /** @group getParam */
  def getUserOutputCol: String = $(userOutputCol)

  /** Name of the column that receives the indexed user identifiers. */
  val userOutputCol = new Param[String](this, "userOutputCol", "User Output Col")

  /** @group setParam */
  def setItemInputCol(value: String): this.type = set(itemInputCol, value)

  /** @group getParam */
  def getItemInputCol: String = $(itemInputCol)

  /** Name of the column holding the raw item identifiers. */
  val itemInputCol = new Param[String](this, "itemInputCol", "Item Input Col")

  /** @group setParam */
  def setItemOutputCol(value: String): this.type = set(itemOutputCol, value)

  /** @group getParam */
  def getItemOutputCol: String = $(itemOutputCol)

  /** Name of the column that receives the indexed item identifiers. */
  val itemOutputCol = new Param[String](this, "itemOutputCol", "Item Output Col")

  /** @group setParam */
  def setRatingCol(value: String): this.type = set(ratingCol, value)

  /** @group getParam */
  def getRatingCol: String = $(ratingCol)

  /** Name of the rating column. It is not read by `transformSchema` in this trait. */
  val ratingCol = new Param[String](this, "ratingCol", "Rating Col")

  /** Validates the input schema and derives the output schema.
    *
    * Both input columns must be of string or numeric type, and neither output column may
    * already exist. The result is the input fields followed by one nominal-attribute field
    * for the user output column and one for the item output column.
    *
    * @param schema schema of the dataset to be indexed
    * @return `schema` with the user and item output fields appended
    * @throws IllegalArgumentException if an input column has an unsupported type or an
    *                                  output column already exists
    */
  def transformSchema(schema: StructType): StructType = {
    requireIndexableInput(schema, getUserInputCol)
    // The item column is checked against its own type, not the user column's.
    requireIndexableInput(schema, getItemInputCol)

    val inputFields = schema.fields
    requireAbsentOutput(schema, getUserOutputCol)
    val userAttr = NominalAttribute.defaultAttr.withName(getUserOutputCol)
    requireAbsentOutput(schema, getItemOutputCol)
    val itemAttr = NominalAttribute.defaultAttr.withName(getItemOutputCol)

    StructType(inputFields :+ userAttr.toStructField() :+ itemAttr.toStructField())
  }

  /** Requires `colName` to be a string or numeric column of `schema`. */
  private def requireIndexableInput(schema: StructType, colName: String): Unit = {
    val dataType = schema(colName).dataType
    require(dataType == StringType || dataType.isInstanceOf[NumericType],
      s"The input column $colName must be either string type or numeric type, " +
        s"but got $dataType .")
  }

  /** Requires that `schema` has no field named `colName` yet. */
  private def requireAbsentOutput(schema: StructType, colName: String): Unit =
    require(schema.fields.forall(_.name != colName),
      s"Output column $colName already exists.")
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy