Maven / Gradle / Ivy
The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import{FactorizationMachines, FactorizationMachinesParams}
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.sql._
/**
 * Params for FMClassifier.
 *
 * Combines [[ProbabilisticClassifierParams]] with [[FactorizationMachinesParams]].
 */
private[classification] trait FMClassifierParams
  extends ProbabilisticClassifierParams
  with FactorizationMachinesParams
/**
 * Factorization Machines learning algorithm for classification.
 * It supports normal gradient descent and the AdamW solver.
 *
 * The implementation is based upon:
 * S. Rendle. "Factorization machines" 2010.
 *
 * FM is able to estimate interactions even in problems with huge sparsity
 * (like advertising and recommendation systems).
 *
 * The FM formula is:
 * $$
 * \begin{align}
 * y = \sigma\left( w_0 + \sum\limits^n_{i=1} w_i x_i +
 * \sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j \right)
 * \end{align}
 * $$
 *
 * The first two terms are the global bias and the linear term (the same as in linear
 * regression), and the last term is the pairwise interaction term. v_i describes the i-th
 * variable with k factors.
 *
 * The FM classification model uses logistic loss, which can be minimized by gradient descent;
 * regularization terms such as L2 are usually added to the loss function to prevent overfitting.
 *
 * @note Multiclass labels are not currently supported.
 */
class FMClassifier @Since("3.0.0") (
@Since("3.0.0") override val uid: String)
extends ProbabilisticClassifier[Vector, FMClassifier, FMClassificationModel]
with FactorizationMachines with FMClassifierParams with DefaultParamsWritable with Logging {
/** Creates an instance whose uid is generated by `Identifiable.randomUID("fmc")`. */
def this() =
  this(Identifiable.randomUID("fmc"))
/**
 * Sets the dimensionality of the factors.
 * Default is 8.
 *
 * @group setParam
 */
def setFactorSize(value: Int): this.type =
  set(factorSize, value)
/**
 * Sets whether to fit the intercept term.
 * Default is true.
 *
 * @group setParam
 */
def setFitIntercept(value: Boolean): this.type =
  set(fitIntercept, value)
/**
 * Sets whether to fit the linear term.
 * Default is true.
 *
 * @group setParam
 */
def setFitLinear(value: Boolean): this.type =
  set(fitLinear, value)
/**
 * Sets the L2 regularization parameter.
 * Default is 0.0.
 *
 * @group setParam
 */
def setRegParam(value: Double): this.type =
  set(regParam, value)
/**
 * Sets the mini-batch fraction parameter.
 * Default is 1.0.
 *
 * @group setParam
 */
def setMiniBatchFraction(value: Double): this.type =
  set(miniBatchFraction, value)
/**
 * Sets the standard deviation of the initial coefficients.
 * Default is 0.01.
 *
 * @group setParam
 */
def setInitStd(value: Double): this.type =
  set(initStd, value)
/**
 * Sets the maximum number of iterations.
 * Default is 100.
 *
 * @group setParam
 */
def setMaxIter(value: Int): this.type =
  set(maxIter, value)
/**
 * Sets the initial step size for the first step (like a learning rate).
 * Default is 1.0.
 *
 * @group setParam
 */
def setStepSize(value: Double): this.type =
  set(stepSize, value)
/**
 * Sets the convergence tolerance of iterations.
 * Default is 1E-6.
 *
 * @group setParam
 */
def setTol(value: Double): this.type =
  set(tol, value)
/**
 * Sets the solver algorithm used for optimization.
 * Supported options: "gd", "adamW".
 * Default: "adamW"
 *
 * @group setParam
 */
def setSolver(value: String): this.type =
  set(solver, value)
/**
 * Sets the random seed for weight initialization.
 *
 * @group setParam
 */
def setSeed(value: Long): this.type =
  set(seed, value)
// Fits the classifier on `dataset`: checks the thresholds length, logs the params, prepares the
// training data, optimizes with LogisticLoss and hands the split coefficients to createModel.
override protected def train(
dataset: Dataset[_]): FMClassificationModel = instrumented { instr =>
// Only the two-class case is handled here.
val numClasses = 2
// When thresholds are set, there must be exactly one per class.
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
// NOTE(review): the closing brace of the `if` above is not visible in this copy.
instr.logParams(this, factorSize, fitIntercept, fitLinear, regParam,
miniBatchFraction, initStd, maxIter, stepSize, tol, solver, thresholds)
val numFeatures = getNumFeatures(dataset, $(featuresCol))
// Persist the training data here only if the input dataset has no storage level of its own.
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
// NOTE(review): the expression below is truncated in this copy (the start of the chain that
// produces the (label, features) pairs is missing) — restore it from upstream before compiling.
val data =
checkClassificationLabels($(labelCol), Some(2)),
) { case Row(l: Double, v: Vector) => (l, OldVectors.fromML(v))
}.setName("training instances")
if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK)
// Optimize with logistic loss; yields the coefficients and the objective history.
val (coefficients, objectiveHistory) = trainImpl(data, numFeatures, LogisticLoss)
// Break the coefficients into intercept, linear term and factors.
val (intercept, linear, factors) = splitCoefficients(
coefficients, numFeatures, $(factorSize), $(fitIntercept), $(fitLinear))
// Release the data only if it was persisted above.
if (handlePersistence) data.unpersist()
createModel(dataset, intercept, linear, factors, objectiveHistory)
// Wraps the fitted intercept, linear term and factors in an FMClassificationModel and starts
// building its training summary.
private def createModel(
dataset: Dataset[_],
intercept: Double,
linear: Vector,
factors: Matrix,
objectiveHistory: Array[Double]): FMClassificationModel = {
// The new model receives this estimator's param values.
val model = copyValues(new FMClassificationModel(uid, intercept, linear, factors))
// Falls back to the literal column name "weightCol" when the weightCol param is not defined.
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
val summary = new FMClassificationTrainingSummaryImpl(
// NOTE(review): the constructor call above is cut off in this copy — its arguments and the
// rest of the method (including the returned model) are not visible; restore from upstream.
/** Copies this estimator with the `extra` params applied, by delegating to `defaultCopy`. */
override def copy(extra: ParamMap): FMClassifier =
  defaultCopy(extra)
/** Companion of [[FMClassifier]]; extends DefaultParamsReadable for loading. */
object FMClassifier extends DefaultParamsReadable[FMClassifier] {

  /** Loads an [[FMClassifier]] from `path` by delegating to the inherited `load`. */
  override def load(path: String): FMClassifier =
    super.load(path)
}
/** Model produced by [[FMClassifier]]. */
class FMClassificationModel private[classification] (
@Since("3.0.0") override val uid: String,
@Since("3.0.0") val intercept: Double,
@Since("3.0.0") val linear: Vector,
@Since("3.0.0") val factors: Matrix)
extends ProbabilisticClassificationModel[Vector, FMClassificationModel]
with FMClassifierParams with MLWritable
with HasTrainingSummary[FMClassificationTrainingSummary]{
// The class count is fixed at 2 for this model.
override val numClasses: Int = 2
// The feature count is taken from the size of the linear-term vector.
override val numFeatures: Int = linear.size
/**
 * Gets the summary of the model on the training set.
 * An exception is thrown if `hasSummary` is false.
 */
override def summary: FMClassificationTrainingSummary =
  super.summary
/**
 * Evaluates the model on a test dataset.
 *
 * @param dataset Test dataset to evaluate model on.
 * @return a summary built from the transformed `dataset`.
 */
def evaluate(dataset: Dataset[_]): FMClassificationSummary = {
  // Falls back to the literal column name "weightCol" when the weightCol param is not defined.
  val weights = if (isDefined(weightCol)) $(weightCol) else "weightCol"
  // Handle possible missing or invalid probability or prediction columns.
  val (scoringModel, probabilityColName, predictedColName) = findSummaryModel()
  val scored = scoringModel.transform(dataset)
  new FMClassificationSummaryImpl(
    scored, probabilityColName, predictedColName, $(labelCol), weights)
}
/**
 * Computes the raw prediction for one feature vector.
 *
 * @param features feature vector to score.
 * @return dense vector `[-margin, margin]`, where `margin` is the value `getRawPrediction`
 *         returns for this model's intercept, linear term and factors.
 */
override def predictRaw(features: Vector): Vector = {
  val margin = getRawPrediction(features, intercept, linear, factors)
  val pair = Array(-margin, margin)
  Vectors.dense(pair)
}
/**
 * Converts a raw prediction vector into class probabilities, overwriting it in place.
 *
 * Index 1 is replaced by the logistic sigmoid of its raw value, and index 0 by the complement
 * of that probability.
 *
 * @param rawPrediction raw prediction vector; must be a DenseVector.
 * @return the same vector instance, now holding probabilities.
 * @throws RuntimeException if `rawPrediction` is a SparseVector.
 */
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
  rawPrediction match {
    case dv: DenseVector =>
      dv.values(1) = 1.0 / (1.0 + math.exp(-dv.values(1)))
      dv.values(0) = 1.0 - dv.values(1)
      // Yield the mutated vector: this branch must evaluate to a Vector, not Unit.
      dv
    case _: SparseVector =>
      throw new RuntimeException("Unexpected error in FMClassificationModel:" +
        " raw2probabilityInPlace encountered SparseVector")
  }
}
/**
 * Creates a new model with the same uid, intercept, linear term and factors, then copies this
 * model's param values together with `extra` onto it.
 */
override def copy(extra: ParamMap): FMClassificationModel = {
  val duplicate = new FMClassificationModel(uid, intercept, linear, factors)
  copyValues(duplicate, extra)
}
/** Returns a new FMClassificationModelWriter for this model. */
override def write: MLWriter = {
  new FMClassificationModel.FMClassificationModelWriter(this)
}
/** One-line description: uid, class and feature counts, then factorSize and the fit flags. */
override def toString: String = {
  val identity = s"uid=${super.toString}, numClasses=$numClasses, numFeatures=$numFeatures"
  val settings =
    s"factorSize=${$(factorSize)}, fitLinear=${$(fitLinear)}, fitIntercept=${$(fitIntercept)}"
  "FMClassificationModel: " + identity + ", " + settings
}
object FMClassificationModel extends MLReadable[FMClassificationModel] {
/** Returns a new reader for [[FMClassificationModel]]. */
override def read: MLReader[FMClassificationModel] =
  new FMClassificationModelReader
/** Loads an [[FMClassificationModel]] from `path` by delegating to the inherited `load`. */
override def load(path: String): FMClassificationModel =
  super.load(path)
/** [[MLWriter]] instance for [[FMClassificationModel]] */
private[FMClassificationModel] class FMClassificationModelWriter(
instance: FMClassificationModel) extends MLWriter with Logging {
// Bundles the model's intercept, linear term and factors for saving.
private case class Data(
intercept: Double,
linear: Vector,
factors: Matrix)
// Saves the instance metadata under `path`, then prepares the model data and its "data" sub-path.
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.intercept, instance.linear, instance.factors)
val dataPath = new Path(path, "data").toString
// NOTE(review): nothing visible here writes `data` to `dataPath`; the rest of this method (and
// the closing braces) appear to be missing from this copy — confirm against upstream.
// Reader for FMClassificationModel.
private class FMClassificationModelReader extends MLReader[FMClassificationModel] {
// Class name handed to the metadata loader.
private val className = classOf[FMClassificationModel].getName
// Loads the metadata from `path`, reads the model data from the "data" sub-path and constructs
// the model with the stored uid.
override def load(path: String): FMClassificationModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
// NOTE(review): the next two statements have lost the expression between `=` and the opening
// quote in this copy and do not parse as written — restore them from upstream.
val data ="parquet").load(dataPath)
val Row(intercept: Double, linear: Vector, factors: Matrix) ="intercept", "linear", "factors").head()
val model = new FMClassificationModel(metadata.uid, intercept, linear, factors)
// NOTE(review): the method ends here in this copy without yielding `model`; the tail of the
// body appears to be missing — confirm against upstream.
/** Abstraction for FMClassifier results for a given model. */
sealed trait FMClassificationSummary
  extends BinaryClassificationSummary
/** Abstraction for FMClassifier training results. */
sealed trait FMClassificationTrainingSummary
  extends FMClassificationSummary
  with TrainingSummary
/**
 * FMClassifier results for a given model.
 *
 * @param predictions dataframe output by the model's `transform` method.
 * @param scoreCol field in "predictions" which gives the probability of each instance.
 * @param predictionCol field in "predictions" which gives the prediction for a data instance
 *                      as a double.
 * @param labelCol field in "predictions" which gives the true label of each instance.
 * @param weightCol field in "predictions" which gives the weight of each instance.
 */
private class FMClassificationSummaryImpl(
    @transient override val predictions: DataFrame,
    override val scoreCol: String,
    override val predictionCol: String,
    override val labelCol: String,
    override val weightCol: String)
  extends FMClassificationSummary
/**
 * FMClassifier training results.
 *
 * @param predictions dataframe output by the model's `transform` method.
 * @param scoreCol field in "predictions" which gives the probability of each instance.
 * @param predictionCol field in "predictions" which gives the prediction for a data instance
 *                      as a double.
 * @param labelCol field in "predictions" which gives the true label of each instance.
 * @param weightCol field in "predictions" which gives the weight of each instance.
 * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
 */
private class FMClassificationTrainingSummaryImpl(
    predictions: DataFrame,
    scoreCol: String,
    predictionCol: String,
    labelCol: String,
    weightCol: String,
    override val objectiveHistory: Array[Double])
  extends FMClassificationSummaryImpl(predictions, scoreCol, predictionCol, labelCol, weightCol)
  with FMClassificationTrainingSummary
© 2015 - 2025 Weber Informatics LLC | Privacy Policy