All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.ml.util.DatasetUtils.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.util

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.classification.ClassifierParams
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._


/**
 * Helpers for validating and extracting ML inputs (labels, weights, feature vectors) from a
 * [[Dataset]].
 *
 * The `check*` methods do not scan any data themselves: they return a [[Column]] expression in
 * which invalid values are mapped to `raise_error(...)`, so the failure surfaces when a query
 * selecting that column is evaluated.
 */
private[spark] object DatasetUtils extends Logging {

  /**
   * Build a column that casts `colName` to double and rejects Null, NaN and infinite values.
   *
   * @param colName   name of the column to validate
   * @param displayed name used for the values in the error messages (e.g. "Labels")
   * @return the column cast to [[DoubleType]], with invalid values mapped to `raise_error`
   */
  private[ml] def checkNonNanValues(colName: String, displayed: String): Column = {
    val casted = col(colName).cast(DoubleType)
    when(casted.isNull || casted.isNaN, raise_error(lit(s"$displayed MUST NOT be Null or NaN")))
      .when(casted === Double.NegativeInfinity || casted === Double.PositiveInfinity,
        raise_error(concat(lit(s"$displayed MUST NOT be Infinity, but got "), casted)))
      .otherwise(casted)
  }

  /**
   * Validate regression labels: they must be neither Null, NaN nor infinite.
   *
   * @param labelCol name of the label column
   * @return the label column cast to [[DoubleType]], with invalid values mapped to `raise_error`
   */
  private[ml] def checkRegressionLabels(labelCol: String): Column = {
    checkNonNanValues(labelCol, "Labels")
  }

  /**
   * Validate classification labels.
   *
   * - With `numClasses == Some(2)` every label must be exactly 0 or 1.
   * - Otherwise every label must be an integer in `[0, n)`, where `n` is the given number of
   *   classes, or `Int.MaxValue` when `numClasses` is `None`.
   *
   * Null and NaN labels are rejected in both cases.
   *
   * @param labelCol   name of the label column
   * @param numClasses number of classes, if known
   * @return the label column cast to [[DoubleType]], with invalid values mapped to `raise_error`
   * @throws IllegalArgumentException if `numClasses` is given but is not positive
   */
  private[ml] def checkClassificationLabels(
    labelCol: String,
    numClasses: Option[Int]): Column = {
    val casted = col(labelCol).cast(DoubleType)
    numClasses match {
      case Some(2) =>
        when(casted.isNull || casted.isNaN, raise_error(lit("Labels MUST NOT be Null or NaN")))
          .when(casted =!= 0 && casted =!= 1,
            raise_error(concat(lit("Labels MUST be in {0, 1}, but got "), casted)))
          .otherwise(casted)

      case _ =>
        val n = numClasses.getOrElse(Int.MaxValue)
        // An Int can never exceed Int.MaxValue, so positivity is the only constraint to check.
        require(n > 0, s"numClasses MUST be positive, but got $n")
        when(casted.isNull || casted.isNaN, raise_error(lit("Labels MUST NOT be Null or NaN")))
          .when(casted < 0 || casted >= n,
            raise_error(concat(lit(s"Labels MUST be in [0, $n), but got "), casted)))
          // A double that differs from its integer cast has a fractional part.
          .when(casted =!= casted.cast(IntegerType),
            raise_error(concat(lit("Labels MUST be Integers, but got "), casted)))
          .otherwise(casted)
    }
  }

  /**
   * Validate instance weights: they must be neither Null nor NaN, must not be negative and
   * must not be positive infinity.
   *
   * @param weightCol name of the weight column
   * @return the weight column cast to [[DoubleType]], with invalid values mapped to `raise_error`
   */
  private[ml] def checkNonNegativeWeights(weightCol: String): Column = {
    val casted = col(weightCol).cast(DoubleType)
    when(casted.isNull || casted.isNaN, raise_error(lit("Weights MUST NOT be Null or NaN")))
      .when(casted < 0 || casted === Double.PositiveInfinity,
        raise_error(concat(lit("Weights MUST NOT be Negative or Infinity, but got "), casted)))
      .otherwise(casted)
  }

  /**
   * Validate an optional weight column.
   *
   * @param weightCol name of the weight column, if any
   * @return the validated weight column, or the literal `1.0` when no (or an empty) column name
   *         is given
   */
  private[ml] def checkNonNegativeWeights(weightCol: Option[String]): Column = weightCol match {
    case Some(w) if w.nonEmpty => checkNonNegativeWeights(w)
    case _ => lit(1.0)
  }

  /**
   * Validate a vector column: the vector must not be Null and none of its stored values may be
   * NaN or infinite.
   *
   * @param vectorCol the vector column to validate
   * @return the same column, with invalid vectors mapped to `raise_error`
   */
  private[ml] def checkNonNanVectors(vectorCol: Column): Column = {
    when(vectorCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
      .when(!validateVector(vectorCol),
        raise_error(concat(lit("Vector values MUST NOT be NaN or Infinity, but got "),
          vectorCol.cast(StringType))))
      .otherwise(vectorCol)
  }

  /**
   * Validate a vector column given by name; see the [[Column]]-based overload.
   *
   * @param vectorCol name of the vector column to validate
   */
  private[ml] def checkNonNanVectors(vectorCol: String): Column = {
    checkNonNanVectors(col(vectorCol))
  }

  /**
   * UDF returning true iff every stored value of the vector is finite. For a sparse vector
   * only the explicitly stored values are inspected.
   */
  private lazy val validateVector = udf { vector: Vector =>
    val stored = vector match {
      case dv: DenseVector => dv.values
      case sv: SparseVector => sv.values
    }
    stored.forall(v => !v.isNaN && !v.isInfinity)
  }

  /**
   * Extract validated (label, weight, features) instances from a dataset.
   *
   * Labels are checked as classification labels when `p` is a [[ClassifierParams]] and as
   * regression labels otherwise. Weights are checked when `p` mixes in [[HasWeightCol]] and
   * default to `1.0` otherwise. Feature vectors are always checked.
   *
   * @param p          params providing the label, features and (optionally) weight column names
   * @param df         input dataset
   * @param numClasses number of classes, forwarded to the classification label check
   * @return an RDD of [[Instance]]s built from the validated columns
   */
  private[ml] def extractInstances(
      p: PredictorParams,
      df: Dataset[_],
      numClasses: Option[Int] = None): RDD[Instance] = {
    val labelCol = p match {
      case c: ClassifierParams =>
        checkClassificationLabels(c.getLabelCol, numClasses)
      case _ => // TODO: there is no RegressorParams, maybe add it in the future?
        checkRegressionLabels(p.getLabelCol)
    }

    val weightCol = p match {
      case w: HasWeightCol => checkNonNegativeWeights(w.get(w.weightCol))
      case _ => lit(1.0)
    }

    df.select(labelCol, weightCol, checkNonNanVectors(p.getFeaturesCol))
      .rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v) }
  }

  /**
   * Cast a column in a Dataset to Vector type.
   *
   * The supported data types of the input column are
   * - Vector
   * - float/double type Array.
   *
   * Note: The returned column does not have Metadata.
   *
   * @param dataset input DataFrame
   * @param colName column name.
   * @return Vector column
   * @throws IllegalArgumentException if the column is neither a Vector nor an array of
   *                                  float/double
   */
  def columnToVector(dataset: Dataset[_], colName: String): Column = {
    val columnDataType = dataset.schema(colName).dataType
    columnDataType match {
      case _: VectorUDT => col(colName)
      case fdt: ArrayType =>
        val transferUDF = fdt.elementType match {
          case _: FloatType => udf((vector: Seq[Float]) => {
            // Widen in a single pass; avoids positional access on the incoming Seq.
            Vectors.dense(vector.iterator.map(_.toDouble).toArray)
          })
          case _: DoubleType => udf((vector: Seq[Double]) => {
            Vectors.dense(vector.toArray)
          })
          case other =>
            throw new IllegalArgumentException(s"Array[$other] column cannot be cast to Vector")
        }
        transferUDF(col(colName))
      case other =>
        throw new IllegalArgumentException(s"$other column cannot be cast to Vector")
    }
  }

  /**
   * Read a column as an RDD of `mllib` vectors.
   *
   * The column is first converted with [[columnToVector]], so the same input types are
   * supported, and each resulting `ml` vector is then converted with `OldVectors.fromML`.
   *
   * @param dataset input DataFrame
   * @param colName column name.
   * @return RDD of old-API vectors
   */
  def columnToOldVector(dataset: Dataset[_], colName: String): RDD[OldVector] = {
    dataset.select(columnToVector(dataset, colName))
      .rdd.map {
      case Row(point: Vector) => OldVectors.fromML(point)
    }
  }

  /**
   * Get the number of classes.  This looks in column metadata first, and if that is missing,
   * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
   * by finding the maximum label value.
   *
   * When numClasses is inferred from the data, the labels are validated with
   * [[checkClassificationLabels]] using `maxNumClasses` as the number of classes, so a label
   * that is not an integer in `[0, maxNumClasses)` makes the query fail with the error produced
   * by `raise_error` before the range checks below are reached.
   *
   * @param dataset  Dataset which contains the column `labelCol`
   * @param labelCol  Name of the label column
   * @param maxNumClasses  Maximum number of classes allowed when inferred from data.  If numClasses
   *                       is specified in the metadata, then maxNumClasses is ignored.
   * @return  number of classes
   * @throws SparkException  if metadata does not specify numClasses and the dataset yields no
   *                         label to infer it from
   * @throws IllegalArgumentException  if metadata does not specify numClasses, and the
   *                                   actual numClasses exceeds maxNumClasses
   */
  private[ml] def getNumClasses(
      dataset: Dataset[_],
      labelCol: String,
      maxNumClasses: Int = 100): Int = {
    MetadataUtils.getNumClasses(dataset.schema(labelCol)) match {
      case Some(n: Int) => n
      case None =>
        // Get number of classes from dataset itself.
        val maxLabelRow: Array[Row] = dataset
          .select(max(checkClassificationLabels(labelCol, Some(maxNumClasses))))
          .take(1)
        // No row at all, or a Null maximum, both mean there was no label to aggregate.
        val maxDoubleLabel: Double = maxLabelRow.headOption.filterNot(_.isNullAt(0)) match {
          case Some(row) => row.getDouble(0)
          case None => throw new SparkException("ML algorithm was given empty dataset.")
        }
        require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
          s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
        val numClasses = maxDoubleLabel.toInt + 1
        // Defensive: the label check above already rejects labels >= maxNumClasses.
        require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
          s" in column $labelCol, but this exceeded the max numClasses ($maxNumClasses) allowed" +
          s" to be inferred from values.  To avoid this error for labels with > $maxNumClasses" +
          s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
          s" StringIndexer to the label column.")
        logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
          s" labelCol=$labelCol since numClasses was not specified in the column metadata.")
        numClasses
    }
  }

  /**
   * Obtain the number of features in a vector column.
   * If no metadata is available, extract it from the first row of the dataset.
   *
   * @param dataset   Dataset which contains the column `vectorCol`
   * @param vectorCol name of the vector (or float/double array) column
   * @return number of features
   * @throws NoSuchElementException if no metadata is available and the dataset is empty
   */
  private[ml] def getNumFeatures(dataset: Dataset[_], vectorCol: String): Int = {
    MetadataUtils.getNumFeatures(dataset.schema(vectorCol)).getOrElse {
      dataset.select(columnToVector(dataset, vectorCol)).head(1).headOption match {
        case Some(row) => row.getAs[Vector](0).size
        case None =>
          // Same exception type as Dataset.head on an empty dataset, but with a usable message.
          throw new NoSuchElementException(
            s"Cannot infer the number of features of column $vectorCol from an empty dataset.")
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy