All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.ml.spark.stages.MiniBatchTransformer.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import com.microsoft.ml.spark.core.contracts.Wrappable
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}

/** Base trait for transformers that group the rows of each partition into batches.
  *
  * Subclasses supply [[getBatcher]]. Every batch (a `List[Row]`) it yields becomes one output row
  * in which each input column is replaced by an array column holding that column's values for
  * the rows of the batch.
  */
trait MiniBatchBase extends Transformer with DefaultParamsWritable with Wrappable {

  /** Turns a sequence of rows (each a sequence of column values) into a sequence of columns.
    *
    * The first inner sequence fixes the expected width, and every inner sequence must have exactly
    * that width (checked with `assert`). Because the head is read, an empty `nestedSeq` throws.
    */
  def transpose(nestedSeq: Seq[Seq[Any]]): Seq[Seq[Any]] = {
    val width = nestedSeq.head.length
    assert(!nestedSeq.exists(_.lengthCompare(width) != 0))
    Vector.tabulate(width)(column => nestedSeq.map(row => row(column)))
  }

  override def copy(extra: ParamMap): this.type = defaultCopy(extra)

  /** Output schema: every input field keeps its name and becomes an array of its original type.
    * The new fields are built with the two-argument `StructField` constructor, so nullability and
    * metadata take that constructor's defaults rather than being copied from the input.
    */
  override def transformSchema(schema: StructType): StructType =
    StructType(schema.map { field =>
      StructField(field.name, ArrayType(field.dataType))
    })

  /** Supplies the iterator of batches for one partition's rows. */
  def getBatcher(it: Iterator[Row]): Iterator[List[Row]]

  /** Batches each partition with [[getBatcher]] and emits one row per batch.
    * Empty partitions are passed through without requesting a batcher.
    */
  def transform(dataset: Dataset[_]): DataFrame = {
    val outputEncoder = RowEncoder(transformSchema(dataset.schema))
    dataset.toDF().mapPartitions { rows =>
      if (!rows.hasNext) {
        rows
      } else {
        getBatcher(rows).map { batch =>
          val rowsAsValues = batch.map(_.toSeq)
          Row.fromSeq(transpose(rowsAsValues))
        }
      }
    }(outputEncoder)
  }

}

object DynamicMiniBatchTransformer extends DefaultParamsReadable[DynamicMiniBatchTransformer]

/** Mini-batching transformer backed by a [[DynamicBufferedBatcher]].
  *
  * The batcher receives the partition's row iterator together with the `maxBatchSize` value; how
  * batch boundaries are chosen is up to [[DynamicBufferedBatcher]] itself.
  */
class DynamicMiniBatchTransformer(val uid: String)
    extends MiniBatchBase {

  /** Creates the transformer with a random uid prefixed by the class name. */
  def this() = this(Identifiable.randomUID("DynamicMiniBatchTransformer"))

  /** Cap handed to the batcher; defaults to `Int.MaxValue`.
    * @group param
    */
  val maxBatchSize: Param[Int] = new IntParam(
    this, "maxBatchSize", "The max size of the buffer")
  setDefault(maxBatchSize, Int.MaxValue)

  /** @group getParam */
  def getMaxBatchSize: Int = getOrDefault(maxBatchSize)

  /** @group setParam */
  def setMaxBatchSize(value: Int): this.type = set(maxBatchSize, value)

  override def getBatcher(it: Iterator[Row]): DynamicBufferedBatcher[Row] = {
    val cap = getMaxBatchSize
    new DynamicBufferedBatcher[Row](it, cap)
  }

}

object TimeIntervalMiniBatchTransformer extends DefaultParamsReadable[TimeIntervalMiniBatchTransformer]

/** Mini-batching transformer backed by a [[TimeIntervalBatcher]].
  *
  * The batcher receives the partition's row iterator, `millisToWait` and `maxBatchSize`; how it
  * uses them to cut batches is defined by [[TimeIntervalBatcher]] itself.
  */
class TimeIntervalMiniBatchTransformer(val uid: String)
  extends MiniBatchBase {

  /** Cap handed to the batcher; defaults to `Integer.MAX_VALUE`.
    * @group param
    */
  val maxBatchSize: Param[Int] = new IntParam(
    this, "maxBatchSize", "The max size of the buffer")

  /** @group getParam */
  def getMaxBatchSize: Int = $(maxBatchSize)

  /** @group setParam */
  def setMaxBatchSize(value: Int): this.type = set(maxBatchSize, value)

  /** Time to wait before constructing a batch, in milliseconds, passed to the batcher.
    * No default is set in this class, so `getMillisToWait` (and therefore `getBatcher`)
    * needs an explicitly set value.
    * @group param
    */
  val millisToWait: Param[Int] = new IntParam(
    this, "millisToWait", "The time to wait before constructing a batch")

  /** @group getParam */
  def getMillisToWait: Int = $(millisToWait)

  /** @group setParam */
  def setMillisToWait(value: Int): this.type = set(millisToWait, value)

  // The uid prefix must name this class (it previously read "DynamicMiniBatchTransformer").
  def this() = this(Identifiable.randomUID("TimeIntervalMiniBatchTransformer"))

  setDefault(maxBatchSize -> Integer.MAX_VALUE)

  override def getBatcher(it: Iterator[Row]): TimeIntervalBatcher[Row] =
    new TimeIntervalBatcher(it, getMillisToWait, getMaxBatchSize)

}

/** Mixin for stages that delegate row batching to a configurable [[MiniBatchBase]] transformer. */
trait HasMiniBatcher extends Params {
  /** The mini-batching transformer to use. The param validator accepts only [[MiniBatchBase]]
    * instances. No default is set by this trait.
    * @group param
    */
  val miniBatcher: TransformerParam = new TransformerParam(this, "miniBatcher", "Minibatcher to use", {
    case _: MiniBatchBase => true
    case _ => false
  })

  /** @group setParam */
  def setMiniBatcher(value: MiniBatchBase): this.type = set(miniBatcher, value)

  /** @group getParam */
  def getMiniBatcher: MiniBatchBase = $(miniBatcher).asInstanceOf[MiniBatchBase]

  /** Sets the batch size of the current mini-batcher.
    *
    * For the dynamic and time-interval batchers this is `maxBatchSize`; for the fixed batcher it is
    * `batchSize`. The current batcher is copied before it is modified, and the copy is stored with
    * [[setMiniBatcher]]. This keeps a default batcher, or one shared by reference with copies of
    * this stage, from being mutated in place.
    *
    * @throws IllegalArgumentException if the current batcher is not one of the supported types
    */
  def setMiniBatchSize(n: Int): this.type = {
    val updated: MiniBatchBase = getMiniBatcher match {
      case d: DynamicMiniBatchTransformer => d.copy(ParamMap.empty).setMaxBatchSize(n)
      case f: FixedMiniBatchTransformer => f.copy(ParamMap.empty).setBatchSize(n)
      case t: TimeIntervalMiniBatchTransformer => t.copy(ParamMap.empty).setMaxBatchSize(n)
      case other => throw unsupportedMiniBatcher(other)
    }
    setMiniBatcher(updated)
  }

  /** Returns the batch size of the current mini-batcher (see [[setMiniBatchSize]]).
    *
    * @throws IllegalArgumentException if the current batcher is not one of the supported types
    */
  def getMiniBatchSize: Int = getMiniBatcher match {
    case d: DynamicMiniBatchTransformer => d.getMaxBatchSize
    case f: FixedMiniBatchTransformer => f.getBatchSize
    case t: TimeIntervalMiniBatchTransformer => t.getMaxBatchSize
    case other => throw unsupportedMiniBatcher(other)
  }

  /** Builds the error for batchers whose batch-size param is unknown to this trait. */
  private def unsupportedMiniBatcher(batcher: MiniBatchBase): IllegalArgumentException =
    new IllegalArgumentException(
      s"Unsupported mini-batcher ${batcher.getClass.getName} (uid ${batcher.uid}); expected a " +
        "DynamicMiniBatchTransformer, FixedMiniBatchTransformer or TimeIntervalMiniBatchTransformer")
}

/** Companion object for [[FixedMiniBatchTransformer]]; mixes in Spark's `DefaultParamsReadable`. */
object FixedMiniBatchTransformer extends DefaultParamsReadable[FixedMiniBatchTransformer]

/** Mixin providing an integer `batchSize` param. */
trait HasBatchSize extends Params {

  /** Size of each batch. No default is set by this trait.
    * The description previously read "The max size of the buffer", which was copied from an
    * unrelated buffer-size param.
    * @group param
    */
  val batchSize: Param[Int] = new IntParam(
    this, "batchSize", "The size of each batch")

  /** @group getParam */
  def getBatchSize: Int = $(batchSize)

  /** @group setParam */
  def setBatchSize(value: Int): this.type = set(batchSize, value)

}

/** Mini-batching transformer driven by the `batchSize` param.
  *
  * By default each partition is handed to a [[FixedBatcher]]. When `buffered` is true, a
  * [[FixedBufferedBatcher]] is used instead and also receives `maxBufferSize`.
  */
class FixedMiniBatchTransformer(val uid: String)
  extends MiniBatchBase with HasBatchSize {

  /** Creates the transformer with a random uid prefixed by the class name. */
  def this() = this(Identifiable.randomUID("FixedMiniBatchTransformer"))

  /** Buffer cap passed to the buffered batcher only; defaults to `Int.MaxValue`.
    * @group param
    */
  val maxBufferSize: Param[Int] = new IntParam(
    this, "maxBufferSize", "The max size of the buffer")

  /** Chooses between the buffered and unbuffered batcher; defaults to `false`.
    * @group param
    */
  val buffered: Param[Boolean] = new BooleanParam(
    this, "buffered", "Whether or not to buffer batches in memory")

  setDefault(maxBufferSize, Int.MaxValue)
  setDefault(buffered, false)

  /** @group getParam */
  def getMaxBufferSize: Int = getOrDefault(maxBufferSize)

  /** @group setParam */
  def setMaxBufferSize(value: Int): this.type = set(maxBufferSize, value)

  /** @group getParam */
  def getBuffered: Boolean = getOrDefault(buffered)

  /** @group setParam */
  def setBuffered(value: Boolean): this.type = set(buffered, value)

  override def getBatcher(it: Iterator[Row]): Iterator[List[Row]] =
    if (!getBuffered) new FixedBatcher(it, getBatchSize)
    else new FixedBufferedBatcher(it, getBatchSize, getMaxBufferSize)

}

object FlattenBatch extends DefaultParamsReadable[FlattenBatch]

/** Roughly the inverse of [[MiniBatchBase]]: every input column must be an array column, and each
  * input row (whose arrays must all have the same length) is expanded into one output row per
  * array position, with each column replaced by its array's element type.
  */
class FlattenBatch(val uid: String)
    extends Transformer with Wrappable with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("FlattenBatch"))

  /** Transposes a sequence of equal-length sequences (columns -> rows).
    *
    * An empty `nestedSeq` (a row with no columns) yields an empty result instead of failing.
    * Sequences of differing lengths fail the `assert`, with their lengths in the message.
    */
  def transpose(nestedSeq: Seq[Seq[Any]]): Seq[Seq[Any]] = {
    val innerLength = nestedSeq.headOption.fold(0)(_.length)
    // The assert message is by-name, so the lengths are only computed on failure.
    assert(nestedSeq.forall(_.lengthCompare(innerLength) == 0),
      s"All array columns of a batch must have the same length, but got lengths: " +
        nestedSeq.map(_.length).mkString(", "))
    (0 until innerLength).map(i => nestedSeq.map(innerSeq => innerSeq(i)))
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    dataset.toDF().mapPartitions(it =>
      it.flatMap { rowOfLists =>
        val columns = (0 until rowOfLists.length).map(i => rowOfLists.getSeq[Any](i))
        transpose(columns).map(Row.fromSeq)
      }
    )(RowEncoder(transformSchema(dataset.schema)))
  }

  override def copy(extra: ParamMap): this.type = defaultCopy(extra)

  /** Output schema: each array column becomes a column of its element type.
    * The new fields use the two-argument `StructField` constructor's defaults for nullability and
    * metadata. Fails the `assert` (naming the offending columns) if any column is not an array.
    */
  override def transformSchema(schema: StructType): StructType = {
    val nonArrayColumns = schema.fields.collect {
      case field if !field.dataType.isInstanceOf[ArrayType] =>
        s"${field.name} (${field.dataType.simpleString})"
    }
    assert(nonArrayColumns.isEmpty,
      s"FlattenBatch requires every input column to be an array, but got: " +
        nonArrayColumns.mkString(", "))
    StructType(schema.map { field =>
      field.dataType match {
        case ArrayType(elementType, _) => StructField(field.name, elementType)
        // Only reachable when assertions are elided at compile time.
        case other => throw new IllegalArgumentException(
          s"Column ${field.name} has non-array type ${other.simpleString}")
      }
    })
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy