All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.librec.recommender.cf.rating.BPMFRecommender Maven / Gradle / Ivy

/**
 * Copyright (C) 2016 LibRec
 * 

* This file is part of LibRec. * LibRec is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. *

* LibRec is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. *

* You should have received a copy of the GNU General Public License * along with LibRec. If not, see . */ package net.librec.recommender.cf.rating; import net.librec.common.LibrecException; import net.librec.math.algorithm.Randoms; import net.librec.math.structure.*; import net.librec.recommender.MatrixFactorizationRecommender; import java.util.ArrayList; import java.util.List; /** * Salakhutdinov and Mnih, Bayesian Probabilistic Matrix Factorization using Markov Chain Monte Carlo, * ICML 2008. *

* Matlab version is provided by the authors via this * link. This implementation is modified from the BayesianPMF by the PREA package. * Bayesian Probabilistic Matrix Factorization */ public class BPMFRecommender extends MatrixFactorizationRecommender { private double userMu0, userBeta0, userWishartScale0; private double itemMu0, itemBeta0, itemWishartScale0; private DenseVector userMu, itemMu; private DenseMatrix userWishartScale, itemWishartScale; private double userBeta, itemBeta; private double userWishartNu, itemWishartNu; private double ratingSigma; private SparseMatrix predictMatrix; public class HyperParameters { public DenseVector mu; public DenseMatrix variance; HyperParameters(DenseVector _mu, DenseMatrix _variance) { mu = _mu; variance = _variance; } } @Override protected void setup() throws LibrecException { super.setup(); userMu0 = conf.getDouble("rec.recommender.user.mu", 0.0); userBeta0 = conf.getDouble("rec.recommender.user.beta", 1.0); userWishartScale0 = conf.getDouble("rec.recommender.user.wishart.scale", 1.0); itemMu0 = conf.getDouble("rec.recommender.item.mu", 0.0); itemBeta0 = conf.getDouble("rec.recommender.item.beta", 1.0); itemWishartScale0 = conf.getDouble("rec.recommender.item.wishart.scale", 1.0); ratingSigma = conf.getDouble("rec.recommender.rating.sigma", 2.0); } /** * Initialize the model * * @throws LibrecException if error occurs */ protected void initModel() throws LibrecException { userMu = new DenseVector(numFactors); userMu.setAll(userMu0); itemMu = new DenseVector(numFactors); itemMu.setAll(itemMu0); userBeta = userBeta0; itemBeta = itemBeta0; userWishartScale = new DenseMatrix(numFactors, numFactors); itemWishartScale = new DenseMatrix(numFactors, numFactors); for (int i = 0; i < numFactors; i++) { userWishartScale.set(i, i, userWishartScale0); itemWishartScale.set(i, i, itemWishartScale0); } userWishartScale.inv(); itemWishartScale.inv(); userWishartNu = numFactors; itemWishartNu = numFactors; predictMatrix = new SparseMatrix(testMatrix); } /** * */ @Override protected void trainModel() throws LibrecException { initModel(); // Speed up getting user or item vector in Gibbs sampling List userTrainVectors = new ArrayList(numUsers); List itemTrainVectors = new ArrayList(numItems); for (int u = 0; u < numUsers; u++) { userTrainVectors.add(trainMatrix.row(u)); } for (int i = 0; i < numItems; i++) { itemTrainVectors.add(trainMatrix.column(i)); } DenseVector mu_u = new DenseVector(numFactors); DenseVector mu_m = new DenseVector(numFactors); for (int f = 0; f < numFactors; f++) { mu_u.set(f, userFactors.columnMean(f)); mu_m.set(f, itemFactors.columnMean(f)); } DenseMatrix variance_u = userFactors.cov().inv(); DenseMatrix variance_m = itemFactors.cov().inv(); HyperParameters userHyperParameters = new HyperParameters(mu_u, variance_u); HyperParameters itemHyperParameters = new HyperParameters(mu_m, variance_m); for (int iter = 0; iter < numIterations; iter++) { userHyperParameters = samplingHyperParameters(userHyperParameters, userFactors, userMu, userBeta, userWishartScale, userWishartNu); itemHyperParameters = samplingHyperParameters(itemHyperParameters, itemFactors, itemMu, itemBeta, itemWishartScale, itemWishartNu); for (int gibbsIteration = 0; gibbsIteration < 1; gibbsIteration++) { for (int u = 0; u < numUsers; u++) { SparseVector ratings = userTrainVectors.get(u); int count = ratings.getCount(); if (count == 0) { continue; } userFactors.setRow(u, updateParameters(itemFactors, ratings, userHyperParameters)); } for (int i = 0; i < numItems; i++) { SparseVector ratings = itemTrainVectors.get(i); int count = ratings.getCount(); if (count == 0) { continue; } itemFactors.setRow(i, updateParameters(userFactors, ratings, itemHyperParameters)); } } if (iter == 1) { for (MatrixEntry me : testMatrix) { int u = me.row(); int i = me.column(); predictMatrix.set(u, i, 0.0); } } int startnum = 0; if (iter > startnum) { for (MatrixEntry me : testMatrix) { int userIdx = me.row(); int itemIdx = me.column(); double predictValue = (predictMatrix.get(userIdx, itemIdx) * (iter - 1 - startnum) + globalMean + DenseMatrix.rowMult(userFactors, userIdx, itemFactors, itemIdx)) / (iter - startnum); predictMatrix.set(userIdx, itemIdx, predictValue); } } } } protected HyperParameters samplingHyperParameters(HyperParameters hyperParameters, DenseMatrix factors, DenseVector normalMu0, double normalBeta0, DenseMatrix WishartScale0, double WishartNu0) throws LibrecException { int numRows = factors.numRows(); int numColumns = factors.numColumns(); DenseVector mean = new DenseVector(numFactors); for (int i = 0; i < numColumns; i++) { mean.set(i, factors.columnMean(i)); } DenseMatrix populationVariance = factors.cov(); double betaPost = normalBeta0 + numRows; double nuPost = WishartNu0 + 1.0; DenseVector muPost = normalMu0.scale(normalBeta0).add(mean.scale(numRows)).scale(1.0 / betaPost); DenseMatrix WishartScalePost = WishartScale0.add(populationVariance.scale(numRows)); DenseVector muError = normalMu0.minus(mean); WishartScalePost = WishartScalePost.add(muError.outer(muError).scale(normalBeta0 * numRows / betaPost)); WishartScalePost = WishartScalePost.inv(); WishartScalePost = WishartScalePost.add(WishartScalePost.transpose()).scale(0.5); DenseMatrix variance = Randoms.wishart(WishartScalePost, numRows + numColumns); if (variance != null) { hyperParameters.variance = variance; } DenseMatrix normalVariance = hyperParameters.variance.scale(normalBeta0).inv().cholesky(); if (normalVariance != null) { normalVariance = normalVariance.transpose(); DenseVector normalRdn = new DenseVector(numColumns); for (int f = 0; f < numFactors; f++) normalRdn.set(f, Randoms.gaussian(0, 1)); hyperParameters.mu = normalVariance.mult(normalRdn).add(muPost); } return hyperParameters; } protected DenseVector updateParameters(DenseMatrix factors, SparseVector ratings, HyperParameters hyperParameters) throws LibrecException { int num = ratings.getCount(); DenseMatrix XX = new DenseMatrix(num, numFactors); DenseVector ratingsReg = new DenseVector(num); int index = 0; for (int j : ratings.getIndex()) { ratingsReg.set(index, ratings.get(j) - globalMean); XX.setRow(index, factors.row(j)); index++; } DenseMatrix covar = hyperParameters.variance.add((XX.transpose().mult(XX)).scale(ratingSigma)).inv(); DenseVector mu = XX.transpose().mult(ratingsReg).scale(ratingSigma); mu.addEqual(hyperParameters.variance.mult(hyperParameters.mu)); mu = covar.mult(mu); DenseVector factorVector = new DenseVector(numFactors); DenseMatrix lam = covar.cholesky(); if (lam != null) { lam = lam.transpose(); for (int f = 0; f < numFactors; f++) factorVector.set(f, Randoms.gaussian(0, 1)); DenseVector w1_P1_u = lam.mult(factorVector).add(mu); for (int f = 0; f < numFactors; f++) { factorVector.set(f, w1_P1_u.get(f)); } } return factorVector; } @Override protected double predict(int userIdx, int itemIdx) { return predictMatrix.get(userIdx, itemIdx); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy