All Downloads are FREE. Search and download functionalities are using the official Maven repository.

util.Util.scala Maven / Gradle / Ivy

/** Copyright 2014 TappingStone, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */

package io.prediction.engines.util

import io.prediction.controller.NiceRendering

import org.apache.mahout.cf.taste.model.DataModel
import org.apache.mahout.cf.taste.model.Preference
import org.apache.mahout.cf.taste.model.PreferenceArray
import org.apache.mahout.cf.taste.impl.model.GenericDataModel
import org.apache.mahout.cf.taste.impl.model.GenericBooleanPrefDataModel
import org.apache.mahout.cf.taste.impl.model.GenericPreference
import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray
import org.apache.mahout.cf.taste.impl.common.FastByIDMap
import org.apache.mahout.cf.taste.impl.common.FastIDSet

import scala.collection.JavaConversions._
import scala.collection.JavaConversions.asScalaBuffer
import scala.collection.JavaConversions.asScalaSet
import java.util.{ List => JList }
import java.util.{ Set => JSet }
import java.lang.{ Integer => JInteger }
import java.lang.{ Float => JFloat }
import java.lang.{ Long => JLong }

import grizzled.slf4j.Logger
import java.io.FileOutputStream
import java.io.ObjectOutputStream
import java.io.FileInputStream
import java.io.ObjectInputStream

import scala.io.Source
import java.io.PrintWriter
import java.io.File

/** Helpers for turning rating tuples into Mahout Taste DataModel instances */
object MahoutUtil {

  val logger = Logger(MahoutUtil.getClass)

  /** Java-facing wrapper around buildDataModel.
   *  Accepts a java.util.List of (uid, iid, rating, timestamp) tuples with boxed fields.
   */
  def jBuildDataModel(ratingSeq: JList[Tuple4[JInteger, JInteger, JFloat, JLong]]): DataModel = {
    val scalaRatings = asScalaBuffer(ratingSeq).toList
    // The boxed-tuple list is handed on under the primitive tuple type via an unchecked cast.
    buildDataModel(scalaRatings.asInstanceOf[List[(Int, Int, Float, Long)]])
  }

  /** Java-facing wrapper around buildBooleanPrefDataModel.
   *  Accepts a java.util.List of (uid, iid, timestamp) tuples with boxed fields.
   */
  def jBuildBooleanPrefDataModel(ratingSeq: JList[Tuple3[JInteger, JInteger, JLong]]): DataModel = {
    val scalaRatings = asScalaBuffer(ratingSeq).toList
    // Same unchecked cast as in jBuildDataModel, for the three-field tuple.
    buildBooleanPrefDataModel(scalaRatings.asInstanceOf[List[(Int, Int, Long)]])
  }

  /** Creates a GenericDataModel from (uid, iid, rating, timestamp) tuples.
   *
   *  Ratings are grouped per user; each user gets one preference array and one
   *  item -> timestamp map.
   *  NOTE: assumes a user never rates the same iid more than once.
   */
  def buildDataModel(
    ratingSeq: Seq[(Int, Int, Float, Long)]): DataModel = {

    val prefsByUser = new FastByIDMap[PreferenceArray]()
    val timestampsByUser = new FastByIDMap[FastByIDMap[JLong]]()

    val ratingsPerUser = ratingSeq.groupBy { rating => rating._1 }

    ratingsPerUser.foreach { case (uid, userRatings) =>
      val userID = uid.toLong
      // one slot per rating given by this user
      val prefArray = new GenericUserPreferenceArray(userRatings.size)
      // item -> timestamp for this user
      val itemTimestamps = new FastByIDMap[JLong]()

      userRatings.zipWithIndex.foreach { case ((_, iid, rating, timestamp), slot) =>
        val itemID = iid.toLong
        prefArray.set(slot, new GenericPreference(userID, itemID, rating))
        itemTimestamps.put(itemID, timestamp)
      }

      prefsByUser.put(userID, prefArray)
      timestampsByUser.put(userID, itemTimestamps)
    }

    new GenericDataModel(prefsByUser, timestampsByUser)
  }

  /** Creates a GenericBooleanPrefDataModel from (uid, iid, timestamp) tuples.
   *
   *  Each user gets a set of item ids and an item -> timestamp map.
   *  NOTE: assumes a user never has the same iid more than once.
   */
  def buildBooleanPrefDataModel(
    ratingSeq: Seq[(Int, Int, Long)]): DataModel = {

    val itemsByUser = new FastByIDMap[FastIDSet]()
    val timestampsByUser = new FastByIDMap[FastByIDMap[JLong]]()

    for (rating <- ratingSeq) {
      val userID = rating._1.toLong
      val itemID = rating._2.toLong
      val timestamp = rating._3

      // fetch this user's item set, registering a fresh one on first sight
      val userItems = Option(itemsByUser.get(userID)).getOrElse {
        val created = new FastIDSet()
        itemsByUser.put(userID, created)
        created
      }
      userItems.add(itemID)

      // fetch this user's timestamp map, registering a fresh one on first sight
      val userTimestamps = Option(timestampsByUser.get(userID)).getOrElse {
        val created = new FastByIDMap[JLong]
        timestampsByUser.put(userID, created)
        created
      }
      userTimestamps.put(itemID, timestamp)
    }

    new GenericBooleanPrefDataModel(itemsByUser, timestampsByUser)
  }

}


/** Ranking-metric helpers */
object MathUtil {

  /** Average precision at k.
   *
   *  @param k cut-off rank
   *  @param p predicted items, best first
   *  @param r set of relevant items
   *  @return sum of precision@i over the hit positions i within the cut-off,
   *          divided by min(cut-off, r.size); 0 when that divisor is 0
   */
  def averagePrecisionAtK[T](k: Int, p: Seq[T], r: Set[T]): Double = {
    // When fewer than k predictions are available, the available count is the cut-off.
    val cutoff = scala.math.min(p.size, k)

    // Walk the ranked list once, carrying (hits so far, running sum of precision at hits).
    // A miss contributes nothing; a hit at 1-based rank i contributes hits / i.
    val (_, precisionSum) =
      p.take(cutoff).zipWithIndex.foldLeft((0, 0.0)) { case ((hits, acc), (item, idx)) =>
        if (r(item)) {
          val hitsNow = hits + 1
          (hitsNow, acc + hitsNow.toDouble / (idx + 1))
        } else {
          (hits, acc)
        }
      }

    val normalizer = scala.math.min(cutoff, r.size)
    if (normalizer == 0) 0 else precisionSum / normalizer
  }

  /** Java-facing wrapper around averagePrecisionAtK taking java.util collections */
  def jAveragePrecisionAtK[T](k: Integer, p: JList[T], r: JSet[T]): Double = {
    val predicted: List[T] = asScalaBuffer[T](p).toList
    val relevant: Set[T] = asScalaSet[T](r).toSet
    averagePrecisionAtK(k, predicted, relevant)
  }
}

/** Helpers for persisting evaluator output: Java-serialized objects via
  * save/load, and HTML + JSON reports via render.
  */
object EvaluatorVisualization {
  /** ObjectInputStream that first tries to resolve a class through the class
    * loader that loaded this stream class, and only falls back to the default
    * ObjectInputStream lookup when that fails.
    */
  class ObjectInputStreamWithCustomClassLoader(
    fileInputStream: FileInputStream
  ) extends ObjectInputStream(fileInputStream) {
    override def resolveClass(desc: java.io.ObjectStreamClass): Class[_] = {
      try { Class.forName(desc.getName, false, getClass.getClassLoader) }
      catch { case ex: ClassNotFoundException => super.resolveClass(desc) }
    }
  }

  /** Serializes `data` with Java serialization into the file at `path`.
    *
    * @param data object to write
    * @param path destination file path
    */
  def save[T](data: T, path: String): Unit = {
    println(s"Output to: $path")
    val fileOut = new FileOutputStream(path)
    // Close the underlying file even if creating the object stream or
    // writing the object throws.
    try {
      val oos = new ObjectOutputStream(fileOut)
      oos.writeObject(data)
      oos.flush()
    } finally {
      fileOut.close()
    }
  }

  /** Reads back one object from the file at `path` and casts it to `T`.
    *
    * The cast is unchecked: a mismatching `T` is not detected here.
    * NOTE(review): this uses Java deserialization; only load files from a
    * trusted source.
    *
    * @param path file to read
    * @return the deserialized object, typed as `T`
    */
  def load[T](path: String): T = {
    val fileIn = new FileInputStream(path)
    // Close the underlying file even if the stream header or the object
    // cannot be read.
    try {
      val ois = new ObjectInputStreamWithCustomClassLoader(fileIn)
      ois.readObject().asInstanceOf[T]
    } finally {
      fileIn.close()
    }
  }

  /** Writes the HTML and JSON renderings of `data` to `path`.html and
    * `path`.json. The `toHTML` / `toJSON` methods are looked up reflectively
    * on the runtime class of `data`.
    *
    * @param data object to render
    * @param path output path prefix (extensions are appended)
    */
  def render[T <: NiceRendering](data: NiceRendering, path: String): Unit = {
    val htmlPath = s"${path}.html"
    println(s"OutputPath: $htmlPath")
    val dataClass = data.getClass

    writeText(htmlPath) {
      dataClass.getMethod("toHTML").invoke(data).asInstanceOf[String]
    }

    val jsonPath = s"${path}.json"
    writeText(jsonPath) {
      dataClass.getMethod("toJSON").invoke(data).asInstanceOf[String]
    }
  }

  /** Opens `filePath`, writes `content` (evaluated after the file is opened),
    * and always closes the writer.
    */
  private def writeText(filePath: String)(content: => String): Unit = {
    val writer = new PrintWriter(new File(filePath))
    try {
      writer.write(content)
    } finally {
      writer.close()
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy