All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.ppml.crypto.dataframe.EncryptedDataFrameWriter.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.ppml.crypto.dataframe

import com.intel.analytics.bigdl.dllib.utils.Log4Error
import com.intel.analytics.bigdl.ppml.crypto.dataframe.EncryptedDataFrameWriter.writeCsv
import com.intel.analytics.bigdl.ppml.crypto.{AES_CBC_PKCS5PADDING, AES_GCM_CTR_V1, AES_GCM_V1, Crypto, CryptoMode, ENCRYPT, PLAIN_TEXT}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.rdd.RDD
import org.apache.spark.{SerializableWritable, SparkContext, TaskContext}
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}

import java.nio.charset.StandardCharsets
import java.util.{Base64, Locale}

class EncryptedDataFrameWriter(
      sparkSession: SparkSession,
      df: DataFrame,
      encryptMode: CryptoMode,
      dataKeyPlainText: String) {
  protected val extraOptions = new scala.collection.mutable.HashMap[String, String]

  def option(key: String, value: String): this.type = {
    this.extraOptions += (key -> value)
    this
  }

  def option(key: String, value: Boolean): this.type = {
    this.extraOptions += (key -> value.toString)
    this
  }

  private var mode: SaveMode = SaveMode.ErrorIfExists

  def mode(saveMode: String): this.type = {
    this.mode = saveMode.toLowerCase(Locale.ROOT) match {
      case "overwrite" => SaveMode.Overwrite
      case "append" => SaveMode.Append
      case "ignore" => SaveMode.Ignore
      case "error" | "errorifexists" | "default" => SaveMode.ErrorIfExists
      case _ =>
        Log4Error.invalidOperationError(false,
          s"Unknown save mode: $saveMode. " +
        "Accepted save modes are 'overwrite', 'append', 'ignore', 'error', 'errorifexists'.")
        null
    }
    this
  }

  def csv(path: String): Unit = {
    encryptMode match {
      case PLAIN_TEXT =>
        df.write.options(extraOptions).mode(mode).csv(path)
      case AES_CBC_PKCS5PADDING =>
        sparkSession.sparkContext.hadoopConfiguration
          .set("hadoop.io.compression.codecs",
          "com.intel.analytics.bigdl.ppml.crypto.CryptoCodec")
        df.write
          .option("compression", "com.intel.analytics.bigdl.ppml.crypto.CryptoCodec")
          .options(extraOptions).mode(mode).csv(path)
      case _ =>
        Log4Error.invalidOperationError(false, "unknown EncryptMode " + CryptoMode.toString)
    }
  }

  def json(path: String): Unit = {
    encryptMode match {
      case PLAIN_TEXT =>
        df.write.options(extraOptions).mode(mode).json(path)
      case AES_CBC_PKCS5PADDING =>
        sparkSession.sparkContext.hadoopConfiguration
          .set("hadoop.io.compression.codecs",
            "com.intel.analytics.bigdl.ppml.crypto.CryptoCodec")
        df.write
          .option("compression", "com.intel.analytics.bigdl.ppml.crypto.CryptoCodec")
          .options(extraOptions).mode(mode).json(path)
      case _ =>
        Log4Error.invalidOperationError(false, "unknown EncryptMode " + CryptoMode.toString)
    }
  }

  def parquet(path: String): Unit = {
    lazy val header = df.schema.fieldNames.mkString(",")
    encryptMode match {
      case PLAIN_TEXT =>
        df.write.options(extraOptions).mode(mode).parquet(path)
      case AES_GCM_CTR_V1 | AES_GCM_V1 =>
        EncryptedDataFrameWriter.setParquetKey(sparkSession, dataKeyPlainText)
        df.write
          .option("parquet.encryption.column.keys", "key1: " + header)
          .option("parquet.encryption.footer.key", "footerKey")
          .option("parquet.encryption.algorithm", encryptMode.encryptionAlgorithm)
          .options(extraOptions).mode(mode).parquet(path)
      case _ =>
        throw new IllegalArgumentException("unknown EncryptMode " + CryptoMode.toString)
    }
  }

}

object EncryptedDataFrameWriter {
  protected def writeCsv(rdd: RDD[Row],
                         sc: SparkContext,
                         path: String,
                         encryptMode: CryptoMode,
                         dataKeyPlainText: String,
                         schema: String): Unit = {
    val confBroadcast = sc.broadcast(
      new SerializableWritable(sc.hadoopConfiguration)
    )
    rdd.mapPartitions{ rows => {
      if (rows.nonEmpty) {
        val hadoopConf = confBroadcast.value.value
        val fs = FileSystem.get(hadoopConf)
        val partId = TaskContext.getPartitionId()
        // TODO
        val output = fs.create(new Path(path + "/part-" + partId), true)
        val crypto = Crypto(cryptoMode = encryptMode)
        crypto.init(encryptMode, ENCRYPT, dataKeyPlainText)
        // write crypto header
        output.write(crypto.genHeader())
        // write csv header
        if (schema != null && schema.nonEmpty) {
          output.write(crypto.update((schema + "\n").getBytes))
        }
        var row = rows.next()
        while (rows.hasNext) {
          val line = row.toSeq.mkString(",") + "\n"
          output.write(crypto.update(line.getBytes))
//          print(rows.hasNext)
          row = rows.next()
        }
        val lastLine = row.toSeq.mkString(",")
        val (lBytes, hmac) = crypto.doFinal(lastLine.getBytes)
        output.write(lBytes)
        output.write(hmac)
        output.flush()
        output.close()
      }
      Iterator.single(1)
    }}.count()

  }

  private[bigdl] def setParquetKey(sparkSession: SparkSession, dataKeyPlainText: String): Unit = {
    val encoder = Base64.getEncoder
    val footKey = encoder.encodeToString(dataKeyPlainText.slice(0, 16)
      .getBytes(StandardCharsets.UTF_8))
    val dataKey = encoder.encodeToString(dataKeyPlainText.slice(16, 32)
      .getBytes(StandardCharsets.UTF_8))
    sparkSession.sparkContext.hadoopConfiguration.set("parquet.crypto.factory.class",
      "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory")
    sparkSession.sparkContext.hadoopConfiguration.set("parquet.encryption.kms.client.class",
      "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS")
    sparkSession.sparkContext.hadoopConfiguration.set("parquet.encryption.key.list",
      s"footerKey: ${footKey}, key1: ${dataKey}")
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy