All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.kyuubi.service.TBinaryFrontendService.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kyuubi.service

import java.net.ServerSocket
import java.security.KeyStore
import java.util.Locale
import java.util.concurrent.{SynchronousQueue, ThreadPoolExecutor, TimeUnit}
import javax.net.ssl.{KeyManagerFactory, SSLServerSocket}

import scala.util.control.NonFatal

import org.apache.hive.service.rpc.thrift._
import org.apache.thrift.protocol.TBinaryProtocol
import org.apache.thrift.server.{TServer, TThreadPoolServer}
import org.apache.thrift.transport.{TServerSocket, TSSLTransportFactory}

import org.apache.kyuubi.{KyuubiException, Logging}
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.util.NamedThreadFactory

/**
 * Apache Thrift based hive service rpc, served over the binary transport.
 *  1. the server side implementation serves client-server rpc calls
 *  2. the engine side implementations serve server-engine rpc calls
 *
 * The Thrift server is built in [[initialize]] and its accept loop runs in [[run]].
 * SSL is only ever enabled on the server side (when `isServer()` is true).
 */
abstract class TBinaryFrontendService(name: String)
  extends TFrontendService(name) with TCLIService.Iface with Runnable with Logging {

  import KyuubiConf._

  /**
   * @note this is final because we don't want new implementations for engine to override this;
   *       engines simply configure the port as zero to have an available port picked at random.
   */
  final override protected lazy val serverHost: Option[String] =
    conf.get(FRONTEND_THRIFT_BINARY_BIND_HOST)
  final override protected lazy val portNum: Int = conf.get(FRONTEND_THRIFT_BINARY_BIND_PORT)

  /** The Thrift server: set by a successful [[initialize]], reset to `None` by [[stopServer]]. */
  protected var server: Option[TServer] = None
  // Port actually bound by the server socket; differs from `portNum` when `portNum` is 0.
  private var _actualPort: Int = _
  // NOTE(review): as a lazy val this is fixed on first access, so reading it before
  // `initialize` would pin it to 0 — presumably it is only read afterwards.
  override protected lazy val actualPort: Int = _actualPort

  // Removed OOM hook since Kyuubi #1800 to respect the hive server2 #2383

  /**
   * Builds the Thrift server: the worker pool, the (optionally SSL) server socket bound to
   * `serverAddr:portNum`, and the binary protocol/transport factories.
   *
   * If any step fails after the socket has been bound, the socket is closed so the port is
   * released before the failure is rethrown.
   *
   * @param conf the configuration to read the frontend settings from
   * @throws KyuubiException wrapping whatever made the setup fail
   */
  override def initialize(conf: KyuubiConf): Unit = synchronized {
    this.conf = conf
    // Remembers the bound socket so it can be released if a later step fails.
    var boundSocket: Option[TServerSocket] = None
    try {
      val minThreads = conf.get(FRONTEND_THRIFT_MIN_WORKER_THREADS)
      val maxThreads = conf.get(FRONTEND_THRIFT_MAX_WORKER_THREADS)
      val keepAliveTime = conf.get(FRONTEND_THRIFT_WORKER_KEEPALIVE_TIME)
      // A SynchronousQueue hands tasks directly to workers without queueing them, so once
      // `maxThreads` workers are busy further submissions are rejected by the executor.
      val executor = new ThreadPoolExecutor(
        minThreads,
        maxThreads,
        keepAliveTime,
        TimeUnit.MILLISECONDS,
        new SynchronousQueue[Runnable](),
        new NamedThreadFactory(name + "Handler-Pool", false))
      val transFactory = authFactory.getTTransportFactory
      val tProcFactory = authFactory.getTProcessorFactory(this)
      val tServerSocket = createServerSocket()
      boundSocket = Some(tServerSocket)
      _actualPort = tServerSocket.getServerSocket.getLocalPort
      val maxMessageSize = conf.get(FRONTEND_THRIFT_MAX_MESSAGE_SIZE)
      // Thrift takes these as Int; saturate so very large settings do not wrap to negative.
      val requestTimeout = saturatedToInt(conf.get(FRONTEND_THRIFT_LOGIN_TIMEOUT))
      val beBackoffSlotLength =
        saturatedToInt(conf.get(FRONTEND_THRIFT_LOGIN_BACKOFF_SLOT_LENGTH))
      val args = new TThreadPoolServer.Args(tServerSocket)
        .processorFactory(tProcFactory)
        .transportFactory(transFactory)
        .protocolFactory(new TBinaryProtocol.Factory)
        .inputProtocolFactory(
          new TBinaryProtocol.Factory(true, true, maxMessageSize, maxMessageSize))
        .requestTimeout(requestTimeout).requestTimeoutUnit(TimeUnit.MILLISECONDS)
        .beBackoffSlotLength(beBackoffSlotLength)
        .beBackoffSlotLengthUnit(TimeUnit.MILLISECONDS)
        .executorService(executor)
      // TCP Server
      server = Some(new TThreadPoolServer(args))
      server.foreach(_.setServerEventHandler(new FeTServerEventHandler))
      info(s"Initializing $name on ${serverAddr.getHostName}:${_actualPort} with" +
        s" [$minThreads, $maxThreads] worker threads")
    } catch {
      case e: Throwable =>
        error(e)
        // Release the port; otherwise a retry on the same port fails with "address in use".
        boundSocket.foreach(closeQuietly)
        throw new KyuubiException(
          s"Failed to initialize frontend service on $serverAddr:$portNum.",
          e)
    }
    super.initialize(conf)
  }

  /**
   * Binds the frontend server socket to `serverAddr:portNum`. SSL is only enabled on the
   * server side, and only when `FRONTEND_THRIFT_BINARY_SSL_ENABLED` is set.
   *
   * @throws IllegalArgumentException if SSL is enabled but the keystore path or password is
   *                                  not configured
   */
  private def createServerSocket(): TServerSocket =
    if (isServer() && conf.get(FRONTEND_THRIFT_BINARY_SSL_ENABLED)) {
      val keyStorePath = conf.get(FRONTEND_SSL_KEYSTORE_PATH)
      val keyStorePassword = conf.get(FRONTEND_SSL_KEYSTORE_PASSWORD)
      val keyStoreType = conf.get(FRONTEND_SSL_KEYSTORE_TYPE)
      val keyStoreAlgorithm = conf.get(FRONTEND_SSL_KEYSTORE_ALGORITHM)
      val disallowedSslProtocols = conf.get(FRONTEND_THRIFT_BINARY_SSL_DISALLOWED_PROTOCOLS)
      val includeCipherSuites = conf.get(FRONTEND_THRIFT_BINARY_SSL_INCLUDE_CIPHER_SUITES)

      (keyStorePath, keyStorePassword) match {
        case (None, _) =>
          throw new IllegalArgumentException(
            s"${FRONTEND_SSL_KEYSTORE_PATH.key} not configured for SSL connection")
        case (_, None) =>
          throw new IllegalArgumentException(
            s"${FRONTEND_SSL_KEYSTORE_PASSWORD.key} not configured for SSL connection")
        case (Some(path), Some(password)) =>
          getServerSSLSocket(
            path,
            password,
            keyStoreType,
            keyStoreAlgorithm,
            disallowedSslProtocols,
            includeCipherSuites)
      }
    } else {
      new TServerSocket(new ServerSocket(portNum, -1, serverAddr))
    }

  /**
   * Creates an SSL server socket bound to `serverAddr:portNum`.
   *
   * @param keyStorePath           path of the server keystore
   * @param keyStorePassword       password of the server keystore
   * @param keyStoreType           keystore type, defaulting to `KeyStore.getDefaultType`
   * @param keyStoreAlgorithm      key manager algorithm, defaulting to
   *                               `KeyManagerFactory.getDefaultAlgorithm`
   * @param disallowedSslProtocols protocols to remove from the enabled set (case-insensitive)
   * @param includeCipherSuites    cipher suites to use; the Thrift defaults when empty
   * @return the bound socket; it is closed again if restricting its protocols fails
   */
  private def getServerSSLSocket(
      keyStorePath: String,
      keyStorePassword: String,
      keyStoreType: Option[String],
      keyStoreAlgorithm: Option[String],
      disallowedSslProtocols: Seq[String],
      includeCipherSuites: Seq[String]): TServerSocket = {
    val params =
      if (includeCipherSuites.nonEmpty) {
        new TSSLTransportFactory.TSSLTransportParameters("TLS", includeCipherSuites.toArray)
      } else {
        new TSSLTransportFactory.TSSLTransportParameters()
      }
    params.setKeyStore(
      keyStorePath,
      keyStorePassword,
      keyStoreAlgorithm.getOrElse(KeyManagerFactory.getDefaultAlgorithm),
      keyStoreType.getOrElse(KeyStore.getDefaultType))

    val tServerSocket =
      TSSLTransportFactory.getServerSocket(portNum, 0, serverAddr, params)

    try {
      tServerSocket.getServerSocket match {
        case sslServerSocket: SSLServerSocket =>
          disableProtocols(sslServerSocket, disallowedSslProtocols)
        case _ =>
      }
      tServerSocket
    } catch {
      case NonFatal(e) =>
        // The socket is already bound; do not leak it when the protocol setup fails.
        closeQuietly(tServerSocket)
        throw e
    }
  }

  /**
   * Removes every protocol in `disallowed` (matched case-insensitively) from the protocols
   * enabled on `socket`, and logs the protocols that remain enabled.
   */
  private def disableProtocols(socket: SSLServerSocket, disallowed: Seq[String]): Unit = {
    val disallowedLower = disallowed.map(_.toLowerCase(Locale.ROOT)).toSet
    val (removed, enabledProtocols) = socket.getEnabledProtocols.partition { protocol =>
      disallowedLower.contains(protocol.toLowerCase(Locale.ROOT))
    }
    removed.foreach(protocol => debug(s"Disabling SSL Protocol: $protocol"))
    socket.setEnabledProtocols(enabledProtocols)
    info(s"SSL Server Socket enabled protocols: ${enabledProtocols.mkString(",")}")
    if (enabledProtocols.isEmpty) {
      warn(s"No SSL protocol remains enabled after applying " +
        s"${FRONTEND_THRIFT_BINARY_SSL_DISALLOWED_PROTOCOLS.key}")
    }
  }

  /** Closes `socket`, logging rather than propagating a failure so the caller's error wins. */
  private def closeQuietly(socket: TServerSocket): Unit =
    try socket.close()
    catch {
      case NonFatal(t) => warn(s"Failed to close the server socket of $name", t)
    }

  /** Narrows `value` to an Int, clamping to the Int range instead of silently wrapping. */
  private def saturatedToInt(value: Long): Int =
    math.max(Int.MinValue.toLong, math.min(value, Int.MaxValue.toLong)).toInt

  /**
   * Runs the Thrift server's serving loop (if [[initialize]] created one) on the current thread.
   * An interruption is logged; any other failure is logged and terminates the JVM.
   */
  override def run(): Unit =
    try {
      if (isServer()) {
        info(s"Starting and exposing JDBC connection at: jdbc:hive2://$connectionUrl/")
      }
      server.foreach(_.serve())
    } catch {
      case _: InterruptedException =>
        error(s"$getName is interrupted")
        // Restore the interrupt status that this catch consumed.
        Thread.currentThread().interrupt()
      case t: Throwable =>
        error(s"Error starting $getName", t)
        System.exit(-1)
    }

  /** Stops the Thrift server, if any, and forgets it. */
  override protected def stopServer(): Unit = {
    server.foreach(_.stop())
    server = None
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy