org.apache.mahout.math.neighborhood.ProjectionSearch Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-mr Show documentation
Show all versions of mahout-mr Show documentation
Scalable machine learning libraries
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math.neighborhood;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import com.google.common.base.Preconditions;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.BoundType;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.collect.TreeMultiset;
import org.apache.mahout.math.random.RandomProjector;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.random.WeightedThing;
/**
* Does approximate nearest neighbor dudes search by projecting the data.
*/
public class ProjectionSearch extends UpdatableSearcher {
/**
* A lists of tree sets containing the scalar projections of each vector.
* The elements in a TreeMultiset are WeightedThing, where the weight is the scalar
* projection of the vector at the index pointed to by the Integer from the referenceVectors list
* on the basis vector whose index is the same as the index of the TreeSet in the List.
*/
private List>> scalarProjections;
/**
* The list of random normalized projection vectors forming a basis.
* The TreeSet of scalar projections at index i in scalarProjections corresponds to the vector
* at index i from basisVectors.
*/
private Matrix basisMatrix;
/**
* The number of elements to consider on both sides in the ball around the vector found by the
* search in a TreeSet from scalarProjections.
*/
private final int searchSize;
private final int numProjections;
private boolean initialized = false;
private void initialize(int numDimensions) {
if (initialized) {
return;
}
initialized = true;
basisMatrix = RandomProjector.generateBasisNormal(numProjections, numDimensions);
scalarProjections = Lists.newArrayList();
for (int i = 0; i < numProjections; ++i) {
scalarProjections.add(TreeMultiset.>create());
}
}
public ProjectionSearch(DistanceMeasure distanceMeasure, int numProjections, int searchSize) {
super(distanceMeasure);
Preconditions.checkArgument(numProjections > 0 && numProjections < 100,
"Unreasonable value for number of projections. Must be: 0 < numProjections < 100");
this.searchSize = searchSize;
this.numProjections = numProjections;
}
/**
* Adds a WeightedVector into the set of projections for later searching.
* @param vector The WeightedVector to add.
*/
@Override
public void add(Vector vector) {
initialize(vector.size());
Vector projection = basisMatrix.times(vector);
// Add the the new vector and the projected distance to each set separately.
int i = 0;
for (TreeMultiset> s : scalarProjections) {
s.add(new WeightedThing<>(vector, projection.get(i++)));
}
int numVectors = scalarProjections.get(0).size();
for (TreeMultiset> s : scalarProjections) {
Preconditions.checkArgument(s.size() == numVectors, "Number of vectors in projection sets "
+ "differ");
double firstWeight = s.firstEntry().getElement().getWeight();
for (WeightedThing w : s) {
Preconditions.checkArgument(firstWeight <= w.getWeight(), "Weights not in non-decreasing "
+ "order");
firstWeight = w.getWeight();
}
}
}
/**
* Returns the number of scalarProjections that we can search
* @return The number of scalarProjections added to the search so far.
*/
@Override
public int size() {
if (scalarProjections == null) {
return 0;
}
return scalarProjections.get(0).size();
}
/**
* Searches for the query vector returning the closest limit referenceVectors.
*
* @param query the vector to search for.
* @param limit the number of results to return.
* @return a list of Vectors wrapped in WeightedThings where the "thing"'s weight is the
* distance.
*/
@Override
public List> search(Vector query, int limit) {
Set candidates = Sets.newHashSet();
Iterator extends Vector> projections = basisMatrix.iterator();
for (TreeMultiset> v : scalarProjections) {
Vector basisVector = projections.next();
WeightedThing projectedQuery = new WeightedThing<>(query,
query.dot(basisVector));
for (WeightedThing candidate : Iterables.concat(
Iterables.limit(v.tailMultiset(projectedQuery, BoundType.CLOSED), searchSize),
Iterables.limit(v.headMultiset(projectedQuery, BoundType.OPEN).descendingMultiset(), searchSize))) {
candidates.add(candidate.getValue());
}
}
// If searchSize * scalarProjections.size() is small enough not to cause much memory pressure,
// this is probably just as fast as a priority queue here.
List> top = Lists.newArrayList();
for (Vector candidate : candidates) {
top.add(new WeightedThing<>(candidate, distanceMeasure.distance(query, candidate)));
}
Collections.sort(top);
return top.subList(0, Math.min(limit, top.size()));
}
/**
* Returns the closest vector to the query.
* When only one the nearest vector is needed, use this method, NOT search(query, limit) because
* it's faster (less overhead).
*
* @param query the vector to search for
* @param differentThanQuery if true, returns the closest vector different than the query (this
* only matters if the query is among the searched vectors), otherwise,
* returns the closest vector to the query (even the same vector).
* @return the weighted vector closest to the query
*/
@Override
public WeightedThing searchFirst(Vector query, boolean differentThanQuery) {
double bestDistance = Double.POSITIVE_INFINITY;
Vector bestVector = null;
Iterator extends Vector> projections = basisMatrix.iterator();
for (TreeMultiset> v : scalarProjections) {
Vector basisVector = projections.next();
WeightedThing projectedQuery = new WeightedThing<>(query, query.dot(basisVector));
for (WeightedThing candidate : Iterables.concat(
Iterables.limit(v.tailMultiset(projectedQuery, BoundType.CLOSED), searchSize),
Iterables.limit(v.headMultiset(projectedQuery, BoundType.OPEN).descendingMultiset(), searchSize))) {
double distance = distanceMeasure.distance(query, candidate.getValue());
if (distance < bestDistance && (!differentThanQuery || !candidate.getValue().equals(query))) {
bestDistance = distance;
bestVector = candidate.getValue();
}
}
}
return new WeightedThing<>(bestVector, bestDistance);
}
@Override
public Iterator iterator() {
return new AbstractIterator() {
private final Iterator> projected = scalarProjections.get(0).iterator();
@Override
protected Vector computeNext() {
if (!projected.hasNext()) {
return endOfData();
}
return projected.next().getValue();
}
};
}
@Override
public boolean remove(Vector vector, double epsilon) {
WeightedThing toRemove = searchFirst(vector, false);
if (toRemove.getWeight() < epsilon) {
Iterator extends Vector> basisVectors = basisMatrix.iterator();
for (TreeMultiset> projection : scalarProjections) {
if (!projection.remove(new WeightedThing<>(vector, vector.dot(basisVectors.next())))) {
throw new RuntimeException("Internal inconsistency in ProjectionSearch");
}
}
return true;
} else {
return false;
}
}
@Override
public void clear() {
if (scalarProjections == null) {
return;
}
for (TreeMultiset> set : scalarProjections) {
set.clear();
}
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy