// All Downloads are FREE. Search and download functionalities are using the official Maven repository.
//
// org.apache.spark.ml.Predictor.scala Maven / Gradle / Ivy
//
// There is a newer version: 4.0.0-preview2
// Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import org.apache.spark.annotation.Since
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}

/**
 * (private[ml])  Trait for parameters for prediction (regression and classification).
 */
private[ml] trait PredictorParams extends Params
  with HasLabelCol with HasFeaturesCol with HasPredictionCol {

  /**
   * Validates the input schema and returns it with the prediction column appended.
   *
   * The features column is always checked against `featuresDataType`. When `fitting` is true,
   * the label column is additionally checked to be numeric, and so is the weight column if this
   * instance mixes in `HasWeightCol` and has a non-empty weight column set.
   *
   * @param schema input schema
   * @param fitting whether this is in fitting
   * @param featuresDataType  SQL DataType for FeaturesType.
   *                          E.g., `VectorUDT` for vector features.
   * @return output schema: the result of appending [[predictionCol]] as `DoubleType`
   */
  protected def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType = {
    // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
    SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
    if (fitting) {
      // Label and weight columns are only inspected while fitting.
      SchemaUtils.checkNumericType(schema, $(labelCol))

      this match {
        case p: HasWeightCol =>
          if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
            SchemaUtils.checkNumericType(schema, $(p.weightCol))
          }
        case _ =>
      }
    }
    SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
  }

  /**
   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
   * and put it in an RDD with strong types.
   *
   * The label is cast to `DoubleType`. If this instance mixes in `HasWeightCol` and a non-empty
   * weight column is set, that column is cast to `DoubleType` and passed through
   * `checkNonNegativeWeight`; in every other case each instance is given a weight of 1.0.
   *
   * @param dataset dataset containing the label, features and (optionally) weight columns
   * @return one [[Instance]] per row of `dataset`
   */
  protected def extractInstances(dataset: Dataset[_]): RDD[Instance] = {
    val w = this match {
      case p: HasWeightCol if isDefined(p.weightCol) && $(p.weightCol).nonEmpty =>
        checkNonNegativeWeight(col($(p.weightCol)).cast(DoubleType))
      // Either weights are not supported by this instance or no weight column is configured.
      // The fallback arm is required: without it, an instance that does not mix in
      // HasWeightCol would fail with a scala.MatchError.
      case _ => lit(1.0)
    }

    dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
      case Row(label: Double, weight: Double, features: Vector) =>
        Instance(label, weight, features)
    }
  }

  /**
   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
   * and put it in an RDD with strong types.
   * Validate the output instances with the given function.
   *
   * @param dataset dataset containing the label, features and (optionally) weight columns
   * @param validateInstance function applied to every extracted instance; its result is
   *                         discarded and the instance is passed through as-is
   * @return one [[Instance]] per row of `dataset`
   */
  protected def extractInstances(
      dataset: Dataset[_],
      validateInstance: Instance => Unit): RDD[Instance] = {
    extractInstances(dataset).map { instance =>
      validateInstance(instance)
      instance
    }
  }
}

/**
 * Abstraction for prediction problems (regression and classification). It accepts all NumericType
 * labels and automatically casts them to DoubleType in `fit()`. If this predictor supports
 * weights, it accepts all NumericType weights, which are likewise automatically cast to
 * DoubleType in `fit()`.
 *
 * @tparam FeaturesType  Type of features.
 *                       E.g., `Vector` for vector features (whose SQL DataType is `VectorUDT`).
 * @tparam Learner  Specialization of this class.  If you subclass this type, use this type
 *                  parameter to specify the concrete type.
 * @tparam M  Specialization of [[PredictionModel]].  If you subclass this type, use this type
 *            parameter to specify the concrete type for the corresponding model.
 */
abstract class Predictor[
    FeaturesType,
    Learner <: Predictor[FeaturesType, Learner, M],
    M <: PredictionModel[FeaturesType, M]]
  extends Estimator[M] with PredictorParams {

  /** @group setParam */
  def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]

  /** @group setParam */
  def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]

  /** @group setParam */
  def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]

  /**
   * Fits a model to the given dataset.
   *
   * The dataset schema is validated first. The label column, and the weight column when this
   * predictor mixes in `HasWeightCol` and has a non-empty weight column set, are then cast to
   * `DoubleType` with their original column metadata kept. The resulting dataset is handed to
   * `train()`; the returned model gets this estimator as its parent and receives the param
   * values via `copyValues`.
   *
   * @param dataset  Training dataset
   * @return  Fitted model
   */
  override def fit(dataset: Dataset[_]): M = {
    // This handles a few items such as schema validation.
    // Developers only need to implement train().
    transformSchema(dataset.schema, logging = true)

    // Cast LabelCol to DoubleType and keep the metadata.
    val labelMeta = dataset.schema($(labelCol)).metadata
    val labelCasted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)

    // Cast WeightCol to DoubleType and keep the metadata.
    val casted = this match {
      case p: HasWeightCol =>
        if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
          val weightMeta = dataset.schema($(p.weightCol)).metadata
          labelCasted.withColumn($(p.weightCol), col($(p.weightCol)).cast(DoubleType), weightMeta)
        } else {
          labelCasted
        }
      case _ => labelCasted
    }

    copyValues(train(casted).setParent(this))
  }

  override def copy(extra: ParamMap): Learner

  /**
   * Train a model using the given dataset and parameters.
   * Developers can implement this instead of `fit()` to avoid dealing with schema validation
   * and copying parameters into the model.
   *
   * @param dataset  Training dataset
   * @return  Fitted model
   */
  protected def train(dataset: Dataset[_]): M

  /**
   * Returns the SQL DataType corresponding to the FeaturesType type parameter.
   *
   * This is used by `validateAndTransformSchema()`.
   * This workaround is needed since SQL has different APIs for Scala and Java.
   *
   * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
   */
  private[ml] def featuresDataType: DataType = new VectorUDT

  /**
   * Validates the given schema in fitting mode (label and weight columns are checked too)
   * and returns the output schema produced by `validateAndTransformSchema()`.
   */
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema, fitting = true, featuresDataType)
  }

  /**
   * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
   * and put it in an RDD with strong types.
   *
   * The label column is cast to `DoubleType` here, as `extractInstances` does, so that any
   * numeric label column is accepted even when the dataset did not go through `fit()`.
   * The cast leaves columns that are already `DoubleType` unchanged.
   *
   * @param dataset dataset containing the label and features columns
   * @return one [[LabeledPoint]] per row of `dataset`
   */
  protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
      case Row(label: Double, features: Vector) => LabeledPoint(label, features)
    }
  }
}

/**
 * Abstraction for a model for prediction tasks (regression and classification).
 *
 * @tparam FeaturesType  Type of features.
 *                       E.g., `Vector` for vector features (whose SQL DataType is `VectorUDT`).
 * @tparam M  Specialization of [[PredictionModel]].  If you subclass this type, use this type
 *            parameter to specify the concrete type for the corresponding model.
 */
abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
  extends Model[M] with PredictorParams {

  /** @group setParam */
  def setFeaturesCol(value: String): M = {
    set(featuresCol, value)
    this.asInstanceOf[M]
  }

  /** @group setParam */
  def setPredictionCol(value: String): M = {
    set(predictionCol, value)
    this.asInstanceOf[M]
  }

  /** Number of features the model was trained on, or -1 when that number is not known. */
  @Since("1.6.0")
  def numFeatures: Int = -1

  /**
   * SQL DataType that corresponds to the FeaturesType type parameter.
   *
   * `validateAndTransformSchema()` checks the features column against this type.
   * It exists as a workaround because SQL has different APIs for Scala and Java.
   *
   * Defaults to VectorUDT; override it when FeaturesType is not Vector.
   */
  protected def featuresDataType: DataType = new VectorUDT

  /**
   * Validates the schema in non-fitting mode and, when a prediction column is set,
   * passes the validated schema through `SchemaUtils.updateNumeric` for that column.
   */
  override def transformSchema(schema: StructType): StructType = {
    val validated = validateAndTransformSchema(schema, fitting = false, featuresDataType)
    if ($(predictionCol).nonEmpty) {
      SchemaUtils.updateNumeric(validated, $(predictionCol))
    } else {
      validated
    }
  }

  /**
   * Reads [[featuresCol]], applies `predict` to it, and stores the result in a new
   * column named [[predictionCol]]. If [[predictionCol]] is empty, a warning is logged
   * and the dataset is returned as a DataFrame without a prediction column.
   *
   * @param dataset input dataset
   * @return transformed dataset with [[predictionCol]] of type `Double`
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    if ($(predictionCol).isEmpty) {
      // Nothing to output: warn and hand the data back untouched.
      this.logWarning(s"$uid: Predictor.transform() does nothing" +
        " because no output columns were set.")
      dataset.toDF
    } else {
      transformImpl(dataset)
    }
  }

  /**
   * Adds the prediction column by wrapping `predict` in a UDF over [[featuresCol]].
   * The new column carries the metadata of [[predictionCol]] from the output schema.
   */
  protected def transformImpl(dataset: Dataset[_]): DataFrame = {
    val predictionMeta = transformSchema(dataset.schema, logging = true)($(predictionCol)).metadata
    val predictFn = udf { rawFeatures: Any =>
      predict(rawFeatures.asInstanceOf[FeaturesType])
    }
    val predictionColumn = predictFn(col($(featuresCol)))
    dataset.withColumn($(predictionCol), predictionColumn, predictionMeta)
  }

  /**
   * Predict label for the given features.
   * `transformImpl()` calls this for each row to fill [[predictionCol]].
   */
  @Since("2.4.0")
  def predict(features: FeaturesType): Double
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy