All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.ml.spark.lightgbm.LightGBMDataset.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.lightgbm

import com.microsoft.ml.lightgbm._

/** Represents a LightGBM dataset.
  * Wraps the native implementation.
  * @param dataset The native representation of the dataset.
  */
class LightGBMDataset(val dataset: SWIGTYPE_p_void) extends AutoCloseable {
  var featureNames: Option[SWIGTYPE_p_p_char] = None
  var featureNamesOpt2: Option[Array[String]] = None

  def validateDataset(): Unit = {
    // Validate num rows
    val numDataPtr = lightgbmlib.new_intp()
    LightGBMUtils.validate(lightgbmlib.LGBM_DatasetGetNumData(dataset, numDataPtr), "DatasetGetNumData")
    val numData = lightgbmlib.intp_value(numDataPtr)
    if (numData <= 0) {
      throw new Exception("Unexpected num data: " + numData)
    }

    // Validate num cols
    val numFeaturePtr = lightgbmlib.new_intp()
    LightGBMUtils.validate(
      lightgbmlib.LGBM_DatasetGetNumFeature(dataset, numFeaturePtr),
      "DatasetGetNumFeature")
    val numFeature = lightgbmlib.intp_value(numFeaturePtr)
    if (numFeature <= 0) {
      throw new Exception("Unexpected num feature: " + numFeature)
    }
  }

  def addFloatField(field: Array[Double], fieldName: String, numRows: Int): Unit = {
    // Generate the column and add to dataset
    var colArray: Option[SWIGTYPE_p_float] = None
    try {
      colArray = Some(lightgbmlib.new_floatArray(numRows))
      field.zipWithIndex.foreach(ri =>
        lightgbmlib.floatArray_setitem(colArray.get, ri._2, ri._1.toFloat))
      val colAsVoidPtr = lightgbmlib.float_to_voidp_ptr(colArray.get)
      val data32bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT32
      LightGBMUtils.validate(
        lightgbmlib.LGBM_DatasetSetField(dataset, fieldName, colAsVoidPtr, numRows, data32bitType),
        "DatasetSetField")
    } finally {
      // Free column
      colArray.foreach(lightgbmlib.delete_floatArray(_))
    }
  }

  def addDoubleField(field: Array[Double], fieldName: String, numRows: Int): Unit = {
    // Generate the column and add to dataset
    var colArray: Option[SWIGTYPE_p_double] = None
    try {
      colArray = Some(lightgbmlib.new_doubleArray(numRows))
      field.zipWithIndex.foreach(ri =>
        lightgbmlib.doubleArray_setitem(colArray.get, ri._2, ri._1))
      val colAsVoidPtr = lightgbmlib.double_to_voidp_ptr(colArray.get)
      val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64
      LightGBMUtils.validate(
        lightgbmlib.LGBM_DatasetSetField(dataset, fieldName, colAsVoidPtr, numRows, data64bitType),
        "DatasetSetField")
    } finally {
      // Free column
      colArray.foreach(lightgbmlib.delete_doubleArray(_))
    }
  }

  def addIntField(field: Array[Int], fieldName: String, numRows: Int): Unit = {
    // Generate the column and add to dataset
    var colArray: Option[SWIGTYPE_p_int] = None
    try {
      colArray = Some(lightgbmlib.new_intArray(numRows))
      field.zipWithIndex.foreach(ri =>
        lightgbmlib.intArray_setitem(colArray.get, ri._2, ri._1))
      val colAsVoidPtr = lightgbmlib.int_to_voidp_ptr(colArray.get)
      val data32bitType = lightgbmlibConstants.C_API_DTYPE_INT32
      LightGBMUtils.validate(
        lightgbmlib.LGBM_DatasetSetField(dataset, fieldName, colAsVoidPtr, numRows, data32bitType),
        "DatasetSetField")
    } finally {
      // Free column
      colArray.foreach(lightgbmlib.delete_intArray(_))
    }
  }

  def setFeatureNames(featureNamesOpt: Option[Array[String]], numCols: Int): Unit = {
    // Add in slot names if they exist
    featureNamesOpt.foreach { featureNamesVal =>
      if (featureNamesVal.nonEmpty) {
        LightGBMUtils.validate(lightgbmlib.LGBM_DatasetSetFeatureNames(dataset, featureNamesVal, numCols),
          "Dataset set feature names")
      }
    }
  }

  override def close(): Unit = {
    // Free dataset
    LightGBMUtils.validate(lightgbmlib.LGBM_DatasetFree(dataset), "Finalize Dataset")
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy