All downloads are free. The search and download features use the official Maven repository.

net.librec.recommender.cf.ranking.ItemBigramRecommender Maven / Gradle / Ivy

/**
 * Copyright (C) 2016 LibRec
 * 

* This file is part of LibRec. * LibRec is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. *

* LibRec is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. *

* You should have received a copy of the GNU General Public License * along with LibRec. If not, see . */ package net.librec.recommender.cf.ranking; import com.google.common.collect.HashBasedTable; import com.google.common.collect.Table; import net.librec.annotation.ModelData; import net.librec.common.LibrecException; import net.librec.math.algorithm.Randoms; import net.librec.math.structure.DenseMatrix; import net.librec.math.structure.DenseVector; import net.librec.math.structure.SparseMatrix; import net.librec.recommender.ProbabilisticGraphicalRecommender; import net.librec.util.RatingContext; import java.util.*; import static net.librec.math.algorithm.Gamma.digamma; /** * Hanna M. Wallach, Topic Modeling: Beyond Bag-of-Words, ICML 2006. * * @author Keqiang Wang **/ @ModelData({"isRanking", "itembigram", "userTopicProbs", "topicPreItemCurItemProbs"}) public class ItemBigramRecommender extends ProbabilisticGraphicalRecommender { private Map> userItemsMap; /** * k: current topic; j: previously rated item; i: current item */ private int[][][] topicPreItemCurItemNum; private DenseMatrix topicItemProbs; private double[][][] topicPreItemCurItemProbs, topicPreItemCurItemSumProbs; private DenseMatrix beta; /** * vector of hyperparameters for alpha */ protected DenseVector alpha; /** * number of topics */ protected int numTopics; /** * Dirichlet hyper-parameters of user-topic distribution: typical value is 50/K */ protected float initAlpha; /** * Dirichlet hyper-parameters of topic-item distribution, typical value is 0.01 */ protected float initBeta; /** * cumulative statistics of theta, phi */ protected DenseMatrix userTopicProbsSum; /** * entry[u, k]: number of tokens assigned to topic k, given user u. */ protected DenseMatrix userTopicNumbers; /** * entry[u]: number of tokens rated by user u. 
*/ protected DenseVector userTokenNumbers; /** * posterior probabilities of parameters */ protected DenseMatrix userTopicProbs; /** * entry[u, i, k]: topic assignment as sparse structure */ protected Table topicAssignments; /** * time sparse matrix */ private SparseMatrix timeMatrix; @Override protected void setup() throws LibrecException { super.setup(); numTopics = conf.getInt("rec.topic.number", 10); initAlpha = conf.getFloat("rec.user.dirichlet.prior", 0.01f); initBeta = conf.getFloat("rec.topic.dirichlet.prior", 0.01f); timeMatrix = (SparseMatrix) getDataModel().getDatetimeDataSet(); // build the training data, sorting by date userItemsMap = new HashMap<>(); for (int userIdx = 0; userIdx < numUsers; userIdx++) { List unsortedItems = trainMatrix.getColumns(userIdx); int size = unsortedItems.size(); List rcs = new ArrayList<>(size); for (Integer itemIdx : unsortedItems) { rcs.add(new RatingContext(userIdx, itemIdx, (long) timeMatrix.get(userIdx, itemIdx))); } Collections.sort(rcs); List sortedItems = new ArrayList<>(size); for (RatingContext rc : rcs) { sortedItems.add(rc.getItem()); } userItemsMap.put(userIdx, sortedItems); } // count variables // initialize count variables. 
userTopicNumbers = new DenseMatrix(numUsers, numTopics); userTokenNumbers = new DenseVector(numUsers); topicPreItemCurItemNum = new int[numTopics][numItems + 1][numItems]; topicItemProbs = new DenseMatrix(numTopics, numItems + 1); // Logs.debug("topicPreItemCurItemNum consumes {} bytes", Strings.toString(Memory.bytes(topicPreItemCurItemNum))); // parameters userTopicProbsSum = new DenseMatrix(numUsers, numTopics); topicPreItemCurItemSumProbs = new double[numTopics][numItems + 1][numItems]; topicPreItemCurItemProbs = new double[numTopics][numItems + 1][numItems]; // hyper-parameters alpha = new DenseVector(numTopics); alpha.setAll(initAlpha); beta = new DenseMatrix(numTopics, numItems + 1); beta.setAll(initBeta); // initialization topicAssignments = HashBasedTable.create(); for (Map.Entry> userItemEntry : userItemsMap.entrySet()) { int userIdx = userItemEntry.getKey(); List itemIdxList = userItemEntry.getValue(); for (int itemIdxIndex = 0; itemIdxIndex < itemIdxList.size(); itemIdxIndex++) { int itemIdx = itemIdxList.get(itemIdxIndex); int topicIdx = (int) (Math.random() * numTopics); topicAssignments.put(userIdx, itemIdx, topicIdx); userTopicNumbers.add(userIdx, topicIdx, 1.0); userTokenNumbers.add(userIdx, 1.0); int preItemIdx = itemIdxIndex > 0 ? itemIdxList.get(itemIdxIndex - 1) : numItems; topicPreItemCurItemNum[topicIdx][preItemIdx][itemIdx]++; topicItemProbs.add(topicIdx, preItemIdx, 1); } } } @Override protected void eStep() { double sumAlpha = alpha.sum(); double tempValue1, tempValue2; for (Map.Entry> userItemEntry : userItemsMap.entrySet()) { int userIdx = userItemEntry.getKey(); List items = userItemEntry.getValue(); for (int itemIdxIndex = 0; itemIdxIndex < items.size(); itemIdxIndex++) { int itemIdx = items.get(itemIdxIndex); int topicIdx = topicAssignments.get(userIdx, itemIdx); userTopicNumbers.add(userIdx, topicIdx, -1.0); userTokenNumbers.add(userIdx, -1.0); int preItemIdx = itemIdxIndex > 0 ? 
items.get(itemIdxIndex - 1) : numItems; topicPreItemCurItemNum[topicIdx][preItemIdx][itemIdx]--; topicItemProbs.add(topicIdx, preItemIdx, -1); double[] tempUserProbs = new double[numTopics]; for (int topicInIdx = 0; topicInIdx < numTopics; topicInIdx++) { tempValue1 = (userTopicNumbers.get(userIdx, topicIdx) + alpha.get(topicInIdx)) / (userTokenNumbers.get(userIdx) + sumAlpha); tempValue2 = (topicPreItemCurItemNum[topicInIdx][preItemIdx][itemIdx] + beta.get(topicInIdx, preItemIdx)) / (topicItemProbs.get(topicInIdx, preItemIdx) + beta.sumOfRow(topicInIdx)); tempUserProbs[topicInIdx] = tempValue1 * tempValue2; } for (int topicInIdx = 1; topicInIdx < numTopics; topicInIdx++) { tempUserProbs[topicInIdx] += tempUserProbs[topicInIdx - 1]; } double rand = Randoms.uniform() * tempUserProbs[numTopics - 1]; for (topicIdx = 0; topicIdx < numTopics; topicIdx++) { if (rand < tempUserProbs[topicIdx]) break; } topicAssignments.put(userIdx, itemIdx, topicIdx); userTopicNumbers.add(userIdx, topicIdx, 1.0d); userTokenNumbers.add(userIdx, 1.0d); topicPreItemCurItemNum[topicIdx][preItemIdx][itemIdx]++; topicItemProbs.add(topicIdx, preItemIdx, 1.0d); } } } @Override protected void mStep() { double sumAlpha = alpha.sum(); for (int topicIdx = 0; topicIdx < numTopics; topicIdx++) { double alphaTopicValue = alpha.get(topicIdx); double numerator = 0, denominator = 0; for (int userIdx = 0; userIdx < numUsers; userIdx++) { numerator += digamma(userTopicNumbers.get(userIdx, topicIdx) + alphaTopicValue) - digamma(alphaTopicValue); denominator += digamma(userTokenNumbers.get(userIdx) + sumAlpha) - digamma(sumAlpha); } if (numerator != 0) alpha.set(topicIdx, alphaTopicValue * (numerator / denominator)); } for (int topicIdx = 0; topicIdx < numTopics; topicIdx++) { double betaTopicValue = beta.sumOfRow(topicIdx); for (int itemIdx = 0; itemIdx < numItems + 1; itemIdx++) { double betaTopicItemValue = beta.get(topicIdx, itemIdx); double numerator = 0.0d, denominator = 0.0d; for (int preItemIdx = 0; 
preItemIdx < numItems; preItemIdx++) { numerator += digamma(topicPreItemCurItemNum[topicIdx][itemIdx][preItemIdx] + betaTopicItemValue) - digamma(betaTopicItemValue); denominator += digamma(topicItemProbs.get(topicIdx, itemIdx) + betaTopicValue) - digamma(betaTopicValue); } if (numerator != 0) beta.set(topicIdx, itemIdx, betaTopicItemValue * (numerator / denominator)); } } } @Override protected void readoutParams() { double val; double sumAlpha = alpha.sum(); for (int userIdx = 0; userIdx < numTopics; userIdx++) { for (int topicIdx = 0; topicIdx < numTopics; topicIdx++) { val = (userTopicNumbers.get(userIdx, topicIdx) + alpha.get(topicIdx)) / (userTokenNumbers.get(userIdx) + sumAlpha); userTopicProbsSum.add(userIdx, topicIdx, val); } } for (int topicIdx = 0; topicIdx < numTopics; topicIdx++) { double betaTopicValue = beta.sumOfRow(topicIdx); for (int itemIdx = 0; itemIdx < numItems + 1; itemIdx++) { for (int preItemIdx = 0; preItemIdx < numItems; preItemIdx++) { val = (topicPreItemCurItemNum[topicIdx][itemIdx][preItemIdx] + beta.get(topicIdx, itemIdx)) / (topicItemProbs.get(topicIdx, itemIdx) + betaTopicValue); topicPreItemCurItemSumProbs[topicIdx][itemIdx][preItemIdx] += val; } } } numStats++; } @Override protected void estimateParams() { userTopicProbs = userTopicProbsSum.scale(1.0 / numStats); for (int topicIdx = 0; topicIdx < numTopics; topicIdx++) { for (int itemIdx = 0; itemIdx < numItems + 1; itemIdx++) { for (int preItemIdx = 0; preItemIdx < numItems; preItemIdx++) { topicPreItemCurItemProbs[topicIdx][itemIdx][preItemIdx] = topicPreItemCurItemSumProbs[topicIdx][itemIdx][preItemIdx] / numStats; } } } } @Override protected double predict(int userIdx, int itemIdx) throws LibrecException { List items = userItemsMap.get(userIdx); int preItemIdx = items.size() < 1 ? 
numItems : items.get(items.size() - 1); // last rated item double predictRating = 0; for (int topicIdx = 0; topicIdx < numTopics; topicIdx++) { predictRating += userTopicProbs.get(userIdx, topicIdx) * topicPreItemCurItemProbs[topicIdx][preItemIdx][itemIdx]; } return predictRating; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy