All downloads are free. Search and download functionality uses the official Maven repository.

com.datastax.insight.ml.spark.mllib.recommendation.als.RatingMetrics Maven / Gradle / Ivy

package com.datastax.insight.ml.spark.mllib.recommendation.als;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.spec.RDDOperator;
import com.datastax.insight.core.service.PersistService;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import scala.Tuple2;

/**
 * Created by datastax on 2017/1/5.
 */
public class RatingMetrics implements RDDOperator {

    public Metrics evaluation(MatrixFactorizationModel model, JavaRDD data) {
        JavaPairRDD usersproducts = data.mapToPair(d-> new Tuple2<>(d.user(), d.product()));

        JavaPairRDD, Double> predictions = model.predict(usersproducts)
                .mapToPair(d-> new Tuple2<>(new Tuple2<>(d.user(), d.product()), d.rating()));
        JavaPairRDD, Double> rawData = data
                .mapToPair(d-> new Tuple2<>(new Tuple2<>(d.user(), d.product()), d.rating()));

        JavaRDD> predictedAndTrue = rawData.join(predictions)
                .map(d->new Tuple2<>(d._2()._1(), d._2()._2()));

        RegressionMetrics regressionMetrics = new RegressionMetrics(predictedAndTrue.rdd());

        Metrics metrics = new Metrics();
        metrics.getIndicator().setMae(regressionMetrics.meanAbsoluteError());
        metrics.getIndicator().setMse(regressionMetrics.meanSquaredError());
        metrics.getIndicator().setRmse(regressionMetrics.rootMeanSquaredError());
        metrics.getIndicator().setR2(regressionMetrics.r2());

        PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                "saveModelMetrics",
                new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});

        return metrics;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy