// com.datastax.insight.ml.spark.mllib.classification.NaiveBayesClassifier (Maven / Gradle / Ivy)
package com.datastax.insight.ml.spark.mllib.classification;
import com.datastax.insight.spec.RDDOperator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;
/**
 * Static helpers that wrap Spark MLlib's {@link NaiveBayes} for training a
 * model and for scoring labeled data with a trained {@link NaiveBayesModel}.
 */
public class NaiveBayesClassifier implements RDDOperator {

    /**
     * Trains a naive Bayes model by delegating to
     * {@link NaiveBayes#train(org.apache.spark.rdd.RDD, double)}.
     *
     * @param data  labeled training points
     * @param lamda value forwarded as the second argument of
     *              {@code NaiveBayes.train} (the smoothing parameter, called
     *              {@code lambda} in the MLlib API)
     * @return the model produced by MLlib
     */
    public static NaiveBayesModel train(JavaRDD<LabeledPoint> data, double lamda) {
        return NaiveBayes.train(data.rdd(), lamda);
    }

    /**
     * Scores every labeled point with the given model and pairs the result
     * with the point's own label.
     *
     * @param data  labeled points to score
     * @param model trained model used to predict from each point's features
     * @return a pair RDD whose key is {@code model.predict(p.features())} and
     *         whose value is {@code p.label()} for each input point {@code p}
     */
    public static JavaPairRDD<Double, Double> predict(JavaRDD<LabeledPoint> data,
                                                      final NaiveBayesModel model) {
        // Explicit type arguments are required: against a raw PairFunction the
        // method below would not override call(Object) and would not compile.
        return data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
            @Override
            public Tuple2<Double, Double> call(LabeledPoint p) {
                return new Tuple2<>(model.predict(p.features()), p.label());
            }
        });
    }
}