All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.librec.recommender.cf.ranking.BPRRecommender Maven / Gradle / Ivy

/**
 * Copyright (C) 2016 LibRec
 * 

* This file is part of LibRec. * LibRec is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. *

* LibRec is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. *

* You should have received a copy of the GNU General Public License * along with LibRec. If not, see . */ package net.librec.recommender.cf.ranking; import net.librec.annotation.ModelData; import net.librec.common.LibrecException; import net.librec.math.algorithm.Maths; import net.librec.math.algorithm.Randoms; import net.librec.math.structure.MatrixEntry; import net.librec.math.structure.SparseMatrix; import net.librec.recommender.MatrixFactorizationRecommender; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; /** * Rendle et al., BPR: Bayesian Personalized Ranking from Implicit Feedback, UAI 2009. * * @author GuoGuibing and Keqiang Wang */ @ModelData({"isRanking", "bpr", "userFactors", "itemFactors"}) public class BPRRecommender extends MatrixFactorizationRecommender { private List> userItemsSet; @Override protected void setup() throws LibrecException { super.setup(); } @Override protected void trainModel() throws LibrecException { userItemsSet = getUserItemsSet(trainMatrix); for (int iter = 1; iter <= numIterations; iter++) { loss = 0.0d; for (int sampleCount = 0, smax = numUsers * 100; sampleCount < smax; sampleCount++) { // randomly draw (userIdx, posItemIdx, negItemIdx) int userIdx, posItemIdx, negItemIdx; while (true) { userIdx = Randoms.uniform(numUsers); Set itemSet = userItemsSet.get(userIdx); if (itemSet.size() == 0 || itemSet.size() == numItems) continue; List itemList = trainMatrix.getColumns(userIdx); posItemIdx = itemList.get(Randoms.uniform(itemList.size())); do { negItemIdx = Randoms.uniform(numItems); } while (itemSet.contains(negItemIdx)); break; } // update parameters double posPredictRating = predict(userIdx, posItemIdx); double negPredictRating = predict(userIdx, negItemIdx); double diffValue = posPredictRating - negPredictRating; double lossValue = -Math.log(Maths.logistic(diffValue)); loss += lossValue; double deriValue = Maths.logistic(-diffValue); for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) { double userFactorValue = userFactors.get(userIdx, factorIdx); double posItemFactorValue = itemFactors.get(posItemIdx, factorIdx); double negItemFactorValue = itemFactors.get(negItemIdx, factorIdx); userFactors.add(userIdx, factorIdx, learnRate * (deriValue * (posItemFactorValue - negItemFactorValue) - regUser * userFactorValue)); itemFactors.add(posItemIdx, factorIdx, learnRate * (deriValue * userFactorValue - regItem * posItemFactorValue)); itemFactors.add(negItemIdx, factorIdx, learnRate * (deriValue * (-userFactorValue) - regItem * negItemFactorValue)); loss += regUser * userFactorValue * userFactorValue + regItem * posItemFactorValue * posItemFactorValue + regItem * negItemFactorValue * negItemFactorValue; } } if (isConverged(iter) && earlyStop) { break; } updateLRate(iter); } } private List> getUserItemsSet(SparseMatrix sparseMatrix) { List> userItemsSet = new ArrayList<>(); for (int userIdx = 0; userIdx < numUsers; ++userIdx) { userItemsSet.add(new HashSet(sparseMatrix.getColumns(userIdx))); } return userItemsSet; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy