All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.spark.examples.h2o.AirlinesWithWeatherDemo2.scala Maven / Gradle / Ivy

There is a newer version: 1.6.8
Show newest version
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.h2o

import java.io.File

import hex.FrameSplitter
import hex.deeplearning.DeepLearning
import hex.deeplearning.DeepLearningModel.DeepLearningParameters
import hex.deeplearning.DeepLearningModel.DeepLearningParameters.Activation
import hex.tree.gbm.GBM
import hex.tree.gbm.GBMModel.GBMParameters
import org.apache.spark.h2o.{H2OContext, H2OFrame}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext, SparkFiles}
import water.Key
import water.fvec.Frame
import water.support.SparkContextSupport

/**
 * Demo for meetup presented at 12/17/2014.
 *
 * Joins airlines data for flights to ORD with weather records, splits the joined
 * frame into train/validation/test parts, trains a DeepLearning model and a GBM model
 * to predict `ArrDelay`, and prints an R script producing a residual plot for each model.
 */
object AirlinesWithWeatherDemo2 extends SparkContextSupport {

  /**
   * Generates an R script that plots prediction residuals against actual values.
   *
   * The script connects to an H2O instance, fetches both frames by their keys,
   * subtracts the actual column from the predicted column and plots the result.
   *
   * NOTE(review): the script reads the residual column as `residuals$<predCol>`, i.e. it
   * assumes the result of `predDelay - actDelay` keeps the prediction column's name
   * (the original script hard-coded `predict` here) - confirm against the H2O R client.
   *
   * @param prediction frame holding the model predictions
   * @param predCol    name of the prediction column inside `prediction`
   * @param actual     frame holding the actual (observed) values
   * @param actCol     name of the actual-value column inside `actual`
   * @param h2oContext optional H2O context; when non-null its `h2oLocalClient` value
   *                   ("ip:port") is used for `h2o.init`, otherwise 127.0.0.1:54321
   * @return the R script as a string
   */
  def residualPlotRCode(prediction:Frame, predCol: String, actual:Frame, actCol:String, h2oContext: H2OContext = null):String = {
    // The public default stays `null` for backward compatibility; wrap it for safe handling.
    val (ip, port) = Option(h2oContext) match {
      case Some(hc) =>
        val address = hc.h2oLocalClient.split(":")
        (address(0), address(1))
      case None =>
        ("127.0.0.1", "54321")
    }

    // `$$` is the interpolator escape for a literal `$` (the R column accessor).
    s"""# R script for residual plot
        |library(h2o)
        |h = h2o.init(ip="${ip}", port=${port})
        |
        |pred = h2o.getFrame(h, "${prediction._key}")
        |act = h2o.getFrame (h, "${actual._key}")
        |
        |predDelay = pred$$${predCol}
        |actDelay = act$$${actCol}
        |
        |nrow(actDelay) == nrow(predDelay)
        |
        |residuals = predDelay - actDelay
        |
        |compare = cbind (as.data.frame(actDelay$$${actCol}), as.data.frame(residuals$$${predCol}))
        |nrow(compare)
        |plot( compare[,1:2] )
        |
      """.stripMargin
  }

  /**
   * Application entry point: runs the whole demo pipeline and then stops Spark and H2O.
   *
   * @param args command line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    // Configure this application
    val conf: SparkConf = configure("Sparkling Water Meetup: Use Airlines and Weather Data for delay prediction")
    // Create SparkContext to execute application on Spark cluster
    val sc = new SparkContext(conf)
    val h2oContext = H2OContext.getOrCreate(sc)
    import h2oContext._
    import h2oContext.implicits._
    // Setup environment
    addFiles(sc,
      absPath("examples/smalldata/Chicago_Ohare_International_Airport.csv"),
      absPath("examples/smalldata/year2005.csv.gz"))

    //val weatherDataFile = "examples/smalldata/Chicago_Ohare_International_Airport.csv"
    val wrawdata = sc.textFile(enforceLocalSparkFile("Chicago_Ohare_International_Airport.csv"),3).cache()
    val weatherTable = wrawdata.map(_.split(",")).map(row => WeatherParse(row)).filter(!_.isWrongRow())

    //
    // Load H2O from CSV file (i.e., access directly H2O cloud)
    // Use super-fast advanced H2O CSV parser !!!
    val airlinesData = new H2OFrame(new File(SparkFiles.get("year2005.csv.gz")))

    val airlinesTable : RDD[Airlines] = asRDD[Airlines](airlinesData)
    // Select flights only to ORD
    val flightsToORD = airlinesTable.filter(f => f.Dest==Some("ORD"))

    // Count once: the RDD is not cached, so every count() launches a separate Spark job.
    val flightsToORDCount = flightsToORD.count()
    println(s"\nFlights to ORD: ${flightsToORDCount}\n")

    // Implicit because the H2O conversions imported above (e.g. asDataFrame) look it up.
    implicit val sqlContext: SQLContext = new SQLContext(sc)
    import sqlContext.implicits._ // import implicit conversions
    flightsToORD.toDF.registerTempTable("FlightsToORD")
    weatherTable.toDF.registerTempTable("WeatherORD")

    //
    // -- Join both tables and select interesting columns
    //
    val joinedTable = sqlContext.sql(
      """SELECT
        |f.Year,f.Month,f.DayofMonth,
        |f.CRSDepTime,f.CRSArrTime,f.CRSElapsedTime,
        |f.UniqueCarrier,f.FlightNum,f.TailNum,
        |f.Origin,f.Distance,
        |w.TmaxF,w.TminF,w.TmeanF,w.PrcpIn,w.SnowIn,w.CDD,w.HDD,w.GDD,
        |f.ArrDelay
        |FROM FlightsToORD f
        |JOIN WeatherORD w
        |ON f.Year=w.Year AND f.Month=w.Month AND f.DayofMonth=w.Day""".stripMargin)
    println(s"\nResult of query: ${joinedTable.count}\n")

    //
    // Split data into 3 tables - train/validation/test
    //
    // Instead of using RDD API we will directly split H2O Frame
    val joinedH2OFrame:H2OFrame = joinedTable // Invoke implicit transformation
    // Transform date related columns (Year, Month, DayofMonth - the first three selected) to enums
    for( i <- 0 to 2) joinedH2OFrame.replace(i, joinedH2OFrame.vec(i).toCategoricalVec)

    //
    // Use low-level task to split the frame
    // Two ratios (.7, .2) are given for three destination keys.
    val sf = new FrameSplitter(joinedH2OFrame, Array(.7, .2), Array("train", "valid","test").map(Key.make[Frame](_)), null)
    water.H2O.submitTask(sf)
    val splits = sf.getResult

    val trainTable = splits(0)
    val validTable = splits(1)
    val testTable  = splits(2)

    //
    // -- Run DeepLearning
    //
    val dlParams = new DeepLearningParameters()
    dlParams._train = trainTable
    dlParams._response_column = 'ArrDelay
    dlParams._valid = validTable
    dlParams._epochs = 5
    dlParams._activation = Activation.RectifierWithDropout
    dlParams._hidden = Array[Int](100, 100)
    dlParams._reproducible = true
    dlParams._force_load_balance = false

    val dl = new DeepLearning(dlParams)
    val dlModel = dl.trainModel.get

    val dlPredictTable = dlModel.score(testTable)('predict)
    // Missing predictions are reported as NaN instead of null.
    val predictionsFromDlModel = asDataFrame(dlPredictTable).collect
                                .map(row => if (row.isNullAt(0)) Double.NaN else row(0))

    println(predictionsFromDlModel.length)
    println(predictionsFromDlModel.mkString("\n===> Model predictions: ", ", ", ", ...\n"))

    // Print R code for residual plot.
    // `print` (not `printf`): the script is plain text, not a format string, so a '%'
    // in a frame key or column name must not be interpreted as a format specifier.
    // The H2O context is passed so the script connects to the actual H2O client address.
    print( residualPlotRCode(dlPredictTable, 'predict, testTable, 'ArrDelay, h2oContext) )

    // GBM Model
    val gbmParams = new GBMParameters()
    gbmParams._train = trainTable
    gbmParams._response_column = 'ArrDelay
    gbmParams._valid = validTable
    gbmParams._ntrees = 100
    //gbmParams._learn_rate = 0.01f
    val gbm = new GBM(gbmParams)
    val gbmModel = gbm.trainModel.get

    // Print R code for residual plot
    val gbmPredictTable = gbmModel.score(testTable)('predict)
    print( residualPlotRCode(gbmPredictTable, 'predict, testTable, 'ArrDelay, h2oContext) )

    // Shutdown Spark cluster and H2O
    h2oContext.stop(stopSparkContext = true)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy