All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.microsoft.ml.spark.stages.TextPreprocessor.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import com.microsoft.ml.spark.core.contracts.{HasInputCol, HasOutputCol, Wrappable}
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.ml.param.{MapParam, Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.udf
import spray.json.DefaultJsonProtocol._

/** Immutable character trie that maps string keys to replacement text.
  *
  * Each node keeps its children in `trieMap` and, in `value`, the replacement text of the
  * key that ends at that node (empty when no key ends there, so an empty replacement is
  * indistinguishable from "no entry"). Key characters are passed through `normFunc` when a
  * key is inserted, and looked-up characters are passed through it in [[get]].
  *
  * @param map          children of this node, keyed by normalized character
  * @param textValue    replacement text of the key ending at this node; empty means none
  * @param normFunction normalization applied to key characters and to looked-up characters
  */
class Trie(map: Map[Char, Trie] = Map.empty,
           textValue: Seq[Char] = Seq.empty,
           normFunction: Char => Char = identity) extends Serializable {
  val trieMap: Map[Char, Trie] = map
  val value: Seq[Char] = textValue
  val normFunc: Char => Char = normFunction

  /** The replacement text stored at this node as a plain string ("" when there is none). */
  def getValueString: String = this.value.mkString

  /** Returns `value` when `key` has exactly one character, i.e. when the node reached through
    * that character is where the key ends; returns an empty sequence for any other key length
    * (including an empty key).
    */
  def getNodeValue(key: Seq[Char], value: Seq[Char]): Seq[Char] =
    if (key.lengthCompare(1) == 0) value else Seq.empty[Char]

  /** Inserts `key -> value`, normalizing the key with `normFunc`. */
  def put(key: String, value: String): Trie = this.put(key.toCharArray, value)

  /** Returns a new trie that additionally maps `key` to `value`.
    *
    * The node at which the key ends always receives `value`, whether or not that node already
    * existed; its existing children are kept. This makes the result independent of insertion
    * order and lets a later `put` overwrite an earlier one.
    *
    * @param key          the remaining key characters; an empty key yields a plain copy
    * @param value        the replacement text
    * @param isNormalized true when `key` has already been passed through `normFunc`
    */
  def put(key: Seq[Char], value: String, isNormalized: Boolean = false): Trie =
    if (key.isEmpty) {
      this.copy()
    } else {
      val normalizedKey = if (isNormalized) key else key.map(this.normFunc)
      val head = normalizedKey.head
      val remainder = normalizedKey.tail
      val child = this.trieMap.getOrElse(head, new Trie(normFunction = this.normFunc))
      val updatedChild =
        if (remainder.isEmpty) new Trie(child.trieMap, value.toCharArray, this.normFunc)
        else child.put(remainder, value, isNormalized = true)
      new Trie(this.trieMap + (head -> updatedChild), this.value, this.normFunc)
    }

  /** Looks up the child reached through `letter` after normalizing it with `normFunc`. */
  def get(letter: Char): Option[Trie] = this.trieMap.get(this.normFunc(letter))

  /** Inserts every entry of `stringMap` and returns the resulting trie. */
  def putAll(stringMap: Map[String, String]): Trie =
    stringMap.foldLeft(this) {
      case (out, (k, v)) => out.put(k, v)
    }

  /** Returns a new node with the same children, value and normalization function. */
  def copy(): Trie = new Trie(this.trieMap, this.value, this.normFunc)

  /** Rewrites `inputText`, replacing matched keys with their replacement text.
    *
    * The text is scanned left to right. At each position the trie is walked as far as the
    * input allows, and the replacement of the last node that had one is used (so the longest
    * matching key wins), including a key that ends exactly at the end of the input. When
    * nothing matches, the original character is emitted and scanning moves on by one.
    *
    * After a replacement, the letters, digits and underscores that directly follow the matched
    * key are skipped, but only when the walk stopped right after the node that supplied the
    * replacement; if the walk went on through further nodes without a replacement, scanning
    * resumes immediately after the matched key.
    *
    * The scan is index based and tail recursive, so stack depth does not grow with the
    * length of the input.
    */
  def mapText(inputText: String): String = {
    import scala.annotation.tailrec

    val chars = inputText.toCharArray
    val length = chars.length
    val outputText = new StringBuilder

    def isAlpha(char: Char): Boolean = char.isLetterOrDigit || char == '_'

    // Index of the first character at or after `index` that is not part of a word.
    @tailrec
    def skipAlphas(index: Int): Int =
      if (index < length && isAlpha(chars(index))) skipAlphas(index + 1) else index

    // Walks the trie from `cursor`, emits the best replacement found (or the fallback in
    // `matched`) and returns the index at which scanning has to resume.
    // `trie` is the node reached by the characters before `cursor`; `rest` is the resume
    // index that belongs to `matched`.
    @tailrec
    def findMatch(rest: Int,
                  matched: Seq[Char],
                  hasMatch: Boolean,
                  cursor: Int,
                  trie: Option[Trie]): Int = trie match {
      case Some(node) if cursor < length =>
        val next = node.get(chars(cursor))
        if (node.value.isEmpty) findMatch(rest, matched, hasMatch = false, cursor + 1, next)
        else findMatch(cursor, node.value, hasMatch = true, cursor + 1, next)
      case Some(node) if node.value.nonEmpty =>
        // The input ends exactly where this key ends: the key still counts as a match.
        outputText ++= node.value
        length
      case _ =>
        outputText ++= matched
        if (hasMatch) skipAlphas(rest) else rest
    }

    @tailrec
    def scan(start: Int): Unit =
      if (start < length) {
        val first = chars(start)
        scan(findMatch(start + 1, Seq(first), hasMatch = false, start + 1, this.get(first)))
      }

    scan(0)
    outputText.toString
  }

  override def toString: String = s"Val: $getValueString Rest:\n $trieMap"
}

/** Factory for [[Trie]] instances. */
object Trie {

  /** Builds a trie with the default (identity) normalization by inserting every
    * `key -> replacement` pair of `map` into an empty trie.
    */
  def apply(map: Map[String, String]): Trie =
    map.foldLeft(new Trie()) { (accumulated, entry) =>
      val (key, replacement) = entry
      accumulated.put(key, replacement)
    }
}

/** Companion object of [[TextPreprocessor]]. It mixes in `ComplexParamsReadable`
  * (presumably what supplies the reader used to load saved instances; confirm
  * against that trait).
  */
object TextPreprocessor extends ComplexParamsReadable[TextPreprocessor]

/** TextPreprocessor takes a dataframe and a dictionary
  * that maps (text -> replacement text), scans each cell in the input col
  * and replaces all substring matches with the corresponding value.
  * Priority is given to longer keys and from left to right.
  * Null cells are passed through as null.
  */
class TextPreprocessor(val uid: String) extends Transformer
  with HasInputCol with HasOutputCol with Wrappable with ComplexParamsWritable {
  def this() = this(Identifiable.randomUID("TextPreprocessor"))

  /** Normalization functions that can be selected by name through the `normFunc` param. */
  val normFuncs: Map[String, Char => Char] = Map[String, Char => Char] (
    "identity"   -> identity,
    "lowerCase"  -> Character.toLowerCase,
    "upperCase"  -> Character.toUpperCase
  )

  val map = new MapParam[String, String](this, "map", "Map of substring match to replacement")

  /** Returns the configured map, or an empty map when none was set.
    * @group getParam
    */
  def getMap: Map[String, String] = get(map).getOrElse(Map())

  /** @group setParam */
  def setMap(value: Map[String, String]): this.type = set(map, value)

  /** True when `normFuncName` is one of the keys of `normFuncs`. */
  def isValidNormFunc(normFuncName: String): Boolean = normFuncs.contains(normFuncName)

  val normFunc = new Param[String](this, "normFunc", "Name of normalization function to apply", isValidNormFunc _)

  /** Returns the configured normalization function name, or "identity" when none was set.
    * @group getParam
    */
  def getNormFunc: String = get(normFunc).getOrElse("identity")

  /** @group setParam */
  def setNormFunc(value: String): this.type = set(normFunc, value)

  /** Applies the replacement map to every cell of the input column.
    *
    * @param dataset - The input dataset, to be transformed
    * @return The input data with the output column holding the mapped text
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    // Validate first so that nothing is built or broadcast for an invalid request.
    require(dataset.columns.contains(getInputCol), s"Input column $getInputCol does not exist")

    val trie = new Trie(normFunction = normFuncs(getNormFunc)).putAll(getMap)
    val broadcastedTrie = dataset.sparkSession.sparkContext.broadcast(trie)

    // The broadcast is dereferenced inside the function: writing
    // `broadcastedTrie.value.mapText` as a method value would evaluate `.value` here on the
    // driver and capture the trie itself in the closure, bypassing the broadcast.
    val mapText: String => String =
      text => Option(text).map(broadcastedTrie.value.mapText).orNull
    val textMapper = udf(mapText)
    dataset.withColumn(getOutputCol, textMapper(dataset(getInputCol)))
  }

  /** Appends the output column, typed as a string, to `schema`. */
  def transformSchema(schema: StructType): StructType = {
    schema.add(StructField(getOutputCol, StringType))
  }

  def copy(extra: ParamMap): TextPreprocessor = defaultCopy(extra)

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy