// com.datastax.insight.ml.spark.mllib.recommendation.als.RatingMetrics (Maven / Gradle / Ivy)
package com.datastax.insight.ml.spark.mllib.recommendation.als;
import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.spec.RDDOperator;
import com.datastax.insight.core.service.PersistService;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import scala.Tuple2;
/**
* Created by datastax on 2017/1/5.
*/
public class RatingMetrics implements RDDOperator {
public Metrics evaluation(MatrixFactorizationModel model, JavaRDD data) {
JavaPairRDD usersproducts = data.mapToPair(d-> new Tuple2<>(d.user(), d.product()));
JavaPairRDD, Double> predictions = model.predict(usersproducts)
.mapToPair(d-> new Tuple2<>(new Tuple2<>(d.user(), d.product()), d.rating()));
JavaPairRDD, Double> rawData = data
.mapToPair(d-> new Tuple2<>(new Tuple2<>(d.user(), d.product()), d.rating()));
JavaRDD> predictedAndTrue = rawData.join(predictions)
.map(d->new Tuple2<>(d._2()._1(), d._2()._2()));
RegressionMetrics regressionMetrics = new RegressionMetrics(predictedAndTrue.rdd());
Metrics metrics = new Metrics();
metrics.getIndicator().setMae(regressionMetrics.meanAbsoluteError());
metrics.getIndicator().setMse(regressionMetrics.meanSquaredError());
metrics.getIndicator().setRmse(regressionMetrics.rootMeanSquaredError());
metrics.getIndicator().setR2(regressionMetrics.r2());
PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
"saveModelMetrics",
new String[]{Long.class.getTypeName(), String.class.getTypeName()},
new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});
return metrics;
}
}