All downloads are free. Search and download functionality uses the official Maven repository.

com.intel.analytics.zoo.common.NNContext.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.zoo.common

import java.util.Properties

import com.intel.analytics.bigdl.utils.Engine
import org.apache.log4j.Logger
import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, SparkException}

/**
 * [[NNContext]] wraps a spark context in Analytics Zoo.
 *
 * It provides the `initNNContext` overloads, which build a SparkConf through
 * `Engine.createSparkConf`, optionally compare the runtime Spark/Scala versions with
 * the versions recorded in `zoo-version-info.properties`, obtain a SparkContext and
 * call `Engine.init`.
 */
object NNContext {

  private val logger = Logger.getLogger(getClass)

  // Matches the leading "major.feature[.maintenance]" digits of a version string.
  // Anything after them (e.g. "-SNAPSHOT", ".cloudera2", "-RC1") is ignored.
  private val VersionPattern = """(\d+)\.(\d+)(?:\.(\d+))?.*""".r

  /**
   * Compares the runtime Spark version with the Spark version recorded at build time.
   * Only a difference in the major component is treated as an error (level 1).
   *
   * @param reportWarning if true, a mismatch is logged as a warning instead of throwing
   */
  private[zoo] def checkSparkVersion(reportWarning: Boolean = false): Unit = {
    checkVersion(SPARK_VERSION, ZooBuildInfo.spark_version, "Spark", reportWarning)
  }

  /**
   * Compares the runtime Scala version with the Scala version recorded at build time.
   * A difference in the major or feature component is treated as an error (level 2).
   *
   * @param reportWarning if true, a mismatch is logged as a warning instead of throwing
   */
  private[zoo] def checkScalaVersion(reportWarning: Boolean = false): Unit = {
    checkVersion(scala.util.Properties.versionNumberString,
      ZooBuildInfo.scala_version, "Scala", reportWarning, level = 2)
  }

  /**
   * Extracts the numeric (major, feature, maintenance) components of a version string.
   * A missing maintenance component is read as 0.
   *
   * @param version version string such as "2.4.0", "2.12" or "2.4.0-SNAPSHOT"
   * @return the parsed components, or None when the string is null, does not start with
   *         "digits.digits", or a component does not fit in an Int
   */
  private def parseVersion(version: String): Option[(Int, Int, Int)] = {
    Option(version).map(_.trim).flatMap {
      case VersionPattern(major, feature, maintenance) =>
        // The optional third group is null when the version has only two components.
        scala.util.Try(
          (major.toInt, feature.toInt, Option(maintenance).fold(0)(_.toInt))).toOption
      case _ => None
    }
  }

  /**
   * Checks a runtime version against a compile time version.
   *
   * Identical strings always pass. Otherwise the most significant differing component
   * decides the "difference level": 1 = major, 2 = feature, 3 = anything else. When the
   * difference level is less than or equal to `level` and `reportWarning` is false,
   * `Utils.logUsageErrorAndThrowException` is invoked; in every other mismatch case a
   * warning is logged.
   *
   * @param runtimeVersion version detected at runtime
   * @param compileTimeVersion version recorded at build time
   * @param project name used in the log/error messages, e.g. "Spark"
   * @param reportWarning if true, never raise an error, only log a warning
   * @param level the least significant component whose mismatch is an error
   */
  private def checkVersion(
                            runtimeVersion: String,
                            compileTimeVersion: String,
                            project: String,
                            reportWarning: Boolean = false,
                            level: Int = 1): Unit = {
    if (runtimeVersion != compileTimeVersion) {
      val warnMessage = s"The compile time $project version is not compatible with" +
        s" the runtime $project version. Compile time version is $compileTimeVersion," +
        s" runtime version is $runtimeVersion. "
      val errorMessage = "\nIf you want to bypass this check, please set " +
        "spark.analytics.zoo.versionCheck to false, and if you want to only " +
        "report a warning message, please set spark.analytics.zoo" +
        ".versionCheck.warning to true."
      val diffLevel = (parseVersion(runtimeVersion), parseVersion(compileTimeVersion)) match {
        case (Some((runtimeMajor, runtimeFeature, _)), Some((compileMajor, compileFeature, _))) =>
          if (runtimeMajor != compileMajor) {
            1
          } else if (runtimeFeature != compileFeature) {
            2
          } else {
            3
          }
        // The strings differ but at least one cannot be parsed, so compatibility cannot
        // be established: treat it as the most severe kind of mismatch.
        case _ => 1
      }
      if (diffLevel <= level && !reportWarning) {
        Utils.logUsageErrorAndThrowException(warnMessage + errorMessage)
      }
      logger.warn(warnMessage)
    } else {
      logger.info(s"$project version check pass")
    }
  }

  /**
   * Build information read once from the `zoo-version-info.properties` resource.
   * Properties missing from the file default to the empty string.
   */
  private[zoo] object ZooBuildInfo {

    // NOTE: "analytics_zoo_verion" (sic) is kept as-is because it is both a member name
    // visible to other zoo code and the key looked up in the properties file.
    val (
      analytics_zoo_verion: String,
      spark_version: String,
      scala_version: String,
      java_version: String) = {

      // Fall back to this class's loader when the thread has no context class loader.
      val classLoader = Option(Thread.currentThread().getContextClassLoader)
        .getOrElse(getClass.getClassLoader)
      val resourceStream = classLoader.getResourceAsStream("zoo-version-info.properties")

      // getResourceAsStream returns null when the resource cannot be found.
      if (resourceStream == null) {
        throw new RuntimeException("Error while locating file zoo-version-info.properties, " +
          "if you are using an IDE to run your program, please make sure the mvn" +
          " generate-resources phase is executed and a zoo-version-info.properties file" +
          " is located in zoo/target/extra-resources")
      }

      try {
        val unknownProp = ""
        val props = new Properties()
        props.load(resourceStream)
        (
          props.getProperty("analytics_zoo_verion", unknownProp),
          props.getProperty("spark_version", unknownProp),
          props.getProperty("scala_version", unknownProp),
          props.getProperty("java_version", unknownProp)
        )
      } catch {
        case e: Exception =>
          throw new RuntimeException("Error loading properties from zoo-version-info.properties", e)
      } finally {
        try {
          resourceStream.close()
        } catch {
          case e: Exception =>
            throw new SparkException("Error closing zoo build info resource stream", e)
        }
      }
    }
  }

  /**
   * Creates or gets a SparkContext with optimized configuration for BigDL performance.
   * The method will also initialize the BigDL engine.
   *
   * The Spark and Scala version checks run only when `spark.analytics.zoo.versionCheck`
   * is set to true in the resulting conf; `spark.analytics.zoo.versionCheck.warning`
   * downgrades a failed check to a warning.
   *
   * Note: if you use spark-shell or Jupyter notebook, as the Spark context is created
   * before your code, you have to set Spark conf values through command line options
   * or properties file, and init BigDL engine manually.
   *
   * @param conf User defined Spark conf, handed to `Engine.createSparkConf` unchanged
   * @param appName name of the current context; ignored when null
   * @return Spark Context
   */
  def initNNContext(conf: SparkConf, appName: String): SparkContext = {
    val bigdlConf = Engine.createSparkConf(conf)
    if (appName != null) {
      bigdlConf.setAppName(appName)
    }
    if (bigdlConf.getBoolean("spark.analytics.zoo.versionCheck", defaultValue = false)) {
      val reportWarning =
        bigdlConf.getBoolean("spark.analytics.zoo.versionCheck.warning", defaultValue = false)
      checkSparkVersion(reportWarning)
      checkScalaVersion(reportWarning)
    }
    val sc = SparkContext.getOrCreate(bigdlConf)
    Engine.init
    sc
  }

  /**
   * Creates or gets SparkContext with optimized configuration for BigDL performance.
   * The method will also initialize the BigDL engine.
   *
   * Note: if you use spark-shell or Jupyter notebook, as the Spark context is created
   * before your code, you have to set Spark conf values through command line options
   * or properties file, and init BigDL engine manually.
   *
   * @param conf User defined Spark conf
   * @return Spark Context
   */
  def initNNContext(conf: SparkConf): SparkContext = {
    initNNContext(conf = conf, appName = null)
  }

  /**
   * Creates or gets a SparkContext with optimized configuration for BigDL performance.
   * The method will also initialize the BigDL engine.
   *
   * Note: if you use spark-shell or Jupyter notebook, as the Spark context is created
   * before your code, you have to set Spark conf values through command line options
   * or properties file, and init BigDL engine manually.
   *
   * @param appName name of the current context
   * @return Spark Context
   */
  def initNNContext(appName: String): SparkContext = {
    initNNContext(conf = null, appName = appName)
  }

  /**
   * Creates or gets a SparkContext without a user supplied conf or application name.
   * Delegates to `initNNContext(conf, appName)` with both arguments null.
   *
   * @return Spark Context
   */
  def initNNContext(): SparkContext = {
    initNNContext(null, null)
  }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy