All downloads are free. Search and download functionality uses the official Maven repository.

com.microsoft.ml.spark.vw.featurizer.StringSplitFeaturizer.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.vw.featurizer

import com.microsoft.ml.spark.vw.VowpalWabbitMurmurWithPrefix
import org.apache.spark.sql.Row
import org.vowpalwabbit.spark.VowpalWabbitMurmur

import scala.collection.mutable.ArrayBuilder

/**
  * Featurizes a string column by splitting it into tokens and emitting one native VW sparse feature per token:
  * (hash(token(0)):1.0, hash(token(1)):1.0, ...).
  *
  * Tokens are maximal runs of Unicode word characters (see [[nonWhiteSpaces]]). Each token is hashed with the
  * column name as prefix together with the namespace hash, and the result is masked. Every occurrence of a token
  * is emitted with value 1.0; repeated tokens are NOT counted or de-duplicated here. A null cell produces no
  * features (same as an empty string).
  *
  * @param fieldIdx input field index.
  * @param columnName used as feature name prefix.
  * @param namespaceHash pre-hashed namespace.
  * @param mask bit mask applied to final hash.
  */
class StringSplitFeaturizer(override val fieldIdx: Int, val columnName: String, val namespaceHash: Int, val mask: Int)
  extends Featurizer(fieldIdx) {

  /**
    * Token pattern: maximal runs of word characters. `(?U)` makes `\w` Unicode-aware
    * (https://stackoverflow.com/questions/4304928/unicode-equivalents-for-w-and-b-in-java-regular-expressions),
    * so letters and digits from any script are kept.
    *
    * Note: despite its name, this matches word characters, not arbitrary non-whitespace. Other punctuation
    * therefore separates tokens and is dropped.
    *
    * We could follow
    * https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html
    * but that strips single-character words...
    *
    * TODO: expose as user configurable parameter
    */
  val nonWhiteSpaces = "(?U)\\w+".r

  /**
    * Hasher that pre-hashes the column name prefix once, so per-token hashing only covers the token itself.
    */
  val hasher = new VowpalWabbitMurmurWithPrefix(columnName)

  /**
    * Featurize a single row by appending one (index, 1.0) pair per token found in the input string.
    * @param row input row.
    * @param indices output indices; one masked hash is appended per token.
    * @param values output values; 1.0 is appended per token, aligned with `indices`.
    * @note this interface isn't very Scala-esque, but it avoids lots of allocation.
    *       Also due to SparseVector limitations we don't support 64bit indices (e.g. indices are signed 32bit ints)
    */
  override def featurize(row: Row, indices: ArrayBuilder[Int], values: ArrayBuilder[Double]): Unit = {
    // A null cell has no tokens: treat it like the empty string instead of failing with an NPE inside the regex.
    if (!row.isNullAt(fieldIdx)) {
      val s = row.getString(fieldIdx)

      for (m <- nonWhiteSpaces.findAllMatchIn(s)) {
        // The hasher accesses the chars of `s` directly via (start, end) offsets,
        // which avoids allocating a substring per token.
        indices += mask & hasher.hash(s, m.start, m.end, namespaceHash)

        values += 1.0
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy