// com.microsoft.ml.spark.lightgbm.TrainUtils.scala (Maven / Gradle / Ivy)
// The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.
package com.microsoft.ml.spark.lightgbm
import java.io._
import java.net._
import scala.util.control.NonFatal
import com.microsoft.ml.lightgbm._
import com.microsoft.ml.spark.core.env.StreamUtilities.using
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.slf4j.Logger
import org.apache.spark.BarrierTaskContext
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
/** Settings a LightGBM worker uses to find its own listen port and to talk to the driver.
  *
  * @param defaultListenPort    starting point of the local listen-port search; each worker offsets
  *                             it by its id times the cores per executor
  * @param addr                 host of the driver socket the worker connects to
  * @param port                 port of the driver socket the worker connects to
  * @param barrierExecutionMode when true, workers wait on a Spark barrier after reporting their
  *                             status, and partition 0 then sends the finished status
  */
case class NetworkParams(
    defaultListenPort: Int,
    addr: String,
    port: Int,
    barrierExecutionMode: Boolean)
/** Worker-side helpers for distributed LightGBM training.
  *
  * They turn a partition's Spark rows into native LightGBM datasets, create and train a booster,
  * and coordinate the native LightGBM network with the driver over plain sockets.
  */
private object TrainUtils extends Serializable {

  /** Builds a native LightGBM dataset from one partition's rows.
    *
    * For binary and multiclass objectives the distinct labels must be exactly 0..n. When the
    * classifier's `generateMissingLabels` debugging option is on, copies of existing rows carrying
    * the missing labels are appended instead of failing.
    *
    * @param rows             partition rows; the first row is inspected, so this must be non-empty
    * @param referenceDataset dataset handed to LightGBM as the reference (`translate` passes the
    *                         training dataset when building the validation dataset)
    * @return the populated dataset (always `Some`); the caller owns it and must `close()` it
    * @throws Exception if classification labels are not 0..n and missing labels are not generated
    */
  def generateDataset(rows: Array[Row], labelColumn: String, featuresColumn: String,
                      weightColumn: Option[String], initScoreColumn: Option[String], groupColumn: Option[String],
                      referenceDataset: Option[LightGBMDataset], schema: StructType,
                      log: Logger, trainParams: TrainParams): Option[LightGBMDataset] = {
    val labelIndex = schema.fieldIndex(labelColumn)
    val labels = rows.map(row => row.getDouble(labelIndex))
    val isClassification = trainParams.objective == LightGBMConstants.MulticlassObjective ||
      trainParams.objective == LightGBMConstants.BinaryObjective
    val missingLabels =
      if (isClassification) checkClassificationLabels(labels, trainParams) else Seq.empty[Int]
    if (missingLabels.nonEmpty) {
      // Every pass appends at least one row with a missing label, so the missing set shrinks and the
      // recursion terminates. More than one pass happens only if there are fewer rows than missing labels.
      generateDataset(rows ++ rowsWithLabels(rows, labelIndex, missingLabels), labelColumn, featuresColumn,
        weightColumn, initScoreColumn, groupColumn, referenceDataset, schema, log, trainParams)
    } else {
      Some(buildDataset(rows, labels, featuresColumn, weightColumn, initScoreColumn, groupColumn,
        referenceDataset, schema, log))
    }
  }

  /** Checks classification labels.
    *
    * With `generateMissingLabels` on, returns every label in `0 until max(label)` that does not occur.
    * Negative labels are ignored here. Otherwise, throws unless the distinct labels are exactly 0..n,
    * and returns an empty sequence.
    */
  private def checkClassificationLabels(labels: Array[Double], trainParams: TrainParams): Seq[Int] = {
    val distinctLabels = labels.distinct.map(_.toInt).sorted
    // TODO: Temporary hack to append missing labels for debugging, off by default
    // try to figure out a better fix in lightgbm
    if (trainParams.asInstanceOf[ClassifierTrainParams].generateMissingLabels) {
      val present = distinctLabels.toSet
      val maxLabel = distinctLabels.lastOption.getOrElse(0)
      (0 until maxLabel).filterNot(present)
    } else {
      val errMsg = "For classification, label values must start from 0 and increase " +
        "by 1 to n for each partition."
      distinctLabels.foldLeft(-1)((base, newLabel) => if (newLabel == base + 1) newLabel else
        throw new Exception(s"$errMsg Missing label ${base + 1}, unique labels ${distinctLabels.mkString(",")}"))
      Seq.empty[Int]
    }
  }

  /** Copies the first `labels.size` rows (at most all of them), replacing the label of each copy
    * with the corresponding entry of `labels`. */
  private def rowsWithLabels(rows: Array[Row], labelIndex: Int, labels: Seq[Int]): Array[Row] =
    rows.take(labels.size).zip(labels).map { case (row, label) =>
      val values = row.toSeq.toArray
      values.update(labelIndex, label.toDouble)
      new GenericRowWithSchema(values, row.schema): Row
    }

  /** Creates the native dataset for `rows`, then attaches the label, weight, group and init-score
    * fields. If any step after creation fails, the dataset is closed before the error propagates. */
  private def buildDataset(rows: Array[Row], labels: Array[Double], featuresColumn: String,
                           weightColumn: Option[String], initScoreColumn: Option[String],
                           groupColumn: Option[String], referenceDataset: Option[LightGBMDataset],
                           schema: StructType, log: Logger): LightGBMDataset = {
    val numRows = rows.length
    val featuresIndex = schema.fieldIndex(featuresColumn)
    // The first row decides the layout: dense if it holds a DenseVector, sparse otherwise.
    val dataset = rows.head.get(featuresIndex) match {
      case _: DenseVector =>
        val rowsAsDoubleArray = rows.map(row => row.get(featuresIndex) match {
          case dense: DenseVector => dense.toArray
          case sparse: SparseVector => sparse.toDense.toArray
        })
        val numCols = rowsAsDoubleArray.head.length
        val slotNames = getSlotNames(schema, featuresColumn, numCols)
        log.info(s"LightGBM worker generating dense dataset with $numRows rows and $numCols columns")
        LightGBMUtils.generateDenseDataset(numRows, rowsAsDoubleArray, referenceDataset, slotNames)
      case _ =>
        val rowsAsSparse = rows.map(row => row.get(featuresIndex) match {
          case dense: DenseVector => dense.toSparse
          case sparse: SparseVector => sparse
        })
        val numCols = rowsAsSparse.head.size
        val slotNames = getSlotNames(schema, featuresColumn, numCols)
        log.info(s"LightGBM worker generating sparse dataset with $numRows rows and $numCols columns")
        LightGBMUtils.generateSparseDataset(rowsAsSparse, referenceDataset, slotNames)
    }
    var populated = false
    try {
      // Validate generated dataset has the correct number of rows and cols
      dataset.validateDataset()
      dataset.addFloatField(labels, "label", numRows)
      weightColumn.foreach { col =>
        val weightIndex = schema.fieldIndex(col)
        dataset.addFloatField(rows.map(row => row.getDouble(weightIndex)), "weight", numRows)
      }
      addGroupColumn(rows, groupColumn, Some(dataset), numRows, schema)
      initScoreColumn.foreach { col =>
        val initScoreIndex = schema.fieldIndex(col)
        dataset.addDoubleField(rows.map(row => row.getDouble(initScoreIndex)), "init_score", numRows)
      }
      populated = true
      dataset
    } finally {
      // On failure the caller never receives the dataset, so release the native memory here.
      if (!populated) dataset.close()
    }
  }

  /** Sets the LightGBM "group" field from `groupColumn` (if given) as the sizes of consecutive runs
    * of equal group ids.
    *
    * The ranker is expected to have sorted each partition by group id, so each group forms one
    * contiguous run. An IntegerType column is read with `getInt`; any other type is read with `getLong`.
    *
    * @param datasetPtr must be defined when `groupColumn` is defined
    * @param numRows    currently unused
    */
  def addGroupColumn(rows: Array[Row], groupColumn: Option[String],
                     datasetPtr: Option[LightGBMDataset], numRows: Int, schema: StructType): Unit = {
    groupColumn.foreach { col =>
      val colIndex = schema.fieldIndex(col)
      // Compare ids at full width: truncating Long ids to Int could merge two distinct groups.
      val groupIds: Array[Long] =
        if (schema.fields(colIndex).dataType == org.apache.spark.sql.types.IntegerType) {
          rows.map(row => row.getInt(colIndex).toLong)
        } else {
          rows.map(row => row.getLong(colIndex))
        }
      val groupCardinality = runLengths(groupIds)
      datasetPtr.get.addIntField(groupCardinality, "group", groupCardinality.length)
    }
  }

  /** Lengths of the maximal runs of equal consecutive values, in order. Any value, including
    * negative ones, is a valid id. Empty input yields an empty array. */
  private def runLengths(values: Array[Long]): Array[Int] = {
    val lengths = Array.newBuilder[Int]
    var start = 0
    while (start < values.length) {
      var end = start + 1
      while (end < values.length && values(end) == values(start)) end += 1
      lengths += end - start
      start = end
    }
    lengths.result()
  }

  /** Creates a native booster over the training dataset.
    *
    * If `trainParams.modelString` is set, the model it describes is merged in first. The validation
    * dataset, if any, is then registered. If merging or registration fails, the new booster is freed
    * before the error propagates.
    *
    * @return the booster handle (always `Some`); the caller must free it with `LGBM_BoosterFree`
    */
  def createBooster(trainParams: TrainParams, trainDatasetPtr: Option[LightGBMDataset],
                    validDatasetPtr: Option[LightGBMDataset]): Option[SWIGTYPE_p_void] = {
    val boosterOutPtr = lightgbmlib.voidpp_handle()
    val parameters = trainParams.toString()
    LightGBMUtils.validate(lightgbmlib.LGBM_BoosterCreate(trainDatasetPtr.map(_.dataset).get,
      parameters, boosterOutPtr), "Booster")
    val booster = lightgbmlib.voidpp_value(boosterOutPtr)
    var initialized = false
    try {
      trainParams.modelString.foreach { modelStr =>
        val modelBooster = LightGBMUtils.getBoosterPtrFromModelString(modelStr)
        // The merge copies the trees into `booster`, so the temporary booster is freed right away.
        // Both results are checked only after freeing it, so it cannot leak.
        val mergeResult = lightgbmlib.LGBM_BoosterMerge(booster, modelBooster)
        val freeResult = lightgbmlib.LGBM_BoosterFree(modelBooster)
        LightGBMUtils.validate(mergeResult, "Booster Merge")
        LightGBMUtils.validate(freeResult, "Finalize Merged Booster")
      }
      validDatasetPtr.foreach { lgbmdataset =>
        LightGBMUtils.validate(lightgbmlib.LGBM_BoosterAddValidData(booster,
          lgbmdataset.dataset), "Add Validation Dataset")
      }
      initialized = true
      Some(booster)
    } finally {
      // Best-effort cleanup: the original failure is the error worth reporting.
      if (!initialized) lightgbmlib.LGBM_BoosterFree(booster)
    }
  }

  /** Serializes the booster's model to a string through the SWIG helper.
    *
    * The arguments `0, -1` are passed through as before; presumably they mean start iteration 0 and
    * all iterations (confirm against the LightGBM C API). The native length cells are always freed.
    *
    * @param log currently unused
    */
  def saveBoosterToString(boosterPtr: Option[SWIGTYPE_p_void], log: Logger): String = {
    val bufferLengthPtr = lightgbmlib.new_longp()
    val bufferOutLengthPtr = lightgbmlib.new_int64_tp()
    try {
      lightgbmlib.longp_assign(bufferLengthPtr, LightGBMConstants.DefaultBufferLength)
      // NOTE(review): presumably a pointer cast over the same cell rather than a new allocation;
      // the cell itself is released through bufferLengthPtr below.
      val bufferLengthPtrInt64 = lightgbmlib.long_to_int64_t_ptr(bufferLengthPtr)
      lightgbmlib.LGBM_BoosterSaveModelToStringSWIG(boosterPtr.get, 0, -1, bufferLengthPtrInt64, bufferOutLengthPtr)
    } finally {
      lightgbmlib.delete_int64_tp(bufferOutLengthPtr)
      lightgbmlib.delete_longp(bufferLengthPtr)
    }
  }

  /** Names of the evaluation metrics the booster reports, in LightGBM's result order. */
  def getEvalNames(boosterPtr: Option[SWIGTYPE_p_void]): Array[String] = {
    val evalCountsPtr = lightgbmlib.new_intp()
    val evalCounts =
      try {
        val resultCounts = lightgbmlib.LGBM_BoosterGetEvalCounts(boosterPtr.get, evalCountsPtr)
        LightGBMUtils.validate(resultCounts, "Booster Get Eval Counts")
        lightgbmlib.intp_value(evalCountsPtr)
      } finally {
        lightgbmlib.delete_intp(evalCountsPtr)
      }
    // NOTE(review): the string array returned here is not freed; confirm its ownership in the wrapper.
    val evalNamesPtr = lightgbmlib.LGBM_BoosterGetEvalNamesSWIG(boosterPtr.get, evalCounts)
    (0 until evalCounts).map(lightgbmlib.stringArray_getitem(evalNamesPtr, _)).toArray
  }

  /** Reads the current value of every metric for data set `dataIndex` of the booster. This file
    * passes 0 for the training data and 1 for the validation data. Native buffers are always freed.
    *
    * @param errorContext label reported by `LightGBMUtils.validate` if the call fails
    */
  private def getEvalResults(boosterPtr: Option[SWIGTYPE_p_void], dataIndex: Int,
                             evalCounts: Int, errorContext: String): Array[Double] = {
    val results = lightgbmlib.new_doubleArray(evalCounts)
    val dummyEvalCountsPtr = lightgbmlib.new_intp()
    try {
      val resultEval = lightgbmlib.LGBM_BoosterGetEval(boosterPtr.get, dataIndex, dummyEvalCountsPtr, results)
      LightGBMUtils.validate(resultEval, errorContext)
      Array.tabulate(evalCounts)(i => lightgbmlib.doubleArray_getitem(results, i))
    } finally {
      lightgbmlib.delete_intp(dummyEvalCountsPtr)
      lightgbmlib.delete_doubleArray(results)
    }
  }

  /** Whether `score` beats `best` for the metric `evalName`. For metrics named auc / ndcg@ / map@,
    * higher is better; for all others, lower is better. */
  private def isImprovement(evalName: String, score: Double, best: Double): Boolean =
    if (evalName.startsWith("auc") || evalName.startsWith("ndcg@") || evalName.startsWith("map@")) score > best
    else score < best

  /** Runs boosting iterations until LightGBM reports it is finished, `numIterations` is reached, an
    * update throws, or early stopping triggers.
    *
    * Early stopping (mirroring LightGBM's Python callback) only applies when `hasValid` is true and
    * `earlyStoppingRound > 0`. LightGBM treats `<= 0` as disabled. Training stops once any
    * validation metric has not improved for `earlyStoppingRound` iterations.
    */
  def trainCore(trainParams: TrainParams, boosterPtr: Option[SWIGTYPE_p_void],
                log: Logger, hasValid: Boolean): Unit = {
    val isFinishedPtr = lightgbmlib.new_intp()
    try {
      var isFinished = false
      var iters = 0
      val evalNames = getEvalNames(boosterPtr)
      val evalCounts = evalNames.length
      // Best validation score per metric, and the iteration it was reached at (-1 = none yet).
      val bestScore = new Array[Double](evalCounts)
      val bestIter = Array.fill(evalCounts)(-1)
      while (!isFinished && iters < trainParams.numIterations) {
        try {
          log.info("LightGBM worker calling LGBM_BoosterUpdateOneIter")
          val result = lightgbmlib.LGBM_BoosterUpdateOneIter(boosterPtr.get, isFinishedPtr)
          LightGBMUtils.validate(result, "Booster Update One Iter")
          isFinished = lightgbmlib.intp_value(isFinishedPtr) == 1
          log.info("LightGBM running iteration: " + iters + " with result: " +
            result + " and is finished: " + isFinished)
        } catch {
          case ex: java.lang.Exception =>
            // Deliberate: a failing update ends training on this worker only. The cause is logged.
            isFinished = true
            log.warn("LightGBM reached early termination on one worker," +
              " stopping training on worker. This message should rarely occur", ex)
        }
        if (trainParams.isProvideTrainingMetric && !isFinished) {
          val trainScores = getEvalResults(boosterPtr, 0, evalCounts, "Booster Get Train Eval")
          evalNames.zip(trainScores).foreach { case (evalName, score) =>
            log.info(s"Train $evalName=$score")
          }
        }
        if (hasValid && !isFinished) {
          val validScores = getEvalResults(boosterPtr, 1, evalCounts, "Booster Get Valid Eval")
          evalNames.zipWithIndex.foreach { case (evalName, index) =>
            val score = validScores(index)
            log.info(s"Valid $evalName=$score")
            if (bestIter(index) < 0 || isImprovement(evalName, score, bestScore(index))) {
              bestScore(index) = score
              bestIter(index) = iters
            } else if (trainParams.earlyStoppingRound > 0 &&
              iters - bestIter(index) >= trainParams.earlyStoppingRound) {
              isFinished = true
              log.info("Early stopping, best iteration is " + bestIter(index))
            }
          }
        }
        iters += 1
      }
    } finally {
      lightgbmlib.delete_intp(isFinishedPtr)
    }
  }

  /** Column names for the features vector, taken from its ML attribute metadata.
    *
    * Features without a named attribute fall back to their index as a string.
    *
    * @return `None` when the column carries no attribute metadata (or an empty attribute list)
    */
  def getSlotNames(schema: StructType, featuresColumn: String, numCols: Int): Option[Array[String]] = {
    val featuresSchema = schema.fields(schema.fieldIndex(featuresColumn))
    AttributeGroup.fromStructField(featuresSchema).attributes.filter(_.nonEmpty).map { attributes =>
      val colnames = Array.tabulate(numCols)(_.toString)
      attributes.foreach { attr =>
        attr.index.foreach(index => colnames(index) = attr.name.getOrElse(index.toString))
      }
      colnames
    }
  }

  /** Trains a booster on this partition's rows and returns its serialized model.
    *
    * Datasets and the booster are always freed, whether training succeeds or fails.
    *
    * @param validationData optional broadcast rows used as the validation set
    * @return an iterator with exactly one [[LightGBMBooster]]
    */
  def translate(labelColumn: String, featuresColumn: String, weightColumn: Option[String],
                initScoreColumn: Option[String], groupColumn: Option[String],
                validationData: Option[Broadcast[Array[Row]]],
                log: Logger, trainParams: TrainParams, schema: StructType,
                inputRows: Iterator[Row]): Iterator[LightGBMBooster] = {
    val rows = inputRows.toArray
    var trainDatasetPtr: Option[LightGBMDataset] = None
    var validDatasetPtr: Option[LightGBMDataset] = None
    try {
      trainDatasetPtr = generateDataset(rows, labelColumn, featuresColumn,
        weightColumn, initScoreColumn, groupColumn, None, schema, log, trainParams)
      validDatasetPtr = validationData.flatMap { data =>
        generateDataset(data.value, labelColumn, featuresColumn, weightColumn, initScoreColumn,
          groupColumn, trainDatasetPtr, schema, log, trainParams)
      }
      var boosterPtr: Option[SWIGTYPE_p_void] = None
      try {
        boosterPtr = createBooster(trainParams, trainDatasetPtr, validDatasetPtr)
        trainCore(trainParams, boosterPtr, log, validDatasetPtr.isDefined)
        val model = saveBoosterToString(boosterPtr, log)
        Iterator(new LightGBMBooster(model))
      } finally {
        // Free booster
        boosterPtr.foreach { booster =>
          LightGBMUtils.validate(lightgbmlib.LGBM_BoosterFree(booster), "Finalize Booster")
        }
      }
    } finally {
      // Free datasets
      trainDatasetPtr.foreach(_.close())
      validDatasetPtr.foreach(_.close())
    }
  }

  /** Reserves a local port by binding a fresh `Socket` to it.
    *
    * Starts at `defaultListenPort + workerId * numCoresPerExec` and tries up to 1000 further ports.
    * Sockets that fail to bind are closed immediately so no descriptors leak.
    *
    * @return the bound socket; the caller must close it
    * @throws Exception if no port in the range can be bound
    */
  private def findOpenPort(defaultListenPort: Int, numCoresPerExec: Int, log: Logger): Socket = {
    val basePort = defaultListenPort + (LightGBMUtils.getId() * numCoresPerExec)
    val lastPort = basePort + 1000
    def tryBind(port: Int): Option[Socket] = {
      val socket = new Socket()
      try {
        socket.bind(new InetSocketAddress(port))
        Some(socket)
      } catch {
        case _: IOException =>
          socket.close()
          log.warn(s"Could not bind to port $port...")
          None
      }
    }
    // Iterator.map is lazy: ports after the first successful bind are never touched.
    Iterator.range(basePort, lastPort + 1).map(tryBind).collectFirst { case Some(socket) => socket } match {
      case Some(socket) =>
        log.info(s"Successfully bound to port ${socket.getLocalPort}")
        socket
      case None =>
        throw new Exception("Error: Could not find open port after 1k tries")
    }
  }

  /** Opens a new connection to the driver and sends the finished-status line.
    *
    * In barrier mode this is called by partition 0 after the barrier.
    *
    * @param localListenPort currently unused
    */
  def setFinishedStatus(networkParams: NetworkParams,
                        localListenPort: Int, log: Logger): Unit = {
    using(new Socket(networkParams.addr, networkParams.port)) {
      driverSocket =>
        using(new BufferedWriter(new OutputStreamWriter(driverSocket.getOutputStream))) {
          driverOutput =>
            log.info("sending finished status to driver")
            driverOutput.write(s"${LightGBMConstants.FinishedStatus}\n")
            driverOutput.flush()
        }.get
    }.get
  }

  /** Reports this worker to the driver and returns the comma-separated node list for network init.
    *
    * An empty partition sends the ignore status and gets that status back without waiting for nodes.
    * Otherwise the worker sends `host:localListenPort` and then blocks reading one line of nodes.
    * In barrier mode all tasks meet at a barrier after sending, and partition 0 then sends the
    * finished status.
    *
    * @throws IOException if the driver closes the connection before sending the node list
    */
  def getNetworkInitNodes(networkParams: NetworkParams,
                          localListenPort: Int, log: Logger,
                          emptyPartition: Boolean): String = {
    using(new Socket(networkParams.addr, networkParams.port)) {
      driverSocket =>
        val driverInput = new BufferedReader(new InputStreamReader(driverSocket.getInputStream))
        val driverOutput = new BufferedWriter(new OutputStreamWriter(driverSocket.getOutputStream))
        using(Seq(driverInput, driverOutput)) {
          _ =>
            val workerStatus =
              if (emptyPartition) {
                log.info("send empty status to driver")
                LightGBMConstants.IgnoreStatus
              } else {
                val workerHost = driverSocket.getLocalAddress.getHostAddress
                val workerInfo = s"$workerHost:$localListenPort"
                log.info(s"send current worker info to driver: $workerInfo ")
                workerInfo
              }
            // Send the current host:port to the driver
            driverOutput.write(s"$workerStatus\n")
            driverOutput.flush()
            // If barrier execution mode enabled, create a barrier across tasks
            if (networkParams.barrierExecutionMode) {
              val context = BarrierTaskContext.get()
              context.barrier()
              if (context.partitionId() == 0) {
                setFinishedStatus(networkParams, localListenPort, log)
              }
            }
            if (workerStatus != LightGBMConstants.IgnoreStatus) {
              // Wait to get the list of nodes from the driver
              val nodes = driverInput.readLine()
              if (nodes == null) {
                // Fail fast rather than let networkInit retry on a null node list.
                throw new IOException("LightGBM worker: driver closed the connection before sending network nodes")
              }
              log.info(s"LightGBM worker got nodes for network init: $nodes")
              nodes
            } else {
              workerStatus
            }
        }.get
    }.get
  }

  /** Initializes the native LightGBM network, retrying on non-fatal failures.
    *
    * @param retry number of retries left
    * @param delay milliseconds to sleep before the next retry; doubled on each retry
    * @throws Throwable the last failure once `retry` reaches 0
    */
  def networkInit(nodes: String, localListenPort: Int, log: Logger, retry: Int, delay: Long): Unit = {
    try {
      LightGBMUtils.validate(lightgbmlib.LGBM_NetworkInit(nodes, localListenPort,
        LightGBMConstants.DefaultListenTimeout, nodes.split(",").length), "Network init")
    } catch {
      case NonFatal(ex) =>
        log.info(s"NetworkInit failed with exception on local port $localListenPort with exception: $ex")
        if (retry > 0) {
          // Sleep only when a retry actually follows.
          Thread.sleep(delay)
          log.info(s"Retrying NetworkInit with local port $localListenPort")
          networkInit(nodes, localListenPort, log, retry - 1, delay * 2)
        } else {
          log.info(s"NetworkInit reached maximum exceptions on retry: $ex")
          throw ex
        }
    }
  }

  /** Per-partition training entry point (curried so the row iterator can be supplied last).
    *
    * Steps:
    *  1. Reserve a local port and report to the driver.
    *  2. Receive the node list.
    *  3. Initialize the native network.
    *  4. Train on the partition via `translate`.
    *
    * An empty partition still reports (as ignored) and yields no booster. The native network is
    * freed after training, whether or not it succeeds.
    */
  def trainLightGBM(networkParams: NetworkParams, labelColumn: String, featuresColumn: String,
                    weightColumn: Option[String], initScoreColumn: Option[String], groupColumn: Option[String],
                    validationData: Option[Broadcast[Array[Row]]], log: Logger,
                    trainParams: TrainParams, numCoresPerExec: Int, schema: StructType)
                   (inputRows: Iterator[Row]): Iterator[LightGBMBooster] = {
    val emptyPartition = !inputRows.hasNext
    // Ideally we would start the socket connections in the C layer, this opens us up for
    // race conditions in case other applications open sockets on cluster, but usually this
    // should not be a problem
    val (nodes, localListenPort) = using(findOpenPort(networkParams.defaultListenPort, numCoresPerExec, log)) {
      openPort =>
        val localListenPort = openPort.getLocalPort
        // Initialize the native library
        LightGBMUtils.initializeNativeLibrary()
        log.info(s"LightGBM worker connecting to host: ${networkParams.addr} and port: ${networkParams.port}")
        (getNetworkInitNodes(networkParams, localListenPort, log, emptyPartition), localListenPort)
    }.get
    if (emptyPartition) {
      log.warn("LightGBM worker encountered empty partition, for best performance ensure no partitions empty")
      Iterator.empty
    } else {
      // Initialize the network communication
      log.info(s"LightGBM worker listening on: $localListenPort")
      try {
        val retries = 3
        val initialDelay = 1000L
        networkInit(nodes, localListenPort, log, retries, initialDelay)
        translate(labelColumn, featuresColumn, weightColumn, initScoreColumn, groupColumn, validationData,
          log, trainParams, schema, inputRows)
      } finally {
        // Finalize network when done
        LightGBMUtils.validate(lightgbmlib.LGBM_NetworkFree(), "Finalize network")
      }
    }
  }
}