All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.cf.taste.impl.recommender.svd.SVDPlusPlusFactorizer Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.common.RandomUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * SVD++, an enhancement of classical matrix factorization for rating prediction.
 * Additionally to using ratings (how did people rate?) for learning, this model also takes into account
 * who rated what.
 *
 * Yehuda Koren: Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model, KDD 2008.
 * http://research.yahoo.com/files/kdd08koren.pdf
 */
public final class SVDPlusPlusFactorizer extends RatingSGDFactorizer {

  private double[][] p;
  private double[][] y;
  private Map> itemsByUser;

  public SVDPlusPlusFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException {
    this(dataModel, numFeatures, 0.01, 0.1, 0.01, numIterations, 1.0);
    biasLearningRate = 0.7;
    biasReg = 0.33;
  }

  public SVDPlusPlusFactorizer(DataModel dataModel, int numFeatures, double learningRate, double preventOverfitting,
      double randomNoise, int numIterations, double learningRateDecay) throws TasteException {
    super(dataModel, numFeatures, learningRate, preventOverfitting, randomNoise, numIterations, learningRateDecay);
  }

  @Override
  protected void prepareTraining() throws TasteException {
    super.prepareTraining();
    Random random = RandomUtils.getRandom();

    p = new double[dataModel.getNumUsers()][numFeatures];
    for (int i = 0; i < p.length; i++) {
      for (int feature = 0; feature < FEATURE_OFFSET; feature++) {
        p[i][feature] = 0;
      }
      for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
        p[i][feature] = random.nextGaussian() * randomNoise;
      }
    }

    y = new double[dataModel.getNumItems()][numFeatures];
    for (int i = 0; i < y.length; i++) {
      for (int feature = 0; feature < FEATURE_OFFSET; feature++) {
        y[i][feature] = 0;
      }
      for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
        y[i][feature] = random.nextGaussian() * randomNoise;
      }
    }

    /* get internal item IDs which we will need several times */
    itemsByUser = new HashMap<>();
    LongPrimitiveIterator userIDs = dataModel.getUserIDs();
    while (userIDs.hasNext()) {
      long userId = userIDs.nextLong();
      int userIndex = userIndex(userId);
      FastIDSet itemIDsFromUser = dataModel.getItemIDsFromUser(userId);
      List itemIndexes = new ArrayList<>(itemIDsFromUser.size());
      itemsByUser.put(userIndex, itemIndexes);
      for (long itemID2 : itemIDsFromUser) {
        int i2 = itemIndex(itemID2);
        itemIndexes.add(i2);
      }
    }
  }

  @Override
  public Factorization factorize() throws TasteException {
    prepareTraining();

    super.factorize();

    for (int userIndex = 0; userIndex < userVectors.length; userIndex++) {
      for (int itemIndex : itemsByUser.get(userIndex)) {
        for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
          userVectors[userIndex][feature] += y[itemIndex][feature];
        }
      }
      double denominator = Math.sqrt(itemsByUser.get(userIndex).size());
      for (int feature = 0; feature < userVectors[userIndex].length; feature++) {
        userVectors[userIndex][feature] =
            (float) (userVectors[userIndex][feature] / denominator + p[userIndex][feature]);
      }
    }

    return createFactorization(userVectors, itemVectors);
  }


  @Override
  protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate) {
    int userIndex = userIndex(userID);
    int itemIndex = itemIndex(itemID);

    double[] userVector = p[userIndex];
    double[] itemVector = itemVectors[itemIndex];

    double[] pPlusY = new double[numFeatures];
    for (int i2 : itemsByUser.get(userIndex)) {
      for (int f = FEATURE_OFFSET; f < numFeatures; f++) {
        pPlusY[f] += y[i2][f];
      }
    }
    double denominator = Math.sqrt(itemsByUser.get(userIndex).size());
    for (int feature = 0; feature < pPlusY.length; feature++) {
      pPlusY[feature] = (float) (pPlusY[feature] / denominator + p[userIndex][feature]);
    }

    double prediction = predictRating(pPlusY, itemIndex);
    double err = rating - prediction;
    double normalized_error = err / denominator;

    // adjust user bias
    userVector[USER_BIAS_INDEX] +=
        biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * userVector[USER_BIAS_INDEX]);

    // adjust item bias
    itemVector[ITEM_BIAS_INDEX] +=
        biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * itemVector[ITEM_BIAS_INDEX]);

    // adjust features
    for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
      double pF = userVector[feature];
      double iF = itemVector[feature];

      double deltaU = err * iF - preventOverfitting * pF;
      userVector[feature] += currentLearningRate * deltaU;

      double deltaI = err * pPlusY[feature] - preventOverfitting * iF;
      itemVector[feature] += currentLearningRate * deltaI;

      double commonUpdate = normalized_error * iF;
      for (int itemIndex2 : itemsByUser.get(userIndex)) {
        double deltaI2 = commonUpdate - preventOverfitting * y[itemIndex2][feature];
        y[itemIndex2][feature] += learningRate * deltaI2;
      }
    }
  }

  private double predictRating(double[] userVector, int itemID) {
    double sum = 0;
    for (int feature = 0; feature < numFeatures; feature++) {
      sum += userVector[feature] * itemVectors[itemID][feature];
    }
    return sum;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy