All Downloads are FREE. Search and download functionalities are using the official Maven repository.

streaming.dsl.mmlib.algs.SQLPageRank.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs

import org.apache.spark.graphx.lib.PageRank
import org.apache.spark.graphx.{Edge, Graph, VertexId}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types.{DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import streaming.dsl.mmlib.SQLAlg

/**
  * Created by allwefantasy on 13/1/2018.
  *
  * [[SQLAlg]] implementation backed by GraphX PageRank. `train` builds a graph
  * from the rows of a DataFrame, runs PageRank until convergence and persists
  * the result as parquet; `load` reads the persisted scores back into a map;
  * `predict` wraps that map in a UDF that looks up the score of a vertex id.
  */
class SQLPageRank extends SQLAlg with Functions {

  /**
    * Runs PageRank over the graph described by `df` and writes the result under `path`.
    *
    * Recognised entries of `params`:
    *  - `edgeSrcCol` (default `edgeSrc`): column holding the source vertex id of each edge.
    *  - `edgeDstCol` (default `edgeDst`): column holding the destination vertex id of each edge.
    *  - `vertexCol` (optional): column holding vertex ids. When absent, the vertex set is
    *    the distinct ids that appear as an endpoint of some edge.
    *  - `tol` (default `0.001`) and `resetProb` (default `0.15`): passed unchanged to
    *    `PageRank.runUntilConvergence`.
    *
    * The id columns are read with `getAs[Long]`. `tol` and `resetProb` are parsed with
    * `toDouble`, so a non-numeric value raises a `NumberFormatException`.
    *
    * Output, both written with `SaveMode.Overwrite`:
    *  - `path + "/vertices"`: columns `vertextId` (long) and `pagerank` (double).
    *  - `path + "/edges"`: columns `srcId` (long), `dstId` (long) and `weight` (double).
    *
    * @param df     input rows; each row contributes one edge
    * @param path   base directory for the parquet output
    * @param params algorithm options, see above
    * @return the result of `emptyDataFrame()(df)`
    */
  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    val edgeSrcCol = params.getOrElse("edgeSrcCol", "edgeSrc")
    val edgeDstCol = params.getOrElse("edgeDstCol", "edgeDst")

    val edges: RDD[Edge[Any]] = df.rdd.map { row =>
      Edge(row.getAs[Long](edgeSrcCol), row.getAs[Long](edgeDstCol))
    }

    // The vertex column is only consulted when the caller names it explicitly;
    // otherwise the vertex set is derived from the edge endpoints.
    val vertices: RDD[(VertexId, Any)] = params.get("vertexCol") match {
      case Some(vertexCol) =>
        df.rdd.map { row =>
          (row.getAs[Long](vertexCol), "")
        }
      case None =>
        edges.flatMap(edge => List(edge.srcId, edge.dstId)).distinct().map { id =>
          (id, "")
        }
    }

    val graph = Graph(vertices, edges, ())
    val tol = params.getOrElse("tol", "0.001").toDouble
    val resetProb = params.getOrElse("resetProb", "0.15").toDouble
    val ranked = PageRank.runUntilConvergence(graph, tol, resetProb)

    val session = df.sparkSession

    val vertexSchema = StructType(Seq(
      StructField("vertextId", LongType),
      StructField("pagerank", DoubleType)))
    val verticesDf = session.createDataFrame(
      ranked.vertices.map { case (id, rank) => Row(id, rank) },
      vertexSchema)

    val edgeSchema = StructType(Seq(
      StructField("srcId", LongType),
      StructField("dstId", LongType),
      StructField("weight", DoubleType)))
    val edgesDf = session.createDataFrame(
      ranked.edges.map(edge => Row(edge.srcId, edge.dstId, edge.attr)),
      edgeSchema)

    verticesDf.write.mode(SaveMode.Overwrite).parquet(path + "/vertices")
    edgesDf.write.mode(SaveMode.Overwrite).parquet(path + "/edges")
    emptyDataFrame()(df)
  }

  /**
    * Loads the vertex scores written by `train`.
    *
    * Reads the parquet data at `path + "/vertices"`, collects every row to the driver and
    * builds a map from the first column (read as long) to the second column (read as
    * double). Columns are accessed by position, not by name.
    *
    * @param sparkSession session used to read the parquet files
    * @param path         base directory that was passed to `train`
    * @param params       not used by this method
    * @return a `Map[Long, Double]` from vertex id to PageRank score
    */
  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
    val verticesDf = sparkSession.read.parquet(path + "/vertices")
    verticesDf.collect().map(row => (row.getLong(0), row.getDouble(1))).toMap
  }

  /**
    * Builds a UDF that maps a vertex id to its PageRank score.
    *
    * The model is broadcast through the session's SparkContext so the function body only
    * captures the broadcast handle rather than the map itself.
    *
    * The lookup uses `Map.apply`, so evaluating the UDF on a vertex id that is not a key
    * of the model throws `NoSuchElementException`.
    *
    * @param sparkSession session whose SparkContext broadcasts the model
    * @param _model       value returned by `load`; cast to `Map[Long, Double]`
    * @param name         not used by this method
    * @param params       not used by this method
    * @return a UDF taking a long vertex id and returning its score as a double
    */
  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
    val model = sparkSession.sparkContext.broadcast(_model.asInstanceOf[Map[Long, Double]])

    val lookup = (vertexId: Long) => model.value(vertexId)
    UserDefinedFunction(lookup, DoubleType, Some(Seq(LongType)))
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy