All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.spark.ml.events.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import com.fasterxml.jackson.annotation.JsonIgnore

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Evolving
import org.apache.spark.internal.Logging
import org.apache.spark.ml.util.{MLReader, MLWriter}
import org.apache.spark.scheduler.SparkListenerEvent
import org.apache.spark.sql.{DataFrame, Dataset}

/**
 * Event emitted by ML operations. Events are fired before and/or after
 * each operation (each event type documents when it is fired).
 *
 * The trait is sealed, so every direct subtype is declared in this file.
 *
 * @note This is supported via [[Pipeline]] and [[PipelineModel]].
 */
@Evolving
sealed trait MLEvent extends SparkListenerEvent {
  // Returning false keeps ML events out of the event log. This should be
  // revisited to see how it works with the history server.
  protected[spark] override def logEvent: Boolean = false
}

/**
 * Event fired before `Transformer.transform`.
 *
 * The fields are mutable and are not constructor parameters, so they are left at their
 * default value on construction and do not take part in the generated
 * `equals`/`hashCode`/`toString`. `MLEvents.withTransformEvent` assigns them right after
 * creating the event, before it is logged and posted. Both carry `@JsonIgnore`, which
 * tells Jackson to skip them when the event is serialized.
 */
@Evolving
case class TransformStart() extends MLEvent {
  // Transformer that is about to be applied.
  @JsonIgnore var transformer: Transformer = _
  // Dataset handed to the transformer.
  @JsonIgnore var input: Dataset[_] = _
}

/**
 * Event fired after `Transformer.transform`.
 *
 * The fields are mutable and are not constructor parameters; `MLEvents.withTransformEvent`
 * assigns them once the wrapped operation has returned, before the event is logged and
 * posted. Both carry `@JsonIgnore`, which tells Jackson to skip them when the event is
 * serialized.
 */
@Evolving
case class TransformEnd() extends MLEvent {
  // Transformer that was applied.
  @JsonIgnore var transformer: Transformer = _
  // Result returned by the wrapped transform operation.
  @JsonIgnore var output: Dataset[_] = _
}

/**
 * Event fired before `Estimator.fit`.
 *
 * The fields are mutable and are not constructor parameters; `MLEvents.withFitEvent`
 * assigns them right after creating the event, before it is logged and posted. Both
 * carry `@JsonIgnore`, which tells Jackson to skip them when the event is serialized.
 *
 * @tparam M type of the model the estimator produces
 */
@Evolving
case class FitStart[M <: Model[M]]() extends MLEvent {
  // Estimator that is about to be fitted.
  @JsonIgnore var estimator: Estimator[M] = _
  // Dataset handed to the estimator.
  @JsonIgnore var dataset: Dataset[_] = _
}

/**
 * Event fired after `Estimator.fit`.
 *
 * The fields are mutable and are not constructor parameters; `MLEvents.withFitEvent`
 * assigns them once the wrapped operation has returned, before the event is logged and
 * posted. Both carry `@JsonIgnore`, which tells Jackson to skip them when the event is
 * serialized.
 *
 * @tparam M type of the model the estimator produces
 */
@Evolving
case class FitEnd[M <: Model[M]]() extends MLEvent {
  // Estimator that was fitted.
  @JsonIgnore var estimator: Estimator[M] = _
  // Model returned by the wrapped fit operation.
  @JsonIgnore var model: M = _
}

/**
 * Event fired before `MLReader.load`.
 *
 * `reader` is a mutable field rather than a constructor parameter;
 * `MLEvents.withLoadInstanceEvent` assigns it right after creating the event, before it
 * is logged and posted. It carries `@JsonIgnore`, which tells Jackson to skip it when
 * the event is serialized.
 *
 * @param path path forwarded from `MLEvents.withLoadInstanceEvent`; presumably the
 *             location the instance is loaded from
 * @tparam T type of the instance the reader loads
 */
@Evolving
case class LoadInstanceStart[T](path: String) extends MLEvent {
  // Reader that is about to perform the load.
  @JsonIgnore var reader: MLReader[T] = _
}

/**
 * Event fired after `MLReader.load`.
 *
 * Unlike [[LoadInstanceStart]], this event does not carry the path. The fields are
 * mutable and are not constructor parameters; `MLEvents.withLoadInstanceEvent` assigns
 * them once the wrapped operation has returned, before the event is logged and posted.
 * Both carry `@JsonIgnore`, which tells Jackson to skip them when the event is serialized.
 *
 * @tparam T type of the loaded instance
 */
@Evolving
case class LoadInstanceEnd[T]() extends MLEvent {
  // Reader that performed the load.
  @JsonIgnore var reader: MLReader[T] = _
  // Instance returned by the wrapped load operation.
  @JsonIgnore var instance: T = _
}

/**
 * Event fired before `MLWriter.save`.
 *
 * `writer` is a mutable field rather than a constructor parameter;
 * `MLEvents.withSaveInstanceEvent` assigns it right after creating the event, before it
 * is logged and posted. It carries `@JsonIgnore`, which tells Jackson to skip it when
 * the event is serialized.
 *
 * @param path path forwarded from `MLEvents.withSaveInstanceEvent`; presumably the
 *             location the instance is saved to
 */
@Evolving
case class SaveInstanceStart(path: String) extends MLEvent {
  // Writer that is about to perform the save.
  @JsonIgnore var writer: MLWriter = _
}

/**
 * Event fired after `MLWriter.save`.
 *
 * `writer` is a mutable field rather than a constructor parameter;
 * `MLEvents.withSaveInstanceEvent` assigns it once the wrapped operation has returned,
 * before the event is logged and posted. It carries `@JsonIgnore`, which tells Jackson
 * to skip it when the event is serialized.
 *
 * @param path path forwarded from `MLEvents.withSaveInstanceEvent`; the same value is
 *             used for the matching [[SaveInstanceStart]]
 */
@Evolving
case class SaveInstanceEnd(path: String) extends MLEvent {
  // Writer that performed the save.
  @JsonIgnore var writer: MLWriter = _
}

/**
 * Mixin with helpers that wrap an ML operation and send the matching pair of
 * [[org.apache.spark.ml.MLEvent]]s around it.
 *
 * Every `with*Event` helper logs and posts a "start" event, evaluates the by-name
 * operation, then logs and posts the corresponding "end" event. When the operation
 * throws, the exception propagates and no "end" event is sent.
 */
private[ml] trait MLEvents extends Logging {

  // Looked up on every use instead of being cached in the trait.
  private def listenerBus = SparkContext.getOrCreate().listenerBus

  /**
   * Log [[MLEvent]] to send. By default, it emits a debug-level log.
   */
  def logEvent(event: MLEvent): Unit = {
    logDebug(s"Sending an MLEvent: $event")
  }

  // Runs the event through `logEvent` first, then posts it on the listener bus.
  private def logAndPost(event: MLEvent): Unit = {
    logEvent(event)
    listenerBus.post(event)
  }

  /**
   * Surrounds a fit operation with [[FitStart]] and [[FitEnd]] events.
   *
   * @param estimator estimator recorded on both events
   * @param dataset dataset recorded on the start event
   * @param func the fit operation; evaluated exactly once, between the two events
   * @return the model produced by `func`
   */
  def withFitEvent[M <: Model[M]](
      estimator: Estimator[M], dataset: Dataset[_])(func: => M): M = {
    val before = FitStart[M]()
    before.estimator = estimator
    before.dataset = dataset
    logAndPost(before)

    val fitted: M = func

    val after = FitEnd[M]()
    after.estimator = estimator
    after.model = fitted
    logAndPost(after)
    fitted
  }

  /**
   * Surrounds a transform operation with [[TransformStart]] and [[TransformEnd]] events.
   *
   * @param transformer transformer recorded on both events
   * @param input dataset recorded on the start event
   * @param func the transform operation; evaluated exactly once, between the two events
   * @return the DataFrame produced by `func`
   */
  def withTransformEvent(
      transformer: Transformer, input: Dataset[_])(func: => DataFrame): DataFrame = {
    val before = TransformStart()
    before.transformer = transformer
    before.input = input
    logAndPost(before)

    val transformed: DataFrame = func

    val after = TransformEnd()
    after.transformer = transformer
    after.output = transformed
    logAndPost(after)
    transformed
  }

  /**
   * Surrounds a load operation with [[LoadInstanceStart]] and [[LoadInstanceEnd]] events.
   *
   * @param reader reader recorded on both events
   * @param path path recorded on the start event
   * @param func the load operation; evaluated exactly once, between the two events
   * @return the instance produced by `func`
   */
  def withLoadInstanceEvent[T](reader: MLReader[T], path: String)(func: => T): T = {
    val before = LoadInstanceStart[T](path)
    before.reader = reader
    logAndPost(before)

    val loaded: T = func

    val after = LoadInstanceEnd[T]()
    after.reader = reader
    after.instance = loaded
    logAndPost(after)
    loaded
  }

  /**
   * Surrounds a save operation with [[SaveInstanceStart]] and [[SaveInstanceEnd]] events.
   *
   * @param writer writer recorded on both events
   * @param path path recorded on both events
   * @param func the save operation; evaluated exactly once, between the two events
   */
  def withSaveInstanceEvent(writer: MLWriter, path: String)(func: => Unit): Unit = {
    val before = SaveInstanceStart(path)
    before.writer = writer
    logAndPost(before)

    func

    val after = SaveInstanceEnd(path)
    after.writer = writer
    logAndPost(after)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy