All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.thrift.transport.TSaslTransport Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.thrift.transport;

import java.nio.charset.StandardCharsets;
import java.util.Objects;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.thrift.EncodingUtils;
import org.apache.thrift.TByteArrayOutputStream;
import org.apache.thrift.TConfiguration;
import org.apache.thrift.transport.layered.TFramedTransport;
import org.apache.thrift.transport.sasl.NegotiationStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A superclass for SASL client/server thrift transports. A subclass need only implement the 
 * open method.
 */
abstract class TSaslTransport extends TEndpointTransport {

  private static final Logger LOGGER = LoggerFactory.getLogger(TSaslTransport.class);

  protected static final int DEFAULT_MAX_LENGTH = 0x7FFFFFFF;

  protected static final int MECHANISM_NAME_BYTES = 1;
  protected static final int STATUS_BYTES = 1;
  protected static final int PAYLOAD_LENGTH_BYTES = 4;

  protected static enum SaslRole {
    SERVER,
    CLIENT;
  }

  /** Transport underlying this one. */
  protected TTransport underlyingTransport;

  /** Either a SASL client or a SASL server. */
  private SaslParticipant sasl;

  /**
   * Whether or not we should wrap/unwrap reads/writes. Determined by whether or not a QOP is
   * negotiated during the SASL handshake.
   */
  private boolean shouldWrap = false;

  /** Buffer for input. */
  private TMemoryInputTransport readBuffer;

  /** Buffer for output. */
  private final TByteArrayOutputStream writeBuffer = new TByteArrayOutputStream(1024);

  /**
   * Create a TSaslTransport. It's assumed that setSaslServer will be called later to initialize the
   * SASL endpoint underlying this transport.
   *
   * @param underlyingTransport The thrift transport which this transport is wrapping.
   */
  protected TSaslTransport(TTransport underlyingTransport) throws TTransportException {
    super(
        Objects.isNull(underlyingTransport.getConfiguration())
            ? new TConfiguration()
            : underlyingTransport.getConfiguration());
    this.underlyingTransport = underlyingTransport;
    this.readBuffer = new TMemoryInputTransport(underlyingTransport.getConfiguration());
  }

  /**
   * Create a TSaslTransport which acts as a client.
   *
   * @param saslClient The SaslClient which this transport will use for SASL
   *     negotiation.
   * @param underlyingTransport The thrift transport which this transport is wrapping.
   */
  protected TSaslTransport(SaslClient saslClient, TTransport underlyingTransport)
      throws TTransportException {
    super(
        Objects.isNull(underlyingTransport.getConfiguration())
            ? new TConfiguration()
            : underlyingTransport.getConfiguration());
    sasl = new SaslParticipant(saslClient);
    this.underlyingTransport = underlyingTransport;
    this.readBuffer = new TMemoryInputTransport(underlyingTransport.getConfiguration());
  }

  protected void setSaslServer(SaslServer saslServer) {
    sasl = new SaslParticipant(saslServer);
  }

  // Used to read the status byte and payload length.
  private final byte[] messageHeader = new byte[STATUS_BYTES + PAYLOAD_LENGTH_BYTES];

  /**
   * Send a complete Thrift SASL message.
   *
   * @param status The status to send.
   * @param payload The data to send as the payload of this message.
   * @throws TTransportException
   */
  protected void sendSaslMessage(NegotiationStatus status, byte[] payload)
      throws TTransportException {
    if (payload == null) payload = new byte[0];

    messageHeader[0] = status.getValue();
    EncodingUtils.encodeBigEndian(payload.length, messageHeader, STATUS_BYTES);

    LOGGER.debug(
        "{}: Writing message with status {} and payload length {}",
        getRole(),
        status,
        payload.length);

    underlyingTransport.write(messageHeader);
    underlyingTransport.write(payload);
    underlyingTransport.flush();
  }

  /**
   * Read a complete Thrift SASL message.
   *
   * @return The SASL status and payload from this message.
   * @throws TTransportException Thrown if there is a failure reading from the underlying transport,
   *     or if a status code of BAD or ERROR is encountered.
   */
  protected SaslResponse receiveSaslMessage() throws TTransportException {
    underlyingTransport.readAll(messageHeader, 0, messageHeader.length);

    byte statusByte = messageHeader[0];

    NegotiationStatus status = NegotiationStatus.byValue(statusByte);
    if (status == null) {
      throw sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte);
    }

    int payloadBytes = EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES);
    if (payloadBytes < 0 || payloadBytes > getConfiguration().getMaxMessageSize() /* 100 MB */) {
      throw sendAndThrowMessage(
          NegotiationStatus.ERROR, "Invalid payload header length: " + payloadBytes);
    }

    byte[] payload = new byte[payloadBytes];
    underlyingTransport.readAll(payload, 0, payload.length);

    if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
      String remoteMessage = new String(payload, StandardCharsets.UTF_8);
      throw new TTransportException("Peer indicated failure: " + remoteMessage);
    }
    LOGGER.debug(
        "{}: Received message with status {} and payload length {}",
        getRole(),
        status,
        payload.length);
    return new SaslResponse(status, payload);
  }

  /**
   * Send a Thrift SASL message with the given status (usually BAD or ERROR) and string message, and
   * then throw a TTransportException with the given message.
   *
   * @param status The Thrift SASL status code to send. Usually BAD or ERROR.
   * @param message The optional message to send to the other side.
   * @throws TTransportException Always thrown with the message provided.
   * @return always throws TTransportException but declares return type to allow throw
   *     sendAndThrowMessage(...) to inform compiler control flow
   */
  protected TTransportException sendAndThrowMessage(NegotiationStatus status, String message)
      throws TTransportException {
    try {
      sendSaslMessage(status, message.getBytes(StandardCharsets.UTF_8));
    } catch (Exception e) {
      LOGGER.warn("Could not send failure response", e);
      message += "\nAlso, could not send response: " + e.toString();
    }
    throw new TTransportException(message);
  }

  /**
   * Implemented by subclasses to start the Thrift SASL handshake process. When this method
   * completes, the SaslParticipant in this class is assumed to be initialized.
   *
   * @throws TTransportException
   * @throws SaslException
   */
  protected abstract void handleSaslStartMessage() throws TTransportException, SaslException;

  protected abstract SaslRole getRole();

  /**
   * Opens the underlying transport if it's not already open and then performs SASL negotiation. If
   * a QOP is negotiated during this SASL handshake, it used for all communication on this transport
   * after this call is complete.
   */
  @Override
  public void open() throws TTransportException {
    /*
     * readSaslHeader is used to tag whether the SASL header has been read properly.
     * If there is a problem in reading the header, there might not be any
     * data in the stream, possibly a TCP health check from load balancer.
     */
    boolean readSaslHeader = false;

    LOGGER.debug("opening transport {}", this);
    if (sasl != null && sasl.isComplete())
      throw new TTransportException("SASL transport already open");

    if (!underlyingTransport.isOpen()) underlyingTransport.open();

    try {
      // Negotiate a SASL mechanism. The client also sends its
      // initial response, or an empty one.
      handleSaslStartMessage();
      readSaslHeader = true;
      LOGGER.debug("{}: Start message handled", getRole());

      SaslResponse message = null;
      while (!sasl.isComplete()) {
        message = receiveSaslMessage();
        if (message.status != NegotiationStatus.COMPLETE
            && message.status != NegotiationStatus.OK) {
          throw new TTransportException("Expected COMPLETE or OK, got " + message.status);
        }

        byte[] challenge = sasl.evaluateChallengeOrResponse(message.payload);

        // If we are the client, and the server indicates COMPLETE, we don't need to
        // send back any further response.
        if (message.status == NegotiationStatus.COMPLETE && getRole() == SaslRole.CLIENT) {
          LOGGER.debug("{}: All done!", getRole());
          continue;
        }

        sendSaslMessage(
            sasl.isComplete() ? NegotiationStatus.COMPLETE : NegotiationStatus.OK, challenge);
      }
      LOGGER.debug("{}: Main negotiation loop complete", getRole());

      // If we're the client, and we're complete, but the server isn't
      // complete yet, we need to wait for its response. This will occur
      // with ANONYMOUS auth, for example, where we send an initial response
      // and are immediately complete.
      if (getRole() == SaslRole.CLIENT
          && (message == null || message.status == NegotiationStatus.OK)) {
        LOGGER.debug("{}: SASL Client receiving last message", getRole());
        message = receiveSaslMessage();
        if (message.status != NegotiationStatus.COMPLETE) {
          throw new TTransportException("Expected SASL COMPLETE, but got " + message.status);
        }
      }
    } catch (SaslException e) {
      try {
        LOGGER.error("SASL negotiation failure", e);
        throw sendAndThrowMessage(NegotiationStatus.BAD, e.getMessage());
      } finally {
        underlyingTransport.close();
      }
    } catch (TTransportException e) {
      // If there is no-data or no-sasl header in the stream,
      // log the failure, and clean up the underlying transport.
      if (!readSaslHeader && e.getType() == TTransportException.END_OF_FILE) {
        underlyingTransport.close();
        LOGGER.debug("No data or no sasl data in the stream during negotiation");
      }
      throw e;
    }

    String qop = (String) sasl.getNegotiatedProperty(Sasl.QOP);
    if (qop != null && !qop.equalsIgnoreCase("auth")) shouldWrap = true;
  }

  /**
   * Get the underlying SaslClient.
   *
   * @return The SaslClient, or null if this transport is backed by a
   *     SaslServer.
   */
  public SaslClient getSaslClient() {
    return sasl.saslClient;
  }

  /**
   * Get the underlying transport that Sasl is using.
   *
   * @return The TTransport transport
   */
  public TTransport getUnderlyingTransport() {
    return underlyingTransport;
  }

  /**
   * Get the underlying SaslServer.
   *
   * @return The SaslServer, or null if this transport is backed by a
   *     SaslClient.
   */
  public SaslServer getSaslServer() {
    return sasl.saslServer;
  }

  /**
   * Read a 4-byte word from the underlying transport and interpret it as an integer.
   *
   * @return The length prefix of the next SASL message to read.
   * @throws TTransportException Thrown if reading from the underlying transport fails.
   */
  protected int readLength() throws TTransportException {
    byte[] lenBuf = new byte[4];
    underlyingTransport.readAll(lenBuf, 0, lenBuf.length);
    return EncodingUtils.decodeBigEndian(lenBuf);
  }

  /**
   * Write the given integer as 4 bytes to the underlying transport.
   *
   * @param length The length prefix of the next SASL message to write.
   * @throws TTransportException Thrown if writing to the underlying transport fails.
   */
  protected void writeLength(int length) throws TTransportException {
    byte[] lenBuf = new byte[4];
    TFramedTransport.encodeFrameSize(length, lenBuf);
    underlyingTransport.write(lenBuf);
  }

  // Below is the SASL implementation of the TTransport interface.

  /**
   * Closes the underlying transport and disposes of the SASL implementation underlying this
   * transport.
   */
  @Override
  public void close() {
    underlyingTransport.close();
    try {
      sasl.dispose();
    } catch (SaslException e) {
      LOGGER.warn("Failed to dispose sasl participant.", e);
    }
  }

  /** True if the underlying transport is open and the SASL handshake is complete. */
  @Override
  public boolean isOpen() {
    return underlyingTransport.isOpen() && sasl != null && sasl.isComplete();
  }

  /**
   * Read from the underlying transport. Unwraps the contents if a QOP was negotiated during the
   * SASL handshake.
   */
  @Override
  public int read(byte[] buf, int off, int len) throws TTransportException {
    if (!isOpen()) throw new TTransportException("SASL authentication not complete");

    int got = readBuffer.read(buf, off, len);
    if (got > 0) {
      return got;
    }

    // Read another frame of data
    try {
      readFrame();
    } catch (SaslException e) {
      throw new TTransportException(e);
    } catch (TTransportException transportException) {
      // If there is no-data or no-sasl header in the stream, log the failure, and rethrow.
      if (transportException.getType() == TTransportException.END_OF_FILE) {
        LOGGER.debug("No data or no sasl data in the stream during negotiation");
      }
      throw transportException;
    }

    return readBuffer.read(buf, off, len);
  }

  /**
   * Read a single frame of data from the underlying transport, unwrapping if necessary.
   *
   * @throws TTransportException Thrown if there's an error reading from the underlying transport.
   * @throws SaslException Thrown if there's an error unwrapping the data.
   */
  private void readFrame() throws TTransportException, SaslException {
    int dataLength = readLength();

    if (dataLength < 0)
      throw new TTransportException("Read a negative frame size (" + dataLength + ")!");

    byte[] buff = new byte[dataLength];
    LOGGER.debug("{}: reading data length: {}", getRole(), dataLength);
    underlyingTransport.readAll(buff, 0, dataLength);
    if (shouldWrap) {
      buff = sasl.unwrap(buff, 0, buff.length);
      LOGGER.debug("data length after unwrap: {}", buff.length);
    }
    readBuffer.reset(buff);
  }

  /** Write to the underlying transport. */
  @Override
  public void write(byte[] buf, int off, int len) throws TTransportException {
    if (!isOpen()) throw new TTransportException("SASL authentication not complete");

    writeBuffer.write(buf, off, len);
  }

  /**
   * Flushes to the underlying transport. Wraps the contents if a QOP was negotiated during the SASL
   * handshake.
   */
  @Override
  public void flush() throws TTransportException {
    byte[] buf = writeBuffer.get();
    int dataLength = writeBuffer.len();
    writeBuffer.reset();

    if (shouldWrap) {
      LOGGER.debug("data length before wrap: {}", dataLength);
      try {
        buf = sasl.wrap(buf, 0, dataLength);
      } catch (SaslException e) {
        throw new TTransportException(e);
      }
      dataLength = buf.length;
    }
    LOGGER.debug("writing data length: {}", dataLength);
    writeLength(dataLength);
    underlyingTransport.write(buf, 0, dataLength);
    underlyingTransport.flush();
  }

  /** Used exclusively by readSaslMessage to return both a status and data. */
  protected static class SaslResponse {
    public NegotiationStatus status;
    public byte[] payload;

    public SaslResponse(NegotiationStatus status, byte[] payload) {
      this.status = status;
      this.payload = payload;
    }
  }

  /**
   * Used to abstract over the SaslServer and SaslClient classes, which
   * share a lot of their interface, but unfortunately don't share a common superclass.
   */
  private static class SaslParticipant {
    // One of these will always be null.
    public SaslServer saslServer;
    public SaslClient saslClient;

    public SaslParticipant(SaslServer saslServer) {
      this.saslServer = saslServer;
    }

    public SaslParticipant(SaslClient saslClient) {
      this.saslClient = saslClient;
    }

    public byte[] evaluateChallengeOrResponse(byte[] challengeOrResponse) throws SaslException {
      if (saslClient != null) {
        return saslClient.evaluateChallenge(challengeOrResponse);
      } else {
        return saslServer.evaluateResponse(challengeOrResponse);
      }
    }

    public boolean isComplete() {
      if (saslClient != null) return saslClient.isComplete();
      else return saslServer.isComplete();
    }

    public void dispose() throws SaslException {
      if (saslClient != null) saslClient.dispose();
      else saslServer.dispose();
    }

    public byte[] unwrap(byte[] buf, int off, int len) throws SaslException {
      if (saslClient != null) return saslClient.unwrap(buf, off, len);
      else return saslServer.unwrap(buf, off, len);
    }

    public byte[] wrap(byte[] buf, int off, int len) throws SaslException {
      if (saslClient != null) return saslClient.wrap(buf, off, len);
      else return saslServer.wrap(buf, off, len);
    }

    public Object getNegotiatedProperty(String propName) {
      if (saslClient != null) return saslClient.getNegotiatedProperty(propName);
      else return saslServer.getNegotiatedProperty(propName);
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy