// com.datastax.insight.ml.spark.mllib.regression.RandomForestRegression Maven / Gradle / Ivy
package com.datastax.insight.ml.spark.mllib.regression;
import com.datastax.insight.spec.RDDOperator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import scala.Tuple2;
import java.util.HashMap;
import java.util.Map;
/**
 * Static helpers for training a Spark MLlib random-forest regressor and for
 * applying a trained model to labeled data.
 *
 * <p>All methods delegate directly to {@link RandomForest} and
 * {@link RandomForestModel}; no state is kept in this class.
 */
public class RandomForestRegression implements RDDOperator {

    /** Impurity measure used by the overload that does not take one explicitly. */
    private static final String DEFAULT_IMPURITY = "variance";

    /**
     * Trains a random-forest regression model using the {@code "variance"} impurity.
     *
     * @param data                  training set of labeled points
     * @param numTrees              number of trees, forwarded to {@code RandomForest.trainRegressor}
     * @param featureSubsetStrategy feature-subset strategy name, forwarded unchanged
     * @param maxDepth              maximum tree depth, forwarded unchanged
     * @param maxBins               maximum number of bins, forwarded unchanged
     * @param seed                  random seed, forwarded unchanged
     * @return the trained model
     */
    public static RandomForestModel train(JavaRDD<LabeledPoint> data,
            int numTrees, String featureSubsetStrategy, int maxDepth, int maxBins, int seed) {
        return train(data, numTrees, featureSubsetStrategy, DEFAULT_IMPURITY, maxDepth, maxBins, seed);
    }

    /**
     * Trains a random-forest regression model with an explicit impurity measure.
     *
     * @param data                  training set of labeled points
     * @param numTrees              number of trees, forwarded to {@code RandomForest.trainRegressor}
     * @param featureSubsetStrategy feature-subset strategy name, forwarded unchanged
     * @param impurity              impurity measure name, forwarded unchanged
     * @param maxDepth              maximum tree depth, forwarded unchanged
     * @param maxBins               maximum number of bins, forwarded unchanged
     * @param seed                  random seed, forwarded unchanged
     * @return the trained model
     */
    public static RandomForestModel train(JavaRDD<LabeledPoint> data,
            int numTrees, String featureSubsetStrategy, String impurity,
            int maxDepth, int maxBins, int seed) {
        // Always passed empty: no feature is declared categorical here. Per the Spark
        // MLlib documentation an empty map means every feature is treated as continuous.
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        return RandomForest.trainRegressor(data,
                categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
                seed);
    }

    /**
     * Pairs the model's prediction for each labeled point with that point's label.
     *
     * @param data  labeled points to score
     * @param model trained model used to compute predictions
     * @return pair RDD whose key is the predicted value and whose value is the original label
     */
    public static JavaPairRDD<Double, Double> predict(JavaRDD<LabeledPoint> data, RandomForestModel model) {
        return data.mapToPair(
                point -> new Tuple2<>(model.predict(point.features()), point.label()));
    }
}