All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.ml.feature.VectorAssembler.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tencent.angel.sona.ml.feature

import java.util.NoSuchElementException

import org.apache.spark.SparkException
import com.tencent.angel.sona.ml.Transformer
import com.tencent.angel.sona.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
import org.apache.spark.linalg.{VectorUDT, Vectors}
import com.tencent.angel.sona.ml.param.{Param, ParamMap, ParamValidators}
import com.tencent.angel.sona.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCol}
import com.tencent.angel.sona.ml.util._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.linalg
import scala.collection.mutable
import scala.language.existentials

/**
  * A feature transformer that merges multiple columns into a vector column.
  *
  * This requires one pass over the entire dataset. In case we need to infer column lengths from the
  * data we require an additional call to the 'first' Dataset method, see 'handleInvalid' parameter.
  *
  * Input columns may be of any numeric type, boolean type, or vector type. When the total assembled
  * length is below `Int.MaxValue` rows are assembled with Int indices, otherwise with Long indices.
  */
class VectorAssembler(override val uid: String)
  extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid
    with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("vecAssembler"))

  /** @group setParam */
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

  /**
    * Param for how to handle invalid data (NULL values, and NaN values in non-vector columns).
    * Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), or 'keep'
    * (return relevant number of NaN in the output). Column lengths are taken from the size of ML
    * Attribute Group, which can be set using `VectorSizeHint` in a pipeline before
    * `VectorAssembler`. Column lengths can also be inferred from first rows of the data since it is
    * safe to do so but only in case of 'error' or 'skip'.
    * Default: "error"
    *
    * @group param
    */
  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
    """Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out
      |rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN
      |in the output). Column lengths are taken from the size of ML Attribute Group, which can be
      |set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also
      |be inferred from first rows of the data since it is safe to do so but only in case of 'error'
      |or 'skip'.""".stripMargin.replaceAll("\n", " "),
    ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))

  setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)

  /**
    * Appends the assembled vector column `outputCol` to `dataset`.
    *
    * With handleInvalid = 'skip', rows with nulls/NaNs in the input columns are dropped first.
    * The output column's metadata holds one ML attribute per output slot. The exception is a
    * vector column whose length is `Int.MaxValue` or more: its per-slot attributes are omitted.
    *
    * @param dataset input data containing every column listed in `inputCols`
    * @return `dataset` (possibly filtered) with the assembled vector column appended
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    // Schema validation; the type matches below rely on it having rejected unsupported types.
    transformSchema(dataset.schema, logging = true)
    val schema = dataset.schema

    val vectorCols = $(inputCols).filter { c =>
      schema(c).dataType match {
        case _: VectorUDT => true
        case _ => false
      }
    }

    // Vector column lengths, taken from metadata or (for 'error' / 'skip') from the first row.
    val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))

    val attributesPerColumn = $(inputCols).map { c =>
      val field = schema(c)
      field.dataType match {
        case DoubleType =>
          val attribute = Attribute.fromStructField(field)
          attribute match {
            case UnresolvedAttribute =>
              Seq(NumericAttribute.defaultAttr.withName(c))
            case _ =>
              Seq(attribute.withName(c))
          }
        case _: NumericType | BooleanType =>
          // If the input column type is a compatible scalar type, assume numeric.
          Seq(NumericAttribute.defaultAttr.withName(c))
        case _: VectorUDT if vectorColsLengths(c) < Int.MaxValue =>
          val attributeGroup = AttributeGroup.fromStructField(field)
          if (attributeGroup.attributes.isDefined) {
            attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) =>
              if (attr.name.isDefined) {
                // TODO: Define a rigorous naming scheme.
                attr.withName(c + "_" + attr.name.get)
              } else {
                attr.withName(c + "_" + i)
              }
            }
          } else {
            // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
            // from metadata, check the first row.
            // TODO: need to handle long key vector
            (0 until vectorColsLengths(c).toInt).map { i =>
              NumericAttribute.defaultAttr.withName(c + "_" + i)
            }
          }
        case _: VectorUDT =>
          // Length >= Int.MaxValue (including exactly Int.MaxValue): too long to enumerate one
          // attribute per slot, so this column contributes no attributes.
          Seq.empty[Attribute]
        case otherType =>
          throw new SparkException(s"VectorAssembler does not support the $otherType type")
      }
    }
    val featureAttributes = attributesPerColumn.flatten[Attribute]
    val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata

    // Scalar columns occupy exactly one slot each.
    val lengths = $(inputCols).map(name => vectorColsLengths.getOrElse(name, 1L))

    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
      case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false)
      case VectorAssembler.KEEP_INVALID => (dataset, true)
      case VectorAssembler.ERROR_INVALID => (dataset, false)
    }

    // Data transformation: use Int indices whenever the total length allows it.
    val assembleFunc = if (lengths.sum < Int.MaxValue) {
      // Converted once here instead of once per row inside the UDF.
      val intLengths = lengths.map(_.toInt)
      udf { r: Row =>
        VectorAssembler.assembleInt(intLengths, keepInvalid)(r.toSeq: _*)
      }
    } else {
      udf { r: Row =>
        VectorAssembler.assembleLong(lengths, keepInvalid)(r.toSeq: _*)
      }
    }

    // Non-double scalars are cast to Double so the assemble functions only see Double or Vector.
    val args = $(inputCols).map { c =>
      schema(c).dataType match {
        case DoubleType => dataset(c)
        case _: VectorUDT => dataset(c)
        case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
      }
    }

    filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
  }

  /**
    * Checks that every input column is numeric, boolean, or vector typed and that the output column
    * does not exist yet, then appends a nullable vector field named `outputCol`.
    *
    * @throws IllegalArgumentException listing every unsupported input column, or if the output
    *                                  column already exists
    */
  override def transformSchema(schema: StructType): StructType = {
    val inputColNames = $(inputCols)
    val outputColName = $(outputCol)
    val incorrectColumns = inputColNames.flatMap { name =>
      schema(name).dataType match {
        case _: NumericType | BooleanType => None
        case _: VectorUDT => None
        case other => Some(s"Data type ${other.catalogString} of column $name is not supported.")
      }
    }
    if (incorrectColumns.nonEmpty) {
      throw new IllegalArgumentException(incorrectColumns.mkString("\n"))
    }
    if (schema.fieldNames.contains(outputColName)) {
      throw new IllegalArgumentException(s"Output column $outputColName already exists.")
    }
    StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true))
  }

  override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
}


object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {

  private[sona] val SKIP_INVALID: String = "skip"
  private[sona] val ERROR_INVALID: String = "error"
  private[sona] val KEEP_INVALID: String = "keep"
  private[sona] val supportedHandleInvalids: Array[String] =
    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)

  /** Message used by the assemble functions when a scalar NaN is met and invalid data is not kept. */
  private val NaNInvalidMessage: String =
    """Encountered NaN while assembling a row with handleInvalid = "error". Consider
      |removing NaNs from dataset or using handleInvalid = "keep" or "skip".""".stripMargin

  /**
    * Message used by the assemble functions when a null cell is met and invalid data is not kept.
    * The transformer drops nulls before assembly under 'skip', so this mode is "error".
    */
  private val NullInvalidMessage: String =
    """Encountered null while assembling a row with handleInvalid = "error". Consider
      |removing nulls from dataset or using handleInvalid = "keep" or "skip".""".stripMargin

  /**
    * Infers lengths of vector columns from the first row of the dataset
    *
    * @param dataset the dataset
    * @param columns name of vector columns whose lengths need to be inferred
    * @return map of column names to lengths
    * @throws NullPointerException   if the first row has a null in one of `columns`
    * @throws NoSuchElementException if the dataset is empty
    */
  private[sona] def getVectorLengthsFromFirstRow(
      dataset: Dataset[_],
      columns: Seq[String]): Map[String, Long] = {
    try {
      val first_row = dataset.toDF().select(columns.map(col): _*).first()
      columns.zip(first_row.toSeq).map {
        case (c, x) => c -> x.asInstanceOf[linalg.Vector].size
      }.toMap
    } catch {
      case e: NullPointerException => throw new NullPointerException(
        s"""Encountered null value while inferring lengths from the first row. Consider using
           |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
          .stripMargin.replaceAll("\n", " ") + e.toString)
      case e: NoSuchElementException => throw new NoSuchElementException(
        s"""Encountered empty dataframe while inferring lengths from the first row. Consider using
           |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
          .stripMargin.replaceAll("\n", " ") + e.toString)
    }
  }

  /**
    * Resolves the length of every vector column in `columns`.
    *
    * Sizes come from each column's AttributeGroup metadata. Columns whose metadata size is -1 are
    * measured from the first row ('error'), or from the first row without nulls in those columns
    * ('skip'). With 'keep' a missing size cannot be inferred and an error is raised.
    *
    * @param dataset       the dataset
    * @param columns       names of vector columns
    * @param handleInvalid one of [[supportedHandleInvalids]]
    * @return map of column names to lengths
    * @throws RuntimeException if a size is missing and handleInvalid is 'keep'
    */
  private[sona] def getLengths(
      dataset: Dataset[_],
      columns: Seq[String],
      handleInvalid: String): Map[String, Long] = {
    val groupSizes = columns.map { c =>
      c -> AttributeGroup.fromStructField(dataset.schema(c)).size
    }.toMap
    val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq
    val firstSizes = (missingColumns.nonEmpty, handleInvalid) match {
      case (true, VectorAssembler.ERROR_INVALID) =>
        getVectorLengthsFromFirstRow(dataset, missingColumns)
      case (true, VectorAssembler.SKIP_INVALID) =>
        getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns)
      case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
        s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint
           |to add metadata for columns: ${columns.mkString("[", ", ", "]")}."""
          .stripMargin.replaceAll("\n", " "))
      case (_, _) => Map.empty
    }
    // Inferred sizes override the -1 placeholders.
    groupSizes ++ firstSizes
  }

  override def load(path: String): VectorAssembler = super.load(path)

  /**
    * Assembles the cells of one row into a single vector using Int indices.
    *
    * @param lengths     an array of lengths of input columns, whose size should be equal to the number
    *                    of cells in the row (vv); only read for null cells when `keepInvalid` is set
    * @param keepInvalid if true, a null cell becomes `lengths(i)` NaN entries; if false, a null cell
    *                    or a NaN scalar raises a SparkException
    * @param vv          the row's cells, each a Double, a vector, or null
    * @return the assembled vector (compressed to its more compact representation)
    * @throws SparkException on invalid data when `keepInvalid` is false, or on an unsupported cell type
    */
  private[sona] def assembleInt(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): linalg.Vector = {
    // Sum in Long: an Int sum would wrap around on overflow and this check could never fail.
    assert(lengths.foldLeft(0L)(_ + _) <= Int.MaxValue,
      s"The Vector is too long, the max index is ${Int.MaxValue}")

    val indices = mutable.ArrayBuilder.make[Int]
    val values = mutable.ArrayBuilder.make[Double]
    var featureIndex = 0

    var inputColumnIndex = 0
    vv.foreach {
      case v: Double =>
        if (v.isNaN && !keepInvalid) {
          throw new SparkException(NaNInvalidMessage)
        } else if (v != 0.0) {
          indices += featureIndex
          values += v
        }
        inputColumnIndex += 1
        featureIndex += 1
      case vec: linalg.Vector =>
        // Only non-zero active entries are stored, shifted by the current offset.
        vec.foreachActive { case (i, v) =>
          if (v != 0.0) {
            indices += featureIndex + i.toInt
            values += v
          }
        }
        inputColumnIndex += 1
        featureIndex += vec.size.toInt
      case null =>
        if (keepInvalid) {
          val length: Int = lengths(inputColumnIndex)
          (0 until length).foreach { i =>
            indices += featureIndex + i
            values += Double.NaN
          }
          inputColumnIndex += 1
          featureIndex += length
        } else {
          throw new SparkException(NullInvalidMessage)
        }
      case o =>
        throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
    }
    Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
  }

  /**
    * Assembles the cells of one row into a single vector using Long indices. It is used when the
    * total length does not fit below `Int.MaxValue`.
    *
    * @param lengths     an array of lengths of input columns, whose size should be equal to the number
    *                    of cells in the row (vv); only read for null cells when `keepInvalid` is set
    * @param keepInvalid if true, a null cell becomes `lengths(i)` NaN entries; if false, a null cell
    *                    or a NaN scalar raises a SparkException
    * @param vv          the row's cells, each a Double, a vector, or null
    * @return the assembled vector (compressed to its more compact representation)
    * @throws SparkException on invalid data when `keepInvalid` is false, or on an unsupported cell type
    */
  private[sona] def assembleLong(lengths: Array[Long], keepInvalid: Boolean)(vv: Any*): linalg.Vector = {
    // NOTE: no total-length check here. A Long sum can never exceed Long.MaxValue (it wraps), so
    // the previous `lengths.sum <= Long.MaxValue` assertion could never fail.
    val indices = mutable.ArrayBuilder.make[Long]
    val values = mutable.ArrayBuilder.make[Double]
    var featureIndex = 0L

    var inputColumnIndex = 0
    vv.foreach {
      case v: Double =>
        if (v.isNaN && !keepInvalid) {
          throw new SparkException(NaNInvalidMessage)
        } else if (v != 0.0) {
          indices += featureIndex
          values += v
        }
        inputColumnIndex += 1
        featureIndex += 1
      case vec: linalg.Vector =>
        // Only non-zero active entries are stored, shifted by the current offset.
        vec.foreachActive { case (i, v) =>
          if (v != 0.0) {
            indices += featureIndex + i
            values += v
          }
        }
        inputColumnIndex += 1
        featureIndex += vec.size
      case null =>
        if (keepInvalid) {
          val length: Long = lengths(inputColumnIndex)
          Range.Long(0, length, 1).foreach { i =>
            indices += featureIndex + i
            values += Double.NaN
          }
          inputColumnIndex += 1
          featureIndex += length
        } else {
          throw new SparkException(NullInvalidMessage)
        }
      case o =>
        throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
    }

    Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy