All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.grouplens.lenskit.vectors.SparseVector Maven / Gradle / Ivy

The newest version!
/*
 * LensKit, an open source recommender systems toolkit.
 * Copyright 2010-2014 LensKit Contributors.  See CONTRIBUTORS.md.
 * Work on LensKit has been funded by the National Science Foundation under
 * grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 51
 * Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */
package org.grouplens.lenskit.vectors;

import com.google.common.base.Function;
import com.google.common.collect.Iterators;
import com.google.common.primitives.Longs;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleCollection;
import it.unimi.dsi.fastutil.doubles.DoubleIterator;
import it.unimi.dsi.fastutil.ints.IntIterator;
import it.unimi.dsi.fastutil.ints.IntIterators;
import it.unimi.dsi.fastutil.longs.*;
import org.apache.commons.lang3.StringUtils;
import org.grouplens.lenskit.collections.LongKeyDomain;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.symbols.TypedSymbol;

import javax.annotation.Nonnull;
import java.io.Serializable;
import java.util.AbstractCollection;
import java.util.Collection;
import java.util.Iterator;
import java.util.Set;

/**
 * Read-only interface to sparse vectors.
 *
 * 

* This vector class works a lot like a map, but it also caches some * commonly-used statistics. The values are stored in parallel arrays sorted by * key. This allows fast lookup and sorted iteration. All iterators access the * items in key order. * *

* Vectors have a key domain, which is a set containing all valid keys in * the vector. This key domain is fixed at construction; mutable vectors cannot * set values for keys not in this domain. Thinking of the vector as a function * from longs to doubles, the key domain would actually be the codomain, and the * key set the algebraic domain, but that gets cumbersome to write in code. So * think of the key domain as the domain from which valid keys are drawn. * *

* This class provides a read-only interface to sparse vectors. It may * actually be a {@link MutableSparseVector}, so the data may be modified by * code elsewhere that has access to the mutable representation. For sparse * vectors that are guaranteed to be unchanging, see * {@link ImmutableSparseVector}. * * @see Sparse Vector tutorial * @compat Public * @deprecated Use maps instead. */ @Deprecated public abstract class SparseVector implements Iterable, Serializable { private static final long serialVersionUID = 2L; /** * The domain of keys. */ final LongKeyDomain keys; /** * The value array. Indexes in this array correspond to indexes produced by {@link #keys}. */ double[] values; //region Constructors /** * Construct a new vector from a key set and value array. * @param ks The key set. Used as-is, and will be modified. Pass a clone, usually. * @param vs The value array. */ @SuppressWarnings("PMD.ArrayIsStoredDirectly") SparseVector(LongKeyDomain ks, double[] vs) { assert vs.length >= ks.domainSize(); keys = ks; keys.acquire(); values = vs; } /** * Construct a new sparse vector with a particular domain. Allocates the value storage. * @param ks The key set. Used as-is, and will be modified. Pass a clone, usually. */ SparseVector(LongKeyDomain ks) { this(ks, new double[ks.domainSize()]); ks.setAllActive(false); } /** * Construct a new vector from the contents of a map. The key domain is the * key set of the map. Therefore, no new keys can be added to this vector. * * @param keyValueMap A map providing the values for the vector. */ SparseVector(Long2DoubleMap keyValueMap) { keys = LongKeyDomain.fromCollection(keyValueMap.keySet(), true); final int len = keys.domainSize(); values = new double[len]; for (int i = 0; i < len; i++) { values[i] = keyValueMap.get(keys.getKey(i)); } } public Long2DoubleMap asMap() { return new SparseVectorMapAdapter(this); } //endregion //region Queries /** * Query whether the vector contains an entry for the key in question. * * @param key The key to search for. * @return {@code true} if the key exists. */ public boolean containsKey(long key) { return keys.keyIsActive(key); } /** * Get the value for key. * * @param key the key to look up; the key must be in the key set. * @return the key's value * @throws IllegalArgumentException if key is not in the key set. */ public double get(long key) { final int idx = keys.getIndexIfActive(key); if (idx >= 0) { return values[idx]; } else { throw new IllegalArgumentException("Key " + key + " is not in the key set"); } } /** * Get the value for key. * * @param key the key to look up * @param dft The value to return if the key is not in the vector * @return the value (or dft if the key is not set to a value) */ public double get(long key, double dft) { final int idx = keys.getIndexIfActive(key); if (idx >= 0) { return values[idx]; } else { return dft; } } /** * Get the value for the entry's key. * * @param entry A {@code VectorEntry} with the key to look up * @return the key's value * @throws IllegalArgumentException if the entry is unset, or if it is not from this vector or another vector * sharing the same key domain. Only vectors and their side channels share key domains for the * purposes of this check. */ public double get(VectorEntry entry) { final SparseVector evec = entry.getVector(); final int eind = entry.getIndex(); if (evec == null) { throw new IllegalArgumentException("entry is not associated with a vector"); } else if (!evec.keys.isCompatibleWith(keys)) { throw new IllegalArgumentException("entry does not have safe key domain"); } assert entry.getKey() == keys.getKey(eind); if (keys.indexIsActive(eind)) { return values[eind]; } else { throw new IllegalArgumentException("Key " + entry.getKey() + " is not set"); } } /** * Check whether an entry is set. * @param entry The entry. * @return {@code true} if the entry is set in this vector. * @throws IllegalArgumentException if the entry is not from this vector or another vector * sharing the same key domain. Only vectors and their side channels share key domains for the * purposes of this check. */ public boolean isSet(VectorEntry entry) { final SparseVector evec = entry.getVector(); final int eind = entry.getIndex(); if (evec == null) { throw new IllegalArgumentException("entry is not associated with a vector"); } else if (!keys.isCompatibleWith(evec.keys)) { throw new IllegalArgumentException("entry does not have safe key domain"); } assert entry.getKey() == keys.getKey(eind); return keys.indexIsActive(eind); } //endregion //region Iterators /** * Fast iterator over all set entries (it can reuse entry objects). * * @return a fast iterator over all key/value pairs * @see #fastIterator(VectorEntry.State) * @see it.unimi.dsi.fastutil.longs.Long2DoubleMap.FastEntrySet#fastIterator() * Long2DoubleMap.FastEntrySet.fastIterator() * @deprecated Fast iteration is going away. */ @Deprecated public Iterator fastIterator() { return iterator(); } boolean isMutable() { return true; } /** * Fast iterator over entries (it can reuse entry objects). * * @param state The state of entries to iterate. * @return a fast iterator over all key/value pairs * @see it.unimi.dsi.fastutil.longs.Long2DoubleMap.FastEntrySet#fastIterator() * Long2DoubleMap.FastEntrySet.fastIterator() * @since 0.11 * @deprecated Fast iteration is going away. */ @Deprecated public Iterator fastIterator(VectorEntry.State state) { return iterator(state); } /** * Return an iterable view of this vector using a fast iterator. This method * delegates to {@link #fast(VectorEntry.State)} with state {@link VectorEntry.State#SET}. * * @return This object wrapped in an iterable that returns a fast iterator. * @see #fastIterator() * @deprecated Fast iteration is going away. */ @Deprecated public Iterable fast() { return view(VectorEntry.State.SET); } /** * Return an iterable view of this vector using a fast iterator. * * @param state The entries the resulting iterable should return. * @return This object wrapped in an iterable that returns a fast iterator. * @see #fastIterator(VectorEntry.State) * @since 0.11 * @deprecated Fast iteration is going away. */ @Deprecated public Iterable fast(final VectorEntry.State state) { return view(state); } // The default iterator for this SparseVector iterates over // entries that are "used". It uses an IterImpl class that // generates a new VectorEntry for every element returned, so the // client can safely keep around the VectorEntrys without concern // they will mutate, at some cost in speed. @Override public Iterator iterator() { return new IterImpl(VectorEntry.State.SET); } public Iterator iterator(final VectorEntry.State state) { return new IterImpl(state); } /** * Get a collection view of a vector entry. * @param state The state of entries to view. * @return A collection of vector entries. */ public Collection view(VectorEntry.State state) { return new View(state); } private class IterImpl implements Iterator { private final VectorEntry.State state; private final IntIterator iter; IterImpl(VectorEntry.State st) { state = st; switch (state) { case SET: iter = keys.activeIndexIterator(isMutable()); break; case UNSET: { iter = keys.clone().invert().activeIndexIterator(false); break; } case EITHER: { iter = IntIterators.fromTo(0, keys.domainSize()); break; } default: throw new IllegalArgumentException(); } } @Override public boolean hasNext() { return iter.hasNext(); } @Override @Nonnull public VectorEntry next() { int pos = iter.nextInt(); boolean isSet = state == VectorEntry.State.SET || keys.indexIsActive(pos); double v = isSet ? values[pos] : Double.NaN; return new VectorEntry(SparseVector.this, pos, keys.getKey(pos), v, isSet); } @Override public void remove() { throw new UnsupportedOperationException(); } } private class View extends AbstractCollection { private final VectorEntry.State state; public View(VectorEntry.State st) { state = st; } @Override public Iterator iterator() { return new IterImpl(state); } @Override public int size() { switch (state) { case SET: return SparseVector.this.size(); case EITHER: return keys.domainSize(); case UNSET: return SparseVector.this.size() - keys.domainSize(); default: throw new IllegalStateException(); } } } //endregion //region Domain, set, and value management /** * Get the key domain for this vector. All keys used are in this * set. The keys will be in sorted order. * * @return The key domain for this vector. */ public LongSortedSet keyDomain() { return keys.domain(); } /** * Get the set of keys of this vector. It is a subset of the key * domain. The keys will be in sorted order. * * @return The set of keys used in this vector. */ public LongSortedSet keySet() { return keys.activeSetView(); } /** * Get the set of unset keys. This is \(D \\ S\), where \(D\) is the key domain and \(S\) the * key set. */ public LongSortedSet unsetKeySet() { return keys.clone().invert().activeSetView(); } /** * Return the keys of this vector sorted by value. * * @return A list of keys in nondecreasing order of value. * @see #keysByValue(boolean) */ public LongArrayList keysByValue() { return keysByValue(false); } /** * Get the collection of values of this vector. * * @return The collection of all values in this vector. */ public DoubleCollection values() { DoubleArrayList lst = new DoubleArrayList(size()); IntIterator iter = keys.activeIndexIterator(false); while (iter.hasNext()) { int idx = iter.nextInt(); lst.add(values[idx]); } return lst; } /** * Get the keys of this vector sorted by the value of the items * stored for each key. * * @param decreasing If {@code true}, sort in decreasing order. * @return The sorted list of keys of this vector. */ public LongArrayList keysByValue(boolean decreasing) { long[] skeys = keySet().toLongArray(); LongComparator cmp; // Set up the comparator. We use the key as a secondary comparison to get // a reproducible sort irrespective of sorting algorithm. if (decreasing) { cmp = new AbstractLongComparator() { @Override public int compare(long k1, long k2) { int c = Double.compare(get(k2), get(k1)); if (c != 0) { return c; } else { return Longs.compare(k1, k2); } } }; } else { cmp = new AbstractLongComparator() { @Override public int compare(long k1, long k2) { int c = Double.compare(get(k1), get(k2)); if (c != 0) { return c; } else { return Longs.compare(k1, k2); } } }; } LongArrays.quickSort(skeys, cmp); return LongArrayList.wrap(skeys); } /** * Get the size of this vector (the number of keys). * * @return The number of keys in the vector. This is at most the size of the * key domain. */ public int size() { return keys.size(); } /** * Query whether this vector is empty. * * @return {@code true} if the vector is empty. */ public boolean isEmpty() { return size() == 0; } //endregion //region Linear algebra /** * Compute and return the L2 norm (Euclidian length) of the vector. * * @return The L2 norm of the vector */ public double norm() { double ssq = 0; DoubleIterator iter = values().iterator(); while (iter.hasNext()) { double v = iter.nextDouble(); ssq += v * v; } return Math.sqrt(ssq); } /** * Compute and return the L1 norm (sum) of the vector. * * @return the sum of the vector's values */ public double sum() { double result = 0; if (keys.isCompletelySet()) { for (int i = keys.domainSize() - 1; i >= 0; i--) { result += values[i]; } } else { DoubleIterator iter = values().iterator(); while (iter.hasNext()) { result += iter.nextDouble(); } } return result; } /** * Compute and return the sum of the absolute values of the vector. * * @return the sum of the vector's absolute values */ public double sumAbs() { double result = 0; if (keys.isCompletelySet()) { for (int i = keys.domainSize() - 1; i >= 0; i--) { result += Math.abs(values[i]); } } else { DoubleIterator iter = values().iterator(); while (iter.hasNext()) { result += Math.abs(iter.nextDouble()); } } return result; } /** * Compute and return the mean of the vector's values. * * @return the mean of the vector */ public double mean() { final int sz = size(); return sz > 0 ? sum() / sz : 0; } /** * Compute the dot product between two vectors. * * @param o The other vector. * @return The dot (inner) product between this vector and o. */ public double dot(SparseVector o) { if (keys.isCompletelySet() && o.keys.isCompletelySet()) { return fastDotProduct(o); } else { return slowDotProduct(o); } } private double fastDotProduct(SparseVector o) { double dot = 0; int sz1 = keys.domainSize(); int sz2 = o.keys.domainSize(); int i1 = 0, i2 = 0; while (i1 < sz1 && i2 < sz2) { final long k1 = keys.getKey(i1); final long k2 = o.keys.getKey(i2); if (k1 < k2) { i1++; } else if (k2 < k1) { i2++; } else { dot += values[i1] * o.values[i2]; i1++; i2++; } } return dot; } private double slowDotProduct(SparseVector o) { // FIXME This code was tested, but no longer is. Add relevant tests. double dot = 0; Iterator i1 = iterator(); Iterator i2 = o.iterator(); VectorEntry e1 = i1.hasNext() ? i1.next() : null; VectorEntry e2 = i2.hasNext() ? i2.next() : null; while (e1 != null && e2 != null) { final long k1 = e1.getKey(); final long k2 = e2.getKey(); if (k1 < k2) { e1 = i1.hasNext() ? i1.next() : null; } else if (k2 < k1) { e2 = i2.hasNext() ? i2.next() : null; } else { dot += e1.getValue() * e2.getValue(); e1 = i1.hasNext() ? i1.next() : null; e2 = i2.hasNext() ? i2.next() : null; } } return dot; } /** * Count the common keys between two vectors. * * @param o The other vector. * @return The number of keys appearing in both this and the other vector. */ public int countCommonKeys(SparseVector o) { int count = 0; Iterator i1 = iterator(); Iterator i2 = o.iterator(); VectorEntry e1 = i1.hasNext() ? i1.next() : null; VectorEntry e2 = i2.hasNext() ? i2.next() : null; while (e1 != null && e2 != null) { final long k1 = e1.getKey(); final long k2 = e2.getKey(); if (k1 < k2) { e1 = i1.hasNext() ? i1.next() : null; } else if (k2 < k1) { e2 = i2.hasNext() ? i2.next() : null; } else { count += 1; e1 = i1.hasNext() ? i1.next() : null; e2 = i2.hasNext() ? i2.next() : null; } } return count; } /** * Combine this vector with another vector by taking the union of the key domains of two vectors. * If both vectors have values the same key, the values in {@code o} override those from the * current vector. * * @param o The other vector * @return A vector whose key domain is the union of the key domains of this vector and the other. */ public abstract SparseVector combineWith(SparseVector o); //endregion //region Object support @Override public String toString() { Function label = new Function() { @Override public String apply(VectorEntry e) { return String.format("%d: %.3f", e.getKey(), e.getValue()); } }; return "{" + StringUtils.join(Iterators.transform(iterator(), label), ", ") + "}"; } @Override public boolean equals(Object o) { if (this == o) { return true; } else if (o instanceof SparseVector) { SparseVector vo = (SparseVector) o; int sz = size(); int osz = vo.size(); if (sz != osz) { return false; } else { if (!this.keySet().equals(vo.keySet())) { return false; // same keys } // we know that sparse vector values are always in key order. so just compare them. return this.values().equals(vo.values()); } } else { return false; } } @Override public int hashCode() { return keySet().hashCode() ^ values().hashCode(); } //endregion //region Copying /** * Return an immutable snapshot of this sparse vector. The new vector's key * domain may be shrunk to remove storage of unused keys; no keys in * the key set will be removed. * * @return An immutable sparse vector whose contents are the same as this * vector. If the vector is already immutable, the returned object * may be identical. */ public abstract ImmutableSparseVector immutable(); /** * Return a mutable copy of this sparse vector. The key domain of the * mutable vector will be the same as this vector's key domain. * * @return A mutable sparse vector which can be modified without modifying * this vector. */ public abstract MutableSparseVector mutableCopy(); //endregion //region Channels /** * Return whether this sparse vector has a channel vector stored under a * particular symbol. * * @param channelSymbol the symbol under which the channel was * stored in the vector. * @return whether this vector has such a channel right now. */ public abstract boolean hasChannelVector(Symbol channelSymbol); /** * Return whether this sparse vector has a channel stored under a * particular typed symbol. * * @param channelSymbol the typed symbol under which the channel was * stored in the vector. * @return whether this vector has such a channel right now. */ public abstract boolean hasChannel(TypedSymbol channelSymbol); /** * Get the vector associated with a particular unboxed channel. * * @param channelSymbol the symbol under which the channel was/is * stored in the vector. * @return The vector corresponding to the specified unboxed channel, or {@code null} if * there is no such channel. */ public abstract SparseVector getChannelVector(Symbol channelSymbol); /** * Fetch the channel stored under a particular typed symbol. * * @param channelSymbol the typed symbol under which the channel was/is * stored in the vector. * @return the channel, which is itself a map from the key domain to objects of * the channel's type, or {@code null} if there is no such channel. */ public abstract Long2ObjectMap getChannel(TypedSymbol channelSymbol); /** * Retrieve all symbols that map to side channels for this vector. * @return A set of symbols, each of which identifies a side channel * of the vector. */ public abstract Set getChannelVectorSymbols(); /** * Retrieve all symbols that map to typed side channels for this vector. * @return A set of symbols, each of which identifies a side channel * of the vector. */ public abstract Set> getChannelSymbols(); //endregion //region Static Constructors /** * Get an empty sparse vector. * * @return An empty sparse vector. The vector is immutable, because mutating an empty vector is * impossible. */ @SuppressWarnings("deprecation") public static ImmutableSparseVector empty() { return new ImmutableSparseVector(); } //endregion }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy