com.databricks.jdbc.dbclient.impl.thrift.DatabricksThriftAccessor

package com.databricks.jdbc.dbclient.impl.thrift;

import static com.databricks.jdbc.common.DatabricksJdbcConstants.IS_FAKE_SERVICE_TEST_PROP;
import static com.databricks.jdbc.common.EnvironmentVariables.*;
import static com.databricks.jdbc.common.util.DatabricksThriftUtil.*;
import static com.databricks.jdbc.model.client.thrift.generated.TStatusCode.*;

import com.databricks.jdbc.api.IDatabricksConnectionContext;
import com.databricks.jdbc.api.IDatabricksSession;
import com.databricks.jdbc.api.impl.*;
import com.databricks.jdbc.api.internal.IDatabricksStatementInternal;
import com.databricks.jdbc.common.StatementType;
import com.databricks.jdbc.dbclient.impl.common.ClientConfigurator;
import com.databricks.jdbc.dbclient.impl.common.StatementId;
import com.databricks.jdbc.dbclient.impl.http.DatabricksHttpClientFactory;
import com.databricks.jdbc.exception.DatabricksHttpException;
import com.databricks.jdbc.exception.DatabricksParsingException;
import com.databricks.jdbc.exception.DatabricksSQLException;
import com.databricks.jdbc.exception.DatabricksSQLFeatureNotSupportedException;
import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import com.databricks.jdbc.model.client.thrift.generated.*;
import com.databricks.sdk.core.DatabricksConfig;
import com.google.common.annotations.VisibleForTesting;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Map;
import org.apache.http.HttpException;
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;

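/**
 * Low-level accessor that wraps the generated {@link TCLIService.Client} and issues Thrift
 * requests against a Databricks SQL endpoint. It keeps one client per thread (the generated
 * client is not thread safe), refreshes authentication headers before each call, supports
 * server-side direct results, and long-polls for statement completion when needed.
 *
 * <p>Illustrative usage sketch only; the driver normally constructs and drives this class
 * internally, and {@code context} below is assumed to be an already-initialized
 * {@link IDatabricksConnectionContext}:
 *
 * <pre>{@code
 * DatabricksThriftAccessor accessor = new DatabricksThriftAccessor(context);
 * TBase openSessionResp = accessor.getThriftResponse(new TOpenSessionReq());
 * }</pre>
 */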
final class DatabricksThriftAccessor {

  private static final JdbcLogger LOGGER =
      JdbcLoggerFactory.getLogger(DatabricksThriftAccessor.class);
  private final DatabricksConfig databricksConfig;
  private final ThreadLocal<TCLIService.Client> thriftClient;
  private final Boolean enableDirectResults;
  private static final TSparkGetDirectResults DEFAULT_DIRECT_RESULTS =
      new TSparkGetDirectResults().setMaxRows(DEFAULT_ROW_LIMIT).setMaxBytes(DEFAULT_BYTE_LIMIT);

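  /**
   * Builds the accessor from the connection context: resolves the SDK {@link DatabricksConfig},
   * performs the initial authentication, and initializes a per-thread Thrift client. When the
   * fake-service test property is set, a single shared client is reused across threads.
   *
   * @param connectionContext connection context describing the endpoint and auth settings
   * @throws DatabricksParsingException if the context cannot be resolved into a valid endpoint
   */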
  DatabricksThriftAccessor(IDatabricksConnectionContext connectionContext)
      throws DatabricksParsingException {
    enableDirectResults = connectionContext.getDirectResultMode();
    this.databricksConfig = new ClientConfigurator(connectionContext).getDatabricksConfig();
    Map<String, String> authHeaders = databricksConfig.authenticate();
    String endPointUrl = connectionContext.getEndpointURL();

    final boolean isFakeServiceTest =
        Boolean.parseBoolean(System.getProperty(IS_FAKE_SERVICE_TEST_PROP));
    if (!isFakeServiceTest) {
      // Create a new Thrift client per thread because client state is not thread safe. The
      // underlying transport reuses the same HTTP client, which is thread safe.
      this.thriftClient =
          ThreadLocal.withInitial(
              () -> createThriftClient(endPointUrl, authHeaders, connectionContext));
    } else {
      TCLIService.Client client = createThriftClient(endPointUrl, authHeaders, connectionContext);
      this.thriftClient = ThreadLocal.withInitial(() -> client);
    }
  }

  @VisibleForTesting
  DatabricksThriftAccessor(
      TCLIService.Client client,
      DatabricksConfig config,
      IDatabricksConnectionContext connectionContext) {
    this.databricksConfig = config;
    this.thriftClient = ThreadLocal.withInitial(() -> client);
    this.enableDirectResults = connectionContext.getDirectResultMode();
  }

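  /**
   * Dispatches a generic Thrift request (session and metadata operations) to the matching
   * TCLIService call and returns the raw response. An {@link HttpException} found anywhere in the
   * cause chain is rethrown as a {@link DatabricksHttpException}; unsupported request types raise
   * a {@link DatabricksSQLFeatureNotSupportedException}.
   */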
  TBase getThriftResponse(TBase request) throws DatabricksSQLException {
    refreshHeadersIfRequired();
    LOGGER.debug(String.format("Fetching thrift response for request {%s}", request.toString()));
    try {
      if (request instanceof TOpenSessionReq) {
        return getThriftClient().OpenSession((TOpenSessionReq) request);
      } else if (request instanceof TCloseSessionReq) {
        return getThriftClient().CloseSession((TCloseSessionReq) request);
      } else if (request instanceof TGetPrimaryKeysReq) {
        return listPrimaryKeys((TGetPrimaryKeysReq) request);
      } else if (request instanceof TGetFunctionsReq) {
        return listFunctions((TGetFunctionsReq) request);
      } else if (request instanceof TGetSchemasReq) {
        return listSchemas((TGetSchemasReq) request);
      } else if (request instanceof TGetColumnsReq) {
        return listColumns((TGetColumnsReq) request);
      } else if (request instanceof TGetCatalogsReq) {
        return getCatalogs((TGetCatalogsReq) request);
      } else if (request instanceof TGetTablesReq) {
        return getTables((TGetTablesReq) request);
      } else if (request instanceof TGetTableTypesReq) {
        return getTableTypes((TGetTableTypesReq) request);
      } else if (request instanceof TGetTypeInfoReq) {
        return getTypeInfo((TGetTypeInfoReq) request);
      } else {
        String errorMessage =
            String.format(
                "No implementation for fetching thrift response for Request {%s}", request);
        LOGGER.error(errorMessage);
        throw new DatabricksSQLFeatureNotSupportedException(errorMessage);
      }
    } catch (TException | SQLException e) {
      Throwable cause = e;
      while (cause != null) {
        if (cause instanceof HttpException) {
          throw new DatabricksHttpException(cause.getMessage(), cause);
        }
        cause = cause.getCause();
      }
      String errorMessage =
          String.format(
              "Error while receiving response from Thrift server. Request {%s}, Error {%s}",
              request, e.getMessage());
      LOGGER.error(e, errorMessage);
      throw new DatabricksSQLException(errorMessage, e);
    }
  }

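  /** Fetches query output for the given operation handle using the default row limit. */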
  TFetchResultsResp getResultSetResp(TOperationHandle operationHandle, String context)
      throws DatabricksHttpException {
    refreshHeadersIfRequired();
    return getResultSetResp(SUCCESS_STATUS, operationHandle, context, DEFAULT_ROW_LIMIT, false);
  }

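  /** Cancels the given operation, wrapping transport failures in a {@link DatabricksHttpException}. */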
  TCancelOperationResp cancelOperation(TCancelOperationReq req) throws DatabricksHttpException {
    refreshHeadersIfRequired();
    try {
      return getThriftClient().CancelOperation(req);
    } catch (TException e) {
      String errorMessage =
          String.format(
              "Error while canceling operation from Thrift server. Request {%s}, Error {%s}",
              req.toString(), e.getMessage());
      LOGGER.error(e, errorMessage);
      throw new DatabricksHttpException(errorMessage, e);
    }
  }

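  /** Closes the given operation, wrapping transport failures in a {@link DatabricksHttpException}. */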
  TCloseOperationResp closeOperation(TCloseOperationReq req) throws DatabricksHttpException {
    refreshHeadersIfRequired();
    try {
      return getThriftClient().CloseOperation(req);
    } catch (TException e) {
      String errorMessage =
          String.format(
              "Error while closing operation from Thrift server. Request {%s}, Error {%s}",
              req.toString(), e.getMessage());
      LOGGER.error(e, errorMessage);
      throw new DatabricksHttpException(errorMessage, e);
    }
  }

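  /**
   * Issues a TFetchResultsReq for query output (fetch type 0), optionally asking the server to
   * include result set metadata, after verifying the status of the originating response and
   * before verifying the status of the fetch itself.
   */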
  private TFetchResultsResp getResultSetResp(
      TStatusCode responseCode,
      TOperationHandle operationHandle,
      String context,
      int maxRows,
      boolean fetchMetadata)
      throws DatabricksHttpException {
    verifySuccessStatus(responseCode, context);
    TFetchResultsReq request =
        new TFetchResultsReq()
            .setOperationHandle(operationHandle)
            .setFetchType((short) 0) // 0 represents Query output. 1 represents Log
            .setMaxRows(maxRows)
            .setMaxBytes(DEFAULT_BYTE_LIMIT);
    if (fetchMetadata) {
      request.setIncludeResultSetMetadata(true);
    }
    TFetchResultsResp response;
    try {
      response = getThriftClient().FetchResults(request);
    } catch (TException e) {
      String errorMessage =
          String.format(
              "Error while fetching results from Thrift server. Request {%s}, Error {%s}",
              request.toString(), e.getMessage());
      LOGGER.error(e, errorMessage);
      throw new DatabricksHttpException(errorMessage, e);
    }
    verifySuccessStatus(
        response.getStatus().getStatusCode(),
        String.format(
            "Error while fetching results Request {%s}. TFetchResultsResp {%s}. ",
            request, response));
    return response;
  }

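  /**
   * Polls GetOperationStatus until the operation leaves STILL_EXECUTING_STATUS, sleeping
   * DEFAULT_SLEEP_DELAY between attempts, then verifies the terminal status.
   */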
  private void longPolling(TOperationHandle operationHandle)
      throws TException, InterruptedException, DatabricksHttpException {
    TGetOperationStatusReq request =
        new TGetOperationStatusReq()
            .setOperationHandle(operationHandle)
            .setGetProgressUpdate(false);
    TGetOperationStatusResp response;
    TStatusCode statusCode;
    do {
      response = getThriftClient().GetOperationStatus(request);
      statusCode = response.getStatus().getStatusCode();
      if (statusCode == TStatusCode.STILL_EXECUTING_STATUS) {
        Thread.sleep(DEFAULT_SLEEP_DELAY);
      }
    } while (statusCode == TStatusCode.STILL_EXECUTING_STATUS);
    verifySuccessStatus(
        statusCode, String.format("Request {%s}, Response {%s}", request, response));
  }

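  /**
   * Executes a statement synchronously. When direct results are enabled the server may return the
   * first result batch inline; otherwise the call long-polls for completion and then fetches the
   * results along with their metadata. The operation handle is recorded on the parent statement,
   * if one is supplied.
   */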
  DatabricksResultSet execute(
      TExecuteStatementReq request,
      IDatabricksStatementInternal parentStatement,
      IDatabricksSession session,
      StatementType statementType)
      throws SQLException {
    refreshHeadersIfRequired();
    int maxRows = (parentStatement == null) ? DEFAULT_ROW_LIMIT : parentStatement.getMaxRows();
    if (enableDirectResults) {
      TSparkGetDirectResults directResults =
          new TSparkGetDirectResults().setMaxBytes(DEFAULT_BYTE_LIMIT).setMaxRows(maxRows);
      request.setGetDirectResults(directResults);
    }
    TExecuteStatementResp response;
    TFetchResultsResp resultSet = null;
    try {
      response = getThriftClient().ExecuteStatement(request);
      if (Arrays.asList(ERROR_STATUS, INVALID_HANDLE_STATUS).contains(response.status.statusCode)) {
        throw new DatabricksSQLException(response.status.errorMessage);
      }
      if (response.isSetDirectResults()) {
        if (enableDirectResults) {
          if (response.getDirectResults().isSetOperationStatus()
              && response.getDirectResults().operationStatus.operationState
                  == TOperationState.ERROR_STATE) {
            throw new DatabricksSQLException(
                response.getDirectResults().getOperationStatus().errorMessage);
          }
        }
        if (((response.status.statusCode == SUCCESS_STATUS)
            || (response.status.statusCode == SUCCESS_WITH_INFO_STATUS))) {
          checkDirectResultsForErrorStatus(response.getDirectResults(), response.toString());
          resultSet = response.getDirectResults().getResultSet();
          resultSet.setResultSetMetadata(response.getDirectResults().getResultSetMetadata());
        }
      } else {
        longPolling(response.getOperationHandle());
        resultSet =
            getResultSetResp(
                response.getStatus().getStatusCode(),
                response.getOperationHandle(),
                response.toString(),
                maxRows,
                true);
      }
    } catch (TException | InterruptedException e) {
      String errorMessage =
          String.format(
              "Error while receiving response from Thrift server. Request {%s}, Error {%s}",
              request.toString(), e.getMessage());
      LOGGER.error(e, errorMessage);
      throw new DatabricksHttpException(errorMessage, e);
    }
    StatementId statementId = new StatementId(response.getOperationHandle().operationId);
    if (parentStatement != null) {
      parentStatement.setStatementId(statementId);
    }
    return new DatabricksResultSet(
        response.getStatus(),
        statementId,
        resultSet.getResults(),
        resultSet.getResultSetMetadata(),
        statementType,
        parentStatement,
        session);
  }

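  /**
   * Submits a statement without waiting for results. Only the status and operation handle are
   * captured; the caller is expected to retrieve results later, e.g. via
   * {@link #getStatementResult}.
   */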
  DatabricksResultSet executeAsync(
      TExecuteStatementReq request,
      IDatabricksStatementInternal parentStatement,
      IDatabricksSession session,
      StatementType statementType)
      throws SQLException {
    refreshHeadersIfRequired();
    TExecuteStatementResp response;
    try {
      response = getThriftClient().ExecuteStatement(request);
      if (Arrays.asList(ERROR_STATUS, INVALID_HANDLE_STATUS).contains(response.status.statusCode)) {
        LOGGER.error(
            "Received error response {%s} from Thrift Server for request {%s}",
            response, request.toString());
        throw new DatabricksSQLException(response.status.errorMessage);
      }
    } catch (TException e) {
      String errorMessage =
          String.format(
              "Error while receiving response from Thrift server. Request {%s}, Error {%s}",
              request.toString(), e.getMessage());
      LOGGER.error(e, errorMessage);
      throw new DatabricksHttpException(errorMessage, e);
    }
    StatementId statementId = new StatementId(response.getOperationHandle().operationId);
    if (parentStatement != null) {
      parentStatement.setStatementId(statementId);
    }
    return new DatabricksResultSet(
        response.getStatus(), statementId, null, null, statementType, parentStatement, session);
  }

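  /**
   * Retrieves the current status of a previously submitted operation and, if it has already
   * succeeded, fetches its result set with metadata (a max-rows value of -1 is passed through to
   * the fetch request).
   */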
  DatabricksResultSet getStatementResult(
      TOperationHandle operationHandle,
      IDatabricksStatementInternal parentStatement,
      IDatabricksSession session)
      throws SQLException {
    LOGGER.debug("Operation handle {%s}", operationHandle);
    TGetOperationStatusReq request =
        new TGetOperationStatusReq()
            .setOperationHandle(operationHandle)
            .setGetProgressUpdate(false);
    TGetOperationStatusResp response;
    TStatusCode statusCode;
    TFetchResultsResp resultSet = null;
    try {
      response = getThriftClient().GetOperationStatus(request);
      statusCode = response.getStatus().getStatusCode();
      if (statusCode == SUCCESS_STATUS || statusCode == SUCCESS_WITH_INFO_STATUS) {
        resultSet =
            getResultSetResp(
                response.getStatus().getStatusCode(),
                operationHandle,
                response.toString(),
                -1,
                true);
      }
    } catch (TException e) {
      String errorMessage =
          String.format(
              "Error while receiving response from Thrift server. Request {%s}, Error {%s}",
              request.toString(), e.getMessage());
      LOGGER.error(e, errorMessage);
      throw new DatabricksHttpException(errorMessage, e);
    }
    StatementId statementId = new StatementId(operationHandle.getOperationId());
    return new DatabricksResultSet(
        response.getStatus(),
        statementId,
        resultSet == null ? null : resultSet.getResults(),
        resultSet == null ? null : resultSet.getResultSetMetadata(),
        StatementType.SQL,
        parentStatement,
        session);
  }

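  /** Updates the access token on the underlying {@link DatabricksConfig} so that subsequent header refreshes use it. */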
  void resetAccessToken(String newAccessToken) {
    this.databricksConfig.setToken(newAccessToken);
  }

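  // The metadata helpers below share one pattern: optionally request direct results, invoke the
  // corresponding TCLIService metadata RPC, and either return the inlined result set or fetch it
  // explicitly via getResultSetResp.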
  private TFetchResultsResp listFunctions(TGetFunctionsReq request)
      throws DatabricksHttpException, TException {
    if (enableDirectResults) request.setGetDirectResults(DEFAULT_DIRECT_RESULTS);
    TGetFunctionsResp response = getThriftClient().GetFunctions(request);
    if (response.isSetDirectResults()) {
      checkDirectResultsForErrorStatus(response.getDirectResults(), response.toString());
      return response.getDirectResults().getResultSet();
    }
    return getResultSetResp(
        response.getStatus().getStatusCode(),
        response.getOperationHandle(),
        response.toString(),
        DEFAULT_ROW_LIMIT,
        false);
  }

  private TFetchResultsResp listPrimaryKeys(TGetPrimaryKeysReq request)
      throws DatabricksHttpException, TException {
    if (enableDirectResults) request.setGetDirectResults(DEFAULT_DIRECT_RESULTS);
    TGetPrimaryKeysResp response = getThriftClient().GetPrimaryKeys(request);
    if (response.isSetDirectResults()) {
      checkDirectResultsForErrorStatus(response.getDirectResults(), response.toString());
      return response.getDirectResults().getResultSet();
    }
    return getResultSetResp(
        response.getStatus().getStatusCode(),
        response.getOperationHandle(),
        response.toString(),
        DEFAULT_ROW_LIMIT,
        false);
  }

  private TFetchResultsResp getTables(TGetTablesReq request)
      throws TException, DatabricksHttpException {
    if (enableDirectResults) request.setGetDirectResults(DEFAULT_DIRECT_RESULTS);
    TGetTablesResp response = getThriftClient().GetTables(request);
    if (response.isSetDirectResults()) {
      checkDirectResultsForErrorStatus(response.getDirectResults(), response.toString());
      return response.getDirectResults().getResultSet();
    }
    return getResultSetResp(
        response.getStatus().getStatusCode(),
        response.getOperationHandle(),
        response.toString(),
        DEFAULT_ROW_LIMIT,
        false);
  }

  private TFetchResultsResp getTableTypes(TGetTableTypesReq request)
      throws TException, DatabricksHttpException {
    if (enableDirectResults) request.setGetDirectResults(DEFAULT_DIRECT_RESULTS);
    TGetTableTypesResp response = getThriftClient().GetTableTypes(request);
    if (response.isSetDirectResults()) {
      checkDirectResultsForErrorStatus(response.getDirectResults(), response.toString());
      return response.getDirectResults().getResultSet();
    }
    return getResultSetResp(
        response.getStatus().getStatusCode(),
        response.getOperationHandle(),
        response.toString(),
        DEFAULT_ROW_LIMIT,
        false);
  }

  private TFetchResultsResp getCatalogs(TGetCatalogsReq request)
      throws TException, DatabricksHttpException {
    if (enableDirectResults) request.setGetDirectResults(DEFAULT_DIRECT_RESULTS);
    TGetCatalogsResp response = getThriftClient().GetCatalogs(request);
    if (response.isSetDirectResults()) {
      checkDirectResultsForErrorStatus(response.getDirectResults(), response.toString());
      return response.getDirectResults().getResultSet();
    }
    return getResultSetResp(
        response.getStatus().getStatusCode(),
        response.getOperationHandle(),
        response.toString(),
        DEFAULT_ROW_LIMIT,
        false);
  }

  private TFetchResultsResp listSchemas(TGetSchemasReq request)
      throws TException, DatabricksHttpException {
    if (enableDirectResults) request.setGetDirectResults(DEFAULT_DIRECT_RESULTS);
    TGetSchemasResp response = getThriftClient().GetSchemas(request);
    if (response.isSetDirectResults()) {
      checkDirectResultsForErrorStatus(response.getDirectResults(), response.toString());
      return response.getDirectResults().getResultSet();
    }
    return getResultSetResp(
        response.getStatus().getStatusCode(),
        response.getOperationHandle(),
        response.toString(),
        DEFAULT_ROW_LIMIT,
        false);
  }

  private TFetchResultsResp getTypeInfo(TGetTypeInfoReq request)
      throws TException, DatabricksHttpException {
    if (enableDirectResults) request.setGetDirectResults(DEFAULT_DIRECT_RESULTS);
    TGetTypeInfoResp response = getThriftClient().GetTypeInfo(request);
    if (response.isSetDirectResults()) {
      checkDirectResultsForErrorStatus(response.getDirectResults(), response.toString());
      return response.getDirectResults().getResultSet();
    }
    return getResultSetResp(
        response.getStatus().getStatusCode(),
        response.getOperationHandle(),
        response.toString(),
        DEFAULT_ROW_LIMIT,
        false);
  }

  private TFetchResultsResp listColumns(TGetColumnsReq request)
      throws TException, DatabricksHttpException {
    if (enableDirectResults) request.setGetDirectResults(DEFAULT_DIRECT_RESULTS);
    TGetColumnsResp response = getThriftClient().GetColumns(request);
    if (response.isSetDirectResults()) {
      checkDirectResultsForErrorStatus(response.getDirectResults(), response.toString());
      return response.getDirectResults().getResultSet();
    }
    return getResultSetResp(
        response.getStatus().getStatusCode(),
        response.getOperationHandle(),
        response.toString(),
        DEFAULT_ROW_LIMIT,
        false);
  }

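  /**
   * Re-authenticates via the SDK config and pushes the refreshed headers onto the underlying
   * {@link DatabricksHttpTTransport} before an RPC is issued.
   */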
  private void refreshHeadersIfRequired() {
    ((DatabricksHttpTTransport) getThriftClient().getInputProtocol().getTransport())
        .setCustomHeaders(databricksConfig.authenticate());
  }

  private TCLIService.Client getThriftClient() {
    return thriftClient.get();
  }

  /**
   * Creates a new thrift client for the given endpoint URL and authentication headers.
   *
   * @param endPointUrl endpoint URL
   * @param authHeaders authentication headers
   * @param connectionContext connection context
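   * @return a new {@code TCLIService.Client} backed by a binary protocol over the Databricks HTTP transport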
   */
  private TCLIService.Client createThriftClient(
      String endPointUrl,
      Map<String, String> authHeaders,
      IDatabricksConnectionContext connectionContext) {
    DatabricksHttpTTransport transport =
        new DatabricksHttpTTransport(
            DatabricksHttpClientFactory.getInstance().getClient(connectionContext), endPointUrl);
    transport.setCustomHeaders(authHeaders);
    TBinaryProtocol protocol = new TBinaryProtocol(transport);

    return new TCLIService.Client(protocol);
  }
}