icu.wuhufly.features.ml_result02.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of shtd-bd Show documentation
Show all versions of shtd-bd Show documentation
bigdata source code for shtd
The newest version!
package icu.wuhufly.features
import icu.wuhufly.utils.{CreateUtils, WriteUtils}
import org.apache.spark.SparkContext
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{Normalizer, PolynomialExpansion, StandardScaler, VectorAssembler}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
/**
 * Batch job: trains a cross-validated random-forest pipeline on
 * `dwd.fact_machine_learning_data` to predict `machine_record_state`,
 * saves the unlabeled hold-out split as `dwd.fact_machine_learning_data_test`,
 * and hands the hold-out predictions to `WriteUtils.writeToMysql` under the
 * name "ml_result".
 */
object ml_result02 {

  /** Column the classifier is trained to predict. */
  private val LabelCol: String = "machine_record_state"

  /**
   * Columns assembled into the feature vector.
   *
   * The label column is deliberately NOT part of this list: feeding the target
   * into the feature vector leaks the answer to the model, makes the
   * cross-validation metric meaningless, and prevents the fitted pipeline from
   * scoring the persisted test table, which is written without the label.
   */
  private val FeatureCols: Array[String] = Array(
    "machine_id",
    "machine_record_mainshaft_speed",
    "machine_record_mainshaft_multiplerate",
    "machine_record_mainshaft_load",
    "machine_record_feed_speed",
    "machine_record_feed_multiplerate",
    "machine_record_pmc_code",
    "machine_record_circle_time",
    "machine_record_run_time",
    "machine_record_effective_shaft",
    "machine_record_amount_process",
    "machine_record_use_memory",
    "machine_record_free_memory",
    "machine_record_amount_use_code",
    "machine_record_amount_free_code"
  )

  /**
   * Entry point. Runs the split / train / predict / export sequence once.
   *
   * @param args command-line arguments; not read by this job
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = CreateUtils.getSpark()
    val sc: SparkContext = spark.sparkContext
    import org.apache.spark.sql.functions._

    try {
      // 80/20 split with a fixed seed so the hold-out set is reproducible.
      val Array(train, test) = spark.read.table("dwd.fact_machine_learning_data")
        .randomSplit(Array(0.8, 0.2), 42)

      // Persist the hold-out rows with the label column removed.
      test
        .select(test.columns.diff(Array(LabelCol)).map(col): _*)
        .write.mode(SaveMode.Overwrite)
        .saveAsTable("dwd.fact_machine_learning_data_test")

      val rfclf: RandomForestClassifier = new RandomForestClassifier()
        .setSeed(42)
        .setLabelCol(LabelCol)
        .setFeaturesCol("features")
        .setPredictionCol("prediction")

      val poly: PolynomialExpansion = new PolynomialExpansion()
        .setInputCol("poly_features")
        .setOutputCol("features")

      // Intermediate column names are kept from the original job: each one is
      // named after the stage that CONSUMES it, not the stage that produces it.
      val pipeline: Pipeline = new Pipeline()
        .setStages(
          Array(
            new VectorAssembler()
              .setInputCols(FeatureCols)
              .setOutputCol("norm_features"),
            new Normalizer()
              .setInputCol("norm_features")
              .setOutputCol("std_features"),
            new StandardScaler()
              .setInputCol("std_features")
              .setOutputCol("poly_features"),
            poly,
            rfclf
          )
        )

      // 4 x 3 x 2 = 24 candidate parameter combinations.
      val paramMaps: Array[ParamMap] = new ParamGridBuilder()
        .addGrid(rfclf.maxBins, Array(8, 16, 32, 48))
        .addGrid(rfclf.maxDepth, Array(5, 8, 12))
        .addGrid(poly.degree, Array(1, 2))
        .build()

      // 3-fold cross-validation scored with the evaluator's default metric,
      // fitting up to 10 candidate models concurrently.
      val cv: CrossValidator = new CrossValidator()
        .setEstimatorParamMaps(paramMaps)
        .setEvaluator(new BinaryClassificationEvaluator().setLabelCol(LabelCol))
        .setEstimator(pipeline)
        .setNumFolds(3)
        .setParallelism(10)

      val model: CrossValidatorModel = cv.fit(train)

      // Export the predicted state under the label's column name, keyed by record id.
      val resDF: DataFrame = model
        .transform(test)
        .selectExpr("machine_record_id", "prediction as machine_record_state")

      // NOTE(review): presumably writes resDF to a MySQL table named "ml_result";
      // confirm against WriteUtils.
      WriteUtils.writeToMysql(
        "ml_result", resDF
      )
    } finally {
      // Release the Spark context even when training or the export fails.
      sc.stop()
    }
  }
}