All Downloads are FREE. Search and download functionalities are using the official Maven repository.

streaming.dsl.mmlib.algs.SQLVecMapInPlace.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs

import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, functions => F}
import org.apache.spark.sql.expressions.UserDefinedFunction
import streaming.dsl.mmlib.SQLAlg
import MetaConst._
import org.apache.spark.ml.help.HSQLStringIndex
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.types._
import streaming.dsl.mmlib.algs.meta.Word2IndexMeta
import org.apache.spark.ml.linalg.SQLDataTypes._

/**
  * Created by allwefantasy on 18/6/2018.
  *
  * Rewrites a map-typed column (feature name -> numeric value) in place as a sparse
  * ML vector. Every distinct map key seen during training is assigned a position by
  * a [[SQLStringIndex]] model; the vector dimension is the number of indexed keys.
  *
  * NOTE(review): this assumes the positions handed out by the string index lie in
  * `[0, number of keys)` -- confirm against `HSQLStringIndex.wordToIndex`.
  */
class SQLVecMapInPlace extends SQLAlg with Functions {

  /**
    * Trains the key index for `inputCol` and returns `df` with that column replaced
    * by the corresponding sparse vector.
    *
    * @param df     input data; `inputCol` is read as a map from `String` to a numeric value
    * @param params must contain `path` (model location) and a non-empty `inputCol`;
    *               the full map is also forwarded to the string-index training
    * @return `df` with `inputCol` converted to a sparse vector column
    * @throws IllegalArgumentException if `inputCol` is missing or empty
    */
  def internal_train(df: DataFrame, params: Map[String, String]): DataFrame = {
    // Validate first: nothing may be persisted for a call that is going to be rejected,
    // otherwise a later load() would find training params without an inputCol.
    val inputCol = params.getOrElse("inputCol", "")
    require(!inputCol.isEmpty, "inputCol is required when use SQLVecMapInPlace")

    val path = params("path")
    val metaPath = getMetaPath(path)
    val spark = df.sparkSession
    saveTraningParams(spark, params, metaPath)

    // One row per map key occurrence, each wrapped in a single-element array because
    // the key DataFrame is declared with an ArrayType(StringType) column.
    val keyRDD = df.rdd.flatMap { row =>
      row.getMap[String, Number](row.fieldIndex(inputCol)).keys.toArray[String]
    }.map(key => Row(Seq(key)))

    val keyDF = spark.createDataFrame(keyRDD, StructType(Seq(
      StructField(inputCol, ArrayType(StringType))
    )))

    // Train the word index, then read it back to obtain the key -> position mapping.
    val wordIndexPath = MetaConst.WORD_INDEX_PATH(metaPath, inputCol)
    val si = new SQLStringIndex()
    si.train(keyDF, wordIndexPath, params)
    val model = si.load(spark, wordIndexPath, params)
    val word2IndexMapping = HSQLStringIndex.wordToIndex(spark, model)
    val featureSize = word2IndexMapping.size

    // Broadcast the mapping (as predict() does) instead of capturing it in the UDF
    // closure, so it is not re-serialized with every task.
    val word2IndexMappingBr = spark.sparkContext.broadcast(word2IndexMapping.toMap)

    // Keep this a local function value that only references local vals: that way the
    // closure does not capture the enclosing (non-serializable) algorithm instance.
    val toVector = (item: Map[String, Number]) => {
      val index = word2IndexMappingBr.value
      // Every key of the training data was indexed above, so a strict lookup is safe here.
      val elements = item.toSeq.map { case (word, weight) =>
        (index(word).toInt, weight.doubleValue())
      }
      Vectors.sparse(featureSize, elements)
    }
    val toVectorUDF = F.udf(toVector)
    df.withColumn(inputCol, toVectorUDF(F.col(inputCol)))
  }

  /**
    * Trains on `df` and writes the transformed data as parquet (overwriting any
    * previous output) to the data path derived from `path`.
    *
    * @param df     training data
    * @param path   model location; injected into `params` under the key `path`
    * @param params training parameters, see [[internal_train]]
    * @return the result of `emptyDataFrame()(df)`
    */
  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    val newDF = internal_train(df, params + ("path" -> path))
    newDF.write.mode(SaveMode.Overwrite).parquet(getDataPath(path))
    emptyDataFrame()(df)
  }

  /**
    * Restores the training params and the key -> position mapping saved by training.
    *
    * @param spark  active session
    * @param _path  model location that was passed to `train`
    * @param params forwarded to the string-index loader
    * @return a [[Word2IndexMeta]] holding the saved training params and the word index
    * @throws IllegalArgumentException if the saved training params contain no `inputCol`
    */
  override def load(spark: SparkSession, _path: String, params: Map[String, String]): Any = {
    import spark.implicits._
    // Load the training params; the first two columns of each row are read as key and value.
    val path = getMetaPath(_path)
    val trainParams = spark.read.parquet(PARAMS_PATH(path, "params"))
      .map(row => (row.getString(0), row.getString(1)))
      .collect()
      .toMap

    // Fail with a clear message instead of building a word-index path from an empty column name.
    val inputCol = trainParams.getOrElse("inputCol", "")
    require(!inputCol.isEmpty, s"inputCol is missing from the training params saved under ${_path}")

    val si = new SQLStringIndex()
    val model = si.load(spark, MetaConst.WORD_INDEX_PATH(path, inputCol), params)
    val word2IndexMapping = HSQLStringIndex.wordToIndex(spark, model)
    Word2IndexMeta(trainParams, word2IndexMapping.toMap)
  }

  /**
    * Builds the UDF that converts a `Map[String, Double]` into a sparse vector using
    * the loaded word index.
    *
    * Keys that were not seen during training have no position in the vector space and
    * are skipped, rather than failing the whole job with a `NoSuchElementException`.
    *
    * @param sparkSession active session, used to broadcast the word index
    * @param _model       the value returned by [[load]]; must be a [[Word2IndexMeta]]
    * @param name         not used by this implementation
    * @param params       not used by this implementation
    * @return a UDF from `map<string,double>` to an ML vector
    */
  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
    val word2IndexMeta = _model.asInstanceOf[Word2IndexMeta]
    val word2IndexMappingBr = sparkSession.sparkContext.broadcast(word2IndexMeta.wordIndex)
    val featureSize = word2IndexMappingBr.value.size

    // Local function value referencing only local vals, so the enclosing instance is
    // not pulled into the serialized closure.
    val toVector = (item: Map[String, Double]) => {
      val index = word2IndexMappingBr.value
      val elements = item.toSeq.flatMap { case (word, weight) =>
        // Unknown key => no vector slot; drop it.
        index.get(word).map(position => (position.toInt, weight))
      }
      Vectors.sparse(featureSize, elements)
    }
    UserDefinedFunction(toVector, VectorType, Some(Seq(MapType(StringType, DoubleType))))
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy