All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.hive.spark.client.rpc.RpcServer Maven / Gradle / Ivy

There is a newer version: 2.3.10
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hive.spark.client.rpc;

import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.SecureRandom;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.TimeUnit;

import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.ThreadFactoryBuilder;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.hadoop.hive.common.classification.InterfaceAudience;

/**
 * An RPC server. The server matches remote clients based on a secret that is generated on
 * the server - the secret needs to be given to the client through some other mechanism for
 * this to work.
 */
@InterfaceAudience.Private
public class RpcServer implements Closeable {

  private static final Logger LOG = LoggerFactory.getLogger(RpcServer.class);
  private static final SecureRandom RND = new SecureRandom();

  private final String address;
  private final Channel channel;
  private final EventLoopGroup group;
  private final int port;
  private final ConcurrentMap pendingClients;
  private final RpcConfiguration config;

  public RpcServer(Map mapConf) throws IOException, InterruptedException {
    this.config = new RpcConfiguration(mapConf);
    this.group = new NioEventLoopGroup(
        this.config.getRpcThreadCount(),
        new ThreadFactoryBuilder()
            .setNameFormat("RPC-Handler-%d")
            .setDaemon(true)
            .build());
     ServerBootstrap serverBootstrap = new ServerBootstrap()
      .group(group)
      .channel(NioServerSocketChannel.class)
      .childHandler(new ChannelInitializer() {
          @Override
          public void initChannel(SocketChannel ch) throws Exception {
            SaslServerHandler saslHandler = new SaslServerHandler(config);
            final Rpc newRpc = Rpc.createServer(saslHandler, config, ch, group);
            saslHandler.rpc = newRpc;

            Runnable cancelTask = new Runnable() {
                @Override
                public void run() {
                  LOG.warn("Timed out waiting for hello from client.");
                  newRpc.close();
                }
            };
            saslHandler.cancelTask = group.schedule(cancelTask,
                RpcServer.this.config.getConnectTimeoutMs(),
                TimeUnit.MILLISECONDS);

          }
      })
      .option(ChannelOption.SO_BACKLOG, 1)
      .option(ChannelOption.SO_REUSEADDR, true)
      .childOption(ChannelOption.SO_KEEPALIVE, true);

    this.channel = bindServerPort(serverBootstrap).channel();
    this.port = ((InetSocketAddress) channel.localAddress()).getPort();
    this.pendingClients = Maps.newConcurrentMap();
    this.address = this.config.getServerAddress();
  }

  /**
   * Retry the list of configured ports until one is found
   * @param serverBootstrap
   * @return
   * @throws InterruptedException
   * @throws IOException
   */
  private ChannelFuture bindServerPort(ServerBootstrap serverBootstrap)
      throws InterruptedException, IOException {
    List ports = config.getServerPorts();
    if (ports.contains(0)) {
      return serverBootstrap.bind(0).sync();
    } else {
      Random rand = new Random();
      while(!ports.isEmpty()) {
        int index = rand.nextInt(ports.size());
        int port = ports.get(index);
        ports.remove(index);
        try {
          return serverBootstrap.bind(port).sync();
        } catch(Exception e) {
          // Retry the next port
        }
      }
      throw new IOException("No available ports from configured RPC Server ports for HiveServer2");
    }
  }

  /**
   * Tells the RPC server to expect a connection from a new client.
   *
   * @param clientId An identifier for the client. Must be unique.
   * @param secret The secret the client will send to the server to identify itself.
   * @param serverDispatcher The dispatcher to use when setting up the RPC instance.
   * @return A future that can be used to wait for the client connection, which also provides the
   *         secret needed for the client to connect.
   */
  public Future registerClient(final String clientId, String secret,
      RpcDispatcher serverDispatcher) {
    return registerClient(clientId, secret, serverDispatcher, config.getServerConnectTimeoutMs());
  }

  @VisibleForTesting
  Future registerClient(final String clientId, String secret,
      RpcDispatcher serverDispatcher, long clientTimeoutMs) {
    final Promise promise = group.next().newPromise();

    Runnable timeout = new Runnable() {
      @Override
      public void run() {
        promise.setFailure(new TimeoutException("Timed out waiting for client connection."));
      }
    };
    ScheduledFuture timeoutFuture = group.schedule(timeout,
        clientTimeoutMs,
        TimeUnit.MILLISECONDS);
    final ClientInfo client = new ClientInfo(clientId, promise, secret, serverDispatcher,
        timeoutFuture);
    if (pendingClients.putIfAbsent(clientId, client) != null) {
      throw new IllegalStateException(
          String.format("Client '%s' already registered.", clientId));
    }

    promise.addListener(new GenericFutureListener>() {
      @Override
      public void operationComplete(Promise p) {
        if (!p.isSuccess()) {
          pendingClients.remove(clientId);
        }
      }
    });

    return promise;
  }

  /**
   * Tells the RPC server to cancel the connection from an existing pending client
   * @param clientId The identifier for the client
   * @param msg The error message about why the connection should be canceled
   */
  public void cancelClient(final String clientId, final String msg) {
    final ClientInfo cinfo = pendingClients.remove(clientId);
    if (cinfo == null) {
      // Nothing to be done here.
      return;
    }
    cinfo.timeoutFuture.cancel(true);
    if (!cinfo.promise.isDone()) {
      cinfo.promise.setFailure(new RuntimeException(
          String.format("Cancel client '%s'. Error: " + msg, clientId)));
    }
  }

  /**
   * Creates a secret for identifying a client connection.
   */
  public String createSecret() {
    byte[] secret = new byte[config.getSecretBits() / 8];
    RND.nextBytes(secret);

    StringBuilder sb = new StringBuilder();
    for (byte b : secret) {
      if (b < 10) {
        sb.append("0");
      }
      sb.append(Integer.toHexString(b));
    }
    return sb.toString();
  }

  public String getAddress() {
    return address;
  }

  public int getPort() {
    return port;
  }

  @Override
  public void close() {
    try {
      channel.close();
      for (ClientInfo client : pendingClients.values()) {
        client.promise.cancel(true);
      }
      pendingClients.clear();
    } finally {
      group.shutdownGracefully();
    }
  }

  private class SaslServerHandler extends SaslHandler implements CallbackHandler {

    private final SaslServer server;
    private Rpc rpc;
    private ScheduledFuture cancelTask;
    private String clientId;
    private ClientInfo client;

    SaslServerHandler(RpcConfiguration config) throws IOException {
      super(config);
      this.server = Sasl.createSaslServer(config.getSaslMechanism(), Rpc.SASL_PROTOCOL,
        Rpc.SASL_REALM, config.getSaslOptions(), this);
    }

    @Override
    protected boolean isComplete() {
      return server.isComplete();
    }

    @Override
    protected String getNegotiatedProperty(String name) {
      return (String) server.getNegotiatedProperty(name);
    }

    @Override
    protected Rpc.SaslMessage update(Rpc.SaslMessage challenge) throws IOException {
      if (clientId == null) {
        Preconditions.checkArgument(challenge.clientId != null,
          "Missing client ID in SASL handshake.");
        clientId = challenge.clientId;
        client = pendingClients.get(clientId);
        Preconditions.checkArgument(client != null,
          "Unexpected client ID '%s' in SASL handshake.", clientId);
      }

      return new Rpc.SaslMessage(server.evaluateResponse(challenge.payload));
    }

    @Override
    public byte[] wrap(byte[] data, int offset, int len) throws IOException {
      return server.wrap(data, offset, len);
    }

    @Override
    public byte[] unwrap(byte[] data, int offset, int len) throws IOException {
      return server.unwrap(data, offset, len);
    }

    @Override
    public void dispose() throws IOException {
      if (!server.isComplete()) {
        onError(new SaslException("Server closed before SASL negotiation finished."));
      }
      server.dispose();
    }

    @Override
    protected void onComplete() throws Exception {
      cancelTask.cancel(true);
      client.timeoutFuture.cancel(true);
      rpc.setDispatcher(client.dispatcher);
      client.promise.setSuccess(rpc);
      pendingClients.remove(client.id);
    }

    @Override
    protected void onError(Throwable error) {
      cancelTask.cancel(true);
      if (client != null) {
        client.timeoutFuture.cancel(true);
        if (!client.promise.isDone()) {
          client.promise.setFailure(error);
        }
      }
    }

    @Override
    public void handle(Callback[] callbacks) {
      Preconditions.checkState(client != null, "Handshake not initialized yet.");
      for (Callback cb : callbacks) {
        if (cb instanceof NameCallback) {
          ((NameCallback)cb).setName(clientId);
        } else if (cb instanceof PasswordCallback) {
          ((PasswordCallback)cb).setPassword(client.secret.toCharArray());
        } else if (cb instanceof AuthorizeCallback) {
          ((AuthorizeCallback) cb).setAuthorized(true);
        } else if (cb instanceof RealmCallback) {
          RealmCallback rb = (RealmCallback) cb;
          rb.setText(rb.getDefaultText());
        }
      }
    }

  }

  private static class ClientInfo {

    final String id;
    final Promise promise;
    final String secret;
    final RpcDispatcher dispatcher;
    final ScheduledFuture timeoutFuture;

    private ClientInfo(String id, Promise promise, String secret, RpcDispatcher dispatcher,
        ScheduledFuture timeoutFuture) {
      this.id = id;
      this.promise = promise;
      this.secret = secret;
      this.dispatcher = dispatcher;
      this.timeoutFuture = timeoutFuture;
    }

  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy