All Downloads are FREE. Search and download functionalities are using the official Maven repository.

streaming.dsl.mmlib.algs.SQLCommunityBasedSimilityInPlace.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs

import org.apache.spark.graphx._
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.mmlib.algs.meta.graphx.{GroupVeterxs, VeterxAndGroup}

/**
  * Created by allwefantasy on 9/8/2018.
  *
  * Finds communities in a weighted edge list: edges whose weight exceeds a similarity
  * threshold are kept, connected components are computed on the remaining graph, and
  * the components that are large enough are written to `path + "/data"` as parquet.
  *
  * This algorithm only supports `train`; there is no model to load or register.
  */
class SQLCommunityBasedSimilarityInPlace extends SQLAlg with Functions {

  /**
    * Builds a graph from the edge rows of `df`, extracts sufficiently large connected
    * components and saves them (overwriting any previous output) under `path + "/data"`.
    *
    * Recognised `params` (all optional):
    *  - `rowNum`              name of the source-vertex column, read as `Long` (default `"i"`)
    *  - `columnNum`           name of the target-vertex column, read as `Long` (default `"j"`)
    *  - `edgeValue`           name of the edge-weight column, read as `Double` (default `"v"`)
    *  - `minSimilarity`       edges are kept only if their weight is strictly greater (default `0.7`)
    *  - `minCommunitySize`    absolute lower bound for the community size threshold (default `10`)
    *  - `minCommunityPercent` fraction of all vertices used as the relative threshold (default `0.1`)
    *
    * @param df     edge list; one row per edge
    * @param path   base output directory; results go to `path + "/data"`
    * @param params algorithm options described above
    * @return the result of `emptyDataFrame()(df)`; the communities themselves are only written to disk
    */
  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    val rowNumCol = params.getOrElse("rowNum", "i")
    val columnNumCol = params.getOrElse("columnNum", "j")
    val edgeValueCol = params.getOrElse("edgeValue", "v")
    val minSimilarity = params.getOrElse("minSimilarity", "0.7").toDouble
    val minCommunitySize = params.getOrElse("minCommunitySize", "10").toInt
    val minCommunityPercent = params.getOrElse("minCommunityPercent", "0.1").toDouble

    val relationships = df.rdd.map { row =>
      Edge(row.getAs[Long](rowNumCol), row.getAs[Long](columnNumCol), row.getAs[Double](edgeValueCol))
    }
    val graph = Graph.fromEdges(relationships, 0d)

    // The threshold is derived from the vertex count of the full (unfiltered) graph:
    // the larger of the relative bound and the absolute bound.
    val vertexCount = Math.max(Math.round(graph.vertices.count() * minCommunityPercent), minCommunitySize)

    // Drop weak edges first, then label every vertex with its connected-component id.
    val validGraph = graph
      .subgraph(epred = et => et.attr > minSimilarity)
      .connectedComponents()

    // Group vertices by component id and keep only components strictly larger than the threshold.
    val rdd = validGraph.vertices
      .map { case (vertexId, group) => VeterxAndGroup(vertexId, group) }
      .groupBy(_.group)
      .filter { case (_, members) => members.size > vertexCount }
      .map { case (group, members) => GroupVeterxs(group, members.map(_.vertexId).toSeq) }

    import df.sparkSession.implicits._
    df.sparkSession.createDataset(rdd).write.mode(SaveMode.Overwrite).parquet(path + "/data")
    emptyDataFrame()(df)
  }

  /**
    * Not supported: this module produces no loadable model.
    *
    * @throws UnsupportedOperationException always
    */
  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
    // The exception must be thrown; merely constructing it would hand the exception
    // object back to the caller as if it were a model.
    throw new UnsupportedOperationException("Register is not supported by SQLCommunityBasedSimilarityInPlace module")
  }

  /**
    * No prediction function exists for this module.
    *
    * @return always `null`
    */
  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
    null
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy