org.apache.spark.ml.events.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml
import com.fasterxml.jackson.annotation.JsonIgnore
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Evolving
import org.apache.spark.internal.Logging
import org.apache.spark.ml.util.{MLReader, MLWriter}
import org.apache.spark.scheduler.SparkListenerEvent
import org.apache.spark.sql.{DataFrame, Dataset}
/**
 * Event emitted by ML operations. Events are fired before and/or
 * after each operation (each event's documentation states which).
*
* @note This is supported via [[Pipeline]] and [[PipelineModel]].
*/
@Evolving
@Evolving
sealed trait MLEvent extends SparkListenerEvent {
  // Sealed so every concrete ML event type is declared in this file, which
  // keeps pattern matches over MLEvent exhaustive at compile time.
  // Do not log ML events in event log. It should be revisited to see
  // how it works with history server.
  protected[spark] override def logEvent: Boolean = false
}
/**
* Event fired before `Transformer.transform`.
*/
@Evolving
@Evolving
case class TransformStart() extends MLEvent {
  // Payload is carried in mutable fields rather than constructor parameters,
  // and is excluded from JSON serialization via @JsonIgnore.
  @JsonIgnore var transformer: Transformer = _  // transformer whose transform() is starting
  @JsonIgnore var input: Dataset[_] = _         // dataset passed to transform()
}
/**
* Event fired after `Transformer.transform`.
*/
@Evolving
@Evolving
case class TransformEnd() extends MLEvent {
  // Payload fields are mutable and kept out of JSON output via @JsonIgnore.
  @JsonIgnore var transformer: Transformer = _  // transformer whose transform() completed
  @JsonIgnore var output: Dataset[_] = _        // DataFrame produced by transform()
}
/**
* Event fired before `Estimator.fit`.
*/
@Evolving
@Evolving
case class FitStart[M <: Model[M]]() extends MLEvent {
  // M is the model type the estimator will produce. Payload fields are
  // mutable and excluded from JSON serialization via @JsonIgnore.
  @JsonIgnore var estimator: Estimator[M] = _  // estimator whose fit() is starting
  @JsonIgnore var dataset: Dataset[_] = _      // training dataset passed to fit()
}
/**
* Event fired after `Estimator.fit`.
*/
@Evolving
@Evolving
case class FitEnd[M <: Model[M]]() extends MLEvent {
  // M is the model type produced by the estimator. Payload fields are
  // mutable and excluded from JSON serialization via @JsonIgnore.
  @JsonIgnore var estimator: Estimator[M] = _  // estimator whose fit() completed
  @JsonIgnore var model: M = _                 // fitted model returned by fit()
}
/**
* Event fired before `MLReader.load`.
*/
@Evolving
@Evolving
case class LoadInstanceStart[T](path: String) extends MLEvent {
  // `path` (a constructor parameter, so it IS serialized) is the path the
  // instance is being loaded from; the reader itself is kept out of JSON.
  @JsonIgnore var reader: MLReader[T] = _  // reader performing the load
}
/**
* Event fired after `MLReader.load`.
*/
@Evolving
@Evolving
case class LoadInstanceEnd[T]() extends MLEvent {
  // Payload fields are mutable and excluded from JSON output via @JsonIgnore.
  @JsonIgnore var reader: MLReader[T] = _  // reader that performed the load
  @JsonIgnore var instance: T = _          // instance returned by the load
}
/**
* Event fired before `MLWriter.save`.
*/
@Evolving
@Evolving
case class SaveInstanceStart(path: String) extends MLEvent {
  // `path` (a constructor parameter, so it IS serialized) is the destination
  // path of the save; the writer itself is kept out of JSON.
  @JsonIgnore var writer: MLWriter = _  // writer performing the save
}
/**
* Event fired after `MLWriter.save`.
*/
@Evolving
@Evolving
case class SaveInstanceEnd(path: String) extends MLEvent {
  // `path` (a constructor parameter, so it IS serialized) is the destination
  // path of the completed save; the writer itself is kept out of JSON.
  @JsonIgnore var writer: MLWriter = _  // writer that performed the save
}
/**
* A small trait that defines some methods to send [[org.apache.spark.ml.MLEvent]].
*/
/**
 * A small trait that defines some methods to send [[org.apache.spark.ml.MLEvent]].
 */
private[ml] trait MLEvents extends Logging {

  // Resolved on each use so posts always target the currently active context.
  private def listenerBus = SparkContext.getOrCreate().listenerBus

  /**
   * Log [[MLEvent]] to send. By default, it emits a debug-level log.
   */
  def logEvent(event: MLEvent): Unit = logDebug(s"Sending an MLEvent: $event")

  // Logs the event and then posts it to the listener bus, in that order.
  private def emit(event: MLEvent): Unit = {
    logEvent(event)
    listenerBus.post(event)
  }

  /**
   * Runs `func` (an `Estimator.fit` body), emitting [[FitStart]] before and
   * [[FitEnd]] after it. Note: if `func` throws, no end event is emitted.
   */
  def withFitEvent[M <: Model[M]](
      estimator: Estimator[M], dataset: Dataset[_])(func: => M): M = {
    val start = FitStart[M]()
    start.estimator = estimator
    start.dataset = dataset
    emit(start)
    val fitted: M = func
    val end = FitEnd[M]()
    end.estimator = estimator
    end.model = fitted
    emit(end)
    fitted
  }

  /**
   * Runs `func` (a `Transformer.transform` body), emitting [[TransformStart]]
   * before and [[TransformEnd]] after it. If `func` throws, no end event is emitted.
   */
  def withTransformEvent(
      transformer: Transformer, input: Dataset[_])(func: => DataFrame): DataFrame = {
    val start = TransformStart()
    start.transformer = transformer
    start.input = input
    emit(start)
    val result: DataFrame = func
    val end = TransformEnd()
    end.transformer = transformer
    end.output = result
    emit(end)
    result
  }

  /**
   * Runs `func` (an `MLReader.load` body), emitting [[LoadInstanceStart]]
   * before and [[LoadInstanceEnd]] after it. If `func` throws, no end event is emitted.
   */
  def withLoadInstanceEvent[T](reader: MLReader[T], path: String)(func: => T): T = {
    val start = LoadInstanceStart[T](path)
    start.reader = reader
    emit(start)
    val loaded: T = func
    val end = LoadInstanceEnd[T]()
    end.reader = reader
    end.instance = loaded
    emit(end)
    loaded
  }

  /**
   * Runs `func` (an `MLWriter.save` body), emitting [[SaveInstanceStart]]
   * before and [[SaveInstanceEnd]] after it. If `func` throws, no end event is emitted.
   */
  def withSaveInstanceEvent(writer: MLWriter, path: String)(func: => Unit): Unit = {
    val start = SaveInstanceStart(path)
    start.writer = writer
    emit(start)
    func
    val end = SaveInstanceEnd(path)
    end.writer = writer
    emit(end)
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy