// org.apache.hadoop.hbase.security.SaslClientHandler Maven / Gradle / Ivy
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.security;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hbase.classification.InterfaceAudience;
import org.apache.hadoop.ipc.RemoteException;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import java.io.IOException;
import java.nio.charset.Charset;
import java.security.PrivilegedExceptionAction;
import java.util.Random;
/**
* Handles Sasl connections
*/
@InterfaceAudience.Private
public class SaslClientHandler extends ChannelDuplexHandler {
private static final Log LOG = LogFactory.getLog(SaslClientHandler.class);
private final boolean fallbackAllowed;
private final UserGroupInformation ticket;
/**
* Used for client or server's token to send or receive from each other.
*/
private final SaslClient saslClient;
private final SaslExceptionHandler exceptionHandler;
private final SaslSuccessfulConnectHandler successfulConnectHandler;
private byte[] saslToken;
private boolean firstRead = true;
private int retryCount = 0;
private Random random;
/**
* Constructor
*
* @param ticket the ugi
* @param method auth method
* @param token for Sasl
* @param serverPrincipal Server's Kerberos principal name
* @param fallbackAllowed True if server may also fall back to less secure connection
* @param rpcProtection Quality of protection. Can be 'authentication', 'integrity' or
* 'privacy'.
* @param exceptionHandler handler for exceptions
* @param successfulConnectHandler handler for succesful connects
* @throws java.io.IOException if handler could not be created
*/
public SaslClientHandler(UserGroupInformation ticket, AuthMethod method,
Token extends TokenIdentifier> token, String serverPrincipal, boolean fallbackAllowed,
String rpcProtection, SaslExceptionHandler exceptionHandler,
SaslSuccessfulConnectHandler successfulConnectHandler) throws IOException {
this.ticket = ticket;
this.fallbackAllowed = fallbackAllowed;
this.exceptionHandler = exceptionHandler;
this.successfulConnectHandler = successfulConnectHandler;
SaslUtil.initSaslProperties(rpcProtection);
switch (method) {
case DIGEST:
if (LOG.isDebugEnabled())
LOG.debug("Creating SASL " + AuthMethod.DIGEST.getMechanismName()
+ " client to authenticate to service at " + token.getService());
saslClient = createDigestSaslClient(new String[] { AuthMethod.DIGEST.getMechanismName() },
SaslUtil.SASL_DEFAULT_REALM, new HBaseSaslRpcClient.SaslClientCallbackHandler(token));
break;
case KERBEROS:
if (LOG.isDebugEnabled()) {
LOG.debug("Creating SASL " + AuthMethod.KERBEROS.getMechanismName()
+ " client. Server's Kerberos principal name is " + serverPrincipal);
}
if (serverPrincipal == null || serverPrincipal.isEmpty()) {
throw new IOException("Failed to specify server's Kerberos principal name");
}
String[] names = SaslUtil.splitKerberosName(serverPrincipal);
if (names.length != 3) {
throw new IOException(
"Kerberos principal does not have the expected format: " + serverPrincipal);
}
saslClient = createKerberosSaslClient(new String[] { AuthMethod.KERBEROS.getMechanismName() },
names[0], names[1]);
break;
default:
throw new IOException("Unknown authentication method " + method);
}
if (saslClient == null) {
throw new IOException("Unable to find SASL client implementation");
}
}
/**
* Create a Digest Sasl client
*
* @param mechanismNames names of mechanisms
* @param saslDefaultRealm default realm for sasl
* @param saslClientCallbackHandler handler for the client
* @return new SaslClient
* @throws java.io.IOException if creation went wrong
*/
protected SaslClient createDigestSaslClient(String[] mechanismNames, String saslDefaultRealm,
CallbackHandler saslClientCallbackHandler) throws IOException {
return Sasl.createSaslClient(mechanismNames, null, null, saslDefaultRealm, SaslUtil.SASL_PROPS,
saslClientCallbackHandler);
}
/**
* Create Kerberos client
*
* @param mechanismNames names of mechanisms
* @param userFirstPart first part of username
* @param userSecondPart second part of username
* @return new SaslClient
* @throws java.io.IOException if fails
*/
protected SaslClient createKerberosSaslClient(String[] mechanismNames, String userFirstPart,
String userSecondPart) throws IOException {
return Sasl
.createSaslClient(mechanismNames, null, userFirstPart, userSecondPart, SaslUtil.SASL_PROPS,
null);
}
@Override
public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
saslClient.dispose();
}
private byte[] evaluateChallenge(final byte[] challenge) throws Exception {
return ticket.doAs(new PrivilegedExceptionAction() {
@Override
public byte[] run() throws Exception {
return saslClient.evaluateChallenge(challenge);
}
});
}
@Override
public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
saslToken = new byte[0];
if (saslClient.hasInitialResponse()) {
saslToken = evaluateChallenge(saslToken);
}
if (saslToken != null) {
writeSaslToken(ctx, saslToken);
if (LOG.isDebugEnabled()) {
LOG.debug("Have sent token of size " + saslToken.length + " from initSASLContext.");
}
}
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
ByteBuf in = (ByteBuf) msg;
// If not complete, try to negotiate
if (!saslClient.isComplete()) {
while (!saslClient.isComplete() && in.isReadable()) {
readStatus(in);
int len = in.readInt();
if (firstRead) {
firstRead = false;
if (len == SaslUtil.SWITCH_TO_SIMPLE_AUTH) {
if (!fallbackAllowed) {
throw new IOException("Server asks us to fall back to SIMPLE auth, " + "but this "
+ "client is configured to only allow secure connections.");
}
if (LOG.isDebugEnabled()) {
LOG.debug("Server asks us to fall back to simple auth.");
}
saslClient.dispose();
ctx.pipeline().remove(this);
successfulConnectHandler.onSuccess(ctx.channel());
return;
}
}
saslToken = new byte[len];
if (LOG.isDebugEnabled()) {
LOG.debug("Will read input token of size " + saslToken.length
+ " for processing by initSASLContext");
}
in.readBytes(saslToken);
saslToken = evaluateChallenge(saslToken);
if (saslToken != null) {
if (LOG.isDebugEnabled()) {
LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext.");
}
writeSaslToken(ctx, saslToken);
}
}
if (saslClient.isComplete()) {
String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP);
if (LOG.isDebugEnabled()) {
LOG.debug("SASL client context established. Negotiated QoP: " + qop);
}
boolean useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
if (!useWrap) {
ctx.pipeline().remove(this);
}
successfulConnectHandler.onSuccess(ctx.channel());
}
}
// Normal wrapped reading
else {
try {
int length = in.readInt();
if (LOG.isDebugEnabled()) {
LOG.debug("Actual length is " + length);
}
saslToken = new byte[length];
in.readBytes(saslToken);
} catch (IndexOutOfBoundsException e) {
return;
}
try {
ByteBuf b = ctx.channel().alloc().buffer(saslToken.length);
b.writeBytes(saslClient.unwrap(saslToken, 0, saslToken.length));
ctx.fireChannelRead(b);
} catch (SaslException se) {
try {
saslClient.dispose();
} catch (SaslException ignored) {
LOG.debug("Ignoring SASL exception", ignored);
}
throw se;
}
}
}
/**
* Write SASL token
* @param ctx to write to
* @param saslToken to write
*/
private void writeSaslToken(final ChannelHandlerContext ctx, byte[] saslToken) {
ByteBuf b = ctx.alloc().buffer(4 + saslToken.length);
b.writeInt(saslToken.length);
b.writeBytes(saslToken, 0, saslToken.length);
ctx.writeAndFlush(b).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
exceptionCaught(ctx, future.cause());
}
}
});
}
/**
* Get the read status
*
* @param inStream to read
* @throws org.apache.hadoop.ipc.RemoteException if status was not success
*/
private static void readStatus(ByteBuf inStream) throws RemoteException {
int status = inStream.readInt(); // read status
if (status != SaslStatus.SUCCESS.state) {
throw new RemoteException(inStream.toString(Charset.forName("UTF-8")),
inStream.toString(Charset.forName("UTF-8")));
}
}
@Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
throws Exception {
saslClient.dispose();
ctx.close();
if (this.random == null) {
this.random = new Random();
}
exceptionHandler.handle(this.retryCount++, this.random, cause);
}
@Override
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
// If not complete, try to negotiate
if (!saslClient.isComplete()) {
super.write(ctx, msg, promise);
} else {
ByteBuf in = (ByteBuf) msg;
try {
saslToken = saslClient.wrap(in.array(), in.readerIndex(), in.readableBytes());
} catch (SaslException se) {
try {
saslClient.dispose();
} catch (SaslException ignored) {
LOG.debug("Ignoring SASL exception", ignored);
}
promise.setFailure(se);
}
if (saslToken != null) {
ByteBuf out = ctx.channel().alloc().buffer(4 + saslToken.length);
out.writeInt(saslToken.length);
out.writeBytes(saslToken, 0, saslToken.length);
ctx.write(out).addListener(new ChannelFutureListener() {
@Override public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
exceptionCaught(ctx, future.cause());
}
}
});
saslToken = null;
}
}
}
/**
* Handler for exceptions during Sasl connection
*/
public interface SaslExceptionHandler {
/**
* Handle the exception
*
* @param retryCount current retry count
* @param random to create new backoff with
* @param cause of fail
*/
public void handle(int retryCount, Random random, Throwable cause);
}
/**
* Handler for successful connects
*/
public interface SaslSuccessfulConnectHandler {
/**
* Runs on success
*
* @param channel which is successfully authenticated
*/
public void onSuccess(Channel channel);
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy