All downloads are free. Search and download functionality uses the official Maven repository.

com.microsoft.ml.spark.io.powerbi.PowerBIWriter.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.io.powerbi

import com.microsoft.ml.spark.io.http._
import com.microsoft.ml.spark.stages._
import org.apache.http.client.HttpResponseException
import org.apache.log4j.{LogManager, Logger}
import org.apache.spark.ml.NamespaceInjections
import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.{DataFrame, ForeachWriter, Row}

import scala.collection.JavaConverters._

/** A streaming `ForeachWriter` that accepts every partition and discards every row.
  *
  * It performs no output of its own; it only gives a streaming query a sink to run against.
  */
private[ml] class StreamMaterializer extends ForeachWriter[Row] {

  /** Always accepts the partition, so `process` is invoked for each of its rows. */
  override def open(partitionId: Long, version: Long): Boolean = {
    true
  }

  /** Ignores the row. */
  override def process(value: Row): Unit = {}

  /** Nothing to release; any error passed in is ignored here. */
  override def close(errorOrNull: Throwable): Unit = {}
}

/** Sends the rows of a DataFrame to an HTTP endpoint (a Power BI URL, going by this package) in
  * mini-batches, using `SimpleHTTPTransformer`.
  *
  * Supported `options` (all values are strings, parsed as noted):
  *  - `consolidate` (Boolean, default `false`): first run the input through a `PartitionConsolidator`.
  *  - `concurrency` (Int, default `1`) and `concurrentTimeout` (Double, default `30.0`): forwarded
  *    to the HTTP transformer.
  *  - `minibatcher` (`dynamic`, `fixed` or `timed`; default `fixed`): how rows are grouped.
  *  - `maxBatchSize` (Int, default `Int.MaxValue`): used by the `dynamic` and `timed` batchers.
  *  - `batchSize` (Int, default `10`), `buffered` (Boolean, default `false`) and
  *    `maxBufferSize` (Int, default `5`): used by the `fixed` batcher.
  *  - `millisToWait` (Int, default `1000`): used by the `timed` batcher.
  */
object PowerBIWriter {

  val Logger: Logger = LogManager.getRootLogger

  /** Option keys accepted by `write` and `stream`. */
  private val ApplicableOptions: Set[String] = Set(
    "consolidate", "concurrency", "concurrentTimeout", "minibatcher",
    "maxBatchSize", "batchSize", "buffered", "maxBufferSize", "millisToWait"
  )

  /** Reads option `key` and converts it with `parse`, or returns `default` when the key is absent.
    *
    * @throws IllegalArgumentException if the raw value cannot be parsed; the message names the key.
    */
  private def parseOption[T](options: Map[String, String], key: String, default: T)
                            (parse: String => T): T =
    options.get(key).fold(default) { raw =>
      // NumberFormatException (from toInt/toDouble) is a subclass of IllegalArgumentException,
      // as is the exception thrown by toBoolean.
      try parse(raw) catch {
        case e: IllegalArgumentException =>
          throw new IllegalArgumentException(s"Invalid value '$raw' for option '$key'", e)
      }
    }

  /** Returns `response` unchanged if its status code is 200.
    *
    * @throws HttpResponseException for any other status code. The message includes the reason
    *                               phrase and the body, decoded as UTF-8 (empty when the response
    *                               carries no entity).
    */
  private def ensureSuccess(response: HTTPResponseData): HTTPResponseData = {
    val status = response.statusLine
    val code = status.statusCode
    if (code != 200) {
      // The entity is optional; a missing body must not mask the HTTP error behind a
      // NoSuchElementException from Option.get.
      val content = response.entity
        .map(entity => new String(entity.content, java.nio.charset.StandardCharsets.UTF_8))
        .getOrElse("")
      throw new HttpResponseException(code, s"Request failed with \n " +
        s"code: $code, \n" +
        s"reason:${status.reasonPhrase}, \n" +
        s"content: $content")
    }
    response
  }

  /** Builds the transformed DataFrame that sends the rows of `df` to `url`.
    *
    * All columns of `df` are packed into a single struct column `input`, grouped by the selected
    * mini-batcher and passed through a `SimpleHTTPTransformer`. Its `output` column holds the
    * (unflattened) per-batch responses. Any non-200 response raises an `HttpResponseException`.
    * Callers force evaluation of the result (see `write` and `stream`).
    *
    * @throws IllegalArgumentException on an unknown option key, an unparsable option value or an
    *                                  unknown `minibatcher` name.
    */
  private def prepareDF(df: DataFrame, url: String, options: Map[String, String] = Map()): DataFrame = {
    // require (unlike assert) cannot be elided by -Xdisable-assertions, so validation of
    // user-supplied options always runs.
    options.keys.foreach { k =>
      require(ApplicableOptions(k), s"$k not an applicable option ${ApplicableOptions.toList}")
    }

    // All options are parsed eagerly, so a malformed value fails even if the selected
    // mini-batcher would not use it.
    val consolidate = parseOption(options, "consolidate", false)(_.toBoolean)
    val concurrency = parseOption(options, "concurrency", 1)(_.toInt)
    val concurrentTimeout = parseOption(options, "concurrentTimeout", 30.0)(_.toDouble)

    val minibatcher = options.getOrElse("minibatcher", "fixed")
    val maxBatchSize = parseOption(options, "maxBatchSize", Integer.MAX_VALUE)(_.toInt)
    val batchSize = parseOption(options, "batchSize", 10)(_.toInt)
    val isBuffered = parseOption(options, "buffered", false)(_.toBoolean)
    val maxBufferSize = parseOption(options, "maxBufferSize", 5)(_.toInt)
    val millisToWait = parseOption(options, "millisToWait", 1000)(_.toInt)

    val mb = minibatcher match {
      case "dynamic" =>
        new DynamicMiniBatchTransformer()
          .setMaxBatchSize(maxBatchSize)
      case "fixed" =>
        new FixedMiniBatchTransformer()
          .setBuffered(isBuffered)
          .setBatchSize(batchSize)
          .setMaxBufferSize(maxBufferSize)
      case "timed" =>
        new TimeIntervalMiniBatchTransformer()
          .setMillisToWait(millisToWait)
          .setMaxBatchSize(maxBatchSize)
      case other =>
        throw new IllegalArgumentException(
          s"Unknown minibatcher '$other'; expected one of dynamic, fixed, timed")
    }

    val consolidated = if (consolidate) new PartitionConsolidator().transform(df) else df

    new SimpleHTTPTransformer()
      .setUrl(url)
      .setMiniBatcher(mb)
      .setFlattenOutputBatches(false)
      .setOutputParser(new CustomOutputParser().setUDF({ response: HTTPResponseData =>
        ensureSuccess(response)
      }))
      .setConcurrency(concurrency)
      .setConcurrentTimeout(concurrentTimeout)
      .setInputCol("input")
      .setOutputCol("output")
      .transform(consolidated.select(struct(consolidated.columns.map(col): _*).alias("input")))
  }

  /** Converts a (possibly null) Java options map into an immutable Scala map. */
  private def toScalaOptions(options: java.util.Map[String, String]): Map[String, String] =
    Option(options).fold(Map.empty[String, String])(_.asScala.toMap)

  /** Returns a `DataStreamWriter` for a streaming `df` whose `foreach` sink discards the rows.
    * Rows are sent to `url` as the transformed stream is evaluated. The caller must still
    * configure and `start()` the query.
    */
  def stream(df: DataFrame, url: String, options: Map[String, String] = Map()): DataStreamWriter[Row] = {
    prepareDF(df, url, options).writeStream.foreach(new StreamMaterializer)
  }

  /** Sends all rows of a batch `df` to `url`. Each partition of the transformed DataFrame is
    * drained, which forces its evaluation. The response rows themselves are discarded.
    */
  def write(df: DataFrame, url: String, options: Map[String, String] = Map()): Unit = {
    // The explicit parameter type disambiguates the Scala-function overload from the
    // ForeachPartitionFunction overload under Scala 2.12.
    prepareDF(df, url, options).foreachPartition((it: Iterator[Row]) => it.foreach(_ => ()))
  }

  /** Java-friendly overload of `stream`; a null `options` map is treated as empty. */
  def stream(df: DataFrame, url: String,
             options: java.util.HashMap[String, String]): DataStreamWriter[Row] = {
    stream(df, url, toScalaOptions(options))
  }

  /** Java-friendly overload of `write`; a null `options` map is treated as empty. */
  def write(df: DataFrame, url: String,
            options: java.util.HashMap[String, String]): Unit = {
    write(df, url, toScalaOptions(options))
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy