All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.examples.h2o.AmazonFineFood.scala Maven / Gradle / Ivy

There is a newer version: 1.6.8
Show newest version
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.h2o

import org.apache.spark
import org.apache.spark.h2o._
import org.apache.spark.ml.feature.StopWords
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}
import org.joda.time.{DateTimeZone, MutableDateTime}
import water.MRTask
import water.app.ModelMetricsSupport
import water.support.SparkContextSupport
import water.fvec._

/**
  * Created by michal on 6/5/16.
  */
object AmazonFineFood extends SparkContextSupport with ModelMetricsSupport {
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = configure("Amazon Fine Food Review Sentiment Analysis")
    val sc = new SparkContext(conf)

    implicit val sqlContext = SQLContext.getOrCreate(sc)
    @transient val hc = H2OContext.getOrCreate(sc)

    val reviews = new H2OFrame(new java.io.File("/Users/michal/Tmp/amazon-fine-foods/Reviews.csv"))

    // We do not need redundant data
    reviews.remove("Id").remove
    reviews.remove("ProfileName").remove

    // Update in DKV
    reviews.update


    val refinedTime = new RefineTime().doAll(5, Vec.T_NUM, reviews).outputFrame(Array("Day", "Month", "Year", "DayOfWeek", "Hour"), null /* no domains */)
    reviews.add(refinedTime).remove("Time").remove()
    reviews.update
    // NOTE: hour is not useful

    val df = hc.asDataFrame(reviews)
    df.printSchema()

    import org.apache.spark.sql.functions._
    val avgScorePerYear = hc.asH2OFrame(df.groupBy("Year").agg(mean("Score"), count("Score")), "avgScorePerYear")
    val avgScorePerMonth = hc.asH2OFrame(df.groupBy("Month").agg(mean("Score"), count("Score")), "avgScorePerMonth")
    val avgScorePerDay = hc.asH2OFrame(df.groupBy("DayOfWeek").agg(mean("Score"), count("Score")), "avgScorePerDay")

    // Input for sentiment analysis
    val sentimentDF = hc.asDataFrame(reviews('Score, 'Month, 'Day, 'DayOfWeek, 'Summary))

    // Transform Score to binary +/- feature - skip neutral reviews
    val toBinaryScore = udf { score: Byte => if (score < 3.toByte) "negative" else "positive" }

    val toTokens = udf { summary: String =>
      summary.split(",")
        .map(v => v.trim.toLowerCase.replaceAll("[^\\p{IsAlphabetic}]", ""))
          .filter(v => !StopWords.English.contains(v))
    }

    val hashingTF = new HashingTF(4096) // Larger space?
    val toNumericFeatures = udf { terms: Seq[_] => hashingTF.transform(terms) }

    // Skip all neutral reviews
    val vectorizedFrame: DataFrame = sentimentDF.where("Score != 3")
      .withColumn("Score", toBinaryScore(col("Score")))
      .withColumn("Summary", toNumericFeatures(toTokens(col("Summary"))))

    val idfModel = new IDF(minDocFreq = 1).fit(vectorizedFrame.select("Summary").map { case Row(v: spark.mllib.linalg.Vector) => v})
    val toIdf = udf { vector: spark.mllib.linalg.Vector => idfModel.transform(vector)}
    val finalFrame: DataFrame = vectorizedFrame.withColumn("Summary", toIdf(col("Summary")))
    finalFrame.printSchema()

    val p = hc.asH2OFrame(finalFrame, "finalFrame")
    // Cleanup
    reviews.delete()

    // RUN GLM or SVM

    // Create a predictor function
    println("DONE")
  }
}

class RefineTime extends MRTask[RefineTime] {
  override def map(in: Array[Chunk], out: Array[NewChunk]): Unit = {
    val mdt = new MutableDateTime(DateTimeZone.UTC)
    val timeCol = in(5 /* Index of Time column*/)
    val (dayNC, monthNC, yearNC, dayOfWeekNC, hourNC) = (out(0), out(1), out(2), out(3), out(4))
    for (row <- 0 until timeCol._len) {
      val time = timeCol.at8(row) * 1000 /* JODA API expect millis seconds from epoch */
      mdt.setMillis(time)
      dayNC.addNum(mdt.getDayOfMonth, 0)
      monthNC.addNum(mdt.getMonthOfYear, 0)
      yearNC.addNum(mdt.getYear, 0)
      dayOfWeekNC.addNum(mdt.getDayOfWeek, 0)
      hourNC.addNum(mdt.getHourOfDay, 0)
    }
  }
}

object H2OStopWords {
  val English = StopWords.English
}

/*
plot (g) -> g(
  g.point(
    g.position "Year", "avg(Score)"
    g.fillColor "count(Score)"
    g.size "count(Score)"
  )
  g.from inspect "data", getFrame "avgScorePerYear"
)
 */




© 2015 - 2025 Weber Informatics LLC | Privacy Policy