com.datastax.insight.ml.spark.mllib.regression.LogisticRegression Maven / Gradle / Ivy
package com.datastax.insight.ml.spark.mllib.regression;
import com.datastax.insight.spec.RDDOperator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;
public class LogisticRegression implements RDDOperator {
/**
 * Fits a logistic regression model using Spark MLlib's L-BFGS-based trainer.
 *
 * <p>NOTE(review): {@code data} is declared as a raw {@code JavaRDD}. Its elements must be
 * {@link LabeledPoint}s, because the underlying RDD is handed to
 * {@link LogisticRegressionWithLBFGS#run}.
 *
 * @param data training examples; converted with {@code data.rdd()} before training
 * @param numClasses number of label classes, forwarded unchanged to
 *     {@link LogisticRegressionWithLBFGS#setNumClasses(int)}
 * @return the model produced by the trainer
 */
public static LogisticRegressionModel train(JavaRDD data,int numClasses){
    // Configure the trainer first, then run it on the Scala RDD behind the Java wrapper.
    LogisticRegressionWithLBFGS trainer = new LogisticRegressionWithLBFGS();
    trainer.setNumClasses(numClasses);
    return trainer.run(data.rdd());
}
public static JavaRDD> predict(JavaRDD data, LogisticRegressionModel model){
JavaRDD> predictionAndLabels = data.map(
new Function>() {
public Tuple2