org.apache.mahout.classifier.naivebayes.SparkNaiveBayes.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-spark_2.10 Show documentation
Show all versions of mahout-spark_2.10 Show documentation
Mahout Bindings for Apache Spark
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.naivebayes
import org.apache.mahout.classifier.stats.{ClassifierResult, ResultAnalyzer}
import org.apache.mahout.math._
import org.apache.mahout.sparkbindings.drm.CheckpointedDrmSpark
import scalabindings._
import scalabindings.RLikeOps._
import drm.RLikeDrmOps._
import drm._
import scala.reflect.ClassTag
import scala.language.asInstanceOf
import collection._
import JavaConversions._
import org.apache.mahout.sparkbindings._
/**
* Distributed training and testing of a Naive Bayes model. Follows the approach presented in Rennie et al.:
* "Tackling the Poor Assumptions of Naive Bayes Text Classifiers", ICML 2003,
* http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf
*/
object SparkNaiveBayes extends NaiveBayes {

  /**
   * Math-Scala Naive Bayes optimized for Spark.
   *
   * Extracts label keys from a raw TF or TF-IDF matrix generated by seqdirectory/seq2sparse
   * and aggregates (sums) the TF or TF-IDF rows by their label.
   *
   * @param stringKeyedObservations DrmLike matrix; output from seq2sparse
   *                                in the form K = e.g. /Category/document_title
   *                                V = TF or TF-IDF values per term.
   *                                Its checkpoint is cast to CheckpointedDrmSpark[String],
   *                                so the rows must be String-keyed.
   * @param cParser a String => String function used to extract categories from
   *                keys of the stringKeyedObservations DRM. The default
   *                CategoryParser will extract "Category" from: '/Category/document_id'
   * @tparam K row key type of the input DRM; treated as String by this implementation
   * @return (labelIndexMap, aggregatedByLabelObservationDrm)
   *         labelIndexMap is a HashMap K = label, V = row index of that label in the
   *         aggregated DRM; indices are assigned in ascending (lexicographic) label order.
   *         aggregatedByLabelObservationDrm is a DrmLike[Int] of aggregated
   *         TF or TF-IDF counts per label.
   */
  override def extractLabelsAndAggregateObservations[K](stringKeyedObservations: DrmLike[K],
                                                        cParser: CategoryParser = seq2SparseCategoryParser)
                                                       (implicit ctx: DistributedContext):
  (mutable.HashMap[String, Integer], DrmLike[Int]) = {

    // Keep the handle returned by checkpoint() so that exactly this checkpoint is released below.
    val checkpointedObservations = stringKeyedObservations.checkpoint()
    val stringKeyedRdd = checkpointedObservations
      .asInstanceOf[CheckpointedDrmSpark[String]]
      .rdd

    // Sum all rows that map to the same category into a single row per category.
    val aggregatedRdd = stringKeyedRdd
      .map { case (key, row) => (cParser(key), row) }
      .reduceByKey(_ + _)

    // One action gathers the (relatively few) distinct labels. Sorting them on the driver
    // yields the same ascending order as takeOrdered, without a separate count() pass.
    val categories = aggregatedRdd.keys.collect().sorted

    checkpointedObservations.uncache()

    val labelIndexMap = new mutable.HashMap[String, Integer]
    categories.zipWithIndex.foreach { case (label, index) => labelIndexMap.put(label, index) }

    // Re-key each aggregated row by the integer index assigned to its label.
    val intKeyedRdd = aggregatedRdd.map { case (label, row) => (labelIndexMap(label).toInt, row) }
    val aggregatedObservationByLabelDrm = drmWrap(intKeyedRdd)

    (labelIndexMap, aggregatedObservationByLabelDrm)
  }

  /**
   * Test a trained model with a labeled dataset.
   *
   * @param model a trained NBModel
   * @param testSet a labeled testing set; its rows must be keyed by document key (String) so
   *                that the true category of each row can be recovered with cParser
   * @param testComplementary test using a complementary (true) or a standard (false) NB
   *                          classifier; true requires a model trained as complementary
   * @param cParser a String => String function used to extract categories from
   *                keys of the testing set DRM. The default
   *                CategoryParser will extract "Category" from: '/Category/document_id'
   * @tparam K implicitly determined key type of the test set DRM: String
   * @return a result analyzer with confusion matrix and accuracy statistics
   */
  override def test[K: ClassTag](model: NBModel,
                                 testSet: DrmLike[K],
                                 testComplementary: Boolean = false,
                                 cParser: CategoryParser = seq2SparseCategoryParser)
                                (implicit ctx: DistributedContext): ResultAnalyzer = {

    // Validate the request before doing any distributed work.
    if (testComplementary) {
      assert(model.isComplementary,
        "Complementary Label Assignment requires Complementary Training")
    }

    val labelMap = model.labelIndex
    val numLabels = model.numLabels

    // Work on the checkpoint handle itself, so the data mapBlock/collect read is the data
    // that gets released afterwards.
    val checkpointedTestSet = testSet.checkpoint()

    // Instantiate the requested type of classifier and ship it to the workers.
    val classifier =
      if (testComplementary) new ComplementaryNBClassifier(model) with Serializable
      else new StandardNBClassifier(model) with Serializable
    val bCastClassifier = ctx.broadcast(classifier)

    // Score every row against every label; the result has one column per label.
    val scoredTestSet = checkpointedTestSet.mapBlock(ncol = numLabels) {
      case (keys, block) =>
        val nbClassifier = bCastClassifier.value
        val numInstances = keys.size
        val scores = block.like(numInstances, numLabels)
        for (i <- 0 until numInstances) {
          scores(i, ::) := nbClassifier.classifyFull(block(i, ::))
        }
        keys -> scores
    }

    // Collect exactly once, so the row-label bindings and the score rows come from the same
    // in-core matrix. This requires numDocuments x numLabels to fit in driver memory.
    val inCoreScoredTestSet = scoredTestSet.collect
    checkpointedTestSet.uncache()

    val testSetLabelMap = inCoreScoredTestSet.getRowLabelBindings
    require(testSetLabelMap != null,
      "Scored test set has no row label bindings; test set rows must be keyed by document label")

    // Reverse the test set's label map: row index -> true category.
    val reverseTestSetLabelMap = testSetLabelMap.map(x => x._2 -> cParser(x._1))
    // Reverse the model's label map: label index -> label.
    val reverseLabelMap = labelMap.map(x => x._2 -> x._1)

    val analyzer = new ResultAnalyzer(labelMap.keys.toList.sorted, "DEFAULT")

    // Assign labels: winner takes all.
    for (i <- 0 until inCoreScoredTestSet.nrow) {
      val (bestIdx, bestScore) = argmax(inCoreScoredTestSet(i, ::))
      val classifierResult = new ClassifierResult(reverseLabelMap(bestIdx), bestScore)
      analyzer.addInstance(reverseTestSetLabelMap(i), classifierResult)
    }
    analyzer
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy