com.intel.analytics.zoo.common.PythonInterpreter.scala Maven / Gradle / Ivy
/*
* Copyright 2018 Analytics Zoo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.zoo.common
import java.util.concurrent.{ExecutorService, Executors, ThreadFactory}
import com.intel.analytics.zoo.core.TFNetNative
import com.intel.analytics.zoo.pipeline.api.net.NetUtils
import jep.{JepConfig, JepException, NamingConventionClassEnquirer, SharedInterpreter}
import org.apache.commons.lang.exception.ExceptionUtils
import org.apache.logging.log4j.{Level, Logger, LogManager}
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.concurrent.duration.Duration
/**
 * Process-wide access point to a single jep [[SharedInterpreter]].
 *
 * The interpreter is created lazily on first use. Interpreter creation as well as `exec`, `set`
 * and `getValue` all run on one dedicated daemon thread named "jep-thread <id>", presumably
 * because jep requires an interpreter to be used on the thread that created it.
 */
object PythonInterpreter {
  protected val logger = LogManager.getLogger(this.getClass)

  // Pool of size one: every interpreter operation issued by this object runs on this daemon thread.
  private val threadPool: ExecutorService = Executors.newFixedThreadPool(1, new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = Executors.defaultThreadFactory().newThread(r)
      t.setName("jep-thread " + t.getId)
      t.setDaemon(true)
      t
    }
  })

  // Written by init() while holding the object lock, but read by check() without it
  // (double-checked locking); @volatile makes that unlocked read see the published value.
  @volatile private var sharedInterpreter: SharedInterpreter = null

  /**
   * Returns the shared interpreter, or null if it has not been created yet
   * (call [[check]] or [[init]] first).
   */
  def getSharedInterpreter(): SharedInterpreter = {
    sharedInterpreter
  }

  /** Creates the shared interpreter if it does not exist yet; cheap once initialised. */
  def check(): Unit = {
    if (sharedInterpreter == null) {
      init()
    }
  }

  /**
   * Creates the shared interpreter exactly once, synchronised on this object.
   *
   * @throws RuntimeException if the PYTHONHOME environment variable is unset
   * @throws JepException     if creating the interpreter fails or times out
   */
  def init(): Unit = synchronized {
    if (sharedInterpreter == null) {
      sharedInterpreter = createInterpreter()
    }
  }

  /**
   * Builds a new [[SharedInterpreter]] on the jep thread.
   *
   * @throws RuntimeException if PYTHONHOME is unset
   * @throws JepException     wrapping any failure (or timeout) while constructing the interpreter
   */
  private def createInterpreter(): SharedInterpreter = {
    if (System.getenv("PYTHONHOME") == null) {
      throw new RuntimeException("PYTHONHOME is unset, please set PYTHONHOME first.")
    }
    // Load TFNet before create interpreter, or the TFNet will throw an OMP error #13
    TFNetNative.isLoaded
    val createInterp = () => {
      val config: JepConfig = new JepConfig()
      config.setClassEnquirer(new NamingConventionClassEnquirer())
      SharedInterpreter.setConfig(config)
      new SharedInterpreter()
    }
    logger.debug("Creating jep interpreter...")
    threadExecute(createInterp)
  }

  /**
   * Runs `task` on the jep thread and blocks the calling thread until it completes.
   *
   * @param task    the work to run on the jep thread
   * @param timeout how long to wait for the task; 100 seconds by default
   * @return the value produced by `task`
   * @throws JepException wrapping whatever `task` threw, or the timeout / interruption
   *                      encountered while waiting for it
   */
  private def threadExecute[T](task: () => T,
                               timeout: Duration = Duration("100s")): T = {
    val promise = Promise[T]()
    try {
      threadPool.execute(new Runnable {
        override def run(): Unit = {
          try {
            promise.success(task())
          } catch {
            // Catch every Throwable, not only NonFatal ones: a LinkageError (e.g. an
            // UnsatisfiedLinkError/NoClassDefFoundError while loading jep) must also release
            // the waiting caller immediately instead of leaving it blocked until the timeout.
            case t: Throwable => promise.failure(t)
          }
        }
      })
      Await.result(promise.future, timeout)
    } catch {
      // Don't use logger here, or spark local will stuck when catch an exception.
      case e: InterruptedException =>
        // Preserve the caller's interrupt status before converting to a JepException.
        Thread.currentThread().interrupt()
        throw new JepException(e)
      case t: Throwable =>
        throw new JepException(t)
    }
  }

  /**
   * Executes Python source `s` in the shared interpreter, creating it first if needed.
   *
   * @throws JepException if execution fails or times out
   */
  def exec(s: String): Unit = {
    logger.debug(s"jep exec ${s}")
    check()
    threadExecute(() => sharedInterpreter.exec(s))
  }

  /**
   * Binds the Java object `o` to the Python variable named `s` in the shared interpreter.
   *
   * @throws JepException if the assignment fails or times out
   */
  def set(s: String, o: AnyRef): Unit = {
    logger.debug(s"jep set ${s}")
    check()
    threadExecute(() => sharedInterpreter.set(s, o))
  }

  /**
   * Reads the Python variable `name` from the shared interpreter.
   *
   * The final cast to `T` is unchecked (type erasure), so a mismatched `T` only fails later,
   * at the caller's use site.
   *
   * @throws JepException if the lookup fails or times out
   */
  def getValue[T](name: String): T = {
    logger.debug(s"jep getValue ${name}")
    check()
    threadExecute(() => sharedInterpreter.getValue(name)).asInstanceOf[T]
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy