All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.classifier.naivebayes.SparkNaiveBayes.scala Maven / Gradle / Ivy

There is a newer version: 0.13.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.naivebayes

import org.apache.mahout.classifier.stats.{ClassifierResult, ResultAnalyzer}
import org.apache.mahout.math._

import org.apache.mahout.sparkbindings.drm.CheckpointedDrmSpark

import scalabindings._
import scalabindings.RLikeOps._
import drm.RLikeDrmOps._
import drm._
import scala.reflect.ClassTag
import scala.language.asInstanceOf
import collection._
import JavaConversions._

import org.apache.mahout.sparkbindings._

/**
 * Distributed training of a Naive Bayes model. Follows the approach presented in Rennie et al.: Tackling the poor
 * assumptions of Naive Bayes Text classifiers, ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf
 */
object SparkNaiveBayes extends NaiveBayes{

  /**
   * Math-Scala Naive Bayes optimized for Spark.
   *
   * Extract label Keys from raw TF or TF-IDF Matrix generated by seqdirectory/seq2sparse
   * and aggregate TF or TF-IDF values by their label
   *
   * @param stringKeyedObservations DrmLike matrix; Output from seq2sparse
   *   in form K = e.g./Category/document_title
   *           V = TF or TF-IDF values per term
   *   The DRM is cast to a String-keyed Spark DRM, so the keys must be Strings.
   * @param cParser a String => String function used to extract categories from
   *   Keys of the stringKeyedObservations DRM. The default
   *   CategoryParser will extract "Category" from: '/Category/document_id'
   * @return  (labelIndexMap, aggregatedByLabelObservationDrm)
   *   labelIndexMap is a HashMap  K = label
   *                               V = label row index, assigned in ascending label order
   *   aggregatedByLabelObservationDrm is a DrmLike[Int] of aggregated
   *   TF or TF-IDF counts per label, keyed by the label row index
   */
  override def extractLabelsAndAggregateObservations[K](stringKeyedObservations: DrmLike[K],
                                                        cParser: CategoryParser = seq2SparseCategoryParser)
                                                       (implicit ctx: DistributedContext):
                                                       (mutable.HashMap[String, Integer], DrmLike[Int]) = {

    val stringKeyedRdd = stringKeyedObservations
                           .checkpoint()
                           .asInstanceOf[CheckpointedDrmSpark[String]]
                           .rdd

    // replace each row key by its category and sum the rows that share a category
    val aggregatedRdd = stringKeyedRdd
                          .map(x => (cParser(x._1), x._2))
                          .reduceByKey(_ + _)

    stringKeyedObservations.uncache()

    // After reduceByKey there is one row per category, so the key set is small:
    // fetch it with a single job and sort it on the driver.
    val sortedCategories = aggregatedRdd.keys.collect().sorted

    // row index of a label == its position in the sorted category list
    val labelIndexMap = new mutable.HashMap[String, Integer]
    for ((category, index) <- sortedCategories.zipWithIndex) {
      labelIndexMap.put(category, index)
    }

    // re-key the aggregated rows by their label row index
    val intKeyedRdd = aggregatedRdd.map(x => (labelIndexMap(x._1).toInt, x._2))

    val aggregatedObservationByLabelDrm = drmWrap(intKeyedRdd)

    (labelIndexMap, aggregatedObservationByLabelDrm)
  }

  /**
   * Test a trained model with a labeled dataset
   * @param model a trained NBModel
   * @param testSet a labeled testing set
   * @param testComplementary test using a complementary or a standard NB classifier;
   *   when true the model must have been trained as complementary (asserted)
   * @param cParser a String => String function used to extract categories from
   *   Keys of the testing set DRM. The default
   *   CategoryParser will extract "Category" from: '/Category/document_id'
   * @tparam K implicitly determined Key type of test set DRM: String
   * @return a result analyzer with confusion matrix and accuracy statistics
   */
  override def test[K: ClassTag](model: NBModel,
                                 testSet: DrmLike[K],
                                 testComplementary: Boolean = false,
                                 cParser: CategoryParser = seq2SparseCategoryParser)
                                (implicit ctx: DistributedContext): ResultAnalyzer = {

    // fail fast: check the model/classifier pairing before any checkpoint or broadcast
    if (testComplementary) {
      assert(testComplementary == model.isComplementary,
        "Complementary Label Assignment requires Complementary Training")
    }

    val labelMap = model.labelIndex
    val numLabels = model.numLabels

    testSet.checkpoint()

    val numTestInstances = testSet.nrow.toInt

    // instantiate the correct type of classifier
    val classifier =
      if (testComplementary) new ComplementaryNBClassifier(model) with Serializable
      else new StandardNBClassifier(model) with Serializable

    val bCastClassifier = ctx.broadcast(classifier)

    // score every test row against every label: result is numTestInstances x numLabels
    val scoredTestSet = testSet.mapBlock(ncol = numLabels){
      case (keys, block) =>
        val numInstances = keys.size
        val blockB = block.like(numInstances, numLabels)
        for (i <- 0 until numInstances) {
          blockB(i, ::) := bCastClassifier.value.classifyFull(block(i, ::))
        }
        keys -> blockB
    }

    // collect so that we can slice rows.
    // may want to strip this down if we think that numDocuments x numLabels wont fit into memory
    val inCoreScoredTestSet = scoredTestSet.collect

    // release the cached test set only after the scores have been collected
    testSet.uncache()

    // row labels are read off the collected matrix, so labels and scores
    // come from one and the same materialization
    val testSetLabelMap = inCoreScoredTestSet.getRowLabelBindings

    // reverse the label map and extract the labels
    val reverseTestSetLabelMap = testSetLabelMap.map(x => x._2 -> cParser(x._1))

    // reverse the label map from our model
    val reverseLabelMap = labelMap.map(x => x._2 -> x._1)

    val analyzer = new ResultAnalyzer(labelMap.keys.toList.sorted, "DEFAULT")

    // assign labels- winner takes all
    for (i <- 0 until numTestInstances) {
      val (bestIdx, bestScore) = argmax(inCoreScoredTestSet(i, ::))
      val classifierResult = new ClassifierResult(reverseLabelMap(bestIdx), bestScore)
      analyzer.addInstance(reverseTestSetLabelMap(i), classifierResult)
    }

    analyzer
  }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy