All downloads are FREE. Search and download functionality uses the official Maven repository.

com.microsoft.ml.spark.stages.Timer.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import com.microsoft.ml.spark.core.contracts.Wrappable
import com.microsoft.ml.spark.core.serialize.{ConstructorReadable, ConstructorWritable}
import org.apache.spark.ml._
import org.apache.spark.ml.param.{BooleanParam, ParamMap, PipelineStageParam}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

import scala.reflect.runtime.universe.{TypeTag, typeTag}

/** Companion object of the [[Timer]] estimator.
  *
  * It adds no members of its own; everything it offers comes from
  * `ComplexParamsReadable[Timer]` (presumably the read/load entry points used to
  * restore a saved `Timer` -- confirm against that trait's definition).
  */
object Timer extends ComplexParamsReadable[Timer]

/** Params and small helpers shared by [[Timer]] and [[TimerModel]].
  *
  * Holds the wrapped stage, the console-logging switch and the materialization
  * switch, plus the message formatting/printing used after a timed run.
  */
trait TimerParams extends Wrappable {
  /** The pipeline stage whose execution is being timed. */
  val stage = new PipelineStageParam(this, "stage", "The stage to time")

  def getStage: PipelineStage = $(stage)

  /** Whether `log` prints timing messages to stdout. Defaults to true. */
  val logToScala = new BooleanParam(this, "logToScala", "Whether to output the time to the scala console")
  setDefault(logToScala -> true)

  def getLogToScala: Boolean = $(logToScala)

  def setLogToScala(v: Boolean): this.type = set(logToScala, v)

  /** When true (the default) the timed operations in this file skip their
    * `cache()`/`count()` calls, so no Spark action is forced and no row count
    * is reported in the message.
    */
  val disableMaterialization = new BooleanParam(this, "disableMaterialization",
                 "Whether to disable timing (so that one can turn it off for evaluation)")

  setDefault(disableMaterialization->true)

  def getDisableMaterialization: Boolean = $(disableMaterialization)

  /** Setter for `disableMaterialization` (name kept as-is for existing callers). */
  def setDisable(v: Boolean): this.type = set(disableMaterialization, v)

  /** Builds the human-readable timing message.
    *
    * @param t           elapsed time in nanoseconds
    * @param isTransform true when a `transform` was timed, false when a `fit` was timed
    * @param count       number of rows processed, when it was materialized
    * @param stage       the stage that was timed (rendered via its `toString`)
    * @return the message, e.g. `"<stage> took 1.5s to fit 100 rows"`
    */
  protected def formatTime(t: Long, isTransform: Boolean, count: Option[Long], stage: PipelineStage): String = {
    val seconds = t / 1000000000.0
    // The non-transform caller times Estimator.fit, so the verb is "fit"
    // (it was previously mislabelled as "predict").
    val action = if (isTransform) "transform" else "fit"
    val amount = count.fold("")(n => s"$n rows")
    s"$stage took ${seconds}s to $action $amount"
  }

  /** Prints a timing message to stdout.
    *
    * Nothing is printed for the empty string (used when no timing took place)
    * or when `logToScala` is false; in the latter case callers can still get
    * the message from `fitWithTime` / `transformWithTime`.
    */
  protected def log(str: String): Unit = {
    if (getLogToScala && str != "") println(str)
  }

}

/** Estimator that wraps another pipeline stage and measures how long fitting it takes.
  *
  * If the wrapped stage is already a `Transformer` there is nothing to fit, so it
  * is wrapped in a [[TimerModel]] directly and no message is produced.
  *
  * @param uid unique identifier of this stage
  */
class Timer(val uid: String) extends Estimator[TimerModel]
  with TimerParams with ComplexParamsWritable {

  def this() = this(Identifiable.randomUID("Timer"))

  /** Sets the stage to be timed. */
  def setStage(s: PipelineStage): this.type = set(stage, s)

  override def copy(extra: ParamMap): this.type = defaultCopy(extra)

  /** The output schema is whatever the wrapped stage produces. */
  override def transformSchema(schema: StructType): StructType = getStage.transformSchema(schema)

  /** Wraps a transformer in a model that inherits this estimator's timing settings.
    *
    * Without the explicit propagation the model would fall back to the param
    * defaults and ignore what was configured on the estimator.
    */
  private def wrap(transformer: Transformer): TimerModel =
    new TimerModel(uid, transformer)
      .setParent(this)
      .setDisable(getDisableMaterialization)
      .setLogToScala(getLogToScala)

  /** Fits the wrapped stage and reports how long it took.
    *
    * @param dataset training data
    * @return the resulting model and the timing message; the message is empty
    *         when the wrapped stage is a `Transformer` (nothing was fitted)
    */
  def fitWithTime(dataset: Dataset[_]): (TimerModel, String) = {
    getStage match {
      // Nothing to fit, so do not pay for a cache()/count() job either.
      case t: Transformer => (wrap(t), "")
      case e: Estimator[_] =>
        val materialize = !getDisableMaterialization
        // Cache and count before starting the clock so that reading the input
        // is not attributed to the fit itself.
        val input: Dataset[_] = if (materialize) dataset.cache() else dataset
        val rowCount: Option[Long] = if (materialize) Some(input.count()) else None
        val t0 = System.nanoTime()
        val fitE = e.fit(input)
        val t1 = System.nanoTime()
        // Second argument is isTransform: false because a fit was timed.
        (wrap(fitE), formatTime(t1 - t0, false, rowCount, e))
    }
  }

  /** Fits the wrapped stage, logs the timing message and returns the model. */
  def fit(dataset: Dataset[_]): TimerModel = {
    val (model, message) = fitWithTime(dataset)
    log(message)
    model
  }

}

/** Companion object of [[TimerModel]].
  *
  * It adds no members of its own; everything it offers comes from
  * `ConstructorReadable[TimerModel]` (presumably the counterpart of the model's
  * `ConstructorWritable` mix-in, used to rebuild a saved model -- confirm against
  * that trait's definition).
  */
object TimerModel extends ConstructorReadable[TimerModel]

/** Model that wraps a transformer and measures how long its `transform` takes.
  *
  * @param uid unique identifier of this stage
  * @param t   the transformer being timed
  */
class TimerModel(val uid: String, t: Transformer)
  extends Model[TimerModel] with TimerParams with ConstructorWritable[TimerModel]{

  override val objectsToSave = List(uid, t)
  override val ttag: TypeTag[TimerModel] = typeTag[TimerModel]

  /** Copies this model together with its wrapped transformer.
    *
    * `copyValues` carries over the params set on this model (and any matching
    * overrides in `extra`); without it the copy would silently revert to defaults.
    */
  override def copy(extra: ParamMap): TimerModel = {
    copyValues(new TimerModel(uid, t.copy(extra)), extra).setParent(parent)
  }

  /** Runs the wrapped transformer and reports how long it took.
    *
    * When materialization is enabled, the input is cached and counted before the
    * clock starts, and the output is cached and counted before it stops, so the
    * measured interval covers actually computing the result. When it is disabled,
    * no Spark action is triggered here and no row count is reported.
    *
    * @param dataset input data
    * @return the transformed data and the timing message
    */
  def transformWithTime(dataset: Dataset[_]): (DataFrame, String) = {
    val materialize = !getDisableMaterialization
    val input: Dataset[_] = if (materialize) dataset.cache() else dataset
    // The count's value is not needed; it only forces the cache to be populated
    // so that reading the input is not charged to the transformer.
    if (materialize) input.count()

    val t0 = System.nanoTime()

    val res = t.transform(input)
    val output: DataFrame = if (materialize) res.cache() else res
    val rowCount: Option[Long] = if (materialize) Some(output.count()) else None

    val t1 = System.nanoTime()

    // Second argument is isTransform: true because a transform was timed.
    (output, formatTime(t1 - t0, true, rowCount, t))
  }

  /** The output schema is whatever the wrapped transformer produces. */
  def transformSchema(schema: StructType): StructType = t.transformSchema(schema)

  /** Runs the wrapped transformer, logs the timing message and returns the result. */
  def transform(dataset: Dataset[_]): DataFrame = {
    val (transformed, message) = transformWithTime(dataset)
    log(message)
    transformed
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy