net.librec.recommender.cf.ranking.AoBPRRecommender Maven / Gradle / Ivy
Show all versions of librec-core Show documentation
/**
* Copyright (C) 2016 LibRec
*
* This file is part of LibRec.
* LibRec is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LibRec is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LibRec. If not, see .
*/
package net.librec.recommender.cf.ranking;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.algorithm.Randoms;
import net.librec.math.algorithm.Stats;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.recommender.MatrixFactorizationRecommender;
import net.librec.util.Lists;
import org.apache.commons.lang.ArrayUtils;
import java.util.*;
import static net.librec.math.algorithm.Maths.logistic;
/**
* AoBPR: BPR with Adaptive Oversampling
*
* Rendle and Freudenthaler, Improving pairwise learning for item recommendation from implicit
* feedback, WSDM 2014.
*
* @author guoguibing and Keqiang Wang
*/
@ModelData({"isRanking", "aobpr", "userFactors", "itemFactors"})
public class AoBPRRecommender extends MatrixFactorizationRecommender {
private int loopNumber;
/**
* item geometric distribution parameter
*/
private int lambdaItem;
private double[] var;
private int[][] factorRanking;
private double[] RankingPro;
private List> userItemsSet;
@Override
protected void setup() throws LibrecException {
super.setup();
//set for this alg
lambdaItem = (int) (conf.getFloat("rec.item.distribution.parameter") * numItems);
//lamda_Item=500;
loopNumber = (int) (numItems * Math.log(numItems));
var = new double[numFactors];
factorRanking = new int[numFactors][numItems];
RankingPro = new double[numItems];
double sum = 0;
for (int i = 0; i < numItems; i++) {
RankingPro[i] = Math.exp(-(i + 1) / lambdaItem);
sum += RankingPro[i];
}
for (int i = 0; i < numItems; i++) {
RankingPro[i] /= sum;
}
}
@Override
protected void trainModel() throws LibrecException {
userItemsSet = getUserItemsSet(trainMatrix);
List[] dataLists = getTrainList(trainMatrix);
List userTrainList = dataLists[0];
List itemTrainList = dataLists[1];
int countIter = 0;
int maxSample = trainMatrix.size();
for (int iter = 1; iter <= numIterations; iter++) {
loss = 0.0d;
for (int s = 0; s < maxSample; s++) {
//update Ranking every |I|log|I|
if (countIter % loopNumber == 0) {
updateRankingInFactor();
countIter = 0;
}
countIter++;
// randomly draw (u, i, j)
int userIdx, posItemIdx, negItemIdx;
while (true) {
int dataIdx = Randoms.uniform(numRates);
userIdx = userTrainList.get(dataIdx);
Set itemSet = userItemsSet.get(userIdx);
if (itemSet.size() == 0 || itemSet.size() == numItems)
continue;
posItemIdx = itemTrainList.get(dataIdx);
do {
//randoms get a r by exp(-r/lamda)
int randomNegItemIndex = 0;
do {
randomNegItemIndex = Randoms.discrete(RankingPro);
} while (randomNegItemIndex > numItems);
//randoms get a f by p(f|c)
double[] pfc = new double[numFactors];
double sumfc = 0;
for (int pfcFactprIdx = 0; pfcFactprIdx < numFactors; pfcFactprIdx++) {
double tempAbsValue = Math.abs(userFactors.get(userIdx, pfcFactprIdx));
sumfc += tempAbsValue * var[pfcFactprIdx];
pfc[pfcFactprIdx] = tempAbsValue * var[pfcFactprIdx];
}
//normalization
for (int pfcFactprIdx = 0; pfcFactprIdx < numFactors; pfcFactprIdx++) {
pfc[pfcFactprIdx] /= sumfc;
}
int factorIdx = Randoms.discrete(pfc);
//get the r-1 in f item
if (userFactors.get(userIdx, factorIdx) > 0) {
negItemIdx = factorRanking[factorIdx][randomNegItemIndex];
} else {
negItemIdx = factorRanking[factorIdx][numItems - randomNegItemIndex - 1];
}
} while (itemSet.contains(negItemIdx));
break;
}
// update parameters
double posPredictRating = predict(userIdx, posItemIdx);
double negPredictRating = predict(userIdx, negItemIdx);
double diffValue = posPredictRating - negPredictRating;
double lossValue = -Math.log(Maths.logistic(diffValue));
loss += lossValue;
double deriValue = logistic(-diffValue);
for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
double userFactorValue = userFactors.get(userIdx, factorIdx);
double posItemFactorValue = itemFactors.get(posItemIdx, factorIdx);
double negItemFactorValue = itemFactors.get(negItemIdx, factorIdx);
userFactors.plus(userIdx, factorIdx, learnRate * (deriValue * (posItemFactorValue - negItemFactorValue) - regUser * userFactorValue));
itemFactors.plus(posItemIdx, factorIdx, learnRate * (deriValue * userFactorValue - regItem * posItemFactorValue));
itemFactors.plus(negItemIdx, factorIdx, learnRate * (deriValue * (-userFactorValue) - regItem * negItemFactorValue));
loss += regUser * userFactorValue * userFactorValue + regItem * posItemFactorValue * posItemFactorValue + regItem * negItemFactorValue * negItemFactorValue;
}
}
if (isConverged(iter) && earlyStop) {
break;
}
updateLRate(iter);
}
}
public List> sortByDenseVectorValue(DenseVector vector) {
List> sortList = new ArrayList<>();
for (int itemIdx = 0; itemIdx < numItems; itemIdx++) {
sortList.add(new AbstractMap.SimpleImmutableEntry(itemIdx, vector.get(itemIdx)));
}
Lists.sortList(sortList, true);
return sortList;
}
public void updateRankingInFactor() {
//echo for each factors
for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
// VectorBasedDenseVector factorVector = itemFactors.column(factorIdx).clone();
DenseVector factorVector = itemFactors.column(factorIdx).clone();
List> sort = sortByDenseVectorValue(factorVector);
double[] valueList = new double[numItems];
for (int itemIdx = 0; itemIdx < numItems; itemIdx++) {
factorRanking[factorIdx][itemIdx] = sort.get(itemIdx).getKey();
valueList[itemIdx] = sort.get(itemIdx).getValue();
}
//get
var[factorIdx] = Stats.variance(valueList);
}
}
private List> getUserItemsSet(SequentialAccessSparseMatrix sparseMatrix) {
List> userItemsSet = new ArrayList<>();
for (int userIdx = 0; userIdx < numUsers; ++userIdx) {
int[] itemIndexes = sparseMatrix.row(userIdx).getIndices();
Integer[] inputBoxed = ArrayUtils.toObject(itemIndexes);
List itemList = Arrays.asList(inputBoxed);
userItemsSet.add(new HashSet(itemList));
}
return userItemsSet;
}
private List[] getTrainList(SequentialAccessSparseMatrix sparseMatrix) {
List userTrainList = new ArrayList<>(), itemTrainList = new ArrayList<>();
for (MatrixEntry matrixEntry : sparseMatrix) {
int userIdx = matrixEntry.row();
int itemIdx = matrixEntry.column();
userTrainList.add(userIdx);
itemTrainList.add(itemIdx);
}
return new List[]{userTrainList, itemTrainList};
}
}