All downloads are free. The search and download features use the official Maven repository.

org.apache.hive.spark.client.SparkClientImpl Maven / Gradle / Ivy

There is a newer version: 2.3.10
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hive.spark.client;

import com.google.common.base.Charsets;
import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.Resources;

import io.netty.channel.ChannelHandlerContext;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.Promise;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Serializable;
import java.io.Writer;
import java.net.URI;
import java.net.URL;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.UUID;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.conf.HiveConf.ConfVars;
import org.apache.hadoop.hive.shims.Utils;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hive.spark.client.rpc.Rpc;
import org.apache.hive.spark.client.rpc.RpcConfiguration;
import org.apache.hive.spark.client.rpc.RpcServer;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class SparkClientImpl implements SparkClient {
  private static final long serialVersionUID = 1L;

  private static final Logger LOG = LoggerFactory.getLogger(SparkClientImpl.class);

  private static final long DEFAULT_SHUTDOWN_TIMEOUT = 10000; // In milliseconds

  private static final String OSX_TEST_OPTS = "SPARK_OSX_TEST_OPTS";
  private static final String SPARK_HOME_ENV = "SPARK_HOME";
  private static final String SPARK_HOME_KEY = "spark.home";
  private static final String DRIVER_OPTS_KEY = "spark.driver.extraJavaOptions";
  private static final String EXECUTOR_OPTS_KEY = "spark.executor.extraJavaOptions";
  private static final String DRIVER_EXTRA_CLASSPATH = "spark.driver.extraClassPath";
  private static final String EXECUTOR_EXTRA_CLASSPATH = "spark.executor.extraClassPath";

  private final Map conf;
  private final HiveConf hiveConf;
  private final AtomicInteger childIdGenerator;
  private final Thread driverThread;
  private final Map> jobs;
  private final Rpc driverRpc;
  private final ClientProtocol protocol;
  private volatile boolean isAlive;

  /**
   * Starts the remote driver and waits for it to register with {@code rpcServer}.
   *
   * <p>A random client id and a server-generated secret are handed to the driver; the
   * constructor then waits on the future returned by {@code rpcServer.registerClient(...)}.
   * If that wait fails for any reason the driver thread is interrupted and joined before the
   * failure is propagated.
   *
   * @param rpcServer server the driver connects back to
   * @param conf client configuration used to launch the driver
   * @param hiveConf Hive configuration consulted while building the launch command
   * @throws IOException if the driver cannot be launched
   * @throws SparkException declared for callers; not thrown directly by this constructor
   */
  SparkClientImpl(RpcServer rpcServer, Map<String, String> conf, HiveConf hiveConf)
      throws IOException, SparkException {
    this.conf = conf;
    this.hiveConf = hiveConf;
    this.childIdGenerator = new AtomicInteger();
    this.jobs = Maps.newConcurrentMap();

    String clientId = UUID.randomUUID().toString();
    String secret = rpcServer.createSecret();
    this.driverThread = startDriver(rpcServer, clientId, secret);
    this.protocol = new ClientProtocol();

    try {
      // The RPC server will take care of timeouts here.
      this.driverRpc = rpcServer.registerClient(clientId, secret, protocol).get();
    } catch (Throwable e) {
      LOG.warn("Error while waiting for client to connect.", e);
      driverThread.interrupt();
      try {
        driverThread.join();
      } catch (InterruptedException ie) {
        // Give up, but keep the interrupt visible to whoever handles the propagated failure.
        LOG.debug("Interrupted before driver thread was finished.");
        Thread.currentThread().interrupt();
      }
      throw Throwables.propagate(e);
    }

    // Mark the client alive BEFORE registering the listener: otherwise a channel close that
    // is reported between the two statements would be overwritten by the assignment and the
    // client would keep claiming to be alive.
    isAlive = true;
    driverRpc.addListener(new Rpc.Listener() {
        @Override
        public void rpcClosed(Rpc rpc) {
          if (isAlive) {
            LOG.warn("Client RPC channel closed unexpectedly.");
            isAlive = false;
          }
        }
    });
  }

  /**
   * Submits {@code job} to the remote driver through the client protocol.
   *
   * @param job the job to execute remotely
   * @return the handle created by the protocol for the submitted job
   */
  @Override
  public <T extends Serializable> JobHandle<T> submit(Job<T> job) {
    return protocol.submit(job);
  }

  /**
   * Sends {@code job} to the remote driver as a synchronous job request.
   *
   * @param job the job to execute remotely
   * @return a future for the reply to the request
   */
  @Override
  public <T extends Serializable> Future<T> run(Job<T> job) {
    return protocol.run(job);
  }

  /**
   * Shuts the client down.
   *
   * <p>If the client is still marked alive, an end-session request is sent and the RPC
   * channel is closed. The method then waits up to {@code DEFAULT_SHUTDOWN_TIMEOUT}
   * milliseconds for the driver thread and interrupts it if it is still running afterwards.
   */
  @Override
  public void stop() {
    if (isAlive) {
      isAlive = false;
      try {
        protocol.endSession();
      } catch (Exception e) {
        LOG.warn("Exception while waiting for end session reply.", e);
      } finally {
        driverRpc.close();
      }
    }

    boolean interrupted = false;
    try {
      driverThread.join(DEFAULT_SHUTDOWN_TIMEOUT);
    } catch (InterruptedException ie) {
      LOG.debug("Interrupted before driver thread was finished.");
      interrupted = true;
    }

    // Decide from the thread's actual state rather than from the wall clock: this also covers
    // the case where the wait was cut short by an interrupt and the driver is still running.
    if (driverThread.isAlive()) {
      if (interrupted) {
        LOG.warn("Interrupted while shutting down remote driver, interrupting...");
      } else {
        LOG.warn("Timed out shutting down remote driver, interrupting...");
      }
      driverThread.interrupt();
    }

    if (interrupted) {
      // Restore the caller's interrupt status that join() consumed.
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Runs a job on the remote driver that adds the jar at {@code uri} to the Spark context.
   *
   * @param uri location of the jar; passed to the driver in its string form
   * @return a future for the reply to the request
   */
  @Override
  public Future addJar(URI uri) {
    final AddJarJob addJarJob = new AddJarJob(uri.toString());
    return run(addJarJob);
  }

  /**
   * Runs a job on the remote driver that adds the file at {@code uri} to the Spark context.
   *
   * @param uri location of the file; passed to the driver in its string form
   * @return a future for the reply to the request
   */
  @Override
  public Future addFile(URI uri) {
    final AddFileJob addFileJob = new AddFileJob(uri.toString());
    return run(addFileJob);
  }

  /**
   * Asks the remote driver how many executors it currently has.
   *
   * @return a future for the executor count reported by the driver
   */
  @Override
  public Future getExecutorCount() {
    final GetExecutorCountJob countJob = new GetExecutorCountJob();
    return run(countJob);
  }

  /**
   * Asks the remote driver for the default parallelism of its Spark context.
   *
   * @return a future for the value reported by the driver
   */
  @Override
  public Future getDefaultParallelism() {
    final GetDefaultParallelismJob parallelismJob = new GetDefaultParallelismJob();
    return run(parallelismJob);
  }

  /**
   * Reports whether this client can still be used.
   *
   * @return {@code true} only if the client has not been stopped or marked dead and the
   *     driver RPC also reports itself active
   */
  @Override
  public boolean isActive() {
    if (!isAlive) {
      return false;
    }
    return driverRpc.isActive();
  }

  /**
   * Forwards a cancellation request for the given job to the client protocol.
   *
   * @param jobId client-side id of the job to cancel
   */
  void cancel(String jobId) {
    this.protocol.cancel(jobId);
  }

  /**
   * Launches the remote driver and returns the daemon thread (named "Driver") tied to it.
   *
   * <p>When the configuration contains {@code SparkClientFactory.CONF_KEY_IN_PROCESS} the
   * returned thread runs {@code RemoteDriver.main} directly. Otherwise a child process is
   * started through {@code sh -c} (using {@code spark-submit} when a Spark home is known, or
   * {@code java ... org.apache.spark.deploy.SparkSubmit} when it is not), its stdout/stderr
   * are redirected to the log, and the returned thread waits for the process to exit.
   *
   * @param rpcServer server whose address and port are passed to the driver
   * @param clientId id the driver uses to identify itself when connecting back
   * @param secret connection secret; written to a properties file, never put on the command
   *     line of the child process
   * @return the already started driver thread
   * @throws IOException if the properties file cannot be prepared or the process cannot start
   */
  private Thread startDriver(final RpcServer rpcServer, final String clientId, final String secret)
      throws IOException {
    Runnable runnable;
    final String serverAddress = rpcServer.getAddress();
    final String serverPort = String.valueOf(rpcServer.getPort());

    if (conf.containsKey(SparkClientFactory.CONF_KEY_IN_PROCESS)) {
      // Mostly for testing things quickly. Do not do this in production.
      LOG.warn("!!!! Running remote driver in-process. !!!!");
      runnable = new Runnable() {
        @Override
        public void run() {
          List<String> args = Lists.newArrayList();
          args.add("--remote-host");
          args.add(serverAddress);
          args.add("--remote-port");
          args.add(serverPort);
          args.add("--client-id");
          args.add(clientId);
          args.add("--secret");
          args.add(secret);

          for (Map.Entry<String, String> e : conf.entrySet()) {
            args.add("--conf");
            args.add(String.format("%s=%s", e.getKey(), e.getValue()));
          }
          try {
            RemoteDriver.main(args.toArray(new String[args.size()]));
          } catch (Exception e) {
            LOG.error("Error running driver.", e);
          }
        }
      };
    } else {
      // If a Spark installation is provided, use the spark-submit script. Otherwise, call the
      // SparkSubmit class directly, which has some caveats (like having to provide a proper
      // version of Guava on the classpath depending on the deploy mode).
      String sparkHome = conf.get(SPARK_HOME_KEY);
      if (sparkHome == null) {
        sparkHome = System.getenv(SPARK_HOME_ENV);
      }
      if (sparkHome == null) {
        sparkHome = System.getProperty(SPARK_HOME_KEY);
      }
      String sparkLogDir = conf.get("hive.spark.log.dir");
      if (sparkLogDir == null) {
        if (sparkHome == null) {
          sparkLogDir = "./target/";
        } else {
          sparkLogDir = sparkHome + "/logs/";
        }
      }

      String osxTestOpts = "";
      if (Strings.nullToEmpty(System.getProperty("os.name")).toLowerCase().contains("mac")) {
        osxTestOpts = Strings.nullToEmpty(System.getenv(OSX_TEST_OPTS));
      }

      String driverJavaOpts = Joiner.on(" ").skipNulls().join(
          "-Dhive.spark.log.dir=" + sparkLogDir, osxTestOpts, conf.get(DRIVER_OPTS_KEY));
      String executorJavaOpts = Joiner.on(" ").skipNulls().join(
          "-Dhive.spark.log.dir=" + sparkLogDir, osxTestOpts, conf.get(EXECUTOR_OPTS_KEY));

      // Create a file with all the job properties to be read by spark-submit. Change the
      // file's permissions so that only the owner can read it. This avoids having the
      // connection secret show up in the child process's command line.
      File properties = File.createTempFile("spark-submit.", ".properties");
      if (!properties.setReadable(false) || !properties.setReadable(true, true)) {
        throw new IOException("Cannot change permissions of job properties file.");
      }
      properties.deleteOnExit();

      Properties allProps = new Properties();
      // first load the defaults from spark-defaults.conf if available
      try {
        URL sparkDefaultsUrl = Thread.currentThread().getContextClassLoader().getResource("spark-defaults.conf");
        if (sparkDefaultsUrl != null) {
          LOG.info("Loading spark defaults: " + sparkDefaultsUrl);
          allProps.load(new ByteArrayInputStream(Resources.toByteArray(sparkDefaultsUrl)));
        }
      } catch (Exception e) {
        String msg = "Exception trying to load spark-defaults.conf: " + e;
        throw new IOException(msg, e);
      }
      // then load the SparkClientImpl config
      for (Map.Entry<String, String> e : conf.entrySet()) {
        allProps.put(e.getKey(), e.getValue());
      }
      allProps.put(SparkClientFactory.CONF_CLIENT_ID, clientId);
      allProps.put(SparkClientFactory.CONF_KEY_SECRET, secret);
      allProps.put(DRIVER_OPTS_KEY, driverJavaOpts);
      allProps.put(EXECUTOR_OPTS_KEY, executorJavaOpts);

      String isTesting = conf.get("spark.testing");
      if ("true".equalsIgnoreCase(isTesting)) {
        String hiveHadoopTestClasspath = Strings.nullToEmpty(System.getenv("HIVE_HADOOP_TEST_CLASSPATH"));
        if (!hiveHadoopTestClasspath.isEmpty()) {
          appendClasspath(allProps, DRIVER_EXTRA_CLASSPATH, hiveHadoopTestClasspath);
          appendClasspath(allProps, EXECUTOR_EXTRA_CLASSPATH, hiveHadoopTestClasspath);
        }
      }

      Writer writer = new OutputStreamWriter(new FileOutputStream(properties), Charsets.UTF_8);
      try {
        allProps.store(writer, "Spark Context configuration");
      } finally {
        writer.close();
      }

      // Define how to pass options to the child process. If launching in client (or local)
      // mode, the driver options need to be passed directly on the command line. Otherwise,
      // SparkSubmit will take care of that for us.
      String master = conf.get("spark.master");
      Preconditions.checkArgument(master != null, "spark.master is not defined.");

      List<String> argv = Lists.newArrayList();

      if ("kerberos".equalsIgnoreCase(hiveConf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION))) {
        argv.add("kinit");
        String principal = SecurityUtil.getServerPrincipal(hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL),
            "0.0.0.0");
        String keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB);
        argv.add(principal);
        argv.add("-k");
        argv.add("-t");
        // The trailing ';' ends the kinit command so the launcher that follows is run by the
        // same shell as a second command.
        argv.add(keyTabFile + ";");
      }

      if (sparkHome != null) {
        argv.add(new File(sparkHome, "bin/spark-submit").getAbsolutePath());
      } else {
        LOG.info("No spark.home provided, calling SparkSubmit directly.");
        argv.add(new File(System.getProperty("java.home"), "bin/java").getAbsolutePath());

        if (master.startsWith("local") || master.startsWith("mesos") || master.endsWith("-client") || master.startsWith("spark")) {
          String mem = conf.get("spark.driver.memory");
          if (mem != null) {
            argv.add("-Xms" + mem);
            argv.add("-Xmx" + mem);
          }

          String cp = conf.get("spark.driver.extraClassPath");
          if (cp != null) {
            argv.add("-classpath");
            argv.add(cp);
          }

          String libPath = conf.get("spark.driver.extraLibPath");
          if (libPath != null) {
            argv.add("-Djava.library.path=" + libPath);
          }

          String extra = conf.get(DRIVER_OPTS_KEY);
          if (extra != null) {
            for (String opt : extra.split("[ ]")) {
              if (!opt.trim().isEmpty()) {
                argv.add(opt.trim());
              }
            }
          }
        }

        argv.add("org.apache.spark.deploy.SparkSubmit");
      }

      if (master.equals("yarn-cluster")) {
        String executorCores = conf.get("spark.executor.cores");
        if (executorCores != null) {
          argv.add("--executor-cores");
          argv.add(executorCores);
        }

        String executorMemory = conf.get("spark.executor.memory");
        if (executorMemory != null) {
          argv.add("--executor-memory");
          argv.add(executorMemory);
        }

        String numOfExecutors = conf.get("spark.executor.instances");
        if (numOfExecutors != null) {
          argv.add("--num-executors");
          argv.add(numOfExecutors);
        }
      }

      if (hiveConf.getBoolVar(HiveConf.ConfVars.HIVE_SERVER2_ENABLE_DOAS)) {
        try {
          String currentUser = Utils.getUGI().getShortUserName();
          // do not do impersonation in CLI mode
          if (!currentUser.equals(System.getProperty("user.name"))) {
            LOG.info("Attempting impersonation of " + currentUser);
            argv.add("--proxy-user");
            argv.add(currentUser);
          }
        } catch (Exception e) {
          String msg = "Cannot obtain username: " + e;
          throw new IllegalStateException(msg, e);
        }
      }

      argv.add("--properties-file");
      argv.add(properties.getAbsolutePath());
      argv.add("--class");
      argv.add(RemoteDriver.class.getName());

      String jar = "spark-internal";
      if (SparkContext.jarOfClass(this.getClass()).isDefined()) {
        jar = SparkContext.jarOfClass(this.getClass()).get();
      }
      argv.add(jar);

      argv.add("--remote-host");
      argv.add(serverAddress);
      argv.add("--remote-port");
      argv.add(serverPort);

      //hive.spark.* keys are passed down to the RemoteDriver via --conf,
      //as --properties-file contains the spark.* keys that are meant for SparkConf object.
      for (String hiveSparkConfKey : RpcConfiguration.HIVE_SPARK_RSC_CONFIGS) {
        String value = RpcConfiguration.getValue(hiveConf, hiveSparkConfKey);
        argv.add("--conf");
        argv.add(String.format("%s=%s", hiveSparkConfKey, value));
      }

      // NOTE(review): the arguments are joined with single spaces and interpreted by
      // "sh -c", so a value containing whitespace or shell metacharacters is not passed
      // through verbatim. Kept as is because the optional kinit prefix relies on the shell.
      String cmd = Joiner.on(" ").join(argv);
      LOG.info("Running client driver with argv: {}", cmd);
      ProcessBuilder pb = new ProcessBuilder("sh", "-c", cmd);

      // Prevent hive configurations from being visible in Spark.
      pb.environment().remove("HIVE_HOME");
      pb.environment().remove("HIVE_CONF_DIR");

      if (isTesting != null) {
        pb.environment().put("SPARK_TESTING", isTesting);
      }

      final Process child = pb.start();
      int childId = childIdGenerator.incrementAndGet();
      redirect("stdout-redir-" + childId, child.getInputStream());
      redirect("stderr-redir-" + childId, child.getErrorStream());

      runnable = new Runnable() {
        @Override
        public void run() {
          try {
            int exitCode = child.waitFor();
            if (exitCode != 0) {
              rpcServer.cancelClient(clientId, "Child process exited before connecting back");
              LOG.warn("Child process exited with code {}.", exitCode);
            }
          } catch (InterruptedException ie) {
            LOG.warn("Waiting thread interrupted, killing child process.");
            // Clears the interrupt flag (if any) before tearing the child process down.
            Thread.interrupted();
            child.destroy();
          } catch (Exception e) {
            LOG.warn("Exception while waiting for child process.", e);
          }
        }
      };
    }

    Thread thread = new Thread(runnable);
    thread.setDaemon(true);
    thread.setName("Driver");
    thread.start();
    return thread;
  }

  /**
   * Appends {@code extraEntries} to the classpath stored under {@code key} in {@code props},
   * inserting {@link File#pathSeparator} when the existing value does not already end with it.
   * If no value (or an empty one) is present, {@code extraEntries} becomes the value.
   */
  private static void appendClasspath(Properties props, String key, String extraEntries) {
    String current = Strings.nullToEmpty((String) props.get(key));
    if (current.isEmpty()) {
      props.put(key, extraEntries);
    } else {
      String prefix = current.endsWith(File.pathSeparator) ? current : current + File.pathSeparator;
      props.put(key, prefix + extraEntries);
    }
  }

  /**
   * Starts a daemon thread with the given name that copies every line read from {@code in}
   * to the log.
   *
   * @param name name given to the redirector thread
   * @param in stream to drain
   */
  private void redirect(String name, InputStream in) {
    final Thread redirector = new Thread(new Redirector(in), name);
    redirector.setDaemon(true);
    redirector.start();
  }

  private class ClientProtocol extends BaseProtocol {

     JobHandleImpl submit(Job job) {
      final String jobId = UUID.randomUUID().toString();
      final Promise promise = driverRpc.createPromise();
      final JobHandleImpl handle = new JobHandleImpl(SparkClientImpl.this, promise, jobId);
      jobs.put(jobId, handle);

      final io.netty.util.concurrent.Future rpc = driverRpc.call(new JobRequest(jobId, job));
      LOG.debug("Send JobRequest[{}].", jobId);

      // Link the RPC and the promise so that events from one are propagated to the other as
      // needed.
      rpc.addListener(new GenericFutureListener>() {
        @Override
        public void operationComplete(io.netty.util.concurrent.Future f) {
          if (f.isSuccess()) {
            handle.changeState(JobHandle.State.QUEUED);
          } else if (!promise.isDone()) {
            promise.setFailure(f.cause());
          }
        }
      });
      promise.addListener(new GenericFutureListener>() {
        @Override
        public void operationComplete(Promise p) {
          if (jobId != null) {
            jobs.remove(jobId);
          }
          if (p.isCancelled() && !rpc.isDone()) {
            rpc.cancel(true);
          }
        }
      });
      return handle;
    }

     Future run(Job job) {
      @SuppressWarnings("unchecked")
      final io.netty.util.concurrent.Future rpc = (io.netty.util.concurrent.Future)
        driverRpc.call(new SyncJobRequest(job), Serializable.class);
      return rpc;
    }

    void cancel(String jobId) {
      driverRpc.call(new CancelJob(jobId));
    }

    Future endSession() {
      return driverRpc.call(new EndSession());
    }

    private void handle(ChannelHandlerContext ctx, Error msg) {
      LOG.warn("Error reported from remote driver.", msg.cause);
    }

    private void handle(ChannelHandlerContext ctx, JobMetrics msg) {
      JobHandleImpl handle = jobs.get(msg.jobId);
      if (handle != null) {
        handle.getMetrics().addMetrics(msg.sparkJobId, msg.stageId, msg.taskId, msg.metrics);
      } else {
        LOG.warn("Received metrics for unknown job {}", msg.jobId);
      }
    }

    private void handle(ChannelHandlerContext ctx, JobResult msg) {
      JobHandleImpl handle = jobs.remove(msg.id);
      if (handle != null) {
        LOG.info("Received result for {}", msg.id);
        handle.setSparkCounters(msg.sparkCounters);
        Throwable error = msg.error != null ? new SparkException(msg.error) : null;
        if (error == null) {
          handle.setSuccess(msg.result);
        } else {
          handle.setFailure(error);
        }
      } else {
        LOG.warn("Received result for unknown job {}", msg.id);
      }
    }

    private void handle(ChannelHandlerContext ctx, JobStarted msg) {
      JobHandleImpl handle = jobs.get(msg.id);
      if (handle != null) {
        handle.changeState(JobHandle.State.STARTED);
      } else {
        LOG.warn("Received event for unknown job {}", msg.id);
      }
    }

    private void handle(ChannelHandlerContext ctx, JobSubmitted msg) {
      JobHandleImpl handle = jobs.get(msg.clientJobId);
      if (handle != null) {
        LOG.info("Received spark job ID: {} for {}", msg.sparkJobId, msg.clientJobId);
        handle.addSparkJobId(msg.sparkJobId);
      } else {
        LOG.warn("Received spark job ID: {} for unknown job {}", msg.sparkJobId, msg.clientJobId);
      }
    }

  }

  private class Redirector implements Runnable {

    private final BufferedReader in;

    Redirector(InputStream in) {
      this.in = new BufferedReader(new InputStreamReader(in));
    }

    @Override
    public void run() {
      try {
        String line = null;
        while ((line = in.readLine()) != null) {
          LOG.info(line);
        }
      } catch (Exception e) {
        LOG.warn("Error in redirector thread.", e);
      }
    }

  }

  /** Job that adds a jar to the Spark context of the remote driver. */
  private static class AddJarJob implements Job {
    private static final long serialVersionUID = 1L;

    private final String path;

    /** No-argument constructor; leaves the path {@code null}. */
    AddJarJob() {
      this(null);
    }

    /**
     * @param path location of the jar to add
     */
    AddJarJob(String path) {
      this.path = path;
    }

    @Override
    public Serializable call(JobContext jc) throws Exception {
      jc.sc().addJar(path);
      // Jobs that follow may need classes from this jar and run on another thread, so the
      // path is also recorded in the JobContext together with the time it was added.
      final long addedAt = System.currentTimeMillis();
      jc.getAddedJars().put(path, addedAt);
      return null;
    }

  }

  /** Job that adds a file to the Spark context of the remote driver. */
  private static class AddFileJob implements Job {
    private static final long serialVersionUID = 1L;

    private final String path;

    /** No-argument constructor; leaves the path {@code null}. */
    AddFileJob() {
      this(null);
    }

    /**
     * @param path location of the file to add
     */
    AddFileJob(String path) {
      this.path = path;
    }

    @Override
    public Serializable call(JobContext jc) throws Exception {
      final String target = this.path;
      jc.sc().addFile(target);
      return null;
    }

  }

  /** Job that reports how many executors the remote Spark context has. */
  private static class GetExecutorCountJob implements Job {
    private static final long serialVersionUID = 1L;

    @Override
    public Integer call(JobContext jc) throws Exception {
      // The memory status map also lists the driver, so subtract one to count executors only.
      return jc.sc().sc().getExecutorMemoryStatus().size() - 1;
    }

  }

  /** Job that reports the default parallelism of the remote Spark context. */
  private static class GetDefaultParallelismJob implements Job {
    private static final long serialVersionUID = 1L;

    @Override
    public Integer call(JobContext jc) throws Exception {
      final int parallelism = jc.sc().sc().defaultParallelism();
      return Integer.valueOf(parallelism);
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy