All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.librec.recommender.cf.BHFreeRecommender Maven / Gradle / Ivy

/**
 * Copyright (C) 2016 LibRec
 * 

* This file is part of LibRec. * LibRec is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. *

* LibRec is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. *

* You should have received a copy of the GNU General Public License * along with LibRec. If not, see <http://www.gnu.org/licenses/>. */ package net.librec.recommender.cf; import com.google.common.collect.HashBasedTable; import com.google.common.collect.Table; import net.librec.common.LibrecException; import net.librec.math.algorithm.Randoms; import net.librec.math.structure.DenseMatrix; import net.librec.math.structure.DenseVector; import net.librec.math.structure.MatrixEntry; import net.librec.recommender.ProbabilisticGraphicalRecommender; /** * Barbieri et al., Balancing Prediction and Recommendation Accuracy: Hierarchical Latent Factors for Preference * Data, SDM 2012.
*

* Remarks: this class implements the BH-free method. * * @author Guo Guibing and haidong zhang */ public class BHFreeRecommender extends ProbabilisticGraphicalRecommender { private float initGamma, initSigma, initAlpha, initBeta; /** * number of user communities */ private int numUserTopics; // K /** * number of item categories */ private int numItemTopics; // L /** * evaluation of the user u which have been assigned to the user topic k */ private DenseMatrix userTopicNum; /** * observations for the user */ private DenseVector userNum; /** * observations associated with community k */ private DenseVector uTopicNum; /** * number of user communities * number of topics */ private DenseMatrix userTopicItemTopicNum; // Nkl /** * number of user communities * number of topics * number of ratings */ private int[][][] userTopicItemTopicRatingNum, userTopicItemTopicItemNum; // Nklr, Nkli; /** * */ private Table userTopics, itemTopics; /** * */ private int numRatingLevels; // parameters private DenseMatrix userTopicProbs, userTopicItemTopicProbs; private DenseMatrix userTopicSumProbs, userTopicItemTopicSumProbs; private double[][][] userTopicItemTopicRatingProbs, userTopicItemTopicItemProbs; private double[][][] userTopicItemTopicRatingSumProbs, userTopicItemTopicItemSumProbs; protected void setup() throws LibrecException { super.setup(); numUserTopics = conf.getInt("rec.bhfree.user.topic.number", 10); numItemTopics = conf.getInt("rec.bhfree.item.topic.number", 10); initAlpha = conf.getFloat("rec.bhfree.alpha", 1.0f / numUserTopics); initBeta = conf.getFloat("rec.bhfree.beta", 1.0f / numItemTopics); initGamma = conf.getFloat("rec.bhfree.gamma", 1.0f / numRatingLevels); initSigma = conf.getFloat("rec.sigma", 1.0f / numItems); numRatingLevels = trainMatrix.getValueSet().size(); userTopicNum = new DenseMatrix(numUsers, numUserTopics); userNum = new DenseVector(numUsers); userTopicItemTopicNum = new DenseMatrix(numUserTopics, numItemTopics); uTopicNum = new 
DenseVector(numUserTopics); userTopicItemTopicRatingNum = new int[numUserTopics][numItemTopics][numRatingLevels]; userTopicItemTopicItemNum = new int[numUserTopics][numItemTopics][numItems]; userTopics = HashBasedTable.create(); itemTopics = HashBasedTable.create(); for (MatrixEntry me : trainMatrix) { int u = me.row(); int i = me.column(); double rate = me.get(); int r = ratingScale.indexOf(rate); int k = (int) (numUserTopics * Randoms.uniform()); // user's topic k int l = (int) (numItemTopics * Randoms.uniform()); // item's topic l userTopicNum.add(u, k, 1); userNum.add(u, 1); userTopicItemTopicNum.add(k, l, 1); uTopicNum.add(k, 1); userTopicItemTopicRatingNum[k][l][r]++; userTopicItemTopicItemNum[k][l][i]++; userTopics.put(u, i, k); itemTopics.put(u, i, l); } // parameters userTopicSumProbs = new DenseMatrix(numUsers, numUserTopics); userTopicItemTopicSumProbs = new DenseMatrix(numUserTopics, numItemTopics); userTopicItemTopicRatingSumProbs = new double[numUserTopics][numItemTopics][numRatingLevels]; userTopicItemTopicRatingProbs = new double[numUserTopics][numItemTopics][numRatingLevels]; userTopicItemTopicItemSumProbs = new double[numUserTopics][numItemTopics][numItems]; userTopicItemTopicItemProbs = new double[numUserTopics][numItemTopics][numItems]; } @Override protected void eStep() { for (MatrixEntry me : trainMatrix) { int u = me.row(); int i = me.column(); double rate = me.get(); int r = ratingScale.indexOf(rate); int k = userTopics.get(u, i); int l = itemTopics.get(u, i); userTopicNum.add(u, k, -1); userNum.add(u, -1); userTopicItemTopicNum.add(k, l, -1); uTopicNum.add(k, -1); userTopicItemTopicRatingNum[k][l][r]--; userTopicItemTopicItemNum[k][l][i]--; DenseMatrix userTopicItemTopicProbs = new DenseMatrix(numUserTopics, numItemTopics); double sum = 0; for (int z = 0; z < numUserTopics; z++) { for (int w = 0; w < numItemTopics; w++) { double v1 = (userTopicNum.get(u, k) + initAlpha) / (userNum.get(u) + numUserTopics * initAlpha); double v2 = 
(userTopicItemTopicNum.get(k, l) + initBeta) / (uTopicNum.get(k) + numItemTopics * initBeta); double v3 = (userTopicItemTopicRatingNum[k][l][r] + initGamma) / (userTopicItemTopicNum.get(k, l) + numRatingLevels * initGamma); double v4 = (userTopicItemTopicItemNum[k][l][i] + initSigma) / (userTopicItemTopicNum.get(k, l) + numItems * initSigma); double val = v1 * v2 * v3 * v4; userTopicItemTopicProbs.set(z, w, val); sum += val; } } // normalization userTopicItemTopicProbs = userTopicItemTopicProbs.scale(1.0 / sum); // resample k double[] userTopicProbs = new double[numUserTopics]; for (int z = 0; z < numUserTopics; z++) { userTopicProbs[z] = userTopicItemTopicProbs.sumOfRow(z); } for (int z = 1; z < numUserTopics; z++) { userTopicProbs[z] += userTopicProbs[z - 1]; } double rand = Randoms.uniform(); for (k = 0; k < numUserTopics - 1; k++) { if (rand < userTopicProbs[k]) break; } // resample item topic double[] itemTopicProbs = new double[numItemTopics]; for (int w = 0; w < numItemTopics; w++) { itemTopicProbs[w] = userTopicItemTopicProbs.sumOfColumn(w); } for (int w = 1; w < numItemTopics; w++) { itemTopicProbs[w] += itemTopicProbs[w - 1]; } rand = Randoms.uniform(); for (l = 0; l < numItemTopics - 1; l++) { if (rand < itemTopicProbs[l]) break; } // add statistic userTopicNum.add(u, k, 1); userNum.add(u, 1); userTopicItemTopicNum.add(k, l, 1); uTopicNum.add(k, 1); userTopicItemTopicRatingNum[k][l][r]++; userTopicItemTopicItemNum[k][l][i]++; userTopics.put(u, i, k); itemTopics.put(u, i, l); } } @Override protected void mStep() { } @Override protected void readoutParams() { for (int u = 0; u < numUsers; u++) { for (int k = 0; k < numUserTopics; k++) { userTopicSumProbs.add(u, k, (userTopicNum.get(u, k) + initAlpha) / (userNum.get(u) + numUserTopics * initAlpha)); } } for (int k = 0; k < numUserTopics; k++) { for (int l = 0; l < numItemTopics; l++) { userTopicItemTopicSumProbs.add(k, l, (userTopicItemTopicNum.get(k, l) + initBeta) / (uTopicNum.get(k) + numItemTopics * 
initBeta)); } } for (int k = 0; k < numUserTopics; k++) { for (int l = 0; l < numItemTopics; l++) { for (int r = 0; r < numRatingLevels; r++) { userTopicItemTopicRatingSumProbs[k][l][r] += (userTopicItemTopicRatingNum[k][l][r] + initGamma) / (userTopicItemTopicNum.get(k, l) + numRatingLevels * initGamma); } } } for (int k = 0; k < numUserTopics; k++) { for (int l = 0; l < numItemTopics; l++) { for (int i = 0; i < numItems; i++) { userTopicItemTopicItemSumProbs[k][l][i] += (userTopicItemTopicItemNum[k][l][i] + initSigma) / (userTopicItemTopicNum.get(k, l) + numItems * initSigma); } } } numStats++; } @Override protected void estimateParams() { double scale = 1.0 / numStats; userTopicProbs = userTopicSumProbs.scale(scale); userTopicItemTopicProbs = userTopicItemTopicSumProbs.scale(scale); for (int k = 0; k < numUserTopics; k++) { for (int l = 0; l < numItemTopics; l++) { for (int r = 0; r < numRatingLevels; r++) { userTopicItemTopicRatingProbs[k][l][r] = userTopicItemTopicRatingSumProbs[k][l][r] * scale; } } } for (int k = 0; k < numUserTopics; k++) { for (int l = 0; l < numItemTopics; l++) { for (int i = 0; i < numItems; i++) { userTopicItemTopicItemProbs[k][l][i] = userTopicItemTopicItemSumProbs[k][l][i] * scale; } } } } @Override protected double predict(int userIdx, int itemIdx) throws LibrecException { if (isRanking) { return predictRanking(userIdx, itemIdx); } else { return predictRating(userIdx, itemIdx); } } protected double predictRating(int userIdx, int itemIdx) { double sum = 0, probs = 0; for (int r = 0; r < numRatingLevels; r++) { double rate = ratingScale.get(r); double prob = 0; for (int k = 0; k < numUserTopics; k++) { for (int l = 0; l < numItemTopics; l++) { prob += userTopicProbs.get(userIdx, k) * userTopicItemTopicProbs.get(k, l) * userTopicItemTopicRatingProbs[k][l][r]; } } sum += rate * prob; probs += prob; } return sum / probs; } protected double predictRanking(int userIdx, int itemIdx) { double rank = 0; for (int r = 0; r < numRatingLevels; 
r++) { double rate = ratingScale.get(r); double prob = 0; for (int k = 0; k < numUserTopics; k++) { for (int l = 0; l < numItemTopics; l++) { prob += userTopicProbs.get(userIdx, k) * userTopicItemTopicProbs.get(k, l) * userTopicItemTopicItemSumProbs[k][l][itemIdx] * userTopicItemTopicRatingProbs[k][l][r]; } } rank += rate * prob; } return rank; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy