All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.ainsightful.sparkml.lite.FeatureHasherLite.scala Maven / Gradle / Ivy

The newest version!
package com.ainsightful.sparkml.lite

import org.apache.spark.SparkException
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2}
import org.apache.spark.unsafe.types.UTF8String

import scala.collection.mutable

class FeatureHasherLite(schema: StructType, hashSize: Int, checkNullable: Boolean = true) {

  private val columns = schema.map(struct => struct.name -> struct).toMap

  private val realFields = schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map(_.name).toSet
  private val dataTypes = schema.fields.map(x => x.name -> x.dataType).toMap

  def hash(feature: Map[String, Any]) = {
    val map = new mutable.OpenHashMap[Int, Double]()
    columns.keys.foreach({ colName =>
      if (feature.contains(colName)) {
        val (rawIdx, value) =
          if (realFields(colName)) {
            // numeric values are kept as is, with vector index based on hash of "column_name"
            val value =
              if (dataTypes(colName) == DoubleType) feature(colName).asInstanceOf[Double]
              else feature(colName).toString.toDouble
            val hash = FeatureHasherLite.hash(colName)
            (hash, value)
          } else {
            // string, boolean and numeric values that are in catCols are treated as categorical,
            // with an indicator value of 1.0 and vector index based on hash of "column_name=value"
            val value = feature(colName).toString
            val fieldName = s"$colName=$value"
            val hash = FeatureHasherLite.hash(fieldName)
            (hash, 1.0)
          }
        val idx = FeatureHasherLite.nonNegativeMod(rawIdx, hashSize)
        map.put(idx, map.get(idx).getOrElse(0.0) + value)
      } else if (checkNullable && !columns(colName).nullable) {
        throw new RuntimeException(s"Non-nullable column should have value: ${colName}. Set checkNullable=false if the check is really not needed.")
      }
    })
    Vectors.sparse(hashSize, map.toSeq)
  }
}

object FeatureHasherLite {

  private val seed = 42

  def hash(term: Any): Int = {
    term match {
      case null => seed
      case b: Boolean => hashInt(if (b) 1 else 0, seed)
      case b: Byte => hashInt(b, seed)
      case s: Short => hashInt(s, seed)
      case i: Int => hashInt(i, seed)
      case l: Long => hashLong(l, seed)
      case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
      case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
      case s: String =>
        val utf8 = UTF8String.fromString(s)
        hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
      case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " +
        s"support type ${term.getClass.getCanonicalName} of input data.")
    }
  }

  /* Calculates 'x' modulo 'mod', takes to consideration sign of x,
   * i.e. if 'x' is negative, than 'x' % 'mod' is negative too
   * so function return (x % mod) + mod in that case.
   */
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy