All downloads are free. The search and download functionality uses the official Maven repository.

com.databricks.labs.automl.ensemble.tuner.EnsembleTunerSplits.scala Maven / Gradle / Ivy

package com.databricks.labs.automl.ensemble.tuner

import com.databricks.labs.automl.ensemble.setting.StackingEnsembleSettings
import com.databricks.labs.automl.executor.config.{ConfigurationGenerator, InstanceConfigValidation}
import com.databricks.labs.automl.model.tools.structures.{TrainSplitReferences, TrainTestData, TrainTestPaths}

/**
  * Builds and caches the train/test split references used by a stacking
  * ensemble: one set for the meta learner and one set for the weak learners.
  *
  * Each set is computed on the first request and stored in a private field;
  * later calls return the stored array and ignore their arguments. No
  * synchronization is performed around the cached state.
  */
class EnsembleTunerSplits {

  // Populated on first use by the corresponding getter below.
  private var metaLearnerSplits: Option[Array[TrainSplitReferences]] = None
  private var weakLearnerSplits: Option[Array[TrainSplitReferences]] = None

  /**
    * Returns the weak-learner splits, building them on the first call.
    *
    * The main config is generated from the first entry of
    * `weakLearnersConfigs` only. The meta-learner splits are obtained (and
    * cached, if not already) as part of this call.
    *
    * @param stackingEnsembleSettings settings supplying the weak-learner
    *                                 configs and, transitively, the meta
    *                                 learner config and input data
    * @return the cached weak-learner split references
    * @throws NoSuchElementException if the meta-learner splits are not yet
    *                                cached and the settings carry no
    *                                meta-learner config
    */
  def getWeakLearnersSplits(stackingEnsembleSettings: StackingEnsembleSettings): Array[TrainSplitReferences] =
    weakLearnerSplits.getOrElse {
      val weakConfig = ConfigurationGenerator.generateMainConfig(stackingEnsembleSettings.weakLearnersConfigs(0))
      val metaSplits = getMetaLearnersSplits(Some(stackingEnsembleSettings))

      val weakSplits: Array[TrainSplitReferences] =
        if (weakConfig.geneticConfig.kFold > 1) {
          // One weak split per meta split: the test data/path come from the
          // meta split, the train data/path from a fresh split built here.
          // NOTE(review): every iteration splits metaSplits(0).data.train, not
          // item.data.train -- confirm this is intended, since the paired test
          // set belongs to `item`.
          // Some(1) presumably requests a single split -- TODO confirm against
          // TunerUtils.buildSplitTrainTestData.
          metaSplits.map { item =>
            val n = TunerUtils.buildSplitTrainTestData(weakConfig, metaSplits(0).data.train, Some(1))(0)
            TrainSplitReferences(
              item.kIndex,
              TrainTestData(n.data.train, item.data.test),
              TrainTestPaths(n.paths.train, item.paths.test)
            )
          }
        } else {
          // Split the first meta split's train data; every resulting weak
          // split is evaluated against the first meta split's test data while
          // keeping the paths produced by the weak split itself.
          TunerUtils
            .buildSplitTrainTestData(weakConfig, metaSplits(0).data.train)
            .map { item =>
              TrainSplitReferences(
                item.kIndex,
                TrainTestData(item.data.train, metaSplits(0).data.test),
                item.paths
              )
            }
        }

      weakLearnerSplits = Some(weakSplits)
      weakSplits
    }

  /**
    * Returns the meta-learner splits, building them on the first call.
    *
    * @param stackingEnsembleSettings settings supplying the meta-learner
    *                                 config and input data; only consulted
    *                                 when nothing is cached yet, so it may be
    *                                 omitted on later calls
    * @return the cached meta-learner split references
    * @throws NoSuchElementException if nothing is cached and either the
    *                                settings or their meta-learner config
    *                                are absent
    */
  def getMetaLearnersSplits(stackingEnsembleSettings: Option[StackingEnsembleSettings] = None): Array[TrainSplitReferences] =
    metaLearnerSplits.getOrElse {
      // Same exception type as the previous bare Option.get, but with a
      // message that names what is missing.
      val settings = stackingEnsembleSettings.getOrElse(
        throw new NoSuchElementException(
          "StackingEnsembleSettings must be provided the first time meta learner splits are requested"
        )
      )
      val metaLearnerConfig = settings.metaLearnerConfig.getOrElse(
        throw new NoSuchElementException(
          "StackingEnsembleSettings.metaLearnerConfig must be defined to build meta learner splits"
        )
      )
      val metaConfig = ConfigurationGenerator.generateMainConfig(metaLearnerConfig)
      val metaSplits = TunerUtils.buildSplitTrainTestData(metaConfig, settings.inputData)
      metaLearnerSplits = Some(metaSplits)
      metaSplits
    }

}
/** Companion factory for [[EnsembleTunerSplits]]. */
object EnsembleTunerSplits {

  /**
    * Creates a new [[EnsembleTunerSplits]] instance.
    *
    * @return a freshly constructed instance
    */
  def apply(): EnsembleTunerSplits = {
    val tunerSplits = new EnsembleTunerSplits
    tunerSplits
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy