streaming.dsl.mmlib.algs.SQLTokenAnalysis.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{functions => F}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import streaming.dsl.mmlib.SQLAlg


import scala.collection.JavaConversions._
import scala.collection.mutable

/**
  * Tokenizes the text in a configurable input column with an Ansj analyzer
  * (instantiated reflectively), optionally augmented with user dictionaries,
  * and replaces that column with an array of tokens.
  *
  * Created by allwefantasy on 24/4/2018.
  */
class SQLTokenAnalysis extends SQLAlg with Functions {

  def internal_train(df: DataFrame, params: Map[String, String]) = {
    val session = df.sparkSession

    require(params.contains("inputCol"), "inputCol is required")
    val fieldName = params("inputCol")
    // Optional: split on a delimiter/regex instead of running the analyzer.
    val split = params.getOrElse("split", null)
    // Optional: extra user-dictionary words passed inline, comma separated.
    val arrayWords = if (params.contains("wordsArray")) {
      params("wordsArray").split(",")
    } else {
      Array[String]()
    }

    // Words from dictionary files (dic.paths) plus the inline words.
    val words = SQLTokenAnalysis.loadDics(session, params) ++ arrayWords


    val rdd = df.rdd.mapPartitions { mp =>
      // Build one analyzer per partition on the executors.
      val parser = SQLTokenAnalysis.createAnalyzer(words, params)
      mp.map { row =>
        val content = row.getAs[String](fieldName)
        val res =
          if (split != null) content.split(split)
          else SQLTokenAnalysis.parseStr(parser, content, params)
        // Drop the original text column and append the token array at the end,
        // matching the schema built below.
        val index = row.fieldIndex(fieldName)
        val newValue = row.toSeq.zipWithIndex.filterNot(_._2 == index).map(_._1) ++ Seq(res)
        Row.fromSeq(newValue)
      }
    }
    session.createDataFrame(rdd,
      StructType(df.schema.filterNot(f => f.name == fieldName) ++ Seq(StructField(fieldName, ArrayType(StringType)))))

  }

  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    val newDf = internal_train(df, params)
    require(params.contains("idCol"), "idCol is required")
    val fieldName = params("inputCol")
    val id = params("idCol")
    // Persist the tokens (renamed to "keywords") together with the id column.
    newDf.select(F.col(fieldName).alias("keywords"), F.col(id)).write.mode(SaveMode.Overwrite).parquet(path)
    emptyDataFrame()(df)
  }

  // This algorithm is train-only: its output is written as parquet above, so there
  // is no model to load and no prediction UDF to register.
  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
    null
  }

  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
    null
  }
}
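
/*
 * A minimal usage sketch. The column names and parameter values here are
 * hypothetical, and it assumes the Ansj analyzer and its dictionary (Forest)
 * classes are on the classpath:
 *
 *   val analysis = new SQLTokenAnalysis()
 *   val tokenized = analysis.internal_train(df, Map(
 *     "inputCol" -> "content",            // text column to tokenize (required)
 *     "dic.paths" -> "/data/user.dic",    // optional user dictionaries, comma separated
 *     "wordsArray" -> "wordA,wordB",      // optional inline dictionary words
 *     "ignoreNature" -> "true",           // drop part-of-speech tags
 *     "filterNatures" -> "",              // e.g. "n,v" to keep only those natures
 *     "deduplicateResult" -> "false"
 *   ))
 *   // `content` is now ArrayType(StringType); train() additionally requires "idCol"
 *   // and writes the keywords plus the id column to the given path as parquet.
 */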

object SQLTokenAnalysis {
  def parseStr(parser: Any, content: String, params: Map[String, String]) = {

    // ignoreNature: drop the part-of-speech tag from each token (default true).
    // filterNatures: keep only tokens whose nature is in this comma separated set.
    // deduplicateResult: emit each (word, nature) pair at most once.
    val ignoreNature = params.getOrElse("ignoreNature", "true").toBoolean
    val filterNatures = params.getOrElse("filterNatures", "").split(",").filterNot(f => f.isEmpty).toSet
    val deduplicateResult = params.getOrElse("deduplicateResult", "false").toBoolean

    // The parser class is only known at runtime, so parseStr is invoked reflectively.
    val udg = try {
      parser.getClass.getMethod("parseStr", classOf[String]).invoke(parser, content)
    } catch {
      case e: Exception =>
        println(s"parser invoke error:${content}")
        throw e
    }

    def getAllWords(udg: Any) = {
      val result = udg.getClass.getMethod("getTerms").invoke(udg).asInstanceOf[java.util.List[Object]]
      var res = result.map { f =>
        val (name, nature) = AnsjFunctions.getTerm(f)
        (name.toString, nature.toString)
      }

      if (deduplicateResult) {
        // Keep only the first occurrence of each (word, nature) pair.
        val seen = new mutable.HashSet[(String, String)]()
        res = res.filter(seen.add)
      }

      if (filterNatures.nonEmpty) {
        res = res.filter(f => filterNatures.contains(f._2))
      }

      if (ignoreNature) {
        res.map(f => f._1)
      } else {
        res.map(f => s"${f._1}/${f._2}")
      }
    }

    getAllWords(udg).toArray
  }

  def loadDics(spark: SparkSession, params: Map[String, String]) = {
    // Load user-dictionary words from the comma separated file paths in dic.paths.
    params.getOrElse("dic.paths", "").split(",").filter(f => !f.isEmpty).flatMap { f =>
      spark.sparkContext.textFile(f).collect()
    }
  }
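
  /*
   * Each dic.paths file is read line by line and every line is handed to
   * AnsjFunctions.addWord as one dictionary entry. A hypothetical setup:
   *
   *   dic.paths = "/data/dics/user.dic"
   *   user.dic contents:
   *     wordA
   *     wordB
   */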


  def createForest(words: Array[String], params: Map[String, String]) = {
    // Build the user-dictionary trie (Forest) reflectively and register the extra words.
    val forestClassName = params.getOrElse("forest", "org.nlpcn.commons.lang.tire.domain.Forest")
    val forest = Class.forName(forestClassName).newInstance().asInstanceOf[AnyRef]
    words.foreach { f =>
      AnsjFunctions.addWord(f, forest)
    }
    forest
  }

  def createAnalyzerFromForest(forest: AnyRef, params: Map[String, String]) = {
    // Instantiate the parser reflectively and wire the user-dictionary forest into it.
    val parserClassName = params.getOrElse("parser", "org.ansj.splitWord.analysis.NlpAnalysis")
    val forestClassName = params.getOrElse("forest", "org.nlpcn.commons.lang.tire.domain.Forest")
    val parser = Class.forName(parserClassName).newInstance().asInstanceOf[AnyRef]
    AnsjFunctions.configureDic(parser, forest, parserClassName, forestClassName)
    parser
  }

  def createAnalyzer(words: Array[String], params: Map[String, String]) = {
    // Convenience wrapper: build the user-dictionary forest, then attach it to a new parser.
    val forest = createForest(words, params)
    createAnalyzerFromForest(forest, params)
  }
}
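
/*
 * A sketch of using the companion-object helpers directly, outside Spark.
 * The words and text are hypothetical, and it assumes the default parser class
 * org.ansj.splitWord.analysis.NlpAnalysis is available:
 *
 *   val params = Map("ignoreNature" -> "false")
 *   val parser = SQLTokenAnalysis.createAnalyzer(Array("customWord"), params)
 *   val tokens = SQLTokenAnalysis.parseStr(parser, "some text to tokenize", params)
 *   // tokens is an Array of "word/nature" strings because ignoreNature is false.
 */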





