com.vesoft.nebula.tools.importer.Exchange.scala

/* Copyright (c) 2020 vesoft inc. All rights reserved.
 *
 * This source code is licensed under Apache 2.0 License,
 * attached with Common Clause Condition 1.0, found in the LICENSES directory.
 */

package com.vesoft.nebula.tools.importer

import org.apache.spark.sql.{DataFrame, SparkSession}
import java.io.File
import java.util.concurrent.{Executors, TimeUnit}

import com.google.common.base.Optional
import com.google.common.net.HostAndPort
import com.google.common.util.concurrent.{FutureCallback, Futures, MoreExecutors, RateLimiter}
import com.vesoft.nebula.client.graph.async.AsyncGraphClientImpl
import com.vesoft.nebula.graph.ErrorCode
import com.vesoft.nebula.tools.importer.processor.{EdgeProcessor, VerticesProcessor}
import com.vesoft.nebula.tools.importer.reader.{
  CSVReader,
  HiveReader,
  JSONReader,
  KafkaReader,
  MySQLReader,
  Neo4JReader,
  ORCReader,
  ParquetReader,
  PulsarReader,
  SocketReader
}
import com.vesoft.nebula.tools.importer.writer.{NebulaGraphClientWriter, Writer}
import org.apache.log4j.Logger
import org.apache.spark.SparkConf

import scala.collection.JavaConverters._

final case class Argument(config: File = new File("application.conf"),
                          hive: Boolean = false,
                          directly: Boolean = false,
                          dry: Boolean = false,
                          reload: String = "")
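
// A minimal sketch of how these options typically reach Argument through Configs.parser.
// The flag names, jar name, and config path below are assumptions for illustration only:
//
//   spark-submit --class com.vesoft.nebula.tools.importer.Exchange exchange.jar \
//     -c /path/to/application.conf -h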

final case class TooManyErrorsException(private val message: String) extends Exception(message)

/**
  * Exchange is a simple Spark job used to write data into Nebula Graph in parallel.
  */
object Exchange {
  private[this] val LOG = Logger.getLogger(this.getClass)

  // Abort reloading once the failure accumulator exceeds this threshold.
  private[this] val DEFAULT_ERROR_TIMES = 16

  def main(args: Array[String]): Unit = {
    val PROGRAM_NAME = "Spark Writer"
    val options      = Configs.parser(args, PROGRAM_NAME)
    val c: Argument = options match {
      case Some(config) => config
      case _ => {
        LOG.error("Argument parse failed")
        sys.exit(-1)
      }
    }

    val configs = Configs.parse(c.config)
    LOG.info(s"Config ${configs}")

    val session = SparkSession
      .builder()
      .appName(PROGRAM_NAME)
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config(
        new SparkConf()
          .registerKryoClasses(
            Array(
              classOf[com.facebook.thrift.async.TAsyncClientManager]
            )))
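    // Note: explicitly registering TAsyncClientManager with Kryo (on top of enabling
    // KryoSerializer above) spares Kryo from writing the fully qualified class name
    // with every serialized instance of that class.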

    val spark = if (c.hive) {
      session.enableHiveSupport().getOrCreate()
    } else {
      session.getOrCreate()
    }

    // Reload mode: re-execute the statements recorded (one per line) in the reload file.
    if (!c.reload.isEmpty) {
      val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.reload")
      val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.reload")

      spark.read
        .text(c.reload)
        .foreachPartition { records =>
          val hostAndPorts = configs.databaseConfig.addresses.map(HostAndPort.fromString).asJava
          val client = new AsyncGraphClientImpl(
            hostAndPorts,
            configs.connectionConfig.timeout,
            configs.connectionConfig.retry,
            configs.executionConfig.retry
          )
          client.setUser(configs.userConfig.user)
          client.setPassword(configs.userConfig.password)

          if (isSuccessfully(client.connect())) {
            // Throttle statement submission and run result callbacks on a single-threaded executor.
            val rateLimiter = RateLimiter.create(configs.rateConfig.limit)
            val service     = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1))
            records.foreach { row =>
              val exec = row.getString(0)
              if (rateLimiter.tryAcquire(configs.rateConfig.timeout, TimeUnit.MILLISECONDS)) {
                val future = client.execute(exec)
                Futures.addCallback(
                  future,
                  new FutureCallback[Optional[Integer]] {
                    override def onSuccess(result: Optional[Integer]): Unit = {
                      batchSuccess.add(1)
                    }

                    override def onFailure(t: Throwable): Unit = {
                      if (batchFailure.value > DEFAULT_ERROR_TIMES) {
                        throw TooManyErrorsException("too many errors")
                      }
                      batchFailure.add(1)
                    }
                  },
                  service
                )
              } else {
                batchFailure.add(1)
              }
            }
            client.close()
          } else {
            LOG.error(
              s"Client connection failed. ${configs.userConfig.user}:${configs.userConfig.password}")
          }
        }
      sys.exit(0)
    }
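
    // Illustrative only: each line of the reload file is assumed to hold one complete
    // statement to re-execute, e.g. (hypothetical values)
    //   INSERT VERTEX player(name, age) VALUES 100:("Tim", 42);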

    if (!configs.tagsConfig.isEmpty) {
      for (tagConfig <- configs.tagsConfig) {
        LOG.info(s"Processing Tag ${tagConfig.name}")

        val fieldKeys  = tagConfig.fields.keys.toList
        val nebulaKeys = tagConfig.fields.values.map(_.unwrapped.toString).toList

        val data = createDataSource(spark, tagConfig.dataSourceConfigEntry)
        if (data.isDefined && !c.dry) {
          val batchSuccess =
            spark.sparkContext.longAccumulator(s"batchSuccess.${tagConfig.name}")
          val batchFailure =
            spark.sparkContext.longAccumulator(s"batchFailure.${tagConfig.name}")
          val processor = new VerticesProcessor(repartition(data.get, tagConfig.partition),
                                                tagConfig,
                                                fieldKeys,
                                                nebulaKeys,
                                                configs,
                                                batchSuccess,
                                                batchFailure)

          processor.process()
          LOG.info(s"batchSuccess.${tagConfig.name}: ${batchSuccess.value}")
          LOG.info(s"batchFailure.${tagConfig.name}: ${batchFailure.value}")
        }
      }
    } else {
      LOG.warn("Tag is not defined")
    }

    if (!configs.edgesConfig.isEmpty) {
      for (edgeConfig <- configs.edgesConfig) {
        LOG.info(s"Processing Edge ${edgeConfig.name}")

        val fieldKeys  = edgeConfig.fields.keys.toList
        val nebulaKeys = edgeConfig.fields.values.map(_.unwrapped.toString).toList
        val data       = createDataSource(spark, edgeConfig.dataSourceConfigEntry)
        if (data.isDefined && !c.dry) {
          val batchSuccess = spark.sparkContext.longAccumulator(s"batchSuccess.${edgeConfig.name}")
          val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.${edgeConfig.name}")

          val processor = new EdgeProcessor(repartition(data.get, edgeConfig.partition),
                                            edgeConfig,
                                            fieldKeys,
                                            nebulaKeys,
                                            configs,
                                            batchSuccess,
                                            batchFailure)

          processor.process()
          LOG.info(s"batchSuccess.${edgeConfig.name}: ${batchSuccess.value}")
          LOG.info(s"batchFailure.${edgeConfig.name}: ${batchFailure.value}")
        }
      }
    } else {
      LOG.warn("Edge is not defined")
    }
    spark.close()
  }

  /**
    * Create a data source for the given source category.
    *
    * @param session The Spark session.
    * @param config  The data source config entry.
    * @return The loaded DataFrame wrapped in Some, or None if the source category is unsupported.
    */
  private[this] def createDataSource(
      session: SparkSession,
      config: DataSourceConfigEntry
  ): Option[DataFrame] = {
    config.category match {
      case SourceCategory.PARQUET => {
        val parquetConfig = config.asInstanceOf[FileBaseSourceConfigEntry]
        LOG.info(s"""Loading Parquet files from ${parquetConfig.path}""")
        val reader = new ParquetReader(session, parquetConfig.path)
        Some(reader.read())
      }
      case SourceCategory.ORC => {
        val orcConfig = config.asInstanceOf[FileBaseSourceConfigEntry]
        LOG.info(s"""Loading ORC files from ${orcConfig.path}""")
        val reader = new ORCReader(session, orcConfig.path)
        Some(reader.read())
      }
      case SourceCategory.JSON => {
        val jsonConfig = config.asInstanceOf[FileBaseSourceConfigEntry]
        LOG.info(s"""Loading JSON files from ${jsonConfig.path}""")
        val reader = new JSONReader(session, jsonConfig.path)
        Some(reader.read())
      }
      case SourceCategory.CSV => {
        val csvConfig = config.asInstanceOf[CSVSourceConfigEntry]
        LOG.info(s"""Loading CSV files from ${csvConfig.path}""")
        val reader = new CSVReader(session, csvConfig.path, csvConfig.separator, csvConfig.header)
        Some(reader.read())
      }
      case SourceCategory.HIVE => {
        val hiveConfig = config.asInstanceOf[HiveSourceConfigEntry]
        LOG.info(s"""Loading from Hive and exec ${hiveConfig.exec}""")
        val reader = new HiveReader(session, hiveConfig.exec)
        Some(reader.read())
      }
      // TODO: (darion.yaphet) Support Structured Streaming
      case SourceCategory.SOCKET => {
        val socketConfig = config.asInstanceOf[SocketSourceConfigEntry]
        LOG.warn("Socket streaming mode is not suitable for production environment")
        LOG.info(s"""Reading data stream from Socket ${socketConfig.host}:${socketConfig.port}""")
        val reader = new SocketReader(session, socketConfig.host, socketConfig.port)
        Some(reader.read())
      }
      case SourceCategory.KAFKA => {
        val kafkaConfig = config.asInstanceOf[KafkaSourceConfigEntry]
        LOG.info(s"""Loading from Kafka ${kafkaConfig.server} and subscribe ${kafkaConfig.topic}""")
        val reader = new KafkaReader(session, kafkaConfig.server, kafkaConfig.topic)
        Some(reader.read())
      }
      case SourceCategory.NEO4J =>
        val neo4jConfig = config.asInstanceOf[Neo4JSourceConfigEntry]
        LOG.info(s"Loading from neo4j config: ${neo4jConfig}")
        val reader = new Neo4JReader(session, neo4jConfig)
        Some(reader.read())
      case SourceCategory.MYSQL =>
        val mysqlConfig = config.asInstanceOf[MySQLSourceConfigEntry]
        LOG.info(s"Loading from mysql config: ${mysqlConfig}")
        val reader = new MySQLReader(session,
                                     mysqlConfig.host,
                                     mysqlConfig.port,
                                     mysqlConfig.database,
                                     mysqlConfig.table,
                                     mysqlConfig.user,
                                     mysqlConfig.password,
                                     mysqlConfig.sentence)
        Some(reader.read())
      case SourceCategory.PULSAR =>
        val pulsarConfig = config.asInstanceOf[PulsarSourceConfigEntry]
        LOG.info(s"Loading from pulsar config: ${pulsarConfig}")
        val reader = new PulsarReader(session,
                                      pulsarConfig.serviceUrl,
                                      pulsarConfig.adminUrl,
                                      pulsarConfig.options)
        Some(reader.read())
      case _ => {
        LOG.error(s"Data source ${config.category} not supported")
        None
      }
    }
  }
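
  // A rough sketch, for orientation only, of the kind of source entry this factory maps
  // from application.conf. The field names are assumptions; the authoritative schema is
  // defined by Configs and the *SourceConfigEntry classes:
  //
  //   {
  //     name: player
  //     type: { source: csv, sink: client }
  //     path: "hdfs://namenode:9000/player.csv"
  //     separator: ","
  //     header: true
  //   }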

  private[this] def createDataSink(entry: DataSinkConfigEntry,
                                   dataBaseConfigEntry: DataBaseConfigEntry,
                                   userConfigEntry: UserConfigEntry,
                                   connectionConfigEntry: ConnectionConfigEntry,
                                   executionRetry: Int,
                                   config: SchemaConfigEntry): Writer = {
    entry.category match {
      case SinkCategory.CLIENT =>
        LOG.info("Write to Nebula using Graph Client")
        new NebulaGraphClientWriter(dataBaseConfigEntry,
                                    userConfigEntry,
                                    connectionConfigEntry,
                                    executionRetry,
                                    config)
      case _ =>
        throw new IllegalArgumentException("Not Support")
    }
  }
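
  // Only the CLIENT sink is handled above; any other SinkCategory falls through to the
  // IllegalArgumentException. Note that this factory is not called anywhere else in this
  // file.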

  /**
    * Repartition the data frame to the specified partition number.
    *
    * @param frame     The data frame to repartition.
    * @param partition The target number of partitions; non-positive values leave the frame unchanged.
    * @return The repartitioned data frame.
    */
  private[this] def repartition(frame: DataFrame, partition: Int): DataFrame = {
    if (partition > 0) {
      frame.repartition(partition).toDF
    } else {
      frame
    }
  }
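
  // Example: repartition(df, 32) shuffles the frame into 32 partitions, bounding the
  // number of concurrent write tasks, while a non-positive value keeps the source's
  // original partitioning.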

  /**
    * Check the statement execution result.
    *
    * @param code The statement's execution result code.
    * @return True if the code indicates success (ErrorCode.SUCCEEDED).
    */
  private[this] def isSuccessfully(code: Int) = code == ErrorCode.SUCCEEDED

}



