// org.apache.hive.spark.client.rpc.RpcServer Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hive.spark.client.rpc;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.SecureRandom;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.TimeUnit;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.hadoop.hive.common.classification.InterfaceAudience;
/**
* An RPC server. The server matches remote clients based on a secret that is generated on
* the server - the secret needs to be given to the client through some other mechanism for
* this to work.
*/
@InterfaceAudience.Private
public class RpcServer implements Closeable {
private static final Logger LOG = LoggerFactory.getLogger(RpcServer.class);
private static final SecureRandom RND = new SecureRandom();
private final String address;
private final Channel channel;
private final EventLoopGroup group;
private final int port;
private final ConcurrentMap pendingClients;
private final RpcConfiguration config;
public RpcServer(Map mapConf) throws IOException, InterruptedException {
this.config = new RpcConfiguration(mapConf);
this.group = new NioEventLoopGroup(
this.config.getRpcThreadCount(),
new ThreadFactoryBuilder()
.setNameFormat("RPC-Handler-%d")
.setDaemon(true)
.build());
ServerBootstrap serverBootstrap = new ServerBootstrap()
.group(group)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer() {
@Override
public void initChannel(SocketChannel ch) throws Exception {
SaslServerHandler saslHandler = new SaslServerHandler(config);
final Rpc newRpc = Rpc.createServer(saslHandler, config, ch, group);
saslHandler.rpc = newRpc;
Runnable cancelTask = new Runnable() {
@Override
public void run() {
LOG.warn("Timed out waiting for hello from client.");
newRpc.close();
}
};
saslHandler.cancelTask = group.schedule(cancelTask,
RpcServer.this.config.getConnectTimeoutMs(),
TimeUnit.MILLISECONDS);
}
})
.option(ChannelOption.SO_BACKLOG, 1)
.option(ChannelOption.SO_REUSEADDR, true)
.childOption(ChannelOption.SO_KEEPALIVE, true);
this.channel = bindServerPort(serverBootstrap).channel();
this.port = ((InetSocketAddress) channel.localAddress()).getPort();
this.pendingClients = Maps.newConcurrentMap();
this.address = this.config.getServerAddress();
}
/**
* Retry the list of configured ports until one is found
* @param serverBootstrap
* @return
* @throws InterruptedException
* @throws IOException
*/
private ChannelFuture bindServerPort(ServerBootstrap serverBootstrap)
throws InterruptedException, IOException {
List ports = config.getServerPorts();
if (ports.contains(0)) {
return serverBootstrap.bind(0).sync();
} else {
Random rand = new Random();
while(!ports.isEmpty()) {
int index = rand.nextInt(ports.size());
int port = ports.get(index);
ports.remove(index);
try {
return serverBootstrap.bind(port).sync();
} catch(Exception e) {
// Retry the next port
}
}
throw new IOException("No available ports from configured RPC Server ports for HiveServer2");
}
}
/**
* Tells the RPC server to expect a connection from a new client.
*
* @param clientId An identifier for the client. Must be unique.
* @param secret The secret the client will send to the server to identify itself.
* @param serverDispatcher The dispatcher to use when setting up the RPC instance.
* @return A future that can be used to wait for the client connection, which also provides the
* secret needed for the client to connect.
*/
public Future registerClient(final String clientId, String secret,
RpcDispatcher serverDispatcher) {
return registerClient(clientId, secret, serverDispatcher, config.getServerConnectTimeoutMs());
}
@VisibleForTesting
Future registerClient(final String clientId, String secret,
RpcDispatcher serverDispatcher, long clientTimeoutMs) {
final Promise promise = group.next().newPromise();
Runnable timeout = new Runnable() {
@Override
public void run() {
promise.setFailure(new TimeoutException("Timed out waiting for client connection."));
}
};
ScheduledFuture> timeoutFuture = group.schedule(timeout,
clientTimeoutMs,
TimeUnit.MILLISECONDS);
final ClientInfo client = new ClientInfo(clientId, promise, secret, serverDispatcher,
timeoutFuture);
if (pendingClients.putIfAbsent(clientId, client) != null) {
throw new IllegalStateException(
String.format("Client '%s' already registered.", clientId));
}
promise.addListener(new GenericFutureListener>() {
@Override
public void operationComplete(Promise p) {
if (!p.isSuccess()) {
pendingClients.remove(clientId);
}
}
});
return promise;
}
/**
* Tells the RPC server to cancel the connection from an existing pending client
* @param clientId The identifier for the client
* @param msg The error message about why the connection should be canceled
*/
public void cancelClient(final String clientId, final String msg) {
final ClientInfo cinfo = pendingClients.remove(clientId);
if (cinfo == null) {
// Nothing to be done here.
return;
}
cinfo.timeoutFuture.cancel(true);
if (!cinfo.promise.isDone()) {
cinfo.promise.setFailure(new RuntimeException(
String.format("Cancel client '%s'. Error: " + msg, clientId)));
}
}
/**
* Creates a secret for identifying a client connection.
*/
public String createSecret() {
byte[] secret = new byte[config.getSecretBits() / 8];
RND.nextBytes(secret);
StringBuilder sb = new StringBuilder();
for (byte b : secret) {
if (b < 10) {
sb.append("0");
}
sb.append(Integer.toHexString(b));
}
return sb.toString();
}
public String getAddress() {
return address;
}
public int getPort() {
return port;
}
@Override
public void close() {
try {
channel.close();
for (ClientInfo client : pendingClients.values()) {
client.promise.cancel(true);
}
pendingClients.clear();
} finally {
group.shutdownGracefully();
}
}
private class SaslServerHandler extends SaslHandler implements CallbackHandler {
private final SaslServer server;
private Rpc rpc;
private ScheduledFuture> cancelTask;
private String clientId;
private ClientInfo client;
SaslServerHandler(RpcConfiguration config) throws IOException {
super(config);
this.server = Sasl.createSaslServer(config.getSaslMechanism(), Rpc.SASL_PROTOCOL,
Rpc.SASL_REALM, config.getSaslOptions(), this);
}
@Override
protected boolean isComplete() {
return server.isComplete();
}
@Override
protected String getNegotiatedProperty(String name) {
return (String) server.getNegotiatedProperty(name);
}
@Override
protected Rpc.SaslMessage update(Rpc.SaslMessage challenge) throws IOException {
if (clientId == null) {
Preconditions.checkArgument(challenge.clientId != null,
"Missing client ID in SASL handshake.");
clientId = challenge.clientId;
client = pendingClients.get(clientId);
Preconditions.checkArgument(client != null,
"Unexpected client ID '%s' in SASL handshake.", clientId);
}
return new Rpc.SaslMessage(server.evaluateResponse(challenge.payload));
}
@Override
public byte[] wrap(byte[] data, int offset, int len) throws IOException {
return server.wrap(data, offset, len);
}
@Override
public byte[] unwrap(byte[] data, int offset, int len) throws IOException {
return server.unwrap(data, offset, len);
}
@Override
public void dispose() throws IOException {
if (!server.isComplete()) {
onError(new SaslException("Server closed before SASL negotiation finished."));
}
server.dispose();
}
@Override
protected void onComplete() throws Exception {
cancelTask.cancel(true);
client.timeoutFuture.cancel(true);
rpc.setDispatcher(client.dispatcher);
client.promise.setSuccess(rpc);
pendingClients.remove(client.id);
}
@Override
protected void onError(Throwable error) {
cancelTask.cancel(true);
if (client != null) {
client.timeoutFuture.cancel(true);
if (!client.promise.isDone()) {
client.promise.setFailure(error);
}
}
}
@Override
public void handle(Callback[] callbacks) {
Preconditions.checkState(client != null, "Handshake not initialized yet.");
for (Callback cb : callbacks) {
if (cb instanceof NameCallback) {
((NameCallback)cb).setName(clientId);
} else if (cb instanceof PasswordCallback) {
((PasswordCallback)cb).setPassword(client.secret.toCharArray());
} else if (cb instanceof AuthorizeCallback) {
((AuthorizeCallback) cb).setAuthorized(true);
} else if (cb instanceof RealmCallback) {
RealmCallback rb = (RealmCallback) cb;
rb.setText(rb.getDefaultText());
}
}
}
}
private static class ClientInfo {
final String id;
final Promise promise;
final String secret;
final RpcDispatcher dispatcher;
final ScheduledFuture> timeoutFuture;
private ClientInfo(String id, Promise promise, String secret, RpcDispatcher dispatcher,
ScheduledFuture> timeoutFuture) {
this.id = id;
this.promise = promise;
this.secret = secret;
this.dispatcher = dispatcher;
this.timeoutFuture = timeoutFuture;
}
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy