All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.kyuubi.service.authentication.HadoopThriftAuthBridgeServer.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kyuubi.service.authentication

import java.io.IOException
import java.net.InetAddress
import java.security.PrivilegedAction
import java.util.Base64
import javax.security.auth.callback._
import javax.security.sasl.{AuthorizeCallback, RealmCallback}

import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.security.{SaslRpcServer, UserGroupInformation}
import org.apache.hadoop.security.SaslRpcServer.AuthMethod
import org.apache.hadoop.security.token.SecretManager.InvalidToken
import org.apache.thrift.{TException, TProcessor}
import org.apache.thrift.protocol.TProtocol
import org.apache.thrift.transport._

import org.apache.kyuubi.Logging

class HadoopThriftAuthBridgeServer(secretMgr: KyuubiDelegationTokenManager) {
  import HadoopThriftAuthBridgeServer._

  private val ugi = UserGroupInformation.getCurrentUser

  def createSaslServerTransportFactory(
      saslProps: java.util.Map[String, String]): TSaslServerTransport.Factory = {
    val principal = ugi.getUserName
    val names = SaslRpcServer.splitKerberosName(principal)
    if (names.length != 3) {
      throw new TTransportException(s"Kerberos principal should have 3 parts: $principal")
    }
    val factory = new TSaslServerTransport.Factory
    factory.addServerDefinition(
      AuthMethod.KERBEROS.getMechanismName,
      names(0),
      names(1),
      saslProps,
      new SaslRpcServer.SaslGssCallbackHandler)
    factory.addServerDefinition(
      AuthMethod.TOKEN.getMechanismName,
      null,
      SaslRpcServer.SASL_DEFAULT_REALM,
      saslProps,
      new SaslDigestCallbackHandler(secretMgr))
    factory
  }

  /**
   * Wrap a TTransportFactory in such a way that, before processing any RPC, it
   * assumes the UserGroupInformation of the user authenticated by
   * the SASL transport.
   */
  def wrapTransportFactory(transFactory: TTransportFactory): TTransportFactory = {
    new TUGIAssumingTransportFactory(ugi, transFactory)
  }

  /**
   * Wrap a TProcessor to capture the client information like connecting userid, ip etc
   */
  def wrapNonAssumingProcessor(processor: TProcessor): TProcessor = {
    new TUGIAssumingProcessor(processor, secretMgr)
  }

  def getRemoteAddress: InetAddress = REMOTE_ADDRESS.get

  def getRemoteUser: String = REMOTE_USER.get

  def getUserAuthMechanism: String = USER_AUTH_MECHANISM.get
}

object HadoopThriftAuthBridgeServer {

  final val REMOTE_ADDRESS = new ThreadLocal[InetAddress]() {
    override def initialValue(): InetAddress = null
  }

  final val REMOTE_USER = new ThreadLocal[String]() {
    override protected def initialValue: String = null
  }

  final val USER_AUTH_MECHANISM: ThreadLocal[String] = new ThreadLocal[String]() {
    override protected def initialValue: String = AuthMethod.KERBEROS.getMechanismName
  }

  /**
   * Form Apache Hive
   *
   * A TransportFactory that wraps another one, but assumes a specified UGI
   * before calling through.
   *
   * This is used on the server side to assume the server's Principal when accepting
   * clients.
   */
  class TUGIAssumingTransportFactory(
      ugi: UserGroupInformation,
      wrapped: TTransportFactory) extends TTransportFactory {

    override def getTransport(trans: TTransport): TTransport = {
      ugi.doAs(new PrivilegedAction[TTransport] {
        override def run(): TTransport = wrapped.getTransport(trans)
      })
    }
  }

  /**
   * Form Apache Hive
   *
   * Processor that pulls the SaslServer object out of the transport, and
   * assumes the remote user's UGI before calling through to the original
   * processor.
   *
   * This is used on the server side to set the UGI for each specific call.
   */
  class TUGIAssumingProcessor(
      wrapped: TProcessor,
      secretMgr: KyuubiDelegationTokenManager) extends TProcessor with Logging {
    override def process(in: TProtocol, out: TProtocol): Boolean = {
      val transport = in.getTransport
      transport match {
        case saslTrans: TSaslServerTransport =>
          val saslServer = saslTrans.getSaslServer
          val authId = saslServer.getAuthorizationID
          var endUser = authId
          debug(s"AUTH ID ======> $authId")
          val socket = saslTrans.getUnderlyingTransport.asInstanceOf[TSocket].getSocket
          REMOTE_ADDRESS.set(socket.getInetAddress)
          val mechanismName = saslServer.getMechanismName
          USER_AUTH_MECHANISM.set(mechanismName)
          try {
            if (AuthMethod.PLAIN.getMechanismName.equalsIgnoreCase(mechanismName)) {
              REMOTE_USER.set(endUser)
              wrapped.process(in, out)
            } else {
              if (AuthMethod.TOKEN.getMechanismName.equalsIgnoreCase(mechanismName)) {
                try {
                  val identifier = SaslRpcServer.getIdentifier(authId, secretMgr)
                  endUser = identifier.getUser.getUserName
                } catch {
                  case e: InvalidToken => throw new TException(e.getMessage)
                }
              }
              val clientUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(endUser)
              try {
                REMOTE_USER.set(clientUgi.getShortUserName)
                debug(s"SET REMOTE USER: ${REMOTE_USER.get()} from endUser: $clientUgi")
                wrapped.process(in, out)
              } catch {
                case e: RuntimeException => e.getCause match {
                    case t: TException => throw t
                    case _ => throw e
                  }
                case e: InterruptedException => throw new RuntimeException(e)
                case e: IOException => throw new RuntimeException(e)
              } finally {
                try {
                  FileSystem.closeAllForUGI(clientUgi)
                } catch {
                  case e: IOException =>
                    error(s"Could not clean up file-system handles for UGI: $clientUgi", e)
                }
              }
            }
          } finally {
            REMOTE_USER.remove()
            REMOTE_ADDRESS.remove()
            USER_AUTH_MECHANISM.remove()
          }

        case _ => throw new TException(s"Unexpected non-SASL transport ${transport.getClass}")
      }
    }
  }

  /**
   * From Apache Hive
   */
  class SaslDigestCallbackHandler(secretMgr: KyuubiDelegationTokenManager)
    extends CallbackHandler with Logging {

    def getPasswd(identifier: KyuubiDelegationTokenIdentifier): Array[Char] = {
      val passwd = secretMgr.retrievePassword(identifier)
      Base64.getMimeEncoder.encodeToString(passwd).toCharArray
    }

    override def handle(callbacks: Array[Callback]): Unit = {
      var nc: NameCallback = null
      var pc: PasswordCallback = null
      callbacks.foreach {
        case ac: AuthorizeCallback =>
          val authenticationID = ac.getAuthenticationID
          val authorizationID = ac.getAuthorizationID
          ac.setAuthorized(authenticationID == authorizationID)
          if (ac.isAuthorized) {
            debug(s"SASL server DIGEST-MD5 callback: setting canonicalized client ID" +
              SaslRpcServer.getIdentifier(authorizationID, secretMgr).getUser.getUserName)
            ac.setAuthorizedID(authorizationID)
          }
        case c: NameCallback => nc = c
        case p: PasswordCallback => pc = p
        case _: RealmCallback => // do nothing
        case o => throw new UnsupportedCallbackException(o, "Unrecognized SASL DIGEST-MD5 Callback")
      }
      if (pc != null) {
        val tokenIdentifier = SaslRpcServer.getIdentifier(nc.getDefaultName, secretMgr)
        val password: Array[Char] = getPasswd(tokenIdentifier)
        debug(s"SASL server DIGEST-MD5 callback: setting password for client:" +
          s" ${tokenIdentifier.getUser}")
        pc.setPassword(password)
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy