All downloads are free. Search and download functionality uses the official Maven repository.

io.github.metarank.ltrlib.booster.LightGBMBooster.scala Maven / Gradle / Ivy

There is a newer version: 0.2.6
Show newest version
package io.github.metarank.ltrlib.booster

import io.github.metarank.lightgbm4j.{LGBMBooster, LGBMDataset}
import Booster.{BoosterFactory, BoosterOptions, DatasetOptions}
import com.microsoft.ml.lightgbm.PredictionType
import io.github.metarank.lightgbm4j.LGBMBooster.FeatureImportanceType

import java.nio.charset.StandardCharsets
import scala.collection.mutable

/** [[Booster]] implementation backed by a native lightgbm4j `LGBMBooster`.
  *
  * @param model    the native LightGBM booster this wrapper trains, evaluates and serializes
  * @param datasets datasets already attached to `model` as validation data, mapped to the data index
  *                 they were attached under; lets `evalMetric` attach each dataset only once
  */
case class LightGBMBooster(model: LGBMBooster, datasets: mutable.Map[LGBMDataset, Int] = mutable.Map.empty)
    extends Booster[LGBMDataset] {

  /** Runs one boosting iteration.
    *
    * `updateOneIter` takes no dataset, so `dataset` is ignored and training uses the data `model` was
    * created with. The finished/not-finished flag returned by LightGBM is discarded.
    */
  override def trainOneIteration(dataset: LGBMDataset): Unit = model.updateOneIter()

  /** Returns the first evaluation metric value LightGBM reports for `dataset`.
    *
    * A dataset seen for the first time is attached to `model` as validation data, and its index is cached
    * in `datasets`.
    */
  override def evalMetric(dataset: LGBMDataset): Double = {
    // getOrElse's default is by-name, so attachValidation only runs for unseen datasets
    val index = datasets.getOrElse(dataset, attachValidation(dataset))
    model.getEval(index)(0)
  }

  /** Attaches `dataset` to `model` as validation data and returns the index it was recorded under.
    *
    * Indices start at 1 because LightGBM uses data index 0 for the training data. The index is written to
    * `datasets` only after `addValidData` succeeds. Otherwise a failed attach would leave behind an index
    * the native booster never assigned, and later lookups would query the wrong data index.
    */
  private def attachValidation(dataset: LGBMDataset): Int = {
    val index = datasets.values.reduceLeftOption(math.max).getOrElse(0) + 1
    model.addValidData(dataset)
    datasets.put(dataset, index)
    index
  }

  /** Scores a dense matrix of `rows` x `cols` values (`values` passed with the row-major flag set to true)
    * and returns the raw per-row predictions (`C_API_PREDICT_NORMAL`).
    */
  override def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] = {
    model.predictForMat(values, rows, cols, true, PredictionType.C_API_PREDICT_NORMAL)
  }

  /** Serializes the model using LightGBM's text format, encoded as UTF-8 bytes. */
  override def save(): Array[Byte] =
    model.saveModelToString(0, 0, FeatureImportanceType.SPLIT).getBytes(StandardCharsets.UTF_8)

  /** Per-feature importance weights. */
  override def weights(): Array[Double] =
    // numIteration=0 means "use all of them"
    // we use split here to match xgboost, which can only do split
    model.featureImportance(0, FeatureImportanceType.SPLIT)
}

/** Factory for [[LightGBMBooster]].
  *
  * Converts [[BoosterDataset]]s into native LightGBM datasets, trains new boosters, and restores boosters
  * from serialized bytes.
  */
object LightGBMBooster extends BoosterFactory[LGBMDataset, LightGBMBooster, LightGBMOptions] {

  /** Builds a native LightGBM dataset from `d` and attaches its labels, query groups and feature names.
    *
    * @param d      source data; `d.data` is passed as a `d.rows` x `d.cols` matrix with the row-major flag set
    * @param parent optional reference dataset passed to LightGBM (as `null` when absent)
    * @return the populated native dataset; if populating it fails, its native handle is released and the
    *         original error is rethrown
    */
  override def formatData(d: BoosterDataset, parent: Option[LGBMDataset]): LGBMDataset = {
    import scala.util.control.NonFatal

    val ds = LGBMDataset.createFromMat(d.data, d.rows, d.cols, true, "", parent.orNull)
    try {
      ds.setField("label", d.labels.map(_.toFloat))
      ds.setField("group", d.groups)
      ds.setFeatureNames(d.featureNames)
      ds
    } catch {
      case NonFatal(e) =>
        // free the native handle rather than leaking it when a field cannot be set
        try ds.close()
        catch { case NonFatal(closeError) => e.addSuppressed(closeError) }
        throw e
    }
  }

  /** Restores a booster from the UTF-8 model text produced by `LightGBMBooster.save()`. */
  def apply(string: Array[Byte]): LightGBMBooster = {
    LightGBMBooster(LGBMBooster.loadModelFromString(new String(string, StandardCharsets.UTF_8)))
  }

  /** Creates a lambdarank booster trained on `ds` with the given options.
    *
    * @param ds      training dataset
    * @param options model hyper-parameters
    * @param dso     dataset options; only `categoryFeatures` is used here
    */
  override def apply(ds: LGBMDataset, options: LightGBMOptions, dso: DatasetOptions): LightGBMBooster = {
    val paramsMap = Map(
      "objective"                   -> "lambdarank",
      "metric"                      -> "ndcg",
      "lambdarank_truncation_level" -> options.ndcgCutoff.toString,
      // LightGBM's eval_at defaults to 1,2,3,4,5 and evalMetric reads only the first value, so without
      // this the reported metric would be NDCG@1 instead of NDCG at the configured cutoff
      "eval_at"                     -> options.ndcgCutoff.toString,
      "max_depth"                   -> options.maxDepth.toString,
      "learning_rate"               -> options.learningRate.toString,
      "num_leaves"                  -> options.numLeaves.toString,
      "seed"                        -> options.randomSeed.toString,
      "categorical_feature"         -> dso.categoryFeatures.mkString(","),
      "feature_fraction"            -> options.featureFraction.toString
    )
    // LightGBM expects parameters as space-separated key=value pairs
    val params = paramsMap.map { case (key, value) => s"$key=$value" }.mkString(" ")
    new LightGBMBooster(
      model = LGBMBooster.create(ds, params)
    )
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy