All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.ml.odkl.texts.LanguageAwareAnalyzer.scala Maven / Gradle / Ivy

package org.apache.spark.ml.odkl.texts

import org.apache.lucene.analysis.util.StopwordAnalyzerBase
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.HasOutputCol
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

/**
  * Spark ML [[Transformer]] that analyzes (tokenizes/stems) a text column with a Lucene analyzer
  * chosen per row from a language-code column.
  *
  * For each row, the analyzer registered for the row's language in
  * `LanguageAwareStemmerUtil.languageAnalyzersMap` is used. If that language has no registered
  * analyzer, the analyzer for [[defaultLanguage]] is used instead. The actual stemming is delegated
  * to `LanguageAwareStemmerUtil.stemmString`. The result goes to [[outputCol]], which
  * [[transformSchema]] declares as `array<string>`.
  *
  * Analyzers are held in one `ThreadLocal` per language. Each thread lazily creates and then reuses
  * its own analyzer instance per language.
  *
  * Created by eugeny.malyutin on 05.05.16.
  */
class LanguageAwareAnalyzer(override val uid: String) extends Transformer with HasOutputCol {

  /**
    * One `ThreadLocal` analyzer holder per language code, built from the factories in
    * `LanguageAwareStemmerUtil.languageAnalyzersMap`.
    *
    * The map is materialized eagerly (`map` + `toMap`) instead of using `mapValues`. `mapValues`
    * returns a lazy view that re-runs its function on every lookup. That created a new
    * `ThreadLocal`, and hence a new analyzer, for every row, defeating the per-thread cache.
    * Because the field is `@transient lazy`, it is not serialized and is rebuilt on first use
    * after deserialization.
    */
  @transient private lazy val analyzersByLanguage: Map[String, ThreadLocal[StopwordAnalyzerBase]] =
    LanguageAwareStemmerUtil.languageAnalyzersMap.map { case (language, createAnalyzer) =>
      language -> new ThreadLocal[StopwordAnalyzerBase]() {
        override def initialValue(): StopwordAnalyzerBase = createAnalyzer()
      }
    }.toMap

  /**
    * Analyzer holders keyed by language code.
    *
    * It is still typed as `Object`, as before, for backward compatibility. Cast it to
    * `Map[String, ThreadLocal[StopwordAnalyzerBase]]` to use it.
    */
  @transient lazy val languageAnalyzerMap: Object = analyzersByLanguage

  /**
    * Analyzer holder for [[defaultLanguage]].
    *
    * It is used for rows whose language has no registered analyzer.
    *
    * @throws NoSuchElementException if no analyzer is registered for the default language
    */
  @transient lazy val tokenizer: ThreadLocal[StopwordAnalyzerBase] = {
    val language = $(defaultLanguage)
    analyzersByLanguage.getOrElse(language, throw new NoSuchElementException(
      s"No analyzer registered for default language '$language'"))
  }

  /** Name of the column holding each row's language code (e.g. detected by langdetect). */
  val inputColLang = new Param[String](this, "inputColLang",
    "language code from langdetect")

  setDefault(inputColLang -> "lang")

  /** Name of the column holding the text to analyze. */
  val inputColText = new Param[String](this, "inputColText",
    "column with text")

  setDefault(inputColText -> "text")

  /** Language whose analyzer is used when a row's language has no registered analyzer. */
  val defaultLanguage = new Param[String](this, "defaultLanguage",
    "language to use as default if actual unknown")

  setDefault(defaultLanguage -> "ru")

  /**
    * UDF `(lang, text) => result of LanguageAwareStemmerUtil.stemmString`.
    *
    * It picks this thread's analyzer for `lang`, falling back to [[tokenizer]] for unknown
    * languages.
    */
  val stemmTextUDF = udf((lang: String, text: String) => {
    val analyzer = analyzersByLanguage.getOrElse(lang, tokenizer).get()
    LanguageAwareStemmerUtil.stemmString(text, analyzer)
  })

  /** @group getParam */
  def getInputColLang: String = $(inputColLang)

  /**
    * Returns the value of [[inputColLang]].
    *
    * This is a misnamed alias of [[getInputColLang]], kept so existing callers keep compiling.
    *
    * @group getParam
    */
  def getInputColHash: String = getInputColLang

  /** @group getParam */
  def getInputColText: String = $(inputColText)

  /** @group getParam */
  def getDefaultLanguage: String = $(defaultLanguage)

  /** @group setParam */
  def setInputColLang(value: String): this.type = set(inputColLang, value)

  /** @group setParam */
  def setInputColText(value: String): this.type = set(inputColText, value)

  /** @group setParam */
  def setDefaultLanguage(value: String): this.type = set(defaultLanguage, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** Copies this transformer (same uid) with `extra` params applied on top. */
  override def copy(extra: ParamMap): LanguageAwareAnalyzer = defaultCopy(extra)

  /** Creates an analyzer with a random uid; also required by `defaultCopy`. */
  def this() = this(Identifiable.randomUID("languageAnalyzer"))

  /**
    * Adds (or overwrites) [[outputCol]] with the UDF applied to [[inputColLang]] and
    * [[inputColText]].
    */
  override def transform(dataset: DataFrame): DataFrame = {
    dataset.withColumn($(outputCol), stemmTextUDF(dataset.col($(inputColLang)), dataset.col($(inputColText))))
  }

  /**
    * Declares the output schema: [[outputCol]] is appended as `array<string>`.
    *
    * When the output column equals the text column, the input text column is dropped first. This
    * mirrors `transform`, where `withColumn` overwrites that column.
    */
  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    val baseSchema =
      if ($(inputColText) == $(outputCol)) new StructType(schema.fields.filterNot(_.name == $(inputColText)))
      else schema
    SchemaUtils.appendColumn(baseSchema, $(outputCol), ArrayType(StringType, containsNull = true))
  }

}






© 2015 - 2024 Weber Informatics LLC | Privacy Policy