All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.h2o.sparkling.examples.ChicagoCrimeApp.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ai.h2o.sparkling.examples

import java.io.File
import ai.h2o.sparkling.H2OContext
import ai.h2o.sparkling.ml.algos.{H2ODeepLearning, H2OGBM}
import ai.h2o.sparkling.ml.models.H2OMOJOModel
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{DataFrame, SparkSession}

object ChicagoCrimeApp {

  private val seasonUdf = udf(monthToSeason _)
  private val weekendUdf = udf((isWeekend _).andThen(boolToInt))
  private val dayOfWeekUdf = udf(dayOfWeek _)

  def main(args: Array[String]) {
    implicit val spark = SparkSession
      .builder()
      .appName("Chicago Crime App")
      .getOrCreate()
    import spark.implicits._

    val weatherTable = loadCsv("./examples/smalldata/chicago/chicagoAllWeather.csv").drop("date")
    val chicagoWeatherTableName = "chicagoWeather"
    weatherTable.createOrReplaceTempView(chicagoWeatherTableName)

    val censusTable = loadCsv("./examples/smalldata/chicago/chicagoCensus.csv")
    val chicagoCensusTableName = "chicagoCensus"
    censusTable.createOrReplaceTempView(chicagoCensusTableName)

    val crimesTable = addAdditionalDateColumns(loadCsv("./examples/smalldata/chicago/chicagoCrimes10k.csv"))
    val chicagoCrimeTableName = "chicagoCrime"
    crimesTable.createOrReplaceTempView(chicagoCrimeTableName)

    val crimeDataColumnsForTraining = Seq(
      $"cr.Year",
      $"cr.Month",
      $"cr.Day",
      $"WeekNum",
      $"HourOfDay",
      $"Weekend",
      $"Season",
      $"WeekDay",
      $"IUCR",
      $"Primary_Type",
      $"Location_Description",
      $"Community_Area",
      $"District",
      $"Arrest",
      $"Domestic",
      $"Beat",
      $"Ward",
      $"FBI_Code")

    val censusDataColumnsForTraining = Seq(
      $"PERCENT_AGED_UNDER_18_OR_OVER_64",
      $"PER_CAPITA_INCOME",
      $"HARDSHIP_INDEX",
      $"PERCENT_OF_HOUSING_CROWDED",
      $"PERCENT_HOUSEHOLDS_BELOW_POVERTY",
      $"PERCENT_AGED_16_UNEMPLOYED",
      $"PERCENT_AGED_25_WITHOUT_HIGH_SCHOOL_DIPLOMA")

    val weatherDataColumnsForTraining = Seq($"minTemp", $"maxTemp", $"meanTemp")

    val joinedDataForTraining = spark
      .table(chicagoCrimeTableName)
      .as("cr")
      .join(
        spark.table(chicagoWeatherTableName).as("we"),
        $"cr.Year" === $"we.year" and $"cr.Month" === $"we.month" and $"cr.Day" === $"we.day")
      .join(spark.table(chicagoCensusTableName).as("ce"), $"cr.Community_Area" === $"ce.Community_Area_Number")
      .select(crimeDataColumnsForTraining ++ censusDataColumnsForTraining ++ weatherDataColumnsForTraining: _*)

    H2OContext.getOrCreate()
    val gbmModel = trainGBM(joinedDataForTraining)
    val dlModel = trainDeepLearning(joinedDataForTraining)

    val crimesToScore = Seq(
      CrimeWithCensusData(
        date = "02/08/2015 11:43:58 PM",
        IUCR = 1811,
        Primary_Type = "NARCOTICS",
        Location_Description = "STREET",
        Domestic = false,
        Beat = 422,
        District = 4,
        Ward = 7,
        Community_Area = 46,
        FBI_Code = 18,
        PERCENT_AGED_UNDER_18_OR_OVER_64 = 41.1,
        PER_CAPITA_INCOME = 16579,
        HARDSHIP_INDEX = 75,
        PERCENT_OF_HOUSING_CROWDED = 4.7,
        PERCENT_HOUSEHOLDS_BELOW_POVERTY = 29.8,
        PERCENT_AGED_16_UNEMPLOYED = 19.7,
        PERCENT_AGED_25_WITHOUT_HIGH_SCHOOL_DIPLOMA = 26.6),
      CrimeWithCensusData(
        date = "02/08/2015 11:00:39 PM",
        IUCR = 1150,
        Primary_Type = "DECEPTIVE PRACTICE",
        Location_Description = "RESIDENCE",
        Domestic = false,
        Beat = 923,
        District = 9,
        Ward = 14,
        Community_Area = 63,
        FBI_Code = 11,
        PERCENT_AGED_UNDER_18_OR_OVER_64 = 38.8,
        PER_CAPITA_INCOME = 12171,
        HARDSHIP_INDEX = 93,
        PERCENT_OF_HOUSING_CROWDED = 15.8,
        PERCENT_HOUSEHOLDS_BELOW_POVERTY = 23.4,
        PERCENT_AGED_16_UNEMPLOYED = 18.2,
        PERCENT_AGED_25_WITHOUT_HIGH_SCHOOL_DIPLOMA = 51.5)).toDF

    score(addAdditionalDateColumns(crimesToScore), gbmModel, dlModel)
  }

  private def trainGBM(train: DataFrame): H2OMOJOModel = {
    val gbm = new H2OGBM()
      .setSplitRatio(0.8)
      .setLabelCol("Arrest")
      .setColumnsToCategorical("Arrest")
      .setNtrees(10)
      .setMaxDepth(6)
      .setDistribution("bernoulli")
    gbm.fit(train)
  }

  private def trainDeepLearning(train: DataFrame): H2OMOJOModel = {
    val dl = new H2ODeepLearning()
      .setSplitRatio(0.8)
      .setLabelCol("Arrest")
      .setColumnsToCategorical("Arrest")
      .setEpochs(10)
      .setL1(0.0001)
      .setL2(0.0001)
      .setActivation("RectifierWithDropout")
      .setHidden(Array(200, 200))
    dl.fit(train)
  }

  private def score(crimes: DataFrame, gbmModel: H2OMOJOModel, dlModel: H2OMOJOModel)(
      implicit spark: SparkSession): Unit = {
    import spark.implicits._
    val arrestGBM = gbmModel.transform(crimes)
    val arrestDL = dlModel.transform(crimes)
    val willBeArrestedPrediction = $"prediction" === "1"
    println(s"""
           |  Will be arrested based on DeepLearning: ${arrestDL.where(willBeArrestedPrediction).count()}
           |  Will be arrested based on GBM: ${arrestGBM.where(willBeArrestedPrediction).count()}
           |
        """.stripMargin)
  }

  private def loadCsv(dataPath: String)(implicit spark: SparkSession): DataFrame = {
    val datafile = s"file://${new File(dataPath).getAbsolutePath}"
    val df = spark.read.option("header", "true").option("inferSchema", "true").csv(datafile)
    val renamedColumns = df.columns.map { col =>
      val name = col.trim
        .replace(' ', '_')
        .replace('+', '_')
        .replace("__", "_")
      df(col).as(name)
    }
    df.select(renamedColumns: _*)
  }

  private def addAdditionalDateColumns(df: DataFrame)(implicit spark: SparkSession): DataFrame = {
    import org.apache.spark.sql.functions._
    import spark.implicits._
    df.withColumn("DateTmp", from_unixtime(unix_timestamp('Date, "MM/dd/yyyy hh:mm:ss a")))
      .withColumn("Year", year('DateTmp))
      .withColumn("Month", month('DateTmp))
      .withColumn("Day", dayofmonth('DateTmp))
      .withColumn("WeekNum", weekofyear('DateTmp))
      .withColumn("HourOfDay", hour('DateTmp))
      .withColumn("Season", seasonUdf('Month))
      .withColumn("WeekDay", dayOfWeekUdf(date_format('DateTmp, "E")))
      .withColumn("Weekend", weekendUdf('WeekDay))
      .drop('DateTmp)
  }

  private def monthToSeason(month: Int): String = {
    if (month >= 3 && month <= 5) "Spring"
    else if (month >= 6 && month <= 8) "Summer"
    else if (month >= 9 && month <= 10) "Autumn"
    else "Winter"
  }

  private def isWeekend(dayOfWeek: Int): Boolean = dayOfWeek == 7 || dayOfWeek == 6

  private def boolToInt(bool: Boolean): Int = if (bool) 1 else 0

  private def dayOfWeek(day: String): Int = {
    day match {
      case "Mon" => 1
      case "Tue" => 2
      case "Wed" => 3
      case "Thu" => 4
      case "Fri" => 5
      case "Sat" => 6
      case "Sun" => 7
      case _ => throw new RuntimeException("Invalid day!")
    }
  }

  case class CrimeWithCensusData(
      date: String,
      IUCR: Short,
      Primary_Type: String,
      Location_Description: String,
      Domestic: Boolean,
      Beat: Short,
      District: Byte,
      Ward: Byte,
      Community_Area: Byte,
      FBI_Code: Byte,
      PERCENT_AGED_UNDER_18_OR_OVER_64: Double,
      PER_CAPITA_INCOME: Int,
      HARDSHIP_INDEX: Short,
      PERCENT_OF_HOUSING_CROWDED: Double,
      PERCENT_HOUSEHOLDS_BELOW_POVERTY: Double,
      PERCENT_AGED_16_UNEMPLOYED: Double,
      PERCENT_AGED_25_WITHOUT_HIGH_SCHOOL_DIPLOMA: Double)

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy