// com.datastax.insight.ml.spark.mllib.evaluator.MulticlassMetricsWrapper (Maven / Gradle / Ivy)
package com.datastax.insight.ml.spark.mllib.evaluator;
import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.core.service.PersistService;
import com.datastax.insight.spec.RDDOperator;
import java.util.Objects;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.ClassificationModel;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.TreeEnsembleModel;
import org.apache.spark.mllib.util.Saveable;
import scala.Tuple2;
/**
 * Evaluates an MLlib classification model on labelled data with Spark's {@link MulticlassMetrics}
 * and stores the resulting {@link Metrics} for the current flow via {@link PersistService}.
 */
public class MulticlassMetricsWrapper implements RDDOperator {

    /**
     * Scores {@code data} with {@code model}, computes multiclass metrics and saves them.
     *
     * @param model a trained model; must be a {@link ClassificationModel}, a
     *              {@link DecisionTreeModel} or a {@link TreeEnsembleModel}
     * @param data  labelled points; each point's label is compared with the model's prediction
     *              for that point's features
     * @return the computed metrics (the same object is serialized to JSON and passed to
     *         {@code InsightDAO.saveModelMetrics} together with the current flow id)
     * @throws NullPointerException     if {@code model} or {@code data} is {@code null}
     * @throws IllegalArgumentException if {@code model} is not one of the supported types
     */
    public Metrics evaluation(Saveable model, JavaRDD<LabeledPoint> data) {
        Objects.requireNonNull(model, "model");
        Objects.requireNonNull(data, "data");

        JavaRDD<Tuple2<Object, Object>> predictionAndLabels = predictionAndLabels(model, data);

        Metrics metrics = new Metrics();
        MulticlassMetrics multiclassMetrics = new MulticlassMetrics(predictionAndLabels.rdd());

        // The no-argument precision()/recall()/fMeasure() were deprecated in Spark 2.0 as
        // aliases of accuracy and removed in Spark 3.0; accuracy() yields the same value.
        double accuracy = multiclassMetrics.accuracy();
        metrics.getIndicator().setPrecision(accuracy);
        metrics.getIndicator().setRecall(accuracy);
        metrics.getIndicator().setfMeasure(accuracy);
        metrics.getIndicator().setAccuracy(accuracy);
        metrics.getIndicator().setWeightedPrecision(multiclassMetrics.weightedPrecision());
        metrics.getIndicator().setWeightedRecall(multiclassMetrics.weightedRecall());
        metrics.getIndicator().setWeightedFMeasure(multiclassMetrics.weightedFMeasure());
        metrics.getIndicator().setWeightedTruePositiveRate(multiclassMetrics.weightedTruePositiveRate());
        metrics.getIndicator().setWeightedFalsePositiveRate(multiclassMetrics.weightedFalsePositiveRate());

        // The DAO is invoked reflectively by class/method name with (Long flowId, String json).
        PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                "saveModelMetrics",
                new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});
        return metrics;
    }

    /**
     * Builds the {@code (prediction, label)} pairs that {@link MulticlassMetrics} consumes.
     * The method is static so the Spark lambdas capture only the model, never {@code this}.
     *
     * @throws IllegalArgumentException if {@code model} is not a supported model type
     */
    private static JavaRDD<Tuple2<Object, Object>> predictionAndLabels(Saveable model,
                                                                        JavaRDD<LabeledPoint> data) {
        if (model instanceof ClassificationModel) {
            ClassificationModel realModel = (ClassificationModel) model;
            return data.map(d -> new Tuple2<>(realModel.predict(d.features()), d.label()));
        }
        if (model instanceof DecisionTreeModel) {
            DecisionTreeModel realModel = (DecisionTreeModel) model;
            return data.map(d -> new Tuple2<>(realModel.predict(d.features()), d.label()));
        }
        if (model instanceof TreeEnsembleModel) {
            TreeEnsembleModel realModel = (TreeEnsembleModel) model;
            return data.map(d -> new Tuple2<>(realModel.predict(d.features()), d.label()));
        }
        String message = "[" + model.getClass().getTypeName() + "] is not supported, currently supports: ClassificationModel, DecisionTreeModel, TreeEnsembleModel";
        throw new IllegalArgumentException(message);
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy