All downloads are free. Search and download functionality uses the official Maven repository.

com.datastax.insight.ml.spark.mllib.regression.DecisionTreeRegression Maven / Gradle / Ivy

The newest version!
package com.datastax.insight.ml.spark.mllib.regression;

import com.datastax.insight.spec.RDDOperator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import scala.Tuple2;

import java.util.HashMap;
import java.util.Map;

/**
 * Static helpers for training and applying a Spark MLlib decision-tree
 * regression model on RDDs of {@link LabeledPoint}.
 */
public class DecisionTreeRegression implements RDDOperator {

    /** Impurity measure used when the caller does not supply one. */
    private static final String DEFAULT_IMPURITY = "variance";

    /**
     * Trains a regression tree using the default impurity ({@code "variance"}).
     *
     * @param data     training set of labeled points
     * @param maxDepth maximum depth of the tree, passed through to MLlib
     * @param maxBins  maximum number of bins, passed through to MLlib
     * @return the trained model
     */
    public static DecisionTreeModel train(JavaRDD<LabeledPoint> data,
                                          int maxDepth, int maxBins) {
        return train(data, DEFAULT_IMPURITY, maxDepth, maxBins);
    }

    /**
     * Trains a regression tree with an explicit impurity measure.
     *
     * <p>An empty categorical-features map is handed to
     * {@link DecisionTree#trainRegressor}, i.e. no feature is declared as
     * categorical.
     *
     * @param data     training set of labeled points
     * @param impurity impurity measure name, passed through to MLlib
     * @param maxDepth maximum depth of the tree, passed through to MLlib
     * @param maxBins  maximum number of bins, passed through to MLlib
     * @return the trained model
     */
    public static DecisionTreeModel train(JavaRDD<LabeledPoint> data,
                                          String impurity, int maxDepth, int maxBins) {
        // Feature index -> arity; left empty so no feature is declared categorical.
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        return DecisionTree.trainRegressor(data,
                categoricalFeaturesInfo, impurity, maxDepth, maxBins);
    }

    /**
     * Applies {@code model} to every point in {@code data}.
     *
     * @param data  labeled points to score
     * @param model trained decision-tree model
     * @return pairs of (predicted value, actual label), one per input point
     */
    public static JavaPairRDD<Double, Double> predict(JavaRDD<LabeledPoint> data,
                                                      DecisionTreeModel model) {
        // Explicitly typed so the pair's key/value types are fixed as
        // (prediction, label) rather than left raw.
        PairFunction<LabeledPoint, Double, Double> toPredictionAndLabel =
                point -> new Tuple2<>(model.predict(point.features()), point.label());
        return data.mapToPair(toPredictionAndLabel);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy