All downloads are free. Search and download functionality uses the official Maven repository.

org.lenskit.knn.user.UserSnapshot Maven / Gradle / Ivy

There is a newer version: 3.0-T5
Show newest version
/*
 * LensKit, an open source recommender systems toolkit.
 * Copyright 2010-2014 LensKit Contributors.  See CONTRIBUTORS.md.
 * Work on LensKit has been funded by the National Science Foundation under
 * grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 51
 * Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */
package org.lenskit.knn.user;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import it.unimi.dsi.fastutil.longs.*;
import org.grouplens.grapht.annotation.DefaultProvider;
import org.lenskit.inject.Shareable;
import org.lenskit.inject.Transient;
import org.lenskit.util.io.ObjectStream;
import org.lenskit.data.dao.UserEventDAO;
import org.lenskit.data.events.Event;
import org.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.data.history.UserHistorySummarizer;
import org.grouplens.lenskit.transform.normalize.UserVectorNormalizer;
import org.grouplens.lenskit.vectors.ImmutableSparseVector;
import org.grouplens.lenskit.vectors.MutableSparseVector;
import org.lenskit.util.collections.LongUtils;
import org.lenskit.util.keys.SortedKeyIndex;

import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;
import javax.inject.Provider;
import java.io.Serializable;
import java.util.List;

/**
 * User snapshot used by {@link SnapshotNeighborFinder}.
 * @author GroupLens Research
 * @since 2.1
 */
@Shareable
@ThreadSafe
@DefaultProvider(UserSnapshot.Builder.class)
public class UserSnapshot implements Serializable {
    private static final long serialVersionUID = 1L;
    private final SortedKeyIndex users;
    private final List vectors;
    private final List normedVectors;
    private final Long2ObjectMap itemUserSets;

    /**
     * Construct a user snapshot.
     * @param us The set of users.
     * @param vs The list of raw user vectors.
     * @param nvs The list of normalized user vectors.
     */
    UserSnapshot(SortedKeyIndex us, List vs, List nvs,
                 Long2ObjectMap iuSets) {
        Preconditions.checkArgument(vs.size() == us.size(),
                                    "incorrectly sized vector list");
        Preconditions.checkArgument(nvs.size() == us.size(),
                                    "incorrectly sized normalized vector list");
        users = us;
        vectors = ImmutableList.copyOf(vs);
        normedVectors = ImmutableList.copyOf(nvs);
        itemUserSets = iuSets;
    }

    public ImmutableSparseVector getUserVector(long user) {
        int idx = users.tryGetIndex(user);
        Preconditions.checkArgument(idx >= 0, "invalid user " + user);
        return vectors.get(idx);
    }

    public ImmutableSparseVector getNormalizedUserVector(long user) {
        int idx = users.tryGetIndex(user);
        Preconditions.checkArgument(idx >= 0, "invalid user " + user);
        return normedVectors.get(idx);
    }

    public LongSet getItemUsers(long item) {
        return itemUserSets.get(item);
    }

    public static class Builder implements Provider {
        private final UserEventDAO userEventDAO;
        private final UserVectorNormalizer normalizer;
        private final UserHistorySummarizer summarizer;

        @Inject
        public Builder(@Transient UserEventDAO dao,
                       @Transient UserVectorNormalizer norm,
                       @Transient UserHistorySummarizer sum) {
            userEventDAO = dao;
            normalizer = norm;
            summarizer = sum;
        }

        @Override
        public UserSnapshot get() {
            Long2ObjectMap vectors = new Long2ObjectOpenHashMap();
            ObjectStream> users = userEventDAO.streamEventsByUser(summarizer.eventTypeWanted());
            try {
                for (UserHistory user: users) {
                    MutableSparseVector uvec = summarizer.summarize(user).mutableCopy();
                    vectors.put(user.getUserId(), uvec);
                }
            } finally {
                users.close();
            }

            Long2ObjectMap itemUserLists = new Long2ObjectOpenHashMap();
            SortedKeyIndex domain = SortedKeyIndex.fromCollection(vectors.keySet());
            ImmutableList.Builder vecs = ImmutableList.builder();
            ImmutableList.Builder nvecs = ImmutableList.builder();
            for (LongIterator uiter = domain.keyIterator(); uiter.hasNext();) {
                final long user = uiter.nextLong();
                MutableSparseVector vec = vectors.get(user);
                // save user's original vector
                ImmutableSparseVector userVector = vec.immutable();
                vecs.add(userVector);
                // normalize user vector
                normalizer.normalize(user, userVector, vec);
                // and save normalized vector
                nvecs.add(vec.immutable());
                for (LongIterator iiter = userVector.keySet().iterator(); iiter.hasNext();) {
                    final long item = iiter.nextLong();
                    LongList itemUsers = itemUserLists.get(item);
                    if (itemUsers == null) {
                        itemUsers = new LongArrayList();
                        itemUserLists.put(item, itemUsers);
                    }
                    itemUsers.add(user);
                }
            }

            Long2ObjectMap itemUserSets = new Long2ObjectOpenHashMap();
            for (Long2ObjectMap.Entry entry: itemUserLists.long2ObjectEntrySet()) {
                itemUserSets.put(entry.getLongKey(), LongUtils.packedSet(entry.getValue()));
            }
            return new UserSnapshot(domain, vecs.build(), nvecs.build(), itemUserSets);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy