All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.librec.recommender.cf.ranking.PLSARecommender Maven / Gradle / Ivy

/**
 * Copyright (C) 2016 LibRec
 * 

* This file is part of LibRec. * LibRec is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. *

* LibRec is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. *

* You should have received a copy of the GNU General Public License * along with LibRec. If not, see <http://www.gnu.org/licenses/>. */ package net.librec.recommender.cf.ranking; import com.google.common.collect.HashBasedTable; import com.google.common.collect.Table; import net.librec.common.LibrecException; import net.librec.math.algorithm.Randoms; import net.librec.math.structure.DenseMatrix; import net.librec.math.structure.DenseVector; import net.librec.math.structure.MatrixEntry; import net.librec.recommender.ProbabilisticGraphicalRecommender; /** * Thomas Hofmann, Latent semantic models for collaborative filtering, * ACM Transactions on Information Systems. * 2004.
*
* @author Haidong Zhang and Keqiang Wang
*/
public class PLSARecommender extends ProbabilisticGraphicalRecommender {

    /**
     * Number of latent topics. Read in {@link #setup()} from the
     * {@code rec.topic.number} configuration key, defaulting to 10.
     */
    protected int numTopics;

    /**
     * Per-observation topic distribution, keyed by (user index, item index).
     * Each value is a {@code double[numTopics]} that {@link #eStep()} fills and
     * normalizes over topics and that {@link #mStep()} reads back.
     */
    protected Table<Integer, Integer, double[]> Q;

    /**
     * Conditional probability P(z|u), stored as a (numUsers x numTopics) matrix.
     */
    protected DenseMatrix userTopicProbs;

    /**
     * Accumulator for the P(z|u) numerators; re-created on every {@link #mStep()}.
     */
    protected DenseMatrix userTopicProbsSum;

    /**
     * Conditional probability P(i|z), stored as a (numTopics x numItems) matrix.
     */
    protected DenseMatrix topicItemProbs;

    /**
     * Accumulator for the P(i|z) numerators; re-created on every {@link #mStep()}.
     */
    protected DenseMatrix topicItemProbsSum;

    /**
     * Per-topic total of the weighted responsibilities; used in {@link #mStep()}
     * as the denominator when normalizing {@link #topicItemProbs}.
     */
    private DenseVector topicProbsSum;

    /**
     * entry[u]: sum of the training-matrix entry values in row u. This equals the
     * number of items rated by user u only when every stored value is 1
     * (presumably the case for binarized ranking data -- confirm against the
     * data pipeline).
     */
    protected DenseVector numItemsRateByUser;

    /**
     * Reads the topic count, randomly initializes P(z|u) and P(i|z), allocates one
     * topic array per training entry in {@link #Q}, and accumulates the per-user
     * totals used later as normalizers.
     *
     * @throws LibrecException propagated from {@code super.setup()}
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        numTopics = conf.getInt("rec.topic.number", 10);
        isRanking = true;

        // One random vector per user row.
        // Randoms.randProbs presumably returns a normalized distribution -- confirm.
        userTopicProbs = new DenseMatrix(numUsers, numTopics);
        for (int userIdx = 0; userIdx < numUsers; userIdx++) {
            double[] probs = Randoms.randProbs(numTopics);
            for (int topicIdx = 0; topicIdx < numTopics; topicIdx++) {
                userTopicProbs.set(userIdx, topicIdx, probs[topicIdx]);
            }
        }

        // One random vector per topic row.
        topicItemProbs = new DenseMatrix(numTopics, numItems);
        for (int topicIdx = 0; topicIdx < numTopics; topicIdx++) {
            double[] probs = Randoms.randProbs(numItems);
            for (int itemIdx = 0; itemIdx < numItems; itemIdx++) {
                topicItemProbs.set(topicIdx, itemIdx, probs[itemIdx]);
            }
        }

        // Initialize Q and the per-user totals in a single pass over the training data.
        Q = HashBasedTable.create();
        numItemsRateByUser = new DenseVector(numUsers);
        for (MatrixEntry matrixEntry : trainMatrix) {
            int userIdx = matrixEntry.row();
            int itemIdx = matrixEntry.column();
            Q.put(userIdx, itemIdx, new double[numTopics]);
            numItemsRateByUser.add(userIdx, matrixEntry.get());
        }
    }

    /**
     * E-step: for every training entry (u, i) sets
     * {@code Q(u, i)[z] = P(z|u) * P(i|z)} and then normalizes the array over z.
     */
    @Override
    protected void eStep() {
        for (MatrixEntry matrixEntry : trainMatrix) {
            int userIdx = matrixEntry.row();
            int itemIdx = matrixEntry.column();
            double[] topicDistr = Q.get(userIdx, itemIdx);

            double sum = 0.0;
            for (int topicIdx = 0; topicIdx < numTopics; topicIdx++) {
                double value = userTopicProbs.get(userIdx, topicIdx) * topicItemProbs.get(topicIdx, itemIdx);
                topicDistr[topicIdx] = value;
                sum += value;
            }

            // Guard against division by zero: a non-positive total leaves the values unscaled.
            sum = sum > 0.0 ? sum : 1.0d;

            // Normalize over the latent states.
            for (int topicIdx = 0; topicIdx < numTopics; topicIdx++) {
                topicDistr[topicIdx] /= sum;
            }
        }
    }

    /**
     * M-step: re-estimates P(z|u) and P(i|z) from the responsibilities in
     * {@link #Q}, each weighted by the corresponding training-matrix value.
     */
    @Override
    protected void mStep() {
        userTopicProbsSum = new DenseMatrix(numUsers, numTopics);
        topicItemProbsSum = new DenseMatrix(numTopics, numItems);
        topicProbsSum = new DenseVector(numTopics);

        // Accumulate the weighted responsibilities per (user, topic), (topic, item) and topic.
        for (MatrixEntry matrixEntry : trainMatrix) {
            int userIdx = matrixEntry.row();
            int itemIdx = matrixEntry.column();
            double num = matrixEntry.get();
            double[] topicDistr = Q.get(userIdx, itemIdx);

            for (int topicIdx = 0; topicIdx < numTopics; topicIdx++) {
                double val = topicDistr[topicIdx] * num;
                userTopicProbsSum.add(userIdx, topicIdx, val);
                topicItemProbsSum.add(topicIdx, itemIdx, val);
                topicProbsSum.add(topicIdx, val);
            }
        }

        // P(z|u): divide by the user's total; users with a non-positive total get 0.
        for (int userIdx = 0; userIdx < numUsers; userIdx++) {
            double deno = numItemsRateByUser.get(userIdx);
            for (int topicIdx = 0; topicIdx < numTopics; topicIdx++) {
                double value = deno > 0.0d ? userTopicProbsSum.get(userIdx, topicIdx) / deno : 0.0d;
                userTopicProbs.set(userIdx, topicIdx, value);
            }
        }

        // P(i|z): divide by the topic's total; topics with a non-positive total get 0.
        for (int topicIdx = 0; topicIdx < numTopics; topicIdx++) {
            double topicTotal = topicProbsSum.get(topicIdx);
            for (int itemIdx = 0; itemIdx < numItems; itemIdx++) {
                double value = topicTotal > 0.0d ? topicItemProbsSum.get(topicIdx, itemIdx) / topicTotal : 0.0d;
                topicItemProbs.set(topicIdx, itemIdx, value);
            }
        }
    }

    /**
     * Scores an item for a user from the learned topic model.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @return the result of {@code DenseMatrix.product} on row {@code userIdx} of
     *         P(z|u) and column {@code itemIdx} of P(i|z) -- presumably their dot
     *         product, i.e. the sum over z of P(z|u) * P(i|z); confirm against
     *         {@code DenseMatrix.product}
     * @throws LibrecException declared by the overridden method; not thrown directly here
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return DenseMatrix.product(userTopicProbs, userIdx, topicItemProbs, itemIdx);
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy