All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.librec.recommender.cf.rating.NMFRecommender Maven / Gradle / Ivy

/**
 * Copyright (C) 2016 LibRec
 * 

* This file is part of LibRec.
* LibRec is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*

* LibRec is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*

* You should have received a copy of the GNU General Public License
* along with LibRec. If not, see <http://www.gnu.org/licenses/>.
*/
package net.librec.recommender.cf.rating;

import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseVector;
import net.librec.recommender.AbstractRecommender;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Rating recommender based on non-negative matrix factorization.
 * <p>
 * Daniel D. Lee and H. Sebastian Seung, Algorithms for Non-negative Matrix Factorization, NIPS 2001.
 * <p>
 * Both factor matrices are stored transposed (one row per latent factor) and
 * are refined with multiplicative updates: each entry is multiplied by the
 * ratio between a term computed from the observed ratings and the same term
 * computed from the current predictions.
 *
 * @author guoguibing and Keqiang Wang
 */
@ModelData({"isRating", "nmf", "transUserFactors", "transItemFactors"})
public class NMFRecommender extends AbstractRecommender {

    /**
     * Added to every update denominator so the ratio never divides by zero.
     */
    private static final double DENOMINATOR_EPSILON = 1e-9;

    /**
     * Transposed user factors: {@code numFactors} rows by {@code numUsers} columns.
     */
    DenseMatrix transUserFactors;

    /**
     * Transposed item factors: {@code numFactors} rows by {@code numItems} columns.
     */
    DenseMatrix transItemFactors;

    /**
     * the number of latent factors;
     */
    protected int numFactors;

    /**
     * the number of iterations
     */
    protected int numIterations;

    /**
     * Reads the factor count ({@code rec.factor.number}, default 10) and the
     * iteration limit ({@code rec.iterator.maximum}, default 100) from the
     * configuration, then allocates and initializes both factor matrices.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();

        numFactors = conf.getInt("rec.factor.number", 10);
        numIterations = conf.getInt("rec.iterator.maximum", 100);

        transUserFactors = new DenseMatrix(numFactors, numUsers);
        transItemFactors = new DenseMatrix(numFactors, numItems);

        // NOTE(review): init(0.01) presumably fills the matrix with small
        // starting values bounded by 0.01 -- confirm against DenseMatrix.
        transUserFactors.init(0.01);
        transItemFactors.init(0.01);
    }

    /**
     * Trains the model by alternating user-factor and item-factor updates.
     * <p>
     * At most {@code numIterations} passes are executed. After each pass the
     * loss is recomputed and training stops early when
     * {@code isConverged(iter)} is true and {@code earlyStop} is enabled.
     *
     * @throws LibrecException if a prediction or the convergence check fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        // Strict "<" so that exactly numIterations passes run; the previous
        // "<=" bound executed one pass more than the configured maximum.
        for (int iter = 0; iter < numIterations; ++iter) {
            refineUserFactors();
            refineItemFactors();

            loss = halfSquaredTrainError();

            // isConverged is evaluated first on every pass, as before.
            if (isConverged(iter) && earlyStop) {
                break;
            }
        }
    }

    /**
     * Updates every user's factors while the item factors stay fixed.
     * Users without any training rating are skipped.
     *
     * @throws LibrecException if a prediction fails
     */
    private void refineUserFactors() throws LibrecException {
        for (int userIdx = 0; userIdx < numUsers; userIdx++) {
            SparseVector itemRatingsVector = trainMatrix.row(userIdx);
            if (itemRatingsVector.getCount() <= 0) {
                continue;
            }

            // Predictions are taken once, before any factor of this user
            // changes, so all factors are updated against the same estimate.
            SparseVector itemPredictsVector = new SparseVector(numItems, itemRatingsVector.size());
            for (int itemIdx : itemRatingsVector.getIndex()) {
                itemPredictsVector.append(itemIdx, predict(userIdx, itemIdx));
            }

            for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
                DenseVector factorItemsVector = transItemFactors.row(factorIdx, false);
                double realValue = factorItemsVector.inner(itemRatingsVector);
                double estmValue = factorItemsVector.inner(itemPredictsVector) + DENOMINATOR_EPSILON;

                double current = transUserFactors.get(factorIdx, userIdx);
                transUserFactors.set(factorIdx, userIdx, current * (realValue / estmValue));
            }
        }
    }

    /**
     * Updates every item's factors while the user factors stay fixed.
     * Items without any training rating are skipped.
     *
     * @throws LibrecException if a prediction fails
     */
    private void refineItemFactors() throws LibrecException {
        for (int itemIdx = 0; itemIdx < numItems; itemIdx++) {
            SparseVector userRatingsVector = trainMatrix.column(itemIdx);
            if (userRatingsVector.getCount() <= 0) {
                continue;
            }

            // Predictions are taken once, before any factor of this item
            // changes, so all factors are updated against the same estimate.
            SparseVector userPredictsVector = new SparseVector(numUsers, userRatingsVector.size());
            for (int userIdx : userRatingsVector.getIndex()) {
                userPredictsVector.append(userIdx, predict(userIdx, itemIdx));
            }

            for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
                DenseVector factorUsersVector = transUserFactors.row(factorIdx, false);
                double realValue = factorUsersVector.inner(userRatingsVector);
                double estmValue = factorUsersVector.inner(userPredictsVector) + DENOMINATOR_EPSILON;

                double current = transItemFactors.get(factorIdx, itemIdx);
                transItemFactors.set(factorIdx, itemIdx, current * (realValue / estmValue));
            }
        }
    }

    /**
     * Computes half of the summed squared prediction error over the training
     * entries whose rating is strictly positive.
     *
     * @return {@code 0.5 * sum((predict(u, i) - rating)^2)}
     * @throws LibrecException if a prediction fails
     */
    private double halfSquaredTrainError() throws LibrecException {
        double squaredError = 0.0d;
        for (MatrixEntry matrixEntry : trainMatrix) {
            int userIdx = matrixEntry.row();
            int itemIdx = matrixEntry.column();
            double rating = matrixEntry.get();

            if (rating > 0) {
                double ratingError = predict(userIdx, itemIdx) - rating;
                squaredError += ratingError * ratingError;
            }
        }
        return squaredError * 0.5d;
    }

    /**
     * predict a specific rating for user userIdx on item itemIdx.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @return predictive rating for user userIdx on item itemIdx
     * @throws LibrecException if error occurs
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return DenseMatrix.colMult(transUserFactors, userIdx, transItemFactors, itemIdx);
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy