// com.datastax.insight.ml.spark.mllib.evaluator.BinaryClassificationMetricsWrapper (Maven / Gradle / Ivy)
package com.datastax.insight.ml.spark.mllib.evaluator;
import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.CurvePoint;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.spec.RDDOperator;
import com.datastax.insight.core.service.PersistService;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.ClassificationModel;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.TreeEnsembleModel;
import org.apache.spark.mllib.util.Saveable;
import scala.Tuple2;
import java.util.List;
/**
* Created by datastax on 2017/1/7.
*/
public class BinaryClassificationMetricsWrapper implements RDDOperator {
public Metrics evaluation(Saveable model, JavaRDD data) {
Metrics metrics = new Metrics();
JavaRDD> scoreAndLabels;
// Clear the prediction threshold so the model will return probabilities
if (model instanceof LogisticRegressionModel) {
LogisticRegressionModel logisticRegressionModel = (LogisticRegressionModel) model;
logisticRegressionModel.clearThreshold();
scoreAndLabels = data.map(d -> {
Double score = logisticRegressionModel.predict(d.features());
return new Tuple2<>(score, d.label());
});
} else if (model instanceof ClassificationModel) {
ClassificationModel classificationModel = (ClassificationModel)model;
scoreAndLabels = data.map(d -> {
Double score = classificationModel.predict(d.features());
return new Tuple2<>(score, d.label());
});
} else if (model instanceof DecisionTreeModel) {
DecisionTreeModel decisionTreeModel = (DecisionTreeModel)model;
scoreAndLabels = data.map(d -> {
Double score = decisionTreeModel.predict(d.features());
return new Tuple2<>(score, d.label());
});
} else if (model instanceof TreeEnsembleModel) {
TreeEnsembleModel treeEnsembleModel = (TreeEnsembleModel)model;
scoreAndLabels = data.map(d -> {
Double score = treeEnsembleModel.predict(d.features());
return new Tuple2<>(score, d.label());
});
} else {
String message = "[" + model.getClass().getTypeName() + "] is not supported, currently supports: LogisticRegressionModel, ClassificationModel, DecisionTreeModel, TreeEnsembleModel";
throw new IllegalArgumentException(message);
}
BinaryClassificationMetrics binaryClassificationMetrics = new BinaryClassificationMetrics(scoreAndLabels.rdd());
metrics.getIndicator().setAreaUnderPR(binaryClassificationMetrics.areaUnderPR());
metrics.getIndicator().setAreaUnderROC(binaryClassificationMetrics.areaUnderROC());
List roc = binaryClassificationMetrics.roc().toJavaRDD()
.map(r -> new CurvePoint(Double.parseDouble(r._1().toString()), Double.parseDouble(r._2().toString()))).collect();
metrics.setRoc(roc);
List pr = binaryClassificationMetrics.pr().toJavaRDD()
.map(r -> new CurvePoint(Double.parseDouble(r._1().toString()), Double.parseDouble(r._2().toString()))).collect();
metrics.setPr(pr);
PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
"saveModelMetrics",
new String[]{Long.class.getTypeName(), String.class.getTypeName()},
new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});
return metrics;
}
}