All downloads are free. Search and download functionality uses the official Maven repository.

com.datastax.insight.ml.spark.mllib.regression.LogisticRegression Maven / Gradle / Ivy

package com.datastax.insight.ml.spark.mllib.regression;

import com.datastax.insight.spec.RDDOperator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;

public class LogisticRegression implements RDDOperator {
    public static LogisticRegressionModel train(JavaRDD data,int numClasses){
        LogisticRegressionModel model = new LogisticRegressionWithLBFGS().setNumClasses(numClasses).run(data.rdd());
        return model;
    }

    public static JavaRDD> predict(JavaRDD data, LogisticRegressionModel model){
        JavaRDD> predictionAndLabels = data.map(
                new Function>() {
                    public Tuple2 call(LabeledPoint p) {
                        Double prediction = model.predict(p.features());
                        return new Tuple2(prediction, p.label());
                    }
                }
        );
        return predictionAndLabels;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy