icu.wuhufly.features.ml_result02.scala

package icu.wuhufly.features

import icu.wuhufly.utils.{CreateUtils, WriteUtils}
import org.apache.spark.SparkContext
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{Normalizer, PolynomialExpansion, StandardScaler, VectorAssembler}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

object ml_result02 {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = CreateUtils.getSpark()
    val sc: SparkContext = spark.sparkContext
    import spark.implicits._
    import org.apache.spark.sql.functions._

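    // 80/20 train/test split of the feature table, with a fixed seed for reproducibility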
    val Array(train, test) = spark.read.table("dwd.fact_machine_learning_data")
      .randomSplit(Array(0.8, 0.2), 42)

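    // Persist the test split without the label column so it can be scored independently later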
    test
      .select(test.columns.diff(Array("machine_record_state")).map(col):_*)
      .write.mode(SaveMode.Overwrite)
      .saveAsTable("dwd.fact_machine_learning_data_test")

    // Declared outside the pipeline so the parameter grid below can tune their hyperparameters
    val rfclf: RandomForestClassifier = new RandomForestClassifier()
      .setSeed(42)
      .setLabelCol("machine_record_state")
      .setFeaturesCol("features")
      .setPredictionCol("prediction")
    val poly: PolynomialExpansion = new PolynomialExpansion()
      .setInputCol("scaled_features")
      .setOutputCol("features")

    // Feature pipeline: assemble the raw numeric columns, normalize each row vector,
    // standardize, expand polynomially, then fit the random forest. The label column
    // machine_record_state is excluded from the feature vector to avoid target leakage.
    val pipeline: Pipeline = new Pipeline()
      .setStages(
        Array(
          new VectorAssembler()
            .setInputCols(Array(
              "machine_id",
              "machine_record_mainshaft_speed",
              "machine_record_mainshaft_multiplerate",
              "machine_record_mainshaft_load",
              "machine_record_feed_speed",
              "machine_record_feed_multiplerate",
              "machine_record_pmc_code",
              "machine_record_circle_time",
              "machine_record_run_time",
              "machine_record_effective_shaft",
              "machine_record_amount_process",
              "machine_record_use_memory",
              "machine_record_free_memory",
              "machine_record_amount_use_code",
              "machine_record_amount_free_code"
            ))
            .setOutputCol("assembled_features"),
          new Normalizer()
            .setInputCol("assembled_features")
            .setOutputCol("normalized_features"),
          new StandardScaler()
            .setInputCol("normalized_features")
            .setOutputCol("scaled_features"),
          poly,
          rfclf
        )
      )

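    // Hyperparameter grid searched by cross-validation: the trees' histogram bins and depth,
    // plus the polynomial expansion degree (a degree of 1 means no expansion)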
    val paramMaps: Array[ParamMap] = new ParamGridBuilder()
      .addGrid(rfclf.maxBins, Array(8, 16, 32, 48))
      .addGrid(rfclf.maxDepth, Array(5, 8, 12))
      .addGrid(poly.degree, Array(1, 2))
      .build()

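    // 3-fold cross-validation over the grid, scored with areaUnderROC (the
    // BinaryClassificationEvaluator default), evaluating up to 10 candidates in parallel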
    val cv: CrossValidator = new CrossValidator()
      .setEstimatorParamMaps(paramMaps)
      .setEvaluator(new BinaryClassificationEvaluator().setLabelCol("machine_record_state"))
      .setEstimator(pipeline)
      .setNumFolds(3)
      .setParallelism(10)


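    // Fit the cross-validated pipeline on the training split (the best configuration is
    // refit on the full training data), then generate predictions for the held-out split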
    val model: CrossValidatorModel = cv.fit(train)
    val resDF: DataFrame = model
      .transform(test)
      .selectExpr("machine_record_id", "prediction as machine_record_state")
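
    // Optional sketch (not part of the original flow): the winning hyperparameters could be
    // inspected after fitting, assuming the last pipeline stage is the random forest:
    //   val best = model.bestModel.asInstanceOf[org.apache.spark.ml.PipelineModel]
    //   val bestRf = best.stages.last
    //     .asInstanceOf[org.apache.spark.ml.classification.RandomForestClassificationModel]
    //   println(s"maxDepth=${bestRf.getMaxDepth}, maxBins=${bestRf.getMaxBins}")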

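    // Export the record id with its predicted state to MySQL (table ml_result) via the
    // project's WriteUtils helper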
    WriteUtils.writeToMysql(
      "ml_result", resDF
    )

    sc.stop()
  }
}



