tech.mlsql.plugins.ets.TableRepartition.scala

package tech.mlsql.plugins.ets

import org.apache.spark.ml.param.{IntParam, Param}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.{DataFrame, SparkSession, functions => F}
import streaming.dsl.auth.TableAuthResult
import streaming.dsl.mmlib._
import streaming.dsl.mmlib.algs.param.WowParams
import tech.mlsql.dsl.auth.ETAuth
import tech.mlsql.dsl.auth.dsl.mmlib.ETMethod.ETMethod
import tech.mlsql.version.VersionCompatibility


class TableRepartition(override val uid: String) extends SQLAlg with VersionCompatibility with WowParams with ETAuth {
  def this() = this("tech.mlsql.plugins.ets.TableRepartition")

  // Entry point: parses partitionNum/partitionType/partitionCols from params
  // and returns the repartitioned DataFrame.
  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {

    // partitionNum is required; fail fast when it is missing.
    params.get(partitionNum.name).map { item =>
      set(partitionNum, item.toInt)
      item
    }.getOrElse {
      throw new MLSQLException(s"${partitionNum.name} is required")
    }

    // partitionType defaults to "hash" when not supplied.
    params.get(partitionType.name).map { item =>
      set(partitionType, item)
      item
    }.getOrElse {
      set(partitionType, "hash")
    }

    // partitionCols is optional and only consulted by range partitioning.
    params.get(partitionCols.name).map { item =>
      set(partitionCols, item)
      item
    }.getOrElse {
      set(partitionCols, "")
    }

    $(partitionType) match {
      case "range" =>
        // Range partitioning requires explicit columns to range by.
        require(params.contains(partitionCols.name), "At least one partition-by expression must be specified.")
        df.repartitionByRange($(partitionNum), $(partitionCols).split(",").map(name => F.col(name)): _*)

      case _ =>
        // Default: hash repartition into the requested number of partitions.
        df.repartition($(partitionNum))
    }
  }
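
  // Illustrative behavior (the `age` column is an assumption for the example):
  //   train(df, "", Map("partitionNum" -> "4"))
  //     behaves like df.repartition(4)                      (hash, the default)
  //   train(df, "", Map("partitionNum" -> "4", "partitionType" -> "range", "partitionCols" -> "age"))
  //     behaves like df.repartitionByRange(4, F.col("age"))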

  override def auth(etMethod: ETMethod, path: String, params: Map[String, String]): List[TableAuthResult] = {
    // Repartitioning performs no extra table-level authorization checks.
    List()
  }

  override def supportedVersions: Seq[String] = {
    Seq("1.5.0-SNAPSHOT", "1.5.0", "1.6.0-SNAPSHOT", "1.6.0")
  }


  override def doc: Doc = Doc(MarkDownDoc,
    s"""
       |TableRepartition changes the number of partitions of a table.
       |
       |Parameters:
       |
       |* partitionNum: required, the target number of partitions.
       |* partitionType: optional, "hash" (default) or "range".
       |* partitionCols: comma-separated column names; required when partitionType is "range".
    """.stripMargin)


  override def codeExample: Code = Code(SQLCode,
    """
      |-- repartition table1 into 3 partitions (hash by default)
      |run table1 as TableRepartition.`` where partitionNum="3" as table2;
      |
      |-- range-repartition by a column
      |run table1 as TableRepartition.`` where partitionNum="3"
      |and partitionType="range" and partitionCols="col1" as table3;
    """.stripMargin)

  override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = train(df, path, params)

  // This ET only supports batch usage via `run`; loading a model or
  // registering it as a UDF is intentionally unimplemented.
  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = ???

  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = ???

  final val partitionNum: IntParam = new IntParam(this, "partitionNum",
    "Required. The number of partitions to repartition the table into.")
  final val partitionType: Param[String] = new Param[String](this, "partitionType",
    "Optional. Either \"hash\" (default) or \"range\".")

  final val partitionCols: Param[String] = new Param[String](this, "partitionCols",
    "Comma-separated column names to partition by; required when partitionType is \"range\".")

  override def explainParams(sparkSession: SparkSession): DataFrame = _explainParams(sparkSession)

}
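
For reference, a minimal sketch of invoking this ET directly from Scala rather than through MLSQL. It assumes a local SparkSession with the mlsql runtime and this plugin on the classpath; the table contents and the `age` column are illustrative:

import org.apache.spark.sql.SparkSession
import tech.mlsql.plugins.ets.TableRepartition

object TableRepartitionExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("repartition-example").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 30), (2, 25), (3, 41)).toDF("id", "age")

    // Hash repartition into 4 partitions (partitionType defaults to "hash").
    val hashed = new TableRepartition().train(df, "", Map("partitionNum" -> "4"))
    println(hashed.rdd.getNumPartitions) // 4

    // Range repartition into 2 partitions, ranged by the `age` column.
    val ranged = new TableRepartition().train(df, "",
      Map("partitionNum" -> "2", "partitionType" -> "range", "partitionCols" -> "age"))
    println(ranged.rdd.getNumPartitions) // 2

    spark.stop()
  }
}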



