// All downloads are FREE. Search and download functionality uses the official Maven repository.
//
// com.datastax.insight.ml.spark.mllib.evaluator.RankingMetricsWrapper Maven / Gradle / Ivy
//
// The newest version!
package com.datastax.insight.ml.spark.mllib.evaluator;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.spec.RDDOperator;
import com.datastax.insight.core.service.PersistService;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.evaluation.RankingMetrics;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.List;

/**
 * Created by huangping on 17-1-16.
 */
public class RankingMetricsWrapper implements RDDOperator {

    public Metrics evaluation(MatrixFactorizationModel model, JavaRDD data) {

        //get top 10 recommendations for every user and scala ratings from 0 to 1
        JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD();
        JavaRDD> userRecsScaled = userRecs.map((Function, Tuple2>) t -> {
            Rating[] scaledRating = new Rating[t._2().length];
            for (int i = 0; i < scaledRating.length; i++) {
                double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0);
                scaledRating[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating);
            }
            return new Tuple2<>(t._1(), scaledRating);
        });
        JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled);

        //map ratings to 1 or 0, indicating a product that should be recommended
        JavaRDD binarizedRatings = data.map(r -> new Rating(r.user(), r.product(), r.rating() > 0 ? 1.0 : 0.0));

        //group ratings by common user
        JavaPairRDD> userProducts = binarizedRatings.groupBy((Function) r -> r.user());

        //get true relevant documents from user ratings
        JavaPairRDD> userProductsList = userProducts.mapValues((Function, List>) t -> {
            List products = new ArrayList<>();
            for (Rating r : t) {
                if (r.rating() > 0.0) {
                    products.add(r.product());
                }
            }
            return products;
        });

        //extract the product id from each recommendation
        JavaPairRDD> userRecommendedList = userRecommended.mapValues((Function>) t -> {
            List products = new ArrayList<>();
            for (Rating r : t) {
                products.add(r.product());
            }
            return products;
        });
        JavaRDD, List>> relevantDocs = userProductsList.join(userRecommendedList).values();

        //instantiate the metrics object
        RankingMetrics rankingMetrics = RankingMetrics.of(relevantDocs);
        Metrics metrics = new Metrics();

        metrics.getIndicator().setMeanAveragePrecision(rankingMetrics.meanAveragePrecision());

        PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                "saveModelMetrics",
                new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});

        return metrics;
    }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy