org.apache.thrift.transport.TSaslTransport Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of libthrift Show documentation
Thrift is a software framework for scalable cross-language
services development.
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.thrift.transport;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;

import org.apache.thrift.EncodingUtils;
import org.apache.thrift.TByteArrayOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A superclass for SASL client/server thrift transports. A subclass need only
 * implement {@link #handleSaslStartMessage()} and {@link #getRole()};
 * {@link #open()} drives the remainder of the handshake.
 */
abstract class TSaslTransport extends TTransport {

  private static final Logger LOGGER = LoggerFactory.getLogger(TSaslTransport.class);

  protected static final int DEFAULT_MAX_LENGTH = 0x7FFFFFFF;

  protected static final int MECHANISM_NAME_BYTES = 1;
  protected static final int STATUS_BYTES = 1;
  protected static final int PAYLOAD_LENGTH_BYTES = 4;

  /**
   * Which side of the SASL exchange this transport plays.
   */
  protected static enum SaslRole {
    SERVER, CLIENT;
  }

  /**
   * Status bytes used during the initial Thrift SASL handshake.
   */
  protected static enum NegotiationStatus {
    START((byte)0x01),
    OK((byte)0x02),
    BAD((byte)0x03),
    ERROR((byte)0x04),
    COMPLETE((byte)0x05);

    private final byte value;

    // Reverse lookup from wire byte to status. The type parameters are required so
    // that byValue() can return a NegotiationStatus without an unchecked cast.
    private static final Map<Byte, NegotiationStatus> reverseMap =
        new HashMap<Byte, NegotiationStatus>();

    static {
      for (NegotiationStatus s : values()) {
        reverseMap.put(s.getValue(), s);
      }
    }

    private NegotiationStatus(byte val) {
      this.value = val;
    }

    /**
     * @return the byte written on the wire for this status.
     */
    public byte getValue() {
      return value;
    }

    /**
     * Looks up the status corresponding to a wire byte.
     *
     * @param val
     *          The status byte read from the peer.
     * @return The matching status, or {@code null} if {@code val} is not a known status.
     */
    public static NegotiationStatus byValue(byte val) {
      return reverseMap.get(val);
    }
  }

  /**
   * Transport underlying this one.
   */
  protected TTransport underlyingTransport;

  /**
   * Either a SASL client or a SASL server. May be {@code null} on a transport built with
   * {@link #TSaslTransport(TTransport)} until {@link #setSaslServer(SaslServer)} is called.
   */
  private SaslParticipant sasl;

  /**
   * Whether or not we should wrap/unwrap reads/writes. Determined by whether or
   * not a QOP other than "auth" is negotiated during the SASL handshake.
   */
  private boolean shouldWrap = false;

  /**
   * Buffer for input.
   */
  private final TMemoryInputTransport readBuffer = new TMemoryInputTransport();

  /**
   * Buffer for output.
   */
  private final TByteArrayOutputStream writeBuffer = new TByteArrayOutputStream(1024);

  /**
   * Create a TSaslTransport. It's assumed that setSaslServer will be called
   * later to initialize the SASL endpoint underlying this transport.
   *
   * @param underlyingTransport
   *          The thrift transport which this transport is wrapping.
   */
  protected TSaslTransport(TTransport underlyingTransport) {
    this.underlyingTransport = underlyingTransport;
  }

  /**
   * Create a TSaslTransport which acts as a client.
   *
   * @param saslClient
   *          The {@link SaslClient} which this transport will use for SASL
   *          negotiation.
   * @param underlyingTransport
   *          The thrift transport which this transport is wrapping.
   */
  protected TSaslTransport(SaslClient saslClient, TTransport underlyingTransport) {
    sasl = new SaslParticipant(saslClient);
    this.underlyingTransport = underlyingTransport;
  }

  /**
   * Installs the {@link SaslServer} this transport negotiates with.
   *
   * @param saslServer
   *          The SASL server endpoint.
   */
  protected void setSaslServer(SaslServer saslServer) {
    sasl = new SaslParticipant(saslServer);
  }

  // Used to read/write the status byte and payload length. Shared between calls,
  // so a single transport instance must not send/receive handshake messages concurrently.
  private final byte[] messageHeader = new byte[STATUS_BYTES + PAYLOAD_LENGTH_BYTES];

  /**
   * Send a complete Thrift SASL message: one status byte, a 4-byte big-endian
   * payload length, then the payload.
   *
   * @param status
   *          The status to send.
   * @param payload
   *          The data to send as the payload of this message; {@code null} is sent
   *          as an empty payload.
   * @throws TTransportException
   *           Thrown if writing to or flushing the underlying transport fails.
   */
  protected void sendSaslMessage(NegotiationStatus status, byte[] payload) throws TTransportException {
    if (payload == null) {
      payload = new byte[0];
    }

    messageHeader[0] = status.getValue();
    EncodingUtils.encodeBigEndian(payload.length, messageHeader, STATUS_BYTES);

    LOGGER.debug("{}: Writing message with status {} and payload length {}",
        getRole(), status, payload.length);

    underlyingTransport.write(messageHeader);
    underlyingTransport.write(payload);
    underlyingTransport.flush();
  }

  /**
   * Read a complete Thrift SASL message.
   *
   * @return The SASL status and payload from this message.
   * @throws TTransportException
   *           Thrown if there is a failure reading from the underlying
   *           transport, if the header carries a negative payload length or an
   *           unknown status, or if a status code of BAD or ERROR is encountered.
   */
  protected SaslResponse receiveSaslMessage() throws TTransportException {
    underlyingTransport.readAll(messageHeader, 0, messageHeader.length);

    byte statusByte = messageHeader[0];
    int payloadLength = EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES);
    if (payloadLength < 0) {
      // The length comes from the peer; reject it before it reaches the array allocation,
      // which would otherwise fail with an unchecked NegativeArraySizeException.
      sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid payload header length: " + payloadLength);
    }
    byte[] payload = new byte[payloadLength];
    underlyingTransport.readAll(payload, 0, payload.length);

    NegotiationStatus status = NegotiationStatus.byValue(statusByte);
    if (status == null) {
      sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte);
    } else if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
      String remoteMessage = new String(payload, StandardCharsets.UTF_8);
      throw new TTransportException("Peer indicated failure: " + remoteMessage);
    }

    LOGGER.debug("{}: Received message with status {} and payload length {}",
        getRole(), status, payload.length);
    return new SaslResponse(status, payload);
  }

  /**
   * Send a Thrift SASL message with the given status (usually BAD or ERROR) and
   * string message, and then throw a TTransportException with the given
   * message.
   *
   * @param status
   *          The Thrift SASL status code to send. Usually BAD or ERROR.
   * @param message
   *          The optional message to send to the other side; may be {@code null}
   *          (for example when forwarding {@link SaslException#getMessage()}).
   * @throws TTransportException
   *           Always thrown with the message provided (or a fallback description
   *           when it is {@code null}).
   */
  protected void sendAndThrowMessage(NegotiationStatus status, String message) throws TTransportException {
    // Never encode a null message: fall back to a description derived from the status.
    String text = (message != null) ? message : "SASL negotiation failed with status " + status;
    try {
      // UTF-8 to match the decoding performed on the receiving side in receiveSaslMessage().
      sendSaslMessage(status, text.getBytes(StandardCharsets.UTF_8));
    } catch (Exception e) {
      LOGGER.warn("Could not send failure response", e);
      text += "\nAlso, could not send response: " + e.toString();
    }
    throw new TTransportException(text);
  }

  /**
   * Implemented by subclasses to start the Thrift SASL handshake process. When
   * this method completes, the {@code SaslParticipant} in this class is
   * assumed to be initialized.
   *
   * @throws TTransportException
   * @throws SaslException
   */
  abstract protected void handleSaslStartMessage() throws TTransportException, SaslException;

  /**
   * @return Whether this transport acts as the SASL client or server.
   */
  protected abstract SaslRole getRole();

  /**
   * Opens the underlying transport if it's not already open and then performs
   * SASL negotiation. If a QOP is negotiated during this SASL handshake, it is used
   * for all communication on this transport after this call is complete.
   *
   * @throws TTransportException
   *           Thrown if the transport is already open, if the handshake fails, or
   *           if negotiation ends without the local SASL participant completing. In
   *           the SASL-failure and incomplete-negotiation cases the underlying
   *           transport is closed before the exception propagates.
   */
  @Override
  public void open() throws TTransportException {
    LOGGER.debug("opening transport {}", this);
    if (sasl != null && sasl.isComplete()) {
      throw new TTransportException("SASL transport already open");
    }

    if (!underlyingTransport.isOpen()) {
      underlyingTransport.open();
    }

    try {
      // Negotiate a SASL mechanism. The client also sends its
      // initial response, or an empty one.
      handleSaslStartMessage();
      LOGGER.debug("{}: Start message handled", getRole());

      SaslResponse message = null;
      while (!sasl.isComplete()) {
        message = receiveSaslMessage();
        if (message.status != NegotiationStatus.COMPLETE &&
            message.status != NegotiationStatus.OK) {
          throw new TTransportException("Expected COMPLETE or OK, got " + message.status);
        }

        byte[] challenge = sasl.evaluateChallengeOrResponse(message.payload);

        // If we are the client, and the server indicates COMPLETE, we don't need to
        // send back any further response.
        if (message.status == NegotiationStatus.COMPLETE &&
            getRole() == SaslRole.CLIENT) {
          LOGGER.debug("{}: All done!", getRole());
          break;
        }

        sendSaslMessage(sasl.isComplete() ? NegotiationStatus.COMPLETE : NegotiationStatus.OK,
                        challenge);
      }
      LOGGER.debug("{}: Main negotiation loop complete", getRole());

      // Explicit check instead of `assert`: assertions are disabled unless the JVM runs
      // with -ea, and the loop above can exit via `break` when the server reports
      // COMPLETE even though the local SASL client has not completed.
      if (!sasl.isComplete()) {
        underlyingTransport.close();
        throw new TTransportException(
            "SASL negotiation ended before the local " + getRole() + " participant completed");
      }

      // If we're the client, and we're complete, but the server isn't
      // complete yet, we need to wait for its response. This will occur
      // with ANONYMOUS auth, for example, where we send an initial response
      // and are immediately complete.
      if (getRole() == SaslRole.CLIENT &&
          (message == null || message.status == NegotiationStatus.OK)) {
        LOGGER.debug("{}: SASL Client receiving last message", getRole());
        message = receiveSaslMessage();
        if (message.status != NegotiationStatus.COMPLETE) {
          throw new TTransportException(
              "Expected SASL COMPLETE, but got " + message.status);
        }
      }
    } catch (SaslException e) {
      try {
        sendAndThrowMessage(NegotiationStatus.BAD, e.getMessage());
      } finally {
        underlyingTransport.close();
      }
    }

    String qop = (String) sasl.getNegotiatedProperty(Sasl.QOP);
    if (qop != null && !qop.equalsIgnoreCase("auth")) {
      shouldWrap = true;
    }
  }

  /**
   * Get the underlying {@link SaslClient}.
   *
   * @return The {@link SaslClient}, or {@code null} if this transport is backed
   *         by a {@link SaslServer} or no SASL participant has been set yet.
   */
  public SaslClient getSaslClient() {
    return (sasl == null) ? null : sasl.saslClient;
  }

  /**
   * Get the underlying {@link SaslServer}.
   *
   * @return The {@link SaslServer}, or {@code null} if this transport is backed
   *         by a {@link SaslClient} or no SASL participant has been set yet.
   */
  public SaslServer getSaslServer() {
    return (sasl == null) ? null : sasl.saslServer;
  }

  /**
   * Read a 4-byte word from the underlying transport and interpret it as a
   * big-endian integer.
   *
   * @return The length prefix of the next SASL message to read.
   * @throws TTransportException
   *           Thrown if reading from the underlying transport fails.
   */
  protected int readLength() throws TTransportException {
    byte[] lenBuf = new byte[PAYLOAD_LENGTH_BYTES];
    underlyingTransport.readAll(lenBuf, 0, lenBuf.length);
    return EncodingUtils.decodeBigEndian(lenBuf);
  }

  /**
   * Write the given integer as 4 big-endian bytes to the underlying transport.
   *
   * @param length
   *          The length prefix of the next SASL message to write.
   * @throws TTransportException
   *           Thrown if writing to the underlying transport fails.
   */
  protected void writeLength(int length) throws TTransportException {
    byte[] lenBuf = new byte[PAYLOAD_LENGTH_BYTES];
    // Same codec as readLength(), so the two directions are symmetric by construction.
    EncodingUtils.encodeBigEndian(length, lenBuf);
    underlyingTransport.write(lenBuf);
  }

  // Below is the SASL implementation of the TTransport interface.

  /**
   * Closes the underlying transport and disposes of the SASL implementation
   * underlying this transport, if one has been set.
   */
  @Override
  public void close() {
    underlyingTransport.close();
    if (sasl == null) {
      // No SASL participant was ever installed, so there is nothing to dispose.
      return;
    }
    try {
      sasl.dispose();
    } catch (SaslException e) {
      // Best effort: the transport is already closed, so just record the failure.
      LOGGER.debug("Failed to dispose SASL participant", e);
    }
  }

  /**
   * True if the underlying transport is open and the SASL handshake is
   * complete.
   */
  @Override
  public boolean isOpen() {
    return underlyingTransport.isOpen() && sasl != null && sasl.isComplete();
  }

  /**
   * Read from the underlying transport. Unwraps the contents if a QOP was
   * negotiated during the SASL handshake.
   *
   * @throws TTransportException
   *           Thrown if the handshake is not complete, reading fails, or
   *           unwrapping fails.
   */
  @Override
  public int read(byte[] buf, int off, int len) throws TTransportException {
    if (!isOpen()) {
      throw new TTransportException("SASL authentication not complete");
    }

    int got = readBuffer.read(buf, off, len);
    if (got > 0) {
      return got;
    }

    // Read another frame of data
    try {
      readFrame();
    } catch (SaslException e) {
      throw new TTransportException(e);
    }

    return readBuffer.read(buf, off, len);
  }

  /**
   * Read a single frame of data from the underlying transport, unwrapping if
   * necessary.
   *
   * @throws TTransportException
   *           Thrown if there's an error reading from the underlying transport
   *           or the frame length is negative.
   * @throws SaslException
   *           Thrown if there's an error unwrapping the data.
   */
  private void readFrame() throws TTransportException, SaslException {
    int dataLength = readLength();

    if (dataLength < 0) {
      throw new TTransportException("Read a negative frame size (" + dataLength + ")!");
    }

    byte[] buff = new byte[dataLength];
    LOGGER.debug("{}: reading data length: {}", getRole(), dataLength);
    underlyingTransport.readAll(buff, 0, dataLength);
    if (shouldWrap) {
      buff = sasl.unwrap(buff, 0, buff.length);
      LOGGER.debug("data length after unwrap: {}", buff.length);
    }
    readBuffer.reset(buff);
  }

  /**
   * Buffer bytes for the next {@link #flush()}; nothing is sent until then.
   *
   * @throws TTransportException
   *           Thrown if the SASL handshake is not complete.
   */
  @Override
  public void write(byte[] buf, int off, int len) throws TTransportException {
    if (!isOpen()) {
      throw new TTransportException("SASL authentication not complete");
    }

    writeBuffer.write(buf, off, len);
  }

  /**
   * Flushes buffered writes to the underlying transport as one length-prefixed
   * frame. Wraps the contents if a QOP was negotiated during the SASL handshake.
   *
   * @throws TTransportException
   *           Thrown if wrapping or writing fails.
   */
  @Override
  public void flush() throws TTransportException {
    byte[] buf = writeBuffer.get();
    int dataLength = writeBuffer.len();
    writeBuffer.reset();

    if (shouldWrap) {
      LOGGER.debug("data length before wrap: {}", dataLength);
      try {
        buf = sasl.wrap(buf, 0, dataLength);
      } catch (SaslException e) {
        throw new TTransportException(e);
      }
      dataLength = buf.length;
    }
    LOGGER.debug("writing data length: {}", dataLength);
    writeLength(dataLength);
    underlyingTransport.write(buf, 0, dataLength);
    underlyingTransport.flush();
  }

  /**
   * Used exclusively by receiveSaslMessage to return both a status and data.
   */
  protected static class SaslResponse {
    public NegotiationStatus status;
    public byte[] payload;

    public SaslResponse(NegotiationStatus status, byte[] payload) {
      this.status = status;
      this.payload = payload;
    }
  }

  /**
   * Used to abstract over the {@link SaslServer} and {@link SaslClient}
   * classes, which share a lot of their interface, but unfortunately don't
   * share a common superclass.
   */
  private static class SaslParticipant {
    // One of these will always be null.
    public final SaslServer saslServer;
    public final SaslClient saslClient;

    public SaslParticipant(SaslServer saslServer) {
      this.saslServer = saslServer;
      this.saslClient = null;
    }

    public SaslParticipant(SaslClient saslClient) {
      this.saslServer = null;
      this.saslClient = saslClient;
    }

    public byte[] evaluateChallengeOrResponse(byte[] challengeOrResponse) throws SaslException {
      if (saslClient != null) {
        return saslClient.evaluateChallenge(challengeOrResponse);
      } else {
        return saslServer.evaluateResponse(challengeOrResponse);
      }
    }

    public boolean isComplete() {
      if (saslClient != null) {
        return saslClient.isComplete();
      } else {
        return saslServer.isComplete();
      }
    }

    public void dispose() throws SaslException {
      if (saslClient != null) {
        saslClient.dispose();
      } else {
        saslServer.dispose();
      }
    }

    public byte[] unwrap(byte[] buf, int off, int len) throws SaslException {
      if (saslClient != null) {
        return saslClient.unwrap(buf, off, len);
      } else {
        return saslServer.unwrap(buf, off, len);
      }
    }

    public byte[] wrap(byte[] buf, int off, int len) throws SaslException {
      if (saslClient != null) {
        return saslClient.wrap(buf, off, len);
      } else {
        return saslServer.wrap(buf, off, len);
      }
    }

    public Object getNegotiatedProperty(String propName) {
      if (saslClient != null) {
        return saslClient.getNegotiatedProperty(propName);
      } else {
        return saslServer.getNegotiatedProperty(propName);
      }
    }
  }
}