All downloads are free. Search and download functionality uses the official Maven repository.

streaming.dsl.mmlib.algs.SQLRawSimilarInPlace.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.mmlib.algs.feature.StringFeature

/**
 * Pairwise "raw" similarity between the rows of a DataFrame, computed in place over the
 * training data itself.
 *
 * `train` vectorizes `inputCol` with `StringFeature.raw2vec`. It then scores every unordered
 * pair of rows with `StringFeature.rawSimilar` and writes the scores to `path` as parquet with
 * columns `i`, `j` (row labels, always `i > j`) and `v` (score).
 *
 * `load` turns that parquet back into a symmetric `Map[label, Map[otherLabel, score]]`.
 * `predict` exposes the table as a UDF `(label: Long, threshold: Double) => Map[Long, Double]`.
 *
 * Originally created by zhuml on 9/8/2018.
 */
class SQLRawSimilarInPlace extends SQLAlg with Functions {

  /**
   * Computes a similarity score for every unordered pair of rows in `df` and writes the
   * scores to `path` as parquet (overwriting anything already there).
   *
   * Recognised params:
   *  - `inputCol` (default "content"): column handed to `StringFeature.raw2vec`. After
   *    vectorization it is read back as `Seq[Seq[Double]]`.
   *  - `labelCol` (default "label"): int or long column identifying each row. Pairs are keyed
   *    by it, and rows sharing a label are never paired with each other.
   *  - `threshold` (default "0.8"): parsed as a Double and forwarded to `StringFeature.rawSimilar`.
   *  - `sentenceSplit` (default "。"): forwarded to `StringFeature.raw2vec`.
   *  - `modelType` (default "Word2VecInPlace"): accepted but not consulted. `raw2vec` is always used.
   *  - `modelPath` (required): forwarded to `StringFeature.raw2vec`.
   *
   * @return an empty DataFrame; the actual result is the parquet data written to `path`.
   * @throws NoSuchElementException if `modelPath` is not supplied.
   * @throws IllegalArgumentException if a `labelCol` value is neither an int nor a long.
   */
  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    val inputCol = params.getOrElse("inputCol", "content")
    val labelCol = params.getOrElse("labelCol", "label")
    val threshold = params.getOrElse("threshold", "0.8").toDouble
    val sentenceSplit = params.getOrElse("sentenceSplit", "。")
    val modelPath = params.getOrElse("modelPath",
      throw new NoSuchElementException("SQLRawSimilarInPlace: parameter 'modelPath' is required"))

    // Only one vectorizer exists today, so `modelType` does not select anything.
    val vectorized = StringFeature.raw2vec(df, inputCol, sentenceSplit, modelPath)

    // (vectors, label) per row. Typed accessors are used on purpose: an untyped `getAs` lets
    // the compiler infer `Nothing` for the result, which can fail with ClassCastException.
    // Cached because `cartesian` below would otherwise recompute `raw2vec` for every
    // pairing of partitions.
    val docs = vectorized.rdd.map { row =>
      val vectors = row.getAs[Seq[Seq[Double]]](inputCol)
      val label = row.get(row.fieldIndex(labelCol)) match {
        case n: java.lang.Long    => n.longValue()
        case n: java.lang.Integer => n.longValue()
        case other => throw new IllegalArgumentException(
          s"labelCol '$labelCol' must be an int or long column, got: $other")
      }
      (vectors, label)
    }.cache()

    try {
      // Keep each unordered pair exactly once: only when the left label is the larger one.
      val scores = docs.cartesian(docs)
        .filter { case ((_, labelA), (_, labelB)) => labelA > labelB }
        .map { case ((vecA, labelA), (vecB, labelB)) =>
          Row(labelA, labelB, StringFeature.rawSimilar(vecA, vecB, threshold))
        }

      val schema = StructType(Seq(
        StructField("i", LongType),
        StructField("j", LongType),
        StructField("v", DoubleType)))

      df.sparkSession.createDataFrame(scores, schema)
        .write.mode(SaveMode.Overwrite).parquet(path)
    } finally {
      docs.unpersist()
    }

    emptyDataFrame()(df)
  }

  /**
   * Reads the pair scores written by `train` and returns them as a symmetric lookup table.
   *
   * Each stored row (i, j, v) is registered under both `i` and `j`, so the result maps every
   * label that occurs in at least one pair to `Map(otherLabel -> score)`. A label that never
   * occurs in a pair (for example, when training saw a single row) is absent. The whole table
   * is collected to the driver.
   *
   * @return a `Map[Long, Map[Long, Double]]`, typed as `Any` per the `SQLAlg` contract.
   */
  override def load(spark: SparkSession, path: String, params: Map[String, String]): Any = {
    spark.read.parquet(path).rdd
      .flatMap { row =>
        val i = row.getAs[Long]("i")
        val j = row.getAs[Long]("j")
        val v = row.getAs[Double]("v")
        // Emit both directions so lookups work from either side of the pair.
        Seq((i, (j, v)), (j, (i, v)))
      }
      .groupByKey()
      .mapValues(_.toMap)
      .collect()
      .toMap
  }

  /**
   * Builds a UDF `(label: Long, threshold: Double) => Map[Long, Double]` over the table
   * produced by `load`.
   *
   * The UDF returns every other label whose score with `label` is strictly greater than
   * `threshold`. A label with no stored pairs yields an empty map rather than failing the query.
   *
   * @param _model the value returned by `load`, expected to be a `Map[Long, Map[Long, Double]]`.
   */
  override def predict(spark: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
    val model = spark.sparkContext.broadcast(_model.asInstanceOf[Map[Long, Map[Long, Double]]])

    val f = (i: Long, threshold: Double) => {
      // Labels that never appeared in a pair are missing from the table; treat them as
      // having no similar neighbours instead of throwing NoSuchElementException.
      model.value
        .getOrElse(i, Map.empty[Long, Double])
        .filter { case (_, score) => score > threshold }
    }
    UserDefinedFunction(f, MapType(LongType, DoubleType), Some(Seq(LongType, DoubleType)))
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy