All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.librec.recommender.baseline.ItemClusterRecommender Maven / Gradle / Ivy

The newest version!
// Copyright (C) 2014-2015 Guibing Guo
//
// This file is part of LibRec.
//
// LibRec is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// LibRec is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with LibRec. If not, see <http://www.gnu.org/licenses/>.
//
package net.librec.recommender.baseline;

import com.google.common.collect.Sets;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.*;
import net.librec.recommender.MatrixProbabilisticGraphicalRecommender;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Set;

/**
 * A graphical model that clusters items into K groups for recommendation, as opposed to
 * the {@code UserCluster} recommender, which clusters users.
 *
 * <p>Every cluster (topic) k owns a prior probability {@code Pi_k} and a distribution
 * {@code P_kr} over the rating levels. The parameters are fitted with
 * expectation-maximisation: {@link #eStep()} computes the cluster responsibilities of
 * every item, {@link #mStep()} re-estimates {@code Pi} and {@code P_kr} from them. The
 * predicted rating of an item is the responsibility-weighted mean of the expected rating
 * of each cluster; it does not depend on the user.
 *
 * <p>The number of clusters is read from the configuration key {@code rec.pgm.number}
 * (default 10).
 *
 * @author Guo Guibing and zhanghaidong
 */
public class ItemClusterRecommender extends MatrixProbabilisticGraphicalRecommender {
    private DenseMatrix topicRatingProbs;   // Pkr: probability of rating level r in cluster k
    private DenseVector topicInitialProbs;  // Pi: prior probability of each cluster

    private DenseMatrix itemTopicProbs;     // Gamma_(i,k): responsibility of cluster k for item i
    private DenseMatrix itemNumEachRating;  // Nir: how often item i received rating level r
    private DenseVector itemNumRatings;     // Ni: number of ratings of item i

    private int numTopics;
    private int numRatingLevels;
    private double lastLoss;

    /**
     * Reads the rating scale and the number of clusters, randomly initialises the model
     * parameters and counts, for every item, how often each rating level occurs in the
     * training matrix.
     *
     * @throws LibrecException if the superclass setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();

        isRanking = false;
        Set<Double> ratingScaleSet = Sets.newTreeSet(trainMatrix.getDataTable().values());
        ratingScale = new ArrayList<>(ratingScaleSet);
        numRatingLevels = ratingScale.size();
        Collections.sort(ratingScale);
        numTopics = conf.getInt("rec.pgm.number", 10);

        // Random rating distribution for every cluster.
        topicRatingProbs = new DenseMatrix(numTopics, numRatingLevels);
        for (int k = 0; k < numTopics; k++) {
            double[] probs = Randoms.randProbs(numRatingLevels);
            for (int r = 0; r < numRatingLevels; r++) {
                topicRatingProbs.set(k, r, probs[r]);
            }
        }

        topicInitialProbs = new VectorBasedDenseVector(Randoms.randProbs(numTopics));

        itemTopicProbs = new DenseMatrix(numItems, numTopics);

        itemNumEachRating = new DenseMatrix(numItems, numRatingLevels);
        itemNumRatings = new VectorBasedDenseVector(numItems);

        // Sufficient statistics: per-item histogram over the rating levels.
        for (int i = 0; i < numItems; i++) {
            SequentialSparseVector ri = trainMatrix.column(i);

            for (Vector.VectorEntry vi : ri) {
                double rui = vi.get();
                int r = ratingScale.indexOf(rui);
                itemNumEachRating.plus(i, r, 1);
            }
            itemNumRatings.set(i, ri.size());
        }

        // Placeholder only: isConverged ignores its first comparison through the
        // "iter > 1" guard. Note that Double.MIN_VALUE is the smallest POSITIVE double,
        // not the most negative one.
        lastLoss = Double.MIN_VALUE;
    }

    /**
     * E-step: for every item i computes the posterior responsibility
     * {@code Gamma_ik = Pi_k * prod_r P_kr^Nir / sum_k' (Pi_k' * prod_r P_k'r^Nir')}
     * and stores it, rounded to 6 decimals, in {@code itemTopicProbs}.
     *
     * <p>The product is evaluated exactly with {@link BigDecimal} so that it cannot
     * underflow for items with many ratings. It is built from the per-item rating counts
     * collected in {@link #setup()}; because the arithmetic is exact, this yields the same
     * value as multiplying one factor per individual rating.
     */
    @Override
    protected void eStep() {
        // Exact decimal images of the current parameters, built once per pass instead of
        // once per (item, topic, rating) triple.
        BigDecimal[] priors = new BigDecimal[numTopics];
        BigDecimal[][] ratingProbs = new BigDecimal[numTopics][numRatingLevels];
        for (int k = 0; k < numTopics; k++) {
            priors[k] = new BigDecimal(topicInitialProbs.get(k));
            for (int r = 0; r < numRatingLevels; r++) {
                ratingProbs[k][r] = new BigDecimal(topicRatingProbs.get(k, r));
            }
        }

        int[] counts = new int[numRatingLevels];
        BigDecimal[] joint = new BigDecimal[numTopics];

        for (int i = 0; i < numItems; i++) {
            for (int r = 0; r < numRatingLevels; r++) {
                // The counts are whole numbers stored as doubles, so the cast is exact.
                counts[r] = (int) itemNumEachRating.get(i, r);
            }

            BigDecimal evidence = BigDecimal.ZERO;
            for (int k = 0; k < numTopics; k++) {
                BigDecimal itemTopicProb = priors[k];
                for (int r = 0; r < numRatingLevels; r++) {
                    if (counts[r] > 0) {
                        itemTopicProb = itemTopicProb.multiply(ratingProbs[k][r].pow(counts[r]));
                    }
                }
                joint[k] = itemTopicProb;
                evidence = evidence.add(itemTopicProb);
            }

            // BigDecimal.divide throws ArithmeticException on a zero divisor; when no
            // cluster can explain the item, fall back to the prior.
            boolean noEvidence = evidence.signum() == 0;
            for (int k = 0; k < numTopics; k++) {
                double zik = noEvidence
                        ? topicInitialProbs.get(k)
                        : joint[k].divide(evidence, 6, RoundingMode.HALF_UP).doubleValue();
                itemTopicProbs.set(i, k, zik);
            }
        }
    }

    /**
     * M-step: re-estimates the rating distribution of every cluster,
     * {@code P_kr = sum_i Gamma_ik * Nir / sum_i Gamma_ik * Ni}, and the cluster priors,
     * {@code Pi_k = sum_i Gamma_ik / sum_k' sum_i Gamma_ik'}.
     *
     * <p>A cluster that currently holds no responsibility mass keeps its previous rating
     * distribution: dividing by its zero mass would produce NaN, which would make the
     * next {@link #eStep()} throw and turn predictions into NaN.
     */
    @Override
    protected void mStep() {
        double[] topicMass = new double[numTopics];
        double totalMass = 0;

        for (int k = 0; k < numTopics; k++) {
            double mass = 0;
            double denominator = 0;
            for (int i = 0; i < numItems; i++) {
                double rik = itemTopicProbs.get(i, k);
                mass += rik;
                denominator += rik * itemNumRatings.get(i);
            }

            if (denominator > 0) {
                for (int r = 0; r < numRatingLevels; r++) {
                    double numerator = 0.0;
                    for (int i = 0; i < numItems; i++) {
                        numerator += itemTopicProbs.get(i, k) * itemNumEachRating.get(i, r);
                    }
                    topicRatingProbs.set(k, r, numerator / denominator);
                }
            }

            topicMass[k] = mass;
            totalMass += mass;
        }

        // Without any mass (e.g. no items at all) the priors are left untouched.
        if (totalMass > 0) {
            for (int k = 0; k < numTopics; k++) {
                topicInitialProbs.set(k, topicMass[k] / totalMass);
            }
        }
    }

    /**
     * Computes {@code sum_i sum_k Gamma_ik * (log Pi_k + sum_r Nir * log P_kr)} and
     * compares it with the value of the previous call.
     *
     * <p>Terms with a zero weight are skipped (the usual {@code 0 * log 0 = 0}
     * convention); evaluated literally they give {@code 0 * -Infinity = NaN} as soon as
     * a cluster assigns zero probability to some rating level, which would stop training
     * for no reason.
     *
     * <p>NOTE(review): training is stopped as soon as this value increases, although the
     * value is a log-likelihood; the intended direction of the test should be confirmed.
     *
     * @param iter the current iteration; the first one never reports convergence
     * @return {@code true} if {@code iter > 1} and the value increased or is NaN
     */
    @Override
    protected boolean isConverged(int iter) {
        double loss = 0.0;

        for (int i = 0; i < numItems; i++) {
            for (int k = 0; k < numTopics; k++) {
                double rik = itemTopicProbs.get(i, k);
                if (rik == 0) {
                    continue;
                }

                double sum_nl = 0;
                for (int r = 0; r < numRatingLevels; r++) {
                    double nir = itemNumEachRating.get(i, r);
                    if (nir == 0) {
                        continue;
                    }
                    sum_nl += nir * Math.log(topicRatingProbs.get(k, r));
                }

                loss += rik * (Math.log(topicInitialProbs.get(k)) + sum_nl);
            }
        }

        float deltaLoss = (float) (loss - lastLoss);

        if (iter > 1 && (deltaLoss > 0 || Float.isNaN(deltaLoss))) {
            return true;
        }

        lastLoss = loss;
        return false;
    }

    /**
     * Predicts the rating of an item as {@code sum_k Gamma_ik * sum_r r * P_kr}.
     *
     * @param userIdx user index; not used, the model only depends on the item
     * @param itemIdx item index
     * @return the expected rating of the item under the cluster model
     * @throws LibrecException never thrown by this implementation
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        double pred = 0;

        for (int k = 0; k < numTopics; k++) {
            double pi_k = itemTopicProbs.get(itemIdx, k); // probability that item i belongs to cluster k
            double pred_k = 0; // expected rating inside cluster k

            for (int r = 0; r < numRatingLevels; r++) {
                double rij = ratingScale.get(r);
                double pkr = topicRatingProbs.get(k, r);

                pred_k += rij * pkr;
            }

            pred += pi_k * pred_k;
        }

        return pred;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy