// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.vw

import com.microsoft.ml.spark.core.contracts.{HasInputCols, HasOutputCol, Wrappable}
import com.microsoft.ml.spark.vw.featurizer._
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.ml.param.{IntParam, ParamMap, StringArrayParam}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, struct, udf}
import org.vowpalwabbit.spark.VowpalWabbitMurmur
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType

import scala.collection.mutable.ArrayBuilder

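/** Companion object providing read/load support for persisted VowpalWabbitFeaturizer instances. */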
object VowpalWabbitFeaturizer extends ComplexParamsReadable[VowpalWabbitFeaturizer]

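/**
  * Featurizes the configured input columns into a single sparse feature vector: feature names are
  * hashed with VowpalWabbitMurmur into an index space of 2^numBits entries, using a namespace hash
  * derived from the output column name and the `seed` param; colliding indices are either summed or
  * dropped depending on `sumCollisions`. String columns listed in `stringSplitInputCols` are split
  * at word boundaries before hashing.
  *
  * A minimal usage sketch (illustrative, not part of this source); the DataFrame `df` and the
  * column names are assumptions, and the inherited setters follow the usual param conventions:
  * {{{
  *   val featurizer = new VowpalWabbitFeaturizer()
  *     .setInputCols(Array("age", "income"))          // numeric and/or string columns
  *     .setStringSplitInputCols(Array("description")) // split at word boundaries
  *     .setOutputCol("features")
  *   val featurized = featurizer.transform(df)        // adds a SparseVector column "features"
  * }}}
  */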
class VowpalWabbitFeaturizer(override val uid: String) extends Transformer
  with HasInputCols with HasOutputCol with HasNumBits with HasSumCollisions
  with Wrappable with ComplexParamsWritable {
  def this() = this(Identifiable.randomUID("VowpalWabbitFeaturizer"))

  setDefault(inputCols -> Array())

  val seed = new IntParam(this, "seed", "Hash seed")
  setDefault(seed -> 0)

  def getSeed: Int = $(seed)
  def setSeed(value: Int): this.type = set(seed, value)

  val stringSplitInputCols = new StringArrayParam(this, "stringSplitInputCols",
    "Input cols that should be split at word boundaries")
  setDefault(stringSplitInputCols -> Array())

  def getStringSplitInputCols: Array[String] = $(stringSplitInputCols)
  def setStringSplitInputCols(value: Array[String]): this.type = set(stringSplitInputCols, value)

  private def getAllInputCols = getInputCols ++ getStringSplitInputCols

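  /**
    * Selects the Featurizer implementation matching the column's Spark SQL data type
    * (numeric, boolean, string, string array, string-keyed map, or ML vector);
    * unsupported types raise a RuntimeException.
    */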
  private def getFeaturizer(name: String, dataType: DataType, idx: Int, namespaceHash: Int): Featurizer = {
    val stringSplitInputCols = getStringSplitInputCols

    dataType match {
      case DoubleType => new NumericFeaturizer(idx, name, namespaceHash, getMask, r => r.getDouble(idx))
      case FloatType => new NumericFeaturizer(idx, name, namespaceHash, getMask, r => r.getFloat(idx).toDouble)
      case IntegerType => new NumericFeaturizer(idx, name, namespaceHash, getMask, r => r.getInt(idx).toDouble)
      case LongType => new NumericFeaturizer(idx, name, namespaceHash, getMask, r => r.getLong(idx).toDouble)
      case ShortType => new NumericFeaturizer(idx, name, namespaceHash, getMask, r => r.getShort(idx).toDouble)
      case ByteType => new NumericFeaturizer(idx, name, namespaceHash, getMask, r => r.getByte(idx).toDouble)
      case BooleanType => new BooleanFeaturizer(idx, name, namespaceHash, getMask)
      case StringType =>
        if (stringSplitInputCols.contains(name))
          new StringSplitFeaturizer(idx, name, namespaceHash, getMask)
        else new StringFeaturizer(idx, name, namespaceHash, getMask)
      case arr: ArrayType => {
        if (arr.elementType != DataTypes.StringType)
          throw new RuntimeException(s"Unsupported array element type: ${dataType}")

        new StringArrayFeaturizer(idx, name, namespaceHash, getMask)
      }
      case m: MapType => {
        if (m.keyType != DataTypes.StringType)
          throw new RuntimeException(s"Unsupported map key type: ${dataType}")

        m.valueType match {
          case StringType => new MapStringFeaturizer(idx, name, namespaceHash, getMask)
          case DoubleType => new MapFeaturizer[Double](idx, name, namespaceHash, getMask, v => v)
          case FloatType => new MapFeaturizer[Float](idx, name, namespaceHash, getMask, v => v.toDouble)
          case IntegerType => new MapFeaturizer[Int](idx, name, namespaceHash, getMask, v => v.toDouble)
          case LongType => new MapFeaturizer[Long](idx, name, namespaceHash, getMask, v => v.toDouble)
          case ShortType => new MapFeaturizer[Short](idx, name, namespaceHash, getMask, v => v.toDouble)
          case ByteType => new MapFeaturizer[Byte](idx, name, namespaceHash, getMask, v => v.toDouble)
          case _ => throw new RuntimeException(s"Unsupported map value type: ${dataType}")
        }
      }
      case m: Any =>
        if (m == VectorType) // unfortunately the type is private
          new VectorFeaturizer(idx, getMask)
        else
          throw new RuntimeException(s"Unsupported data type: ${dataType}")
    }
  }

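  /**
    * Builds one featurizer per selected input column and applies them row-wise through a UDF:
    * feature indices and values are collected, sorted and de-duplicated, and written to the
    * output column as a SparseVector of size 2^numBits.
    */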
  override def transform(dataset: Dataset[_]): DataFrame = {
    val inputColsList = getAllInputCols
    val namespaceHash: Int = VowpalWabbitMurmur.hash(this.getOutputCol, this.getSeed)

    val fieldSubset = dataset.schema.fields
      .filter(f => inputColsList.contains(f.name))

    val featurizers: Array[Featurizer] = fieldSubset.zipWithIndex
      .map { case (field, idx) => getFeaturizer(field.name, field.dataType, idx, namespaceHash) }

    // TODO: list types
    // BinaryType
    // CalendarIntervalType
    // DateType
    // NullType
    // TimestampType
    // getStruct

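    // UDF that converts the struct of selected input columns into a single sparse feature vector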
    val mode = udf((r: Row) => {
      val indices = ArrayBuilder.make[Int]
      val values = ArrayBuilder.make[Double]

      // educated guess on size
      indices.sizeHint(featurizers.length)
      values.sizeHint(featurizers.length)

      // apply all featurizers
      for (f <- featurizers)
        if (!r.isNullAt(f.fieldIdx))
          f.featurize(r, indices, values)

      // sort by index and remove duplicates
      // Warning:
      //   - SparseVector does not allow duplicate indices, so we need to filter them here
      //   - the VW command line allows duplicate features with different values (it just updates twice)
      val (indicesSorted, valuesSorted) = VectorUtils.sortAndDistinct(indices.result, values.result, getSumCollisions)

      Vectors.sparse(1 << getNumBits, indicesSorted, valuesSorted)
    })

    dataset.toDF.withColumn(getOutputCol, mode.apply(struct(fieldSubset.map(f => col(f.name)): _*)))
  }

  override def copy(extra: ParamMap): VowpalWabbitFeaturizer = defaultCopy(extra)

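  /** Checks that every configured input column exists in the schema and appends the vector-typed output column. */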
  override def transformSchema(schema: StructType): StructType = {
    val fieldNames = schema.fields.map(_.name)
    for (f <- getAllInputCols)
      if (!fieldNames.contains(f))
        throw new IllegalArgumentException(s"missing input column $f")

    schema.add(new StructField(getOutputCol, VectorType, true))
  }
}
