All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.ml.bundle.MultiInOutFormatSparkOp.scala Maven / Gradle / Ivy

The newest version!
package org.apache.spark.ml.bundle

import ml.combust.bundle.dsl.{Model, Value}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamValidators
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}

trait MultiInOutFormatSparkOp[
  N <: Transformer with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols
]{

  protected def saveMultiInOutFormat(model: Model, obj: N): Model = {
    ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.inputCol), Seq(obj.inputCols))
    ParamValidators.checkSingleVsMultiColumnParams(obj, Seq(obj.outputCol), Seq(obj.outputCols))
    val result = if(obj.isSet(obj.inputCols)) {
      model.withValue("input_cols", Value.stringList(obj.getInputCols))
    } else {
      model.withValue("input_col", Value.string(obj.getInputCol))
    }
    if (obj.isSet(obj.outputCols)) {
      result.withValue("output_cols", Value.stringList(obj.getOutputCols))
    } else {
      result.withValue("output_col", Value.string(obj.getOutputCol))
    }
  }

  protected def loadMultiInOutFormat(model: Model, obj: N): N = {
    val inputCol = model.getValue("input_col").map(_.getString)
    val inputCols = model.getValue("input_cols").map(_.getStringList)
    val outputCol = model.getValue("output_col").map(_.getString)
    val outputCols = model.getValue("output_cols").map(_.getStringList)
    val result: N = (inputCol, inputCols) match {
      case (None, None) => obj
      case (Some(col), None) => obj.set(obj.inputCol, col)
      case (None, Some(cols)) => obj.set(obj.inputCols, cols.toArray)
      case (_, _) => throw new UnsupportedOperationException("Cannot use both inputCol and inputCols")
    }
    (outputCol, outputCols) match {
      case (None, None) => obj
      case (Some(col), None) => result.set(result.outputCol, col)
      case (None, Some(cols)) => result.set(result.outputCols, cols.toArray)
      case (_, _) => throw new UnsupportedOperationException("Cannot use both outputCol and outputCols")
    }
  }

  def sparkInputs(obj: N): Seq[ParamSpec] = {
    if (obj.isSet(obj.inputCols)) {
      Seq(ParamSpec("input", obj.inputCols))
    } else{
      Seq(ParamSpec("input", obj.inputCol))
    }
  }

  def sparkOutputs(obj: N): Seq[ParamSpec] = {
    if (obj.isSet(obj.outputCols)) {
      Seq(ParamSpec("output", obj.outputCols))
    } else{
      Seq(ParamSpec("output", obj.outputCol))
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy