All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.cdap.plugin.gcp.bigquery.sqlengine.BigQuerySparkDatasetProducer Maven / Gradle / Ivy

/*
 * Copyright © 2021 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package io.cdap.plugin.gcp.bigquery.sqlengine;

import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.etl.api.engine.sql.dataset.RecordCollection;
import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDataset;
import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDatasetDescription;
import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDatasetProducer;
import io.cdap.cdap.etl.api.sql.engine.dataset.SparkRecordCollectionImpl;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.DataFrameReader;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import javax.annotation.Nullable;

/**
 * Dataset Producer implementation which uses the Spark-BigQuery connector to extract records.
 */
/**
 * Dataset Producer implementation which uses the Spark-BigQuery connector to extract records.
 *
 * <p>The producer loads the table {@code <project>.<bqDataset>.<bqTable>} through the Spark
 * {@code bigquery} data source. It then casts columns whose CDAP type is {@code INT} or {@code FLOAT}
 * so the resulting dataframe matches the CDAP schema supplied at construction time.
 */
public class BigQuerySparkDatasetProducer
  implements SQLDatasetProducer, Serializable {

  private static final Logger LOG = LoggerFactory.getLogger(BigQuerySparkDatasetProducer.class);

  private static final String FORMAT = "bigquery";
  private static final String CONFIG_CREDENTIALS_FILE = "credentialsFile";
  private static final String CONFIG_CREDENTIALS = "credentials";

  private final BigQuerySQLEngineConfig config;
  private final String project;
  private final String bqDataset;
  private final String bqTable;
  private final Schema schema;

  /**
   * Creates a producer for a single BigQuery table.
   *
   * @param config    engine configuration; supplies the service-account credentials
   * @param project   GCP project that contains the table
   * @param bqDataset BigQuery dataset that contains the table
   * @param bqTable   BigQuery table to read
   * @param schema    CDAP schema of the records; used to adjust INT and FLOAT column types
   */
  public BigQuerySparkDatasetProducer(BigQuerySQLEngineConfig config,
                                      String project,
                                      String bqDataset,
                                      String bqTable,
                                      Schema schema) {
    this.config = config;
    this.project = project;
    this.bqDataset = bqDataset;
    this.bqTable = bqTable;
    this.schema = schema;
  }

  /**
   * Returns the description of this producer.
   *
   * @return always {@code null}; this producer does not supply a description
   */
  @Override
  public SQLDatasetDescription getDescription() {
    return null;
  }

  /**
   * Reads the configured BigQuery table into a Spark record collection.
   *
   * <p>Note: the {@code sqlDataset} argument is not consulted. The table to read is fully determined
   * by the project, dataset and table passed to the constructor.
   *
   * @param sqlDataset the dataset being produced (unused by this implementation)
   * @return a record collection wrapping the loaded and type-adjusted dataframe
   */
  @Override
  @Nullable
  public RecordCollection produce(SQLDataset sqlDataset) {
    // Fully-qualified table reference understood by the Spark-BigQuery connector.
    String path = String.format("%s.%s.%s", project, bqDataset, bqTable);

    // Reuse the active Spark context, if any, for this read.
    SparkContext sc = SparkContext.getOrCreate();
    SparkSession spark = SparkSession.builder()
      .appName("spark-bq-connector-reader")
      .sparkContext(sc)
      .getOrCreate();

    DataFrameReader bqReader = spark.read().format(FORMAT);
    configureCredentials(bqReader);

    Dataset<Row> ds = convertFieldTypes(bqReader.load(path));
    return new SparkRecordCollectionImpl(ds);
  }

  /**
   * Sets the connector credential option on the reader.
   *
   * <p>A service-account file path takes precedence. Otherwise the service-account JSON is passed
   * base64-encoded. If neither is configured, no credential option is set.
   *
   * @param reader the reader to configure; {@code option} mutates it in place
   */
  private void configureCredentials(DataFrameReader reader) {
    if (Boolean.TRUE.equals(config.isServiceAccountFilePath()) && config.getServiceAccountFilePath() != null) {
      reader.option(CONFIG_CREDENTIALS_FILE, config.getServiceAccountFilePath());
    } else if (Boolean.TRUE.equals(config.isServiceAccountJson()) && config.getServiceAccountJson() != null) {
      reader.option(CONFIG_CREDENTIALS, encodeBase64(config.getServiceAccountJson()));
    }
  }

  /**
   * Base64-encodes the UTF-8 bytes of the service account JSON, as expected by the connector's
   * {@code credentials} option.
   *
   * @param serviceAccountJson service account key in JSON form
   * @return base64 encoding of the JSON
   */
  private String encodeBase64(String serviceAccountJson) {
    return Base64.getEncoder().encodeToString(serviceAccountJson.getBytes(StandardCharsets.UTF_8));
  }

  /**
   * Adjusts column types for CDAP {@code INT} and {@code FLOAT} fields.
   *
   * <p>For nullable fields, the underlying non-null type is inspected. Columns whose CDAP type is
   * {@code INT} are cast to Spark {@code IntegerType}. Columns whose CDAP type is {@code FLOAT} are
   * cast to Spark {@code FloatType}. All other columns are left unchanged.
   *
   * @param ds input dataframe
   * @return dataframe with updated column types. The input is returned unchanged when there is no
   *     schema or the schema has no fields.
   */
  private Dataset<Row> convertFieldTypes(Dataset<Row> ds) {
    // Schema.getFields() is null for non-record schemas; there is nothing to convert then.
    if (schema == null || schema.getFields() == null) {
      return ds;
    }

    Dataset<Row> result = ds;
    for (Schema.Field field : schema.getFields()) {
      String fieldName = field.getName();
      Schema fieldSchema = field.getSchema();

      // For nullable types, check the underlying type.
      if (fieldSchema.isNullable()) {
        fieldSchema = fieldSchema.getNonNullable();
      }

      switch (fieldSchema.getType()) {
        case INT:
          LOG.trace("Converting field {} to Integer", fieldName);
          result = result.withColumn(fieldName, result.col(fieldName).cast(DataTypes.IntegerType));
          break;
        case FLOAT:
          LOG.trace("Converting field {} to Float", fieldName);
          result = result.withColumn(fieldName, result.col(fieldName).cast(DataTypes.FloatType));
          break;
        default:
          // Other types already line up with what the connector produces.
          break;
      }
    }

    return result;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy