All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.ml.util.Instrumentation.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tencent.angel.sona.ml.util

import java.util.UUID

import scala.util.{Failure, Success, Try}
import scala.util.control.NonFatal
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.internal.Logging
import com.tencent.angel.sona.ml.PipelineStage
import com.tencent.angel.sona.ml.param.{Param, Params}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.util.Utils

/**
  * A small wrapper that defines a training session for an estimator, and some methods to log
  * useful information during this session.
  *
  * Each instance takes the first 8 characters of a random UUID as a session id and prepends
  * `[<id>] ` to every message routed through its `logDebug`/`logInfo`/`logWarning`/`logError`
  * overrides, so that log lines from one session can be told apart from those of other sessions.
  * The constructor is private; instances are created through the companion object.
  */
private[sona] class Instrumentation private() extends Logging {

  private val id = UUID.randomUUID()
  private val shortId = id.toString.take(8)
  // Session tag prepended to every message logged through this instance.
  private[util] val prefix = s"[$shortId] "

  /**
    * Log some info about the pipeline stage being fit: its runtime class and its uid.
    */
  def logPipelineStage(stage: PipelineStage): Unit = {
    // `getClass.getSimpleName` can fail with a "Malformed class name" error for some nested
    // Scala classes, so log the fully-qualified `getName` instead.
    logInfo(s"Stage class: ${stage.getClass.getName}")
    logInfo(s"Stage uid: ${stage.uid}")
  }

  /**
    * Log some data about the dataset being fit (delegates to the `RDD` overload on `dataset.rdd`).
    */
  def logDataset(dataset: Dataset[_]): Unit = logDataset(dataset.rdd)

  /**
    * Log some data about the dataset being fit: its partition count and storage level.
    */
  def logDataset(dataset: RDD[_]): Unit = {
    logInfo(s"training: numPartitions=${dataset.partitions.length}" +
      s" storageLevel=${dataset.getStorageLevel}")
  }

  /**
    * Logs a debug message with a prefix that uniquely identifies the training session.
    */
  override def logDebug(msg: => String): Unit = {
    super.logDebug(prefix + msg)
  }

  /**
    * Logs a warning message with a prefix that uniquely identifies the training session.
    */
  override def logWarning(msg: => String): Unit = {
    super.logWarning(prefix + msg)
  }

  /**
    * Logs an error message with a prefix that uniquely identifies the training session.
    */
  override def logError(msg: => String): Unit = {
    super.logError(prefix + msg)
  }

  /**
    * Logs an info message with a prefix that uniquely identifies the training session.
    */
  override def logInfo(msg: => String): Unit = {
    super.logInfo(prefix + msg)
  }

  /**
    * Logs the value of the given parameters for the estimator being used in this session,
    * as a single compact JSON object mapping param name to its JSON-encoded value.
    *
    * Only params for which `hasParams.get(p)` yields a value are included.
    * NOTE(review): presumably that means explicitly-set params only (defaults excluded) —
    * confirm against `Params.get`.
    */
  def logParams(hasParams: Params, params: Param[_]*): Unit = {
    val pairs: Seq[(String, JValue)] = for {
      p <- params
      value <- hasParams.get(p)
    } yield {
      // `jsonEncode` expects the param's own value type; the existential is erased here.
      val cast = p.asInstanceOf[Param[Any]]
      p.name -> parse(cast.jsonEncode(value))
    }
    logInfo(compact(render(map2jvalue(pairs.toMap))))
  }

  /** Logs the number of features under the `numFeatures` tag. */
  def logNumFeatures(num: Long): Unit = {
    logNamedValue(Instrumentation.loggerTags.numFeatures, num)
  }

  /** Logs the number of classes under the `numClasses` tag. */
  def logNumClasses(num: Long): Unit = {
    logNamedValue(Instrumentation.loggerTags.numClasses, num)
  }

  /** Logs the number of examples under the `numExamples` tag. */
  def logNumExamples(num: Long): Unit = {
    logNamedValue(Instrumentation.loggerTags.numExamples, num)
  }

  /**
    * Logs the value with customized name field, as a one-field compact JSON object.
    */
  def logNamedValue(name: String, value: String): Unit = {
    logInfo(compact(render(name -> value)))
  }

  def logNamedValue(name: String, value: Long): Unit = {
    logInfo(compact(render(name -> value)))
  }

  def logNamedValue(name: String, value: Double): Unit = {
    logInfo(compact(render(name -> value)))
  }

  // The array overloads first render the array as a compact JSON array, then log that text
  // as a *string* field value (i.e. the array appears quoted inside the outer object).
  def logNamedValue(name: String, value: Array[String]): Unit = {
    logInfo(compact(render(name -> compact(render(value.toSeq)))))
  }

  def logNamedValue(name: String, value: Array[Long]): Unit = {
    logInfo(compact(render(name -> compact(render(value.toSeq)))))
  }

  def logNamedValue(name: String, value: Array[Double]): Unit = {
    logInfo(compact(render(name -> compact(render(value.toSeq)))))
  }


  /**
    * Logs the successful completion of the training session.
    */
  def logSuccess(): Unit = {
    logInfo("training finished")
  }

  /**
    * Logs an exception raised during a training session.
    *
    * The throwable itself is handed to the logger (rather than only its stack frames), so the
    * logging backend can record its class, message, cause chain and stack trace, tagged with
    * this session's prefix.
    */
  def logFailure(e: Throwable): Unit = {
    super.logError(s"${prefix}training failed", e)
  }
}

/**
  * Factory helpers and shared log-field names for `Instrumentation` training sessions.
  */
private[sona] object Instrumentation {

  /** Well-known field names to pass as the `name` argument of `logNamedValue`. */
  object loggerTags {
    val numFeatures: String = "numFeatures"
    val numClasses: String = "numClasses"
    val numExamples: String = "numExamples"
    val meanOfLabels: String = "meanOfLabels"
    val varianceOfLabels: String = "varianceOfLabels"
  }

  /**
    * Runs `body` inside a fresh training session.
    *
    * When `body` completes normally, the session's success message is logged and the result is
    * returned. When it throws a non-fatal exception, the failure is logged through the session
    * and the very same exception is rethrown. Fatal errors are not intercepted at all.
    *
    * @param body the training logic; receives the session to log through
    * @tparam T result type of `body`
    * @return whatever `body` returns
    */
  def instrumented[T](body: Instrumentation => T): T = {
    val session = new Instrumentation()
    val outcome =
      try {
        body(session)
      } catch {
        case NonFatal(err) =>
          session.logFailure(err)
          throw err
      }
    session.logSuccess()
    outcome
  }

  /**
    * Creates a new, independent training session.
    *
    * NOTE(review): `params` and `rdd` are currently ignored; they look like they are kept only
    * for signature compatibility — confirm before relying on them.
    */
  def create(params: Params, rdd: RDD[_]): Instrumentation = new Instrumentation()
}

/**
  * A small wrapper that contains an optional `Instrumentation` object.
  * Its log methods delegate to the wrapped `Instrumentation` when one is present, so messages
  * carry that session's prefix. Otherwise they go to an ordinary logger named `className`.
  */
private[sona] class OptionalInstrumentation private(
    val instrumentation: Option[Instrumentation],
    val className: String) extends Logging {

  // Used as the logger name whenever no `Instrumentation` is wrapped.
  protected override def logName: String = className

  /** Logs at DEBUG level, via the wrapped session when present (consistent with the others). */
  override def logDebug(msg: => String): Unit = {
    instrumentation match {
      case Some(instr) => instr.logDebug(msg)
      case None => super.logDebug(msg)
    }
  }

  /** Logs at INFO level, via the wrapped session when present. */
  override def logInfo(msg: => String): Unit = {
    instrumentation match {
      case Some(instr) => instr.logInfo(msg)
      case None => super.logInfo(msg)
    }
  }

  /** Logs at WARN level, via the wrapped session when present. */
  override def logWarning(msg: => String): Unit = {
    instrumentation match {
      case Some(instr) => instr.logWarning(msg)
      case None => super.logWarning(msg)
    }
  }

  /** Logs at ERROR level, via the wrapped session when present. */
  override def logError(msg: => String): Unit = {
    instrumentation match {
      case Some(instr) => instr.logError(msg)
      case None => super.logError(msg)
    }
  }
}

private[sona] object OptionalInstrumentation {

  /**
    * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object.
    *
    * Log calls are forwarded to `instr`, so they carry its session prefix. The fallback logger
    * name is the `Instrumentation` class name, not the per-session prefix. A random
    * `[xxxxxxxx] ` tag is not a usable logger name: it cannot be targeted by package-level log
    * configuration and differs for every session. This also matches the other `create`, which
    * names the logger after a class.
    */
  def create(instr: Instrumentation): OptionalInstrumentation = {
    new OptionalInstrumentation(Some(instr), instr.getClass.getName.stripSuffix("$"))
  }

  /**
    * Creates an `OptionalInstrumentation` object from a `Class` object.
    * The created `OptionalInstrumentation` object will log messages via common logger and use the
    * specified class name (minus any trailing `$` of a Scala object) as logger name.
    */
  def create(clazz: Class[_]): OptionalInstrumentation = {
    new OptionalInstrumentation(None, clazz.getName.stripSuffix("$"))
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy