All downloads are free. Search and download functionality uses the official Maven repository.

com.databricks.jdbc.api.impl.arrow.InlineChunkProvider Maven / Gradle / Ivy

package com.databricks.jdbc.api.impl.arrow;

import static com.databricks.jdbc.common.util.DatabricksTypeUtil.*;
import static com.databricks.jdbc.common.util.DecompressionUtil.decompress;

import com.databricks.jdbc.common.CompressionType;
import com.databricks.jdbc.exception.DatabricksParsingException;
import com.databricks.jdbc.exception.DatabricksSQLException;
import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import com.databricks.jdbc.model.client.thrift.generated.*;
import com.google.common.annotations.VisibleForTesting;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.SchemaUtility;

/** Class to manage inline Arrow chunks */
/**
 * {@link ChunkProvider} for result sets whose Arrow data arrives inline as a list of {@code
 * TSparkArrowBatch} objects.
 *
 * <p>The serialized Arrow schema is written first, followed by every (optionally decompressed)
 * batch. The result is exposed as a single {@link ArrowResultChunk}, so this provider yields
 * exactly one chunk.
 */
public class InlineChunkProvider implements ChunkProvider {

  private static final JdbcLogger LOGGER = JdbcLoggerFactory.getLogger(InlineChunkProvider.class);

  /** Sum of the row counts of all inline batches; accumulated while building the byte stream. */
  private long totalRows;

  /** -1 until {@link #next()} is first called; 0 once the single chunk has been advanced to. */
  private long currentChunkIndex;

  ArrowResultChunk arrowResultChunk; // There is only one packet of data in case of inline arrow

  /**
   * Builds the provider from inline Arrow batches.
   *
   * @param arrowBatches inline Arrow batches, compressed according to {@code metadata}; {@code
   *     null} is treated as "no batches"
   * @param metadata result-set metadata supplying the Arrow schema (or the Hive schema it is
   *     derived from) and the compression type
   * @param statementId statement identifier, used only in diagnostic messages
   * @throws DatabricksParsingException if the schema or any batch cannot be processed
   */
  InlineChunkProvider(
      List<TSparkArrowBatch> arrowBatches, TGetResultSetMetadataResp metadata, String statementId)
      throws DatabricksParsingException {
    this.currentChunkIndex = -1;
    this.totalRows = 0;
    ByteArrayInputStream byteStream = initializeByteStream(arrowBatches, metadata, statementId);
    arrowResultChunk = ArrowResultChunk.builder().withInputStream(byteStream, totalRows).build();
  }

  /** {@inheritDoc} */
  @Override
  public boolean hasNextChunk() {
    // Only the single inline chunk exists, so there is a "next" chunk only before the first next().
    return this.currentChunkIndex == -1;
  }

  /** {@inheritDoc} */
  @Override
  public boolean next() {
    if (!hasNextChunk()) {
      return false;
    }
    this.currentChunkIndex++;
    return true;
  }

  /** {@inheritDoc} */
  @Override
  public ArrowResultChunk getChunk() {
    return arrowResultChunk;
  }

  /** {@inheritDoc} */
  @Override
  public void close() {
    arrowResultChunk.releaseChunk();
  }

  /**
   * Concatenates the serialized schema and all decompressed batches into one in-memory stream,
   * adding each batch's row count to {@link #totalRows}.
   *
   * @throws DatabricksParsingException if schema serialization, decompression or buffering fails
   */
  private ByteArrayInputStream initializeByteStream(
      List<TSparkArrowBatch> arrowBatches, TGetResultSetMetadataResp metadata, String statementId)
      throws DatabricksParsingException {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    CompressionType compressionType = CompressionType.getCompressionMapping(metadata);
    try {
      byte[] serializedSchema = getSerializedSchema(metadata);
      if (serializedSchema != null) {
        baos.write(serializedSchema);
      }
      if (arrowBatches != null) {
        int batchIndex = 0;
        for (TSparkArrowBatch arrowBatch : arrowBatches) {
          byte[] decompressedBytes =
              decompress(
                  arrowBatch.getBatch(),
                  compressionType,
                  String.format(
                      "Data fetch for inline arrow batch [%d] with [%d] rows and statement [%s]"
                          + " with decompression algorithm : [%s]",
                      batchIndex, arrowBatch.getRowCount(), statementId, compressionType));
          totalRows += arrowBatch.getRowCount();
          baos.write(decompressedBytes);
          batchIndex++;
        }
      }
      return new ByteArrayInputStream(baos.toByteArray());
    } catch (DatabricksParsingException e) {
      // Already wrapped and logged by handleError; wrapping again would duplicate the message.
      throw e;
    } catch (DatabricksSQLException | IOException e) {
      handleError(e);
    }
    // Unreachable: handleError always throws.
    return null;
  }

  /**
   * Returns the serialized Arrow schema. The metadata's Arrow schema bytes are used when present;
   * otherwise the schema is derived from the Hive (Thrift) table schema.
   *
   * @throws DatabricksSQLException if the derived schema cannot be built or serialized
   */
  private byte[] getSerializedSchema(TGetResultSetMetadataResp metadata)
      throws DatabricksSQLException {
    byte[] arrowSchemaBytes = metadata.getArrowSchema();
    if (arrowSchemaBytes != null) {
      return arrowSchemaBytes;
    }
    Schema arrowSchema = hiveSchemaToArrowSchema(metadata.getSchema());
    try {
      return SchemaUtility.serialize(arrowSchema);
    } catch (IOException e) {
      handleError(e);
    }
    // Unreachable: handleError always throws.
    return null;
  }

  /**
   * Converts a Thrift table schema into an Arrow schema with one nullable field per column. A
   * {@code null} schema yields an empty Arrow schema.
   *
   * @throws DatabricksParsingException if any column cannot be mapped to an Arrow type
   */
  private static Schema hiveSchemaToArrowSchema(TTableSchema hiveSchema)
      throws DatabricksParsingException {
    List<Field> fields = new ArrayList<>();
    if (hiveSchema == null) {
      return new Schema(fields);
    }
    try {
      for (TColumnDesc columnDesc : hiveSchema.getColumns()) {
        fields.add(getArrowField(columnDesc));
      }
    } catch (SQLException | RuntimeException e) {
      // RuntimeException is included so malformed metadata (e.g. missing columns) still surfaces
      // as a DatabricksParsingException, as it did before.
      handleError(e);
    }
    return new Schema(fields);
  }

  /** Maps a single Thrift column descriptor to a nullable Arrow field with the same name. */
  private static Field getArrowField(TColumnDesc columnDesc) throws SQLException {
    TTypeId thriftType = getThriftTypeFromTypeDesc(columnDesc.getTypeDesc());
    ArrowType arrowType = mapThriftToArrowType(thriftType);
    FieldType fieldType = new FieldType(true, arrowType, null);
    return new Field(columnDesc.getColumnName(), fieldType, null);
  }

  /**
   * Logs {@code e} and always throws it wrapped in a {@link DatabricksParsingException}.
   *
   * @throws DatabricksParsingException always
   */
  @VisibleForTesting
  static void handleError(Exception e) throws DatabricksParsingException {
    String errorMessage = "Cannot process inline arrow format. Error: " + e.getMessage();
    LOGGER.error(errorMessage);
    throw new DatabricksParsingException(errorMessage, e);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy