All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.librec.recommender.cf.ranking.CLIMFRecommender Maven / Gradle / Ivy

/**
 * Copyright (C) 2016 LibRec
 * 

* This file is part of LibRec.
* LibRec is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*

* LibRec is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*

* You should have received a copy of the GNU General Public License
* along with LibRec. If not, see <http://www.gnu.org/licenses/>.
*/
package net.librec.recommender.cf.ranking;

import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.SparseMatrix;
import net.librec.recommender.MatrixFactorizationRecommender;

import java.util.*;

/**
 * Shi et al., Climf: learning to maximize reciprocal rank with collaborative less-is-more filtering.,
 * RecSys 2012.
 *
 * @author Guibing Guo, Chen Ma and Keqiang Wang
 */
@ModelData({"isRanking", "climf", "userFactors", "itemFactors"})
public class CLIMFRecommender extends MatrixFactorizationRecommender {

    /**
     * One entry per user index, holding the set built from {@code trainMatrix.getColumns(userIdx)}
     * (presumably the indices of the items the user interacted with -- confirm against SparseMatrix).
     */
    private List<Set<Integer>> userItemsSet;

    /**
     * Delegates to the parent setup; this class adds no configuration of its own.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
    }

    /**
     * Trains the user and item factors for up to {@code numIterations} passes over all users.
     * <p>
     * For every user, the gradients for that user's factor row and for the factor rows of every
     * item in the user's item set are computed first, from the factors as they stand, and only
     * then applied (scaled by {@code learnRate}). After the update, the user's contribution is
     * accumulated into {@code loss}. A pass ends early when {@code isConverged(iter)} holds and
     * {@code earlyStop} is set; otherwise {@code updateLRate(iter)} is called.
     *
     * @throws LibrecException if prediction or a parent-class helper fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        userItemsSet = getUserItemsSet(trainMatrix);

        for (int iter = 1; iter <= numIterations; iter++) {
            loss = 0.0f;

            for (int userIdx = 0; userIdx < numUsers; userIdx++) {
                Set<Integer> itemSet = userItemsSet.get(userIdx);

                // Gradient for each latent factor of this user.
                double[] sgds = new double[numFactors];
                for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
                    // Regularization term.
                    double sgd = -regUser * userFactors.get(userIdx, factorIdx);

                    for (int itemIdx : itemSet) {
                        double predictValue = predict(userIdx, itemIdx);
                        double itemFactorValue = itemFactors.get(itemIdx, factorIdx);

                        sgd += Maths.logistic(-predictValue) * itemFactorValue;

                        // Pairwise term against every other item in the user's set.
                        for (int compareItemIdx : itemSet) {
                            if (compareItemIdx == itemIdx) {
                                continue;
                            }
                            double compPredictValue = predict(userIdx, compareItemIdx);
                            double compItemFactorValue = itemFactors.get(compareItemIdx, factorIdx);
                            double diffValue = compPredictValue - predictValue;

                            sgd += Maths.logisticGradientValue(diffValue) / (1 - Maths.logistic(diffValue))
                                    * (itemFactorValue - compItemFactorValue);
                        }
                    }
                    sgds[factorIdx] = sgd;
                }

                // Gradient for each latent factor of every item in the user's set,
                // keyed by item index; list position is the factor index.
                Map<Integer, List<Double>> itemsSgds = new HashMap<>();
                for (int itemIdx : itemSet) {
                    double predictValue = predict(userIdx, itemIdx);
                    List<Double> itemSgds = new ArrayList<>();

                    for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
                        double userFactorValue = userFactors.get(userIdx, factorIdx);
                        double itemFactorValue = itemFactors.get(itemIdx, factorIdx);
                        double judgeValue = 1.0d;

                        double sgd = judgeValue * Maths.logistic(-predictValue) * userFactorValue
                                - regItem * itemFactorValue;

                        for (int compItemIdx : itemSet) {
                            if (compItemIdx == itemIdx) {
                                continue;
                            }
                            double compPredictValue = predict(userIdx, compItemIdx);
                            double diffValue = compPredictValue - predictValue;

                            sgd += Maths.logisticGradientValue(-diffValue)
                                    * (1.0d / (1 - Maths.logistic(diffValue)) - 1.0d / (1 - Maths.logistic(-diffValue)))
                                    * userFactorValue;
                        }
                        itemSgds.add(sgd);
                    }
                    itemsSgds.put(itemIdx, itemSgds);
                }

                // Apply the updates only after all gradients were computed from the old factors.
                for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
                    userFactors.add(userIdx, factorIdx, learnRate * sgds[factorIdx]);
                }

                for (int itemIdx : itemSet) {
                    List<Double> itemSgds = itemsSgds.get(itemIdx);
                    for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
                        itemFactors.add(itemIdx, factorIdx, learnRate * itemSgds.get(factorIdx));
                    }
                }

                // Accumulate this user's contribution to the loss using the updated factors.
                for (int itemIdx = 0; itemIdx < numItems; itemIdx++) {
                    if (itemSet.contains(itemIdx)) {
                        double predictValue = predict(userIdx, itemIdx);
                        loss += Math.log(Maths.logistic(predictValue));

                        // Note: the item itself is not skipped here, unlike in the gradient loops.
                        for (int compItemIdx : itemSet) {
                            double compPredictValue = predict(userIdx, compItemIdx);
                            loss += Math.log(1 - Maths.logistic(compPredictValue - predictValue));
                        }
                    }

                    // Regularization is added for every item, whether or not it is in the user's set.
                    for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
                        double userFactorValue = userFactors.get(userIdx, factorIdx);
                        double itemFactorValue = itemFactors.get(itemIdx, factorIdx);

                        loss += -0.5 * (regUser * userFactorValue * userFactorValue
                                + regItem * itemFactorValue * itemFactorValue);
                    }
                }
            }

            if (isConverged(iter) && earlyStop) {
                break;
            }
            updateLRate(iter);
        }
    }

    /**
     * Builds, for each user index in {@code [0, numUsers)}, a set from
     * {@code sparseMatrix.getColumns(userIdx)}.
     *
     * @param sparseMatrix matrix whose rows are indexed by user
     * @return a list with one set per user, in user-index order
     */
    private List<Set<Integer>> getUserItemsSet(SparseMatrix sparseMatrix) {
        List<Set<Integer>> userItemsSet = new ArrayList<>();
        for (int userIdx = 0; userIdx < numUsers; ++userIdx) {
            userItemsSet.add(new HashSet<Integer>(sparseMatrix.getColumns(userIdx)));
        }
        return userItemsSet;
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy