// Artifact: com.datastax.insight.ml.spark.mllib.classification.GradientBoosting (Maven / Gradle / Ivy)
// The newest version!
package com.datastax.insight.ml.spark.mllib.classification;
import com.datastax.insight.spec.RDDOperator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import scala.Tuple2;
import java.util.HashMap;
import java.util.Map;
/**
 * Thin wrapper around Spark MLlib's {@link GradientBoostedTrees} for training and scoring
 * gradient-boosted tree ensembles on RDDs of {@link LabeledPoint}.
 *
 * <p>All features are treated as continuous: an empty categorical-features map is always
 * installed on the tree strategy.
 */
public class GradientBoosting implements RDDOperator {

    /**
     * Trains a gradient-boosted tree classifier starting from MLlib's "Classification" boosting
     * defaults.
     *
     * @param data          labeled training examples
     * @param numIterations number of boosting iterations
     * @param numClasses    number of label classes. NOTE(review): MLlib's GBT implementation is
     *                      documented as binary-only, so this presumably must be 2; confirm
     *                      against the Spark version in use
     * @param maxDepth      maximum depth of each tree
     * @return the trained ensemble model
     */
    public static GradientBoostedTreesModel trainClassifier(JavaRDD<LabeledPoint> data,
            int numIterations, int numClasses, int maxDepth) {
        return train(data, "Classification", numIterations, numClasses, maxDepth);
    }

    /**
     * Trains a gradient-boosted tree regressor starting from MLlib's "Regression" boosting
     * defaults.
     *
     * @param data          labeled training examples
     * @param numIterations number of boosting iterations
     * @param numClasses    forwarded to the tree strategy. NOTE(review): presumably unused by
     *                      MLlib for regression; kept for signature symmetry with
     *                      {@link #trainClassifier}
     * @param maxDepth      maximum depth of each tree
     * @return the trained ensemble model
     */
    public static GradientBoostedTreesModel trainRegressor(JavaRDD<LabeledPoint> data,
            int numIterations, int numClasses, int maxDepth) {
        return train(data, "Regression", numIterations, numClasses, maxDepth);
    }

    /**
     * Trains a gradient-boosted tree ensemble.
     *
     * <p>Builds a {@link BoostingStrategy} from the named MLlib preset and overrides the
     * iteration count, class count and maximum tree depth. It installs an empty
     * categorical-features map, so every feature is treated as continuous.
     *
     * @param data          labeled training examples
     * @param defaultParams name of the preset passed to
     *                      {@link BoostingStrategy#defaultParams(String)}; this class passes
     *                      {@code "Classification"} or {@code "Regression"}
     * @param numIterations number of boosting iterations
     * @param numClasses    number of label classes set on the tree strategy
     * @param maxDepth      maximum depth of each tree
     * @return the trained ensemble model
     */
    public static GradientBoostedTreesModel train(JavaRDD<LabeledPoint> data, String defaultParams,
            int numIterations, int numClasses, int maxDepth) {
        BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(defaultParams);
        boostingStrategy.setNumIterations(numIterations);
        boostingStrategy.getTreeStrategy().setNumClasses(numClasses);
        boostingStrategy.getTreeStrategy().setMaxDepth(maxDepth);
        // An empty categoricalFeaturesInfo map tells MLlib that every feature is continuous.
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        boostingStrategy.getTreeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo);
        return GradientBoostedTrees.train(data, boostingStrategy);
    }

    /**
     * Scores every example with {@code model}.
     *
     * @param data  labeled examples to score
     * @param model trained ensemble used for prediction
     * @return pairs of (predicted value, true label), one per input example
     */
    public static JavaPairRDD<Double, Double> predict(JavaRDD<LabeledPoint> data,
            GradientBoostedTreesModel model) {
        // Only `model` is captured by the function that Spark ships to the executors.
        return data.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy