All downloads are free. The search and download functionality uses the official Maven repository.

org.grouplens.lenskit.mf.svd.MFModel Maven / Gradle / Ivy

There is a newer version: 3.0-T5
Show newest version
/*
 * LensKit, an open source recommender systems toolkit.
 * Copyright 2010-2014 LensKit Contributors.  See CONTRIBUTORS.md.
 * Work on LensKit has been funded by the National Science Foundation under
 * grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 51
 * Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */
package org.grouplens.lenskit.mf.svd;

import com.google.common.base.Preconditions;
import mikera.matrixx.IMatrix;
import mikera.matrixx.Matrix;
import mikera.matrixx.impl.ImmutableMatrix;
import mikera.vectorz.AVector;
import org.grouplens.lenskit.indexes.IdIndexMapping;

import javax.annotation.Nullable;
import java.io.*;

/**
 * Common model for matrix factorization (SVD) recommendation.
 *
 * @since 2.1
 * @author GroupLens Research
 */
public class MFModel implements Serializable {
    private static final long serialVersionUID = 2L;

    // FIXME Make these final again
    // (every field below is assigned in readObject, which is why none of them is final today)

    /** Number of latent features, i.e. the column count shared by both matrices. */
    protected int featureCount;
    /** Number of users, taken from the size of the user index mapping. */
    protected int userCount;
    /** Number of items, taken from the size of the item index mapping. */
    protected int itemCount;

    /** User-feature matrix (users x features). */
    protected ImmutableMatrix userMatrix;
    /** Item-feature matrix (items x features). */
    protected ImmutableMatrix itemMatrix;
    /** Mapping between user IDs and row indexes of {@link #userMatrix}. */
    protected IdIndexMapping userIndex;
    /** Mapping between item IDs and row indexes of {@link #itemMatrix}. */
    protected IdIndexMapping itemIndex;

    /**
     * Construct a matrix factorization model.  The matrices are not copied, so the caller should
     * make sure they won't be modified by anyone else.
     *
     * @param umat The user feature matrix (users x features).
     * @param imat The item feature matrix (items x features).
     * @param uidx The user index mapping.
     * @param iidx The item index mapping.
     * @throws IllegalArgumentException if the two matrices have different column counts, or if a
     *                                  matrix's row count differs from the size of its index.
     */
    public MFModel(ImmutableMatrix umat, ImmutableMatrix imat,
                   IdIndexMapping uidx, IdIndexMapping iidx) {
        Preconditions.checkArgument(umat.columnCount() == imat.columnCount(),
                                    "mismatched matrix sizes");
        featureCount = umat.columnCount();
        userCount = uidx.size();
        itemCount = iidx.size();
        Preconditions.checkArgument(umat.rowCount() == userCount,
                                    "user matrix has %s rows, expected %s",
                                    umat.rowCount(), userCount);
        Preconditions.checkArgument(imat.rowCount() == itemCount,
                                    "item matrix has %s rows, expected %s",
                                    imat.rowCount(), itemCount);
        userMatrix = umat;
        itemMatrix = imat;
        userIndex = uidx;
        itemIndex = iidx;
    }

    /**
     * Custom serialized form: the three counts, then the user matrix and the item matrix (each
     * row by row), then the two index mappings.
     *
     * @param out The stream to write to.
     * @throws IOException if the underlying stream fails.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(featureCount);
        out.writeInt(userCount);
        out.writeInt(itemCount);

        writeMatrix(out, userMatrix, userCount, featureCount);
        writeMatrix(out, itemMatrix, itemCount, featureCount);

        out.writeObject(userIndex);
        out.writeObject(itemIndex);
    }

    /**
     * Restore a model from the form produced by {@link #writeObject(ObjectOutputStream)},
     * rejecting streams whose contents cannot describe a valid model.
     *
     * @param input The stream to read from.
     * @throws InvalidObjectException if a count is negative, an index mapping is missing, or an
     *                                index mapping's size disagrees with its matrix.
     * @throws IOException            if the underlying stream fails.
     * @throws ClassNotFoundException if the class of an index mapping cannot be found.
     */
    private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException {
        featureCount = input.readInt();
        userCount = input.readInt();
        itemCount = input.readInt();
        // Validate before allocating: the counts come straight from the stream.
        if (featureCount < 0 || userCount < 0 || itemCount < 0) {
            throw new InvalidObjectException("negative dimension in serialized model (features="
                                             + featureCount + ", users=" + userCount
                                             + ", items=" + itemCount + ")");
        }

        userMatrix = readMatrix(input, userCount, featureCount);
        itemMatrix = readMatrix(input, itemCount, featureCount);

        userIndex = (IdIndexMapping) input.readObject();
        itemIndex = (IdIndexMapping) input.readObject();
        // The constructor can never produce a null index, so a null here means a bad stream.
        if (userIndex == null || itemIndex == null) {
            throw new InvalidObjectException("serialized model is missing an index mapping");
        }

        if (userIndex.size() != userMatrix.rowCount()) {
            throw new InvalidObjectException("user matrix and index have different row counts");
        }
        if (itemIndex.size() != itemMatrix.rowCount()) {
            throw new InvalidObjectException("item matrix and index have different row counts");
        }
    }

    /** Write the leading {@code rows x cols} entries of a matrix, one row after another. */
    private static void writeMatrix(ObjectOutputStream out, ImmutableMatrix matrix,
                                    int rows, int cols) throws IOException {
        for (int row = 0; row < rows; row++) {
            for (int col = 0; col < cols; col++) {
                out.writeDouble(matrix.get(row, col));
            }
        }
    }

    /** Read {@code rows x cols} doubles, one row after another, into an immutable matrix. */
    private static ImmutableMatrix readMatrix(ObjectInputStream input,
                                              int rows, int cols) throws IOException {
        Matrix buffer = Matrix.create(rows, cols);
        for (int row = 0; row < rows; row++) {
            for (int col = 0; col < cols; col++) {
                buffer.set(row, col, input.readDouble());
            }
        }
        return ImmutableMatrix.wrap(buffer);
    }

    /**
     * Get the model's feature count.
     *
     * @return The model's feature count.
     */
    public int getFeatureCount() {
        return featureCount;
    }

    /**
     * The number of users.
     *
     * @return The number of users in the model.
     */
    public int getUserCount() {
        return userCount;
    }

    /**
     * The number of items.
     *
     * @return The number of items in the model.
     */
    public int getItemCount() {
        return itemCount;
    }

    /**
     * The item index mapping.
     *
     * @return The mapping between item IDs and item matrix rows.
     */
    public IdIndexMapping getItemIndex() {
        return itemIndex;
    }

    /**
     * The user index mapping.
     *
     * @return The mapping between user IDs and user matrix rows.
     */
    public IdIndexMapping getUserIndex() {
        return userIndex;
    }

    /**
     * Get the user matrix.
     * @return The user matrix (users x features).
     */
    public IMatrix getUserMatrix() {
        return userMatrix;
    }

    /**
     * Get the item matrix.
     * @return The item matrix (items x features).
     */
    public IMatrix getItemMatrix() {
        return itemMatrix;
    }

    /**
     * Get the feature vector for a user.
     *
     * @param user The user ID.
     * @return The user's row of the user matrix, or {@code null} if the user index has no
     *         entry for the ID.
     */
    @Nullable
    public AVector getUserVector(long user) {
        int uidx = userIndex.tryGetIndex(user);
        if (uidx < 0) {
            return null;
        }
        return userMatrix.getRow(uidx);
    }

    /**
     * Get the feature vector for an item.
     *
     * @param item The item ID.
     * @return The item's row of the item matrix, or {@code null} if the item index has no
     *         entry for the ID.
     */
    @Nullable
    public AVector getItemVector(long item) {
        int iidx = itemIndex.tryGetIndex(item);
        if (iidx < 0) {
            return null;
        }
        return itemMatrix.getRow(iidx);
    }

    /**
     * Get a particular feature value for a user.
     * @param uid The user ID.
     * @param feature The feature.
     * @return The user-feature value, or 0 if the user was not in the training set.
     */
    public double getUserFeature(long uid, int feature) {
        int uidx = userIndex.tryGetIndex(uid);
        if (uidx < 0) {
            return 0;
        }
        return userMatrix.get(uidx, feature);
    }

    /**
     * Get a particular feature value for an item.
     * @param iid The item ID.
     * @param feature The feature.
     * @return The item-feature value, or 0 if the item was not in the training set.
     */
    public double getItemFeature(long iid, int feature) {
        int iidx = itemIndex.tryGetIndex(iid);
        if (iidx < 0) {
            return 0;
        }
        return itemMatrix.get(iidx, feature);
    }

    /**
     * Summarize the model as {@code ClassName(nu=<users>, ni=<items>, nf=<features>)}.
     */
    @Override
    public String toString() {
        return getClass().getSimpleName()
                + "(nu=" + getUserCount()
                + ", ni=" + getItemCount()
                + ", nf=" + getFeatureCount()
                + ")";
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy