org.apache.lens.ml.algo.spark.SparkMLDriver Maven / Gradle / Ivy
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.lens.ml.algo.spark;
import java.io.File;
import java.io.FilenameFilter;
import java.util.ArrayList;
import java.util.List;
import org.apache.lens.api.LensConf;
import org.apache.lens.api.LensException;
import org.apache.lens.ml.algo.api.MLAlgo;
import org.apache.lens.ml.algo.api.MLDriver;
import org.apache.lens.ml.algo.lib.Algorithms;
import org.apache.lens.ml.algo.spark.dt.DecisionTreeAlgo;
import org.apache.lens.ml.algo.spark.lr.LogisticRegressionAlgo;
import org.apache.lens.ml.algo.spark.nb.NaiveBayesAlgo;
import org.apache.lens.ml.algo.spark.svm.SVMAlgo;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
/**
* The Class SparkMLDriver.
*/
public class SparkMLDriver implements MLDriver {

  /** The Constant LOG. */
  public static final Log LOG = LogFactory.getLog(SparkMLDriver.class);

  /** Prefix of Lens configuration keys that are copied (with the prefix stripped) into the SparkConf. */
  private static final String SPARK_CONF_PREFIX = "lens.ml.sparkdriver.";

  /** Accepts only files whose name ends with ".jar". */
  private static final FilenameFilter JAR_FILE_FILTER = new FilenameFilter() {
    @Override
    public boolean accept(File dir, String name) {
      return name.endsWith(".jar");
    }
  };

  /**
   * Whether this driver created the Spark context and must therefore stop it in {@link #stop()}. Cleared by
   * {@link #useSparkContext(JavaSparkContext)} when the caller supplies its own context.
   */
  private boolean ownsSparkContext = true;

  /**
   * Deployment mode derived from the {@code spark.master} setting.
   */
  private enum SparkMasterMode {
    /** In-process ("local" / "local[N]") Spark, used in tests. */
    EMBEDDED,
    // Yarn client and Yarn cluster modes are used when deploying the app to Yarn cluster
    /** The yarn client. */
    YARN_CLIENT,
    /** The yarn cluster. */
    YARN_CLUSTER
  }

  /** Registry of the algorithms this driver can instantiate. */
  private final Algorithms algorithms = new Algorithms();

  /** Mode resolved from {@code spark.master} in {@link #init(LensConf)}. */
  private SparkMasterMode clientMode = SparkMasterMode.EMBEDDED;

  /** True between a successful {@link #start()} and the next {@link #stop()}. */
  private boolean isStarted;

  /** Spark configuration built in {@link #init(LensConf)}. */
  private SparkConf sparkConf;

  /** Spark context, either created in {@link #start()} or supplied via {@link #useSparkContext(JavaSparkContext)}. */
  private JavaSparkContext sparkContext;

  /**
   * Makes the driver use an externally managed Spark context. The driver will not stop this context in
   * {@link #stop()}.
   *
   * @param jsc the Spark context to use
   */
  public void useSparkContext(JavaSparkContext jsc) {
    ownsSparkContext = false;
    this.sparkContext = jsc;
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.MLDriver#isAlgoSupported(java.lang.String)
   */
  @Override
  public boolean isAlgoSupported(String name) {
    return algorithms.isAlgoSupported(name);
  }

  /**
   * Creates an instance of the named algorithm. Spark-based algorithms are wired to this driver's Spark context.
   *
   * @param name the algorithm name
   * @return the algorithm instance, or {@code null} if the algorithm is not supported or could not be created
   *         (the creation error is logged)
   * @throws LensException if the driver has not been started
   */
  @Override
  public MLAlgo getAlgoInstance(String name) throws LensException {
    checkStarted();

    if (!isAlgoSupported(name)) {
      return null;
    }

    MLAlgo algo = null;
    try {
      algo = algorithms.getAlgoForName(name);
      if (algo instanceof BaseSparkAlgo) {
        ((BaseSparkAlgo) algo).setSparkContext(sparkContext);
      }
    } catch (LensException exc) {
      LOG.error("Error creating algo object", exc);
    }
    return algo;
  }

  /**
   * Registers the Spark-based algorithms supported by this driver.
   */
  private void registerAlgos() {
    algorithms.register(NaiveBayesAlgo.class);
    algorithms.register(SVMAlgo.class);
    algorithms.register(LogisticRegressionAlgo.class);
    algorithms.register(DecisionTreeAlgo.class);
  }

  /**
   * Builds the Spark configuration from every {@code lens.ml.sparkdriver.*} property in {@code conf} and resolves
   * the deployment mode from {@code spark.master}.
   *
   * @param conf the Lens configuration
   * @throws LensException declared by the driver contract
   * @throws IllegalArgumentException if {@code spark.master} is not a supported value, or if a Yarn mode is
   *                                  requested and no Spark home is configured
   */
  @Override
  public void init(LensConf conf) throws LensException {
    sparkConf = new SparkConf();
    registerAlgos();
    for (String key : conf.getProperties().keySet()) {
      if (key.startsWith(SPARK_CONF_PREFIX)) {
        sparkConf.set(key.substring(SPARK_CONF_PREFIX.length()), conf.getProperties().get(key));
      }
    }

    // SparkConf.get(key) throws NoSuchElementException for a missing key; use the defaulting overload so an
    // unset master really falls through to embedded mode.
    String sparkAppMaster = sparkConf.get("spark.master", "");
    clientMode = parseMasterMode(sparkAppMaster);
    if (StringUtils.isBlank(sparkAppMaster)) {
      // A JavaSparkContext cannot be created without a master URL; embedded mode means in-process Spark.
      sparkConf.setMaster("local");
    }

    if (clientMode == SparkMasterMode.YARN_CLIENT || clientMode == SparkMasterMode.YARN_CLUSTER) {
      String sparkHome = System.getenv("SPARK_HOME");
      if (StringUtils.isNotBlank(sparkHome)) {
        sparkConf.setSparkHome(sparkHome);
      }
      // If SPARK_HOME is not set, SparkConf can read from the Lens-site.xml or System properties.
      String configuredSparkHome = sparkConf.get("spark.home", "");
      if (StringUtils.isBlank(configuredSparkHome)) {
        throw new IllegalArgumentException("Spark home is not set");
      }
      LOG.info("Spark home is set to " + configuredSparkHome);
    }

    sparkConf.setAppName("lens-ml");
  }

  /**
   * Maps a {@code spark.master} value to a deployment mode.
   *
   * @param master the configured master; blank means embedded
   * @return the corresponding mode
   * @throws IllegalArgumentException if the master is not yarn-client, yarn-cluster, local, local[...] or blank
   */
  private static SparkMasterMode parseMasterMode(String master) {
    if ("yarn-client".equalsIgnoreCase(master)) {
      return SparkMasterMode.YARN_CLIENT;
    }
    if ("yarn-cluster".equalsIgnoreCase(master)) {
      return SparkMasterMode.YARN_CLUSTER;
    }
    String localThreadsPrefix = "local[";
    if (StringUtils.isBlank(master) || "local".equalsIgnoreCase(master)
        || master.regionMatches(true, 0, localThreadsPrefix, 0, localThreadsPrefix.length())) {
      return SparkMasterMode.EMBEDDED;
    }
    throw new IllegalArgumentException("Invalid master mode " + master);
  }

  /**
   * Creates the Spark context (unless one was supplied via {@link #useSparkContext(JavaSparkContext)}) and, for
   * non-embedded (Yarn) modes, ships the Hive, HCatalog and Lens jars to it.
   *
   * @throws LensException if {@code HIVE_HOME} is unset or one of its jar directories cannot be listed
   */
  @Override
  public void start() throws LensException {
    // Resolve the jar list before creating the context so a misconfigured HIVE_HOME does not leave a live,
    // unreferenced Spark context behind.
    List<String> extraJars = new ArrayList<String>();
    if (clientMode != SparkMasterMode.EMBEDDED) {
      extraJars = collectClusterJars();
    }

    if (sparkContext == null) {
      sparkContext = new JavaSparkContext(sparkConf);
    }
    for (String jar : extraJars) {
      sparkContext.addJar(jar);
    }

    isStarted = true;
    LOG.info("Created Spark context for app: '" + sparkContext.appName() + "', Spark master: " + sparkContext.master());
  }

  /**
   * Collects the absolute paths of the Hive, HCatalog and Lens jars required by Yarn deployments.
   *
   * @return the jar paths, in the order they should be added
   * @throws LensException if {@code HIVE_HOME} is unset or a jar directory cannot be listed
   */
  private static List<String> collectClusterJars() throws LensException {
    // TODO Figure out only necessary set of JARs to be added for HCatalog
    String hiveLocation = System.getenv("HIVE_HOME");
    if (StringUtils.isBlank(hiveLocation)) {
      throw new LensException("HIVE_HOME is not set");
    }
    LOG.info("HIVE_HOME at " + hiveLocation);

    List<String> jars = new ArrayList<String>();

    // Add hive jars
    for (File jarFile : listJars(new File(hiveLocation, "lib"))) {
      LOG.info("Adding HIVE jar " + jarFile.getAbsolutePath());
      jars.add(jarFile.getAbsolutePath());
    }

    // Add hcatalog jars
    for (File jarFile : listJars(new File(hiveLocation + "/hcatalog/share/hcatalog"))) {
      LOG.info("Adding HCATALOG jar " + jarFile.getAbsolutePath());
      jars.add(jarFile.getAbsolutePath());
    }

    // Add the jar(s) containing this driver
    for (String lensSparkJar : JavaSparkContext.jarOfClass(SparkMLDriver.class)) {
      LOG.info("Adding Lens JAR " + lensSparkJar);
      jars.add(lensSparkJar);
    }
    return jars;
  }

  /**
   * Lists the jar files directly inside a directory.
   *
   * @param dir the directory to scan
   * @return the jar files found (possibly empty)
   * @throws LensException if {@code dir} is not a directory or cannot be read ({@link File#listFiles} returns null)
   */
  private static File[] listJars(File dir) throws LensException {
    File[] jarFiles = dir.listFiles(JAR_FILE_FILTER);
    if (jarFiles == null) {
      throw new LensException("Cannot list jars in directory " + dir.getAbsolutePath());
    }
    return jarFiles;
  }

  /**
   * Stops the driver. The Spark context is stopped only if this driver created it.
   *
   * @throws LensException declared by the driver contract
   */
  @Override
  public void stop() throws LensException {
    if (!isStarted) {
      LOG.warn("Spark driver was not started");
      return;
    }
    isStarted = false;
    if (ownsSparkContext) {
      sparkContext.stop();
      // Forget the stopped context so a subsequent start() creates a fresh one instead of reusing a dead one.
      sparkContext = null;
    }
    LOG.info("Stopped spark context " + this);
  }

  /**
   * Returns the names of all registered algorithms.
   *
   * @return the algorithm names
   */
  @Override
  public List<String> getAlgoNames() {
    return algorithms.getAlgorithmNames();
  }

  /**
   * Verifies that {@link #start()} has completed.
   *
   * @throws LensException if the driver is not started
   */
  public void checkStarted() throws LensException {
    if (!isStarted) {
      throw new LensException("Spark driver is not started yet");
    }
  }

  /**
   * Returns the Spark context in use.
   *
   * @return the current Spark context, or {@code null} if none has been created or supplied (or an owned one was
   *         stopped)
   */
  public JavaSparkContext getSparkContext() {
    return sparkContext;
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy