All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.thrift.transport.sasl.NonblockingSaslHandler Maven / Gradle / Ivy

Go to download

Thrift is a software framework for scalable cross-language services development.

There is a newer version: 0.21.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.thrift.transport.sasl;

import java.nio.channels.SelectionKey;
import java.nio.charset.StandardCharsets;

import javax.security.sasl.SaslServer;

import org.apache.thrift.TByteArrayOutputStream;
import org.apache.thrift.TProcessor;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.server.ServerContext;
import org.apache.thrift.server.TServerEventHandler;
import org.apache.thrift.transport.TMemoryTransport;
import org.apache.thrift.transport.TNonblockingTransport;
import org.apache.thrift.transport.TTransportException;
import org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.apache.thrift.transport.sasl.NegotiationStatus.COMPLETE;
import static org.apache.thrift.transport.sasl.NegotiationStatus.OK;

/**
 * State machine managing one sasl connection in a nonblocking way.
 */
/**
 * State machine managing one sasl connection in a nonblocking way.
 *
 * <p>A connection walks through the phases declared in {@link Phase}: sasl negotiation (reading the
 * START frame, reading and evaluating responses, writing challenges and a final success or failure
 * message), then a loop of reading requests, processing them and writing responses, and finally
 * closing. Phases that wait on the socket declare a read or write selection interest; computation
 * phases declare none.
 *
 * <p>NOTE(review): no synchronization is used; {@link #close()} asks to be invoked from the network
 * thread that owns the selector, so the handler presumably is confined to that thread — confirm
 * against the server using it.
 */
public class NonblockingSaslHandler {
  private static final Logger LOGGER = LoggerFactory.getLogger(NonblockingSaslHandler.class);

  private static final int INTEREST_NONE = 0;
  private static final int INTEREST_READ = SelectionKey.OP_READ;
  private static final int INTEREST_WRITE = SelectionKey.OP_WRITE;

  // Tracking the current running phase
  private Phase currentPhase = Phase.INITIIALIIZING;
  // Tracking the next phase on the next invocation of the state machine.
  // It should be the same as current phase if current phase is not yet finished.
  // Otherwise, if it is different from current phase, the state machine is in a transition state:
  // current phase is done, and next phase is not yet started.
  private Phase nextPhase = currentPhase;

  // Underlying nonblocking transport
  private final SelectionKey selectionKey;
  private final TNonblockingTransport underlyingTransport;

  // APIs for intercepting event / customizing behaviors:
  // Factories (decorating the base implementations) & EventHandler (intercepting)
  private final TSaslServerFactory saslServerFactory;
  private final TSaslProcessorFactory processorFactory;
  private final TProtocolFactory inputProtocolFactory;
  private final TProtocolFactory outputProtocolFactory;
  private final TServerEventHandler eventHandler;
  private ServerContext serverContext;
  // It turns out the event handler implementation in hive sometimes creates a null ServerContext.
  // In order to know whether TServerEventHandler#createContext is called we use such a flag.
  private boolean serverContextCreated = false;

  // Wrapper around sasl server; assigned once the START frame has been read.
  private ServerSaslPeer saslPeer;

  // Sasl negotiation io. Not final: both are dropped (set to null) once authentication succeeds.
  private SaslNegotiationFrameReader saslResponse;
  private SaslNegotiationFrameWriter saslChallenge;
  // IO for request from and response to the socket
  private final DataFrameReader requestReader;
  private final DataFrameWriter responseWriter;
  // If sasl is negotiated for integrity/confidentiality protection
  private boolean dataProtected;

  public NonblockingSaslHandler(SelectionKey selectionKey, TNonblockingTransport underlyingTransport,
                                TSaslServerFactory saslServerFactory, TSaslProcessorFactory processorFactory,
                                TProtocolFactory inputProtocolFactory, TProtocolFactory outputProtocolFactory,
                                TServerEventHandler eventHandler) {
    this.selectionKey = selectionKey;
    this.underlyingTransport = underlyingTransport;
    this.saslServerFactory = saslServerFactory;
    this.processorFactory = processorFactory;
    this.inputProtocolFactory = inputProtocolFactory;
    this.outputProtocolFactory = outputProtocolFactory;
    this.eventHandler = eventHandler;

    saslResponse = new SaslNegotiationFrameReader();
    saslChallenge = new SaslNegotiationFrameWriter();
    requestReader = new DataFrameReader();
    responseWriter = new DataFrameWriter();
  }

  /**
   * Get current phase of the state machine.
   *
   * @return current phase.
   */
  public Phase getCurrentPhase() {
    return currentPhase;
  }

  /**
   * Get next phase of the state machine.
   * It is different from current phase iff current phase is done (and next phase not yet started).
   *
   * @return next phase.
   */
  public Phase getNextPhase() {
    return nextPhase;
  }

  /**
   * @return underlying nonblocking socket
   */
  public TNonblockingTransport getUnderlyingTransport() {
    return underlyingTransport;
  }

  /**
   * The sasl peer is only created once the START frame has been read during the initializing
   * phase, so calling this before then throws {@link NullPointerException}.
   *
   * @return SaslServer instance
   */
  public SaslServer getSaslServer() {
    return saslPeer.getSaslServer();
  }

  /**
   * @return true if current phase is done.
   */
  public boolean isCurrentPhaseDone() {
    return currentPhase != nextPhase;
  }

  /**
   * Run state machine.
   *
   * @throws IllegalStateException if current state is already done.
   */
  public void runCurrentPhase() {
    currentPhase.runStateMachine(this);
  }

  /**
   * When current phase is interested in read selection, calling this will run the current phase
   * and its following phases if the following ones are interested to read, until there is nothing
   * available in the underlying transport.
   *
   * @throws IllegalStateException if is called in an irrelevant phase.
   */
  public void handleRead() {
    handleOps(INTEREST_READ);
  }

  /**
   * Similar to handleRead. But it is for write ops.
   *
   * @throws IllegalStateException if it is called in an irrelevant phase.
   */
  public void handleWrite() {
    handleOps(INTEREST_WRITE);
  }

  // Runs the current phase, then keeps stepping into and running subsequent phases for as long as
  // they share the same selection interest.
  private void handleOps(int interestOps) {
    if (currentPhase.selectionInterest != interestOps) {
      throw new IllegalStateException("Current phase " + currentPhase + " but got interest " +
          interestOps);
    }
    runCurrentPhase();
    while (isCurrentPhaseDone() && nextPhase.selectionInterest == interestOps) {
      stepToNextPhase();
      runCurrentPhase();
    }
  }

  /**
   * When current phase is finished, it's expected to call this method first before running the
   * state machine again.
   * By calling this, "next phase" is marked as started (and not done), thus is ready to run.
   *
   * @throws IllegalArgumentException if current phase is not yet done.
   * @throws IllegalStateException if the next phase is the initializing phase.
   */
  public void stepToNextPhase() {
    if (!isCurrentPhaseDone()) {
      throw new IllegalArgumentException("Not yet done with current phase: " + currentPhase);
    }
    LOGGER.debug("Switch phase {} to {}", currentPhase, nextPhase);
    if (nextPhase == Phase.INITIIALIIZING) {
      throw new IllegalStateException("INITIALIZING cannot be the next phase of " + currentPhase);
    }
    // If next phase's interest is not the same as current, nor the same as the selection key,
    // we need to change interest on the selector.
    if (!(nextPhase.selectionInterest == currentPhase.selectionInterest ||
        nextPhase.selectionInterest == selectionKey.interestOps())) {
      changeSelectionInterest(nextPhase.selectionInterest);
    }
    currentPhase = nextPhase;
  }

  private void changeSelectionInterest(int selectionInterest) {
    selectionKey.interestOps(selectionInterest);
  }

  // Sasl negotiation failure handling: queue an error frame for the client, then close.
  private void failSaslNegotiation(TSaslNegotiationException e) {
    LOGGER.error("Sasl negotiation failed", e);
    String errorMsg = e.getDetails();
    saslChallenge.withHeaderAndPayload(new byte[]{e.getErrorType().code.getValue()},
        errorMsg.getBytes(StandardCharsets.UTF_8));
    nextPhase = Phase.WRITING_FAILURE_MESSAGE;
  }

  // Generic failure: log and close the connection.
  private void fail(Exception e) {
    // SLF4J treats a trailing Throwable argument as the exception to log.
    LOGGER.error("Failed io in {}", currentPhase, e);
    nextPhase = Phase.CLOSING;
  }

  // Transport failure: log with the transport error type, then close the connection.
  private void failIO(TTransportException e) {
    StringBuilder errorMsg = new StringBuilder("IO failure ")
        .append(e.getType())
        .append(" in ")
        .append(currentPhase);
    if (e.getMessage() != null) {
      errorMsg.append(": ").append(e.getMessage());
    }
    LOGGER.error(errorMsg.toString(), e);
    nextPhase = Phase.CLOSING;
  }

  // Read handlings

  private void handleInitializing() {
    try {
      saslResponse.read(underlyingTransport);
      if (saslResponse.isComplete()) {
        SaslNegotiationHeaderReader startHeader = saslResponse.getHeader();
        if (startHeader.getStatus() != NegotiationStatus.START) {
          throw new TInvalidSaslFrameException("Expecting START status but got " + startHeader.getStatus());
        }
        // The START frame's payload carries the mechanism name.
        String mechanism = new String(saslResponse.getPayload(), StandardCharsets.UTF_8);
        saslPeer = saslServerFactory.getSaslPeer(mechanism);
        saslResponse.clear();
        nextPhase = Phase.READING_SASL_RESPONSE;
      }
    } catch (TSaslNegotiationException e) {
      failSaslNegotiation(e);
    } catch (TTransportException e) {
      failIO(e);
    }
  }

  private void handleReadingSaslResponse() {
    try {
      saslResponse.read(underlyingTransport);
      if (saslResponse.isComplete()) {
        nextPhase = Phase.EVALUATING_SASL_RESPONSE;
      }
    } catch (TSaslNegotiationException e) {
      failSaslNegotiation(e);
    } catch (TTransportException e) {
      failIO(e);
    }
  }

  private void handleReadingRequest() {
    try {
      requestReader.read(underlyingTransport);
      if (requestReader.isComplete()) {
        nextPhase = Phase.PROCESSING;
      }
    } catch (TTransportException e) {
      failIO(e);
    }
  }

  // Computation executions

  private void executeEvaluatingSaslResponse() {
    NegotiationStatus status = saslResponse.getHeader().getStatus();
    if (status != OK && status != COMPLETE) {
      String error = "Expect status OK or COMPLETE, but got " + status;
      failSaslNegotiation(new TSaslNegotiationException(ErrorType.PROTOCOL_ERROR, error));
      return;
    }
    try {
      byte[] response = saslResponse.getPayload();
      saslResponse.clear();
      byte[] newChallenge = saslPeer.evaluate(response);
      if (saslPeer.isAuthenticated()) {
        dataProtected = saslPeer.isDataProtected();
        saslChallenge.withHeaderAndPayload(new byte[]{COMPLETE.getValue()}, newChallenge);
        nextPhase = Phase.WRITING_SUCCESS_MESSAGE;
      } else {
        saslChallenge.withHeaderAndPayload(new byte[]{OK.getValue()}, newChallenge);
        nextPhase = Phase.WRITING_SASL_CHALLENGE;
      }
    } catch (TSaslNegotiationException e) {
      failSaslNegotiation(e);
    }
  }

  private void executeProcessing() {
    try {
      byte[] inputPayload = requestReader.getPayload();
      requestReader.clear();
      byte[] rawInput = dataProtected ? saslPeer.unwrap(inputPayload) : inputPayload;
      TMemoryTransport memoryTransport = new TMemoryTransport(rawInput);
      TProtocol requestProtocol = inputProtocolFactory.getProtocol(memoryTransport);
      TProtocol responseProtocol = outputProtocolFactory.getProtocol(memoryTransport);

      if (eventHandler != null) {
        if (!serverContextCreated) {
          serverContext = eventHandler.createContext(requestProtocol, responseProtocol);
          serverContextCreated = true;
        }
        eventHandler.processContext(serverContext, memoryTransport, memoryTransport);
      }

      TProcessor processor = processorFactory.getProcessor(this);
      processor.process(requestProtocol, responseProtocol);
      TByteArrayOutputStream rawOutput = memoryTransport.getOutput();
      if (rawOutput.len() == 0) {
        // This is a oneway request, no response to send back. Waiting for next incoming request.
        nextPhase = Phase.READING_REQUEST;
        return;
      }
      if (dataProtected) {
        byte[] outputPayload = saslPeer.wrap(rawOutput.get(), 0, rawOutput.len());
        responseWriter.withOnlyPayload(outputPayload);
      } else {
        responseWriter.withOnlyPayload(rawOutput.get(), 0, rawOutput.len());
      }
      nextPhase = Phase.WRITING_RESPONSE;
    } catch (TTransportException e) {
      failIO(e);
    } catch (Exception e) {
      fail(e);
    }
  }

  // Write handlings

  private void handleWritingSaslChallenge() {
    try {
      saslChallenge.write(underlyingTransport);
      if (saslChallenge.isComplete()) {
        saslChallenge.clear();
        nextPhase = Phase.READING_SASL_RESPONSE;
      }
    } catch (TTransportException e) {
      failIO(e);
    }
  }

  private void handleWritingSuccessMessage() {
    try {
      saslChallenge.write(underlyingTransport);
      if (saslChallenge.isComplete()) {
        LOGGER.debug("Authentication is done.");
        // Negotiation buffers are no longer needed once authenticated.
        saslChallenge = null;
        saslResponse = null;
        nextPhase = Phase.READING_REQUEST;
      }
    } catch (TTransportException e) {
      failIO(e);
    }
  }

  private void handleWritingFailureMessage() {
    try {
      saslChallenge.write(underlyingTransport);
      if (saslChallenge.isComplete()) {
        nextPhase = Phase.CLOSING;
      }
    } catch (TTransportException e) {
      failIO(e);
    }
  }

  private void handleWritingResponse() {
    try {
      responseWriter.write(underlyingTransport);
      if (responseWriter.isComplete()) {
        responseWriter.clear();
        nextPhase = Phase.READING_REQUEST;
      }
    } catch (TTransportException e) {
      failIO(e);
    }
  }

  /**
   * Release all the resources managed by this state machine (connection, selection and sasl server).
   * To avoid being blocked, this should be invoked in the network thread that manages the selector.
   *
   * <p>Calling this again once the state machine is CLOSED is a no-op. This keeps the sasl server
   * from being disposed twice and the event handler's context from being deleted twice.
   */
  public void close() {
    if (currentPhase == Phase.CLOSED) {
      return;
    }
    try {
      underlyingTransport.close();
      selectionKey.cancel();
      if (saslPeer != null) {
        saslPeer.dispose();
      }
      if (serverContextCreated) {
        eventHandler.deleteContext(serverContext,
            inputProtocolFactory.getProtocol(underlyingTransport),
            outputProtocolFactory.getProtocol(underlyingTransport));
      }
    } finally {
      // Mark closed even if a release step throws, so the handler never runs again.
      nextPhase = Phase.CLOSED;
      currentPhase = Phase.CLOSED;
    }
    LOGGER.trace("Connection closed: {}", underlyingTransport);
  }

  public enum Phase {
    INITIIALIIZING(INTEREST_READ) {
      @Override
      void unsafeRun(NonblockingSaslHandler statemachine) {
        statemachine.handleInitializing();
      }
    },
    READING_SASL_RESPONSE(INTEREST_READ) {
      @Override
      void unsafeRun(NonblockingSaslHandler statemachine) {
        statemachine.handleReadingSaslResponse();
      }
    },
    EVALUATING_SASL_RESPONSE(INTEREST_NONE) {
      @Override
      void unsafeRun(NonblockingSaslHandler statemachine) {
        statemachine.executeEvaluatingSaslResponse();
      }
    },
    WRITING_SASL_CHALLENGE(INTEREST_WRITE) {
      @Override
      void unsafeRun(NonblockingSaslHandler statemachine) {
        statemachine.handleWritingSaslChallenge();
      }
    },
    WRITING_SUCCESS_MESSAGE(INTEREST_WRITE) {
      @Override
      void unsafeRun(NonblockingSaslHandler statemachine) {
        statemachine.handleWritingSuccessMessage();
      }
    },
    WRITING_FAILURE_MESSAGE(INTEREST_WRITE) {
      @Override
      void unsafeRun(NonblockingSaslHandler statemachine) {
        statemachine.handleWritingFailureMessage();
      }
    },
    READING_REQUEST(INTEREST_READ) {
      @Override
      void unsafeRun(NonblockingSaslHandler statemachine) {
        statemachine.handleReadingRequest();
      }
    },
    PROCESSING(INTEREST_NONE) {
      @Override
      void unsafeRun(NonblockingSaslHandler statemachine) {
        statemachine.executeProcessing();
      }
    },
    WRITING_RESPONSE(INTEREST_WRITE) {
      @Override
      void unsafeRun(NonblockingSaslHandler statemachine) {
        statemachine.handleWritingResponse();
      }
    },
    CLOSING(INTEREST_NONE) {
      @Override
      void unsafeRun(NonblockingSaslHandler statemachine) {
        statemachine.close();
      }
    },
    CLOSED(INTEREST_NONE) {
      @Override
      void unsafeRun(NonblockingSaslHandler statemachine) {
        // Do nothing.
      }
    }
    ;

    // The interest on the selection key during the phase
    private final int selectionInterest;

    Phase(int selectionInterest) {
      this.selectionInterest = selectionInterest;
    }

    /**
     * Run this phase's execution on the state machine. The execution records its outcome by
     * updating the state machine's next phase; nothing is returned.
     *
     * @param statemachine The state machine to run.
     * @throws IllegalArgumentException if the state machine's current phase is different.
     * @throws IllegalStateException if the state machine's current phase is already done.
     */
    void runStateMachine(NonblockingSaslHandler statemachine) {
      if (statemachine.currentPhase != this) {
        throw new IllegalArgumentException("State machine is " + statemachine.currentPhase +
            " but is expected to be " + this);
      }
      if (statemachine.isCurrentPhaseDone()) {
        throw new IllegalStateException("State machine should step into " + statemachine.nextPhase);
      }
      unsafeRun(statemachine);
    }

    // Run the state machine without checking its own phase.
    // It should not be called directly by users.
    abstract void unsafeRun(NonblockingSaslHandler statemachine);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy