All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.ml.spark.stages.EnsembleByKey.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import com.microsoft.ml.spark.core.contracts.Wrappable
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.SQLDataTypes._
import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import spray.json.DefaultJsonProtocol._

import scala.collection.mutable

/** Companion object that restores an [[EnsembleByKey]] previously saved via `DefaultParamsWritable`. */
object EnsembleByKey extends DefaultParamsReadable[EnsembleByKey] {

  /** Loads a saved [[EnsembleByKey]]; delegates to the inherited `DefaultParamsReadable` reader. */
  override def load(path: String): EnsembleByKey = super.load(path)
}

/** Groups rows by one or more key columns and ensembles (aggregates) score columns per group.
  *
  * Double/Float columns are averaged with Spark's `mean` (the result is always DoubleType); vector
  * columns are averaged element-wise. When `collapseGroup` is true (the default) the output has one
  * row per group holding the keys followed by the aggregated columns; otherwise the aggregates are
  * joined back onto every input row on the key columns.
  */
class EnsembleByKey(val uid: String) extends Transformer with Wrappable with DefaultParamsWritable {
  def this() = this(Identifiable.randomUID("EnsembleByKey"))

  /** Columns to group by; at least one is required. */
  val keys = new StringArrayParam(this, "keys", "Keys to group by")

  def getKeys: Array[String] = $(keys)

  def setKeys(arr: Array[String]): this.type = set(keys, arr)

  def setKeys(arr: String*): this.type = set(keys, arr.toArray)

  def setKey(value: String): this.type = set(keys, Array(value))

  /** Columns to aggregate within each group; at least one is required. */
  val cols = new StringArrayParam(this, "cols", "Cols to ensemble")

  def getCols: Array[String] = $(cols)

  def setCols(arr: Array[String]): this.type = set(cols, arr)

  def setCols(arr: String*): this.type = set(cols, arr.toArray)

  def setCol(value: String): this.type = set(cols, Array(value))

  /** Output names, positionally matched with `cols`; defaults to `strategy(col)` when unset. */
  val colNames = new StringArrayParam(this, "colNames", "Names of the result of each col")

  def getColNames: Array[String] = $(colNames)

  def setColNames(arr: Array[String]): this.type = set(colNames, arr)

  def setColNames(arr: String*): this.type = set(colNames, arr.toArray)

  def setColName(value: String): this.type = set(colNames, Array(value))

  val allowedStrategies = Set("mean")
  val strategy = new Param[String](this, "strategy", "How to ensemble the scores, ex: mean",
                { x: String => allowedStrategies(x) })

  def getStrategy: String = $(strategy)

  def setStrategy(value: String): this.type = set(strategy, value)

  setDefault(strategy -> "mean")

  val collapseGroup = new BooleanParam(
    this, "collapseGroup", "Whether to collapse all items in group to one entry")

  def getCollapseGroup: Boolean = $(collapseGroup)

  def setCollapseGroup(value: Boolean): this.type = set(collapseGroup, value)

  /** Known vector lengths per column; columns missing here are inspected at transform time. */
  val vectorDims =new MapParam[String, Int](this, "vectorDims",
    "the dimensions of any vector columns, used to avoid materialization")

  def getVectorDims: Map[String, Int] = get(vectorDims).getOrElse(Map())

  def setVectorDims(value: Map[String, Int]): this.type = set(vectorDims, value)

  setDefault(collapseGroup -> true)

  /** Output names actually used: explicit `colNames` if set, otherwise `strategy(col)` per col. */
  private def resolvedColNames: Array[String] = get(colNames).getOrElse(defaultColNames)

  /** The generated `strategy(col)` names used when `colNames` has not been set explicitly. */
  private def defaultColNames: Array[String] = getCols.map(name => s"$getStrategy($name)")

  /** Checks that keys/cols are non-empty and that `cols` and the output names line up 1:1. */
  private def validateParams(outNames: Array[String]): Unit = {
    require(getKeys.nonEmpty, "EnsembleByKey requires at least one key column")
    require(getCols.nonEmpty, "EnsembleByKey requires at least one column to ensemble")
    require(outNames.length == getCols.length,
      s"colNames has ${outNames.length} entries but cols has ${getCols.length}")
  }

  /** Result type of aggregating a column of type `dt` with the current strategy. */
  private def aggregatedType(dt: DataType): DataType = dt match {
    // Spark's `mean` casts its input to double, so Float columns come out as DoubleType.
    case DoubleType | FloatType => DoubleType
    case v if v == VectorType => VectorType
    case t => throw new IllegalArgumentException(s"Cannot operate on type $t with strategy $getStrategy")
  }

  /** Length of the first non-null vector in `colName`, or 0 if the column has none. With no
    * non-null vectors no group ever adds to the sum buffer, so its size does not matter. */
  private def inferVectorDim(dataset: Dataset[_], colName: String): Int =
    dataset.select(colName).where(col(colName).isNotNull).head(1)
      .headOption
      .map(_.getAs[Vector](0).size)
      .getOrElse(0)

  /** Builds the aggregate expression for one input column under the current strategy. */
  private def aggregateColumn(dataset: Dataset[_], inColName: String, outColName: String): Column =
    (getStrategy, dataset.schema(inColName).dataType) match {
      case ("mean", DoubleType | FloatType) =>
        mean(inColName).alias(outColName)
      case ("mean", v) if v == VectorType =>
        // getOrElse is by-name: the data is only scanned when no dimension was supplied.
        val dim = getVectorDims.getOrElse(inColName, inferVectorDim(dataset, inColName))
        val vectorAvg = new VectorAvg(dim)
        vectorAvg(dataset(inColName)).alias(outColName)
      case (_, t) =>
        throw new IllegalArgumentException(s"Cannot operate on type $t with strategy $getStrategy")
    }

  /** Aggregates `cols` per `keys` group.
    *
    * @throws IllegalArgumentException if keys/cols are empty, colNames and cols differ in length,
    *                                  a column is missing, or a column type is unsupported
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    // Keep exposing the generated output names through getColNames once transform has run.
    if (get(colNames).isEmpty) {
      setDefault(colNames -> defaultColNames)
    }

    transformSchema(dataset.schema)

    val outNames = resolvedColNames
    val newCols = getCols.zip(outNames).map { case (inColName, outColName) =>
      aggregateColumn(dataset, inColName, outColName)
    }

    val aggregated = dataset.toDF()
      .groupBy(getKeys.head, getKeys.tail: _*)
      .agg(newCols.head, newCols.tail: _*)

    if (getCollapseGroup) {
      aggregated
    } else {
      // Drop pre-existing columns that would collide with the aggregate names before joining.
      val needToDrop = outNames.toSet & dataset.columns.toSet
      dataset.drop(needToDrop.toList: _*).toDF().join(aggregated, getKeys)
    }
  }

  /** Output schema matching `transform`.
    *
    * Collapsed: key fields (in `keys` order), then one field per aggregated col. Not collapsed:
    * key fields, the remaining input fields (minus any that share an output name, which transform
    * drops), then the aggregated fields, mirroring the column order of a using-columns join.
    */
  override def transformSchema(schema: StructType): StructType = {
    val outNames = resolvedColNames
    validateParams(outNames)

    // StructType.apply throws IllegalArgumentException for a missing field.
    val keyFields = getKeys.map(k => schema(k))
    val newFields = getCols.zip(outNames).map { case (inColName, outColName) =>
      StructField(outColName, aggregatedType(schema(inColName).dataType))
    }

    if (getCollapseGroup) {
      StructType(keyFields ++ newFields)
    } else {
      val keySet = getKeys.toSet
      val dropped = outNames.toSet
      val remaining = schema.fields.filterNot(f => keySet(f.name) || dropped(f.name))
      StructType(keyFields ++ remaining ++ newFields)
    }
  }

  def copy(extra: ParamMap): this.type = defaultCopy(extra)

}

/** Spark UDAF computing the element-wise mean of a vector column.
  *
  * The buffer holds a running per-index sum (`buff`, length `n`) and the number of non-null input
  * vectors seen (`count`). Null inputs are skipped; a group with no non-null inputs evaluates to
  * null, matching SQL `avg` semantics.
  *
  * @param n length of the running-sum buffer; every active index of an input vector must be < n,
  *          otherwise `update` fails with an index-out-of-bounds error.
  */
private class VectorAvg(n: Int) extends UserDefinedAggregateFunction {

  def inputSchema: StructType = new StructType().add("v", VectorType)

  def bufferSchema: StructType =
    new StructType().add("buff", ArrayType(DoubleType)).add("count", LongType)

  def dataType: DataType = VectorType

  def deterministic: Boolean = true

  /** Starts every group with an all-zero sum of length `n` and a zero count. */
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, Array.fill(n)(0.0))
    buffer.update(1, 0L)
  }

  /** Adds one non-null input vector to the running sums; only active entries are visited. */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val sums = buffer.getSeq[Double](0).toArray
      val sparse = input.getAs[Vector](0).toSparse
      // Walk indices and values in lockstep instead of calling sparse(i), which binary-searches
      // the index array on every lookup.
      var k = 0
      while (k < sparse.indices.length) {
        sums(sparse.indices(k)) += sparse.values(k)
        k += 1
      }
      buffer.update(0, sums)
      buffer.update(1, buffer.getLong(1) + 1L)
    }
  }

  /** Combines two partial aggregates: sums are added index-wise, counts are added. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val sums = buffer1.getSeq[Double](0).toArray
    val other = buffer2.getSeq[Double](0).toArray
    other.indices.foreach(i => sums(i) += other(i))
    buffer1.update(0, sums)
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  /** Returns sums divided by count, or null (SQL NULL) when the group had no non-null vectors,
    * instead of a vector of NaNs from 0 / 0. */
  def evaluate(buffer: Row): Vector = {
    val count = buffer.getLong(1)
    if (count == 0L) null
    else Vectors.dense(buffer.getSeq[Double](0).map(_ / count).toArray)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy