All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.h2o.sparkling.ml.algos.H2OAlgoCommonUtils.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ai.h2o.sparkling.ml.algos

import ai.h2o.sparkling.backend.utils.H2OFrameLifecycle
import ai.h2o.sparkling.ml.models.H2OBinaryModel
import ai.h2o.sparkling.ml.utils.EstimatorCommonUtils
import ai.h2o.sparkling.{H2OContext, H2OFrame}
import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions.col

trait H2OAlgoCommonUtils extends EstimatorCommonUtils with H2OFrameLifecycle {

  protected var binaryModel: Option[H2OBinaryModel] = None

  def getBinaryModel(): H2OBinaryModel = {
    if (binaryModel.isEmpty) {
      throw new IllegalArgumentException(
        "Algorithm needs to be fit first with the `keepBinaryModels` parameter " +
          "set to true in order to access binary model.")
    }
    binaryModel.get
  }

  private[sparkling] def getExcludedCols(): Seq[String] = Seq.empty

  /** The list of additional columns that needs to be send to H2O-3 backend for model training. */
  private[sparkling] def getAdditionalCols(): Seq[String] = Seq.empty

  /** The list of additional columns that needs to be send to H2O-3 backend for model validation. */
  private[sparkling] def getAdditionalValidationCols(): Seq[String] = Seq.empty

  private[sparkling] def getInputCols(): Array[String]

  private[sparkling] def getColumnsToCategorical(): Array[String]

  /** List of columns to convert to string before modelling. */
  private[sparkling] def getColumnsToString(): Array[String] = Array.empty

  private[sparkling] def getSplitRatio(): Double

  private[sparkling] def setInputCols(value: Array[String]): this.type

  private[sparkling] def getValidationDataFrame(): DataFrame

  private[sparkling] def prepareDatasetForFitting(dataset: Dataset[_]): (H2OFrame, Option[H2OFrame]) = {
    val excludedCols = getExcludedCols()

    if (getInputCols().isEmpty) {
      val inputs = dataset.columns.filter(c => excludedCols.forall(e => c.compareToIgnoreCase(e) != 0))
      setInputCols(inputs)
    } else {
      val missingColumns = getInputCols()
        .filterNot(col => dataset.columns.contains(col))

      if (missingColumns.nonEmpty) {
        throw new IllegalArgumentException(
          "The following feature columns are not available on" +
            s" the training dataset: '${missingColumns.mkString(", ")}'")
      }
    }

    val featureColumns = getInputCols().map(sanitize).map(col)

    val excludedColumns = excludedCols.map(sanitize).map(col)
    val additionalColumns = getAdditionalCols().map(sanitize).map(col)
    val columns = (featureColumns ++ excludedColumns ++ additionalColumns).distinct
    val h2oContext = H2OContext.ensure(
      "H2OContext needs to be created in order to train the model. Please create one as H2OContext.getOrCreate().")
    val trainFrame = h2oContext.asH2OFrame(dataset.select(columns: _*).toDF(), getInputCols())

    trainFrame.convertColumnsToStrings(getColumnsToString())

    // Our MOJO wrapper needs the full column name before the array/vector expansion in order to do predictions
    trainFrame.convertColumnsToCategorical(getColumnsToCategorical())

    val validationDataFrame = getValidationDataFrame()
    val (resultTrainFrame, resultTestFrame) = if (validationDataFrame != null) {
      val additionalValidationColumns = getAdditionalValidationCols().map(sanitize).map(col)
      val validationColumns = (columns ++ additionalValidationColumns).distinct
      val validationFrame = h2oContext.asH2OFrame(validationDataFrame.select(validationColumns: _*))
      (trainFrame, Some(validationFrame))
    } else if (getSplitRatio() < 1.0) {
      val frames = trainFrame.split(getSplitRatio())
      if (frames.length > 1) {
        (frames(0), Some(frames(1)))
      } else {
        (frames(0), None)
      }
    } else {
      (trainFrame, None)
    }

    registerH2OFrameForDeletion(resultTrainFrame)
    registerH2OFrameForDeletion(resultTestFrame)

    (resultTrainFrame, resultTestFrame)
  }

  def sanitize(colName: String) = '`' + colName + '`'
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy