All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.tree.gbdt.helper.LRScheduler.scala Maven / Gradle / Ivy

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */
package com.tencent.angel.sona.tree.gbdt.helper

import com.tencent.angel.sona.tree.gbdt.helper.LRScheduler._
import com.tencent.angel.sona.tree.objective.ObjectiveFactory
import com.tencent.angel.sona.tree.objective.metric.EvalMetric

object LRScheduler {

  /**
   * Decides whether `latter` is a significant improvement over `former`.
   *
   * The change counts as an improvement when it moves in the favourable
   * direction for the metric (upwards for AUC and precision, downwards for
   * cross-entropy, error, log-loss and RMSE) and its size relative to
   * `former` strictly exceeds `threshold`.
   *
   * @param kind      metric being compared
   * @param former    reference (previous best) value
   * @param latter    candidate value
   * @param threshold minimum relative change, as a fraction of `former`
   * @return `true` if `latter` beats `former` by more than `threshold`
   * @throws RuntimeException if `kind` is not one of the metrics listed above
   */
  private def isBetter(kind: EvalMetric.Kind, former: Double, latter: Double,
                       threshold: Double): Boolean = {
    // Resolve the favourable direction first; unknown metrics fail right here.
    val largerIsBetter: Boolean = kind match {
      case EvalMetric.Kind.AUC | EvalMetric.Kind.PRECISION => true
      case EvalMetric.Kind.CROSS_ENTROPY | EvalMetric.Kind.ERROR
           | EvalMetric.Kind.LOG_LOSS | EvalMetric.Kind.RMSE => false
      case _ => throw new RuntimeException(s"Unrecognizable eval metric: $kind")
    }

    val movedTheRightWay = if (largerIsBetter) latter > former else latter < former
    // The relative change is only evaluated once the direction check has passed.
    movedTheRightWay && Math.abs(latter - former) / former > threshold
  }

  /** -------------------- Default hyper-parameters -------------------- */
  /** Rounds without improvement before the learning rate is decayed. */
  val DEFAULT_PATIENT: Int = 5
  /** Minimum relative improvement that counts as progress. */
  val DEFAULT_THRESHOLD: Double = 1e-4  // 0.01%
  /** Multiplier applied to the learning rate on decay. */
  val DEFAULT_DECAY_FACTOR: Float = 0.1f
  /** Number of patience windows without improvement before early stopping. */
  val DEFAULT_EARLY_STOP: Int = 3
}

/**
 * Learning-rate scheduler driven by the history of evaluation metrics.
 *
 * After every boosting round the caller reports the metrics through [[step]];
 * the scheduler records them, tracks the best value of the chief metric (the
 * first entry of `evalMetricsStr`) on the chosen dataset, and answers with the
 * learning rate to use next: unchanged, decayed, or `0f` to signal early stop.
 *
 * @param maxSteps       capacity of the per-metric history arrays, i.e. the
 *                       maximum number of times [[step]] may be called
 * @param evalMetricsStr metric names, resolved through
 *                       `ObjectiveFactory.getEvalMetricKind`; must be non-empty,
 *                       the first one is the chief metric
 * @param dataset        which history drives the schedule: "train" or "valid"
 * @param patient        rounds without improvement before the learning rate is
 *                       decayed; a non-positive value disables decay and early stop
 * @param threshold      minimum relative improvement of the chief metric that
 *                       counts as progress
 * @param decayFactor    multiplier applied to the learning rate on decay
 * @param earlyStop      early stop fires after `patient * earlyStop` rounds
 *                       without improvement; a non-positive value disables it
 * @throws IllegalArgumentException if `dataset` is neither "train" nor "valid",
 *                                  or if `evalMetricsStr` is empty
 */
class LRScheduler(maxSteps: Int, evalMetricsStr: Seq[String], dataset: String = "valid",
                  patient: Int = DEFAULT_PATIENT, threshold: Double = DEFAULT_THRESHOLD,
                  decayFactor: Float = DEFAULT_DECAY_FACTOR,
                  earlyStop: Int = DEFAULT_EARLY_STOP) extends Serializable {
  require(dataset == "train" || dataset == "valid", s"dataset should be 'train' or 'valid', but got $dataset")
  // Fail with a clear message instead of the NoSuchElementException that
  // `evalMetrics.head` would raise on an empty metric list.
  require(evalMetricsStr.nonEmpty, "evalMetricsStr should contain at least one eval metric")

  // Number of rounds recorded so far; also the next free slot in the history arrays.
  private var currentStep: Int = 0
  private val evalMetrics: Seq[EvalMetric.Kind] = evalMetricsStr.map(ObjectiveFactory.getEvalMetricKind)
  private val chiefMetric: EvalMetric.Kind = evalMetrics.head
  // dataset name -> metric kind -> value per round (pre-allocated, zero-filled).
  private val history: Map[String, Map[EvalMetric.Kind, Array[Double]]] = Map(
    "train" -> evalMetrics.map(kind => kind -> Array.ofDim[Double](maxSteps)).toMap,
    "valid" -> evalMetrics.map(kind => kind -> Array.ofDim[Double](maxSteps)).toMap)
  // 1-based round of the best chief metric so far; -1 until the first round is recorded.
  private var bestStep: Int = -1
  private var bestMetric: Double = Double.NaN

  /**
   * Records the metrics of one round and returns the learning rate to use next.
   *
   * @param metrics   triples of (metric kind, value on train, value on valid);
   *                  every kind must be one of the configured eval metrics
   * @param currentLR learning rate used in the round that just finished
   * @return `0f` when early stopping is triggered, `currentLR * decayFactor`
   *         when the learning rate should decay, otherwise `currentLR`
   * @throws ArrayIndexOutOfBoundsException if called more than `maxSteps` times
   * @throws NoSuchElementException         if a metric kind was not configured
   */
  def step(metrics: Seq[(EvalMetric.Kind, Double, Double)], currentLR: Float): Float = {
    metrics.foreach {
      case (kind, train, valid) =>
        history("train")(kind)(currentStep) = train
        history("valid")(kind)(currentStep) = valid
    }
    currentStep += 1

    val currentMetric = history(dataset)(chiefMetric)(currentStep - 1)
    // Every comparison against NaN is false, so a NaN best could never be beaten
    // through isBetter; accept the first real value that follows it explicitly.
    val replacesNaNBest = bestMetric.isNaN && !currentMetric.isNaN
    if (bestStep == -1 || replacesNaNBest
      || isBetter(chiefMetric, bestMetric, currentMetric, threshold)) {
      bestStep = currentStep
      bestMetric = currentMetric
    }

    // Early stop takes precedence over decay when both conditions hold.
    if (shouldEarlyStop()) {
      println(s"$chiefMetric does not improve for ${patient * earlyStop} " +
        s"rounds, early stopping")
      0f
    } else if (shouldDecay()) {
      val decayedLR = currentLR * decayFactor
      println(s"$chiefMetric does not improve for $patient " +
        s"rounds, decay learning rate from $currentLR to $decayedLR")
      decayedLR
    } else {
      currentLR
    }
  }

  /**
   * @return `true` exactly on the round where the chief metric has gone
   *         `patient` rounds without improvement (so at most one decay per
   *         stretch without improvement); always `false` when `patient <= 0`
   */
  def shouldDecay(): Boolean = patient > 0 && currentStep - bestStep == patient

  /**
   * @return `true` once the chief metric has gone at least
   *         `patient * earlyStop` rounds without improvement; always `false`
   *         when `patient <= 0` or `earlyStop <= 0`
   */
  def shouldEarlyStop(): Boolean = earlyStop > 0 && patient > 0 &&
    currentStep - bestStep >= patient * earlyStop

  /**
   * @return the recorded metrics, keyed by dataset ("train"/"valid") and metric
   *         kind; each array has length `maxSteps` and the slots of rounds not
   *         yet recorded are still zero. The backing arrays are returned directly.
   */
  def getHistory: Map[String, Map[EvalMetric.Kind, Array[Double]]] = history

  /**
   * @return (chief metric, 1-based round of the best value, best value);
   *         the round is -1 and the value NaN before the first call to [[step]]
   */
  def getBest: (EvalMetric.Kind, Int, Double) = (chiefMetric, bestStep, bestMetric)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy