All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.math.neighborhood.FastProjectionSearch Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.neighborhood;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import com.google.common.base.Preconditions;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.random.RandomProjector;
import org.apache.mahout.math.random.WeightedThing;

/**
 * Does approximate nearest neighbor search by projecting the vectors similar to ProjectionSearch.
 * The main difference between this class and the ProjectionSearch is the use of sorted arrays
 * instead of binary search trees to implement the sets of scalar projections.
 *
 * Instead of taking log n time to add a vector to each of the vectors, the pending additions are
 * kept separate and are searched using a brute search. When there are "enough" pending additions,
 * they're committed into the main pool of vectors.
 */
public class FastProjectionSearch extends UpdatableSearcher {
  // The list of vectors that have not yet been projected (that are pending).
  private final List pendingAdditions = Lists.newArrayList();

  // The list of basis vectors. Populated when the first vector's dimension is know by calling
  // initialize once.
  private Matrix basisMatrix = null;

  // The list of sorted lists of scalar projections. The outer list has one entry for each basis
  // vector that all the other vectors will be projected on.
  // For each basis vector, the inner list has an entry for each vector that has been projected.
  // These entries are WeightedThing where the weight is the value of the scalar
  // projection and the value is the vector begin referred to.
  private List>> scalarProjections;

  // The number of projection used for approximating the distance.
  private final int numProjections;

  // The number of elements to keep on both sides of the closest estimated distance as possible
  // candidates for the best actual distance.
  private final int searchSize;

  // Initially, the dimension of the vectors searched by this searcher is unknown. After adding
  // the first vector, the basis will be initialized. This marks whether initialization has
  // happened or not so we only do it once.
  private boolean initialized = false;

  // Removing vectors from the searcher is done lazily to avoid the linear time cost of removing
  // elements from an array. This member keeps track of the number of removed vectors (marked as
  // "impossible" values in the array) so they can be removed when updating the structure.
  private int numPendingRemovals = 0;

  private static final double ADDITION_THRESHOLD = 0.05;
  private static final double REMOVAL_THRESHOLD = 0.02;

  public FastProjectionSearch(DistanceMeasure distanceMeasure, int numProjections, int searchSize) {
    super(distanceMeasure);
    Preconditions.checkArgument(numProjections > 0 && numProjections < 100,
        "Unreasonable value for number of projections. Must be: 0 < numProjections < 100");
    this.numProjections = numProjections;
    this.searchSize = searchSize;
    scalarProjections = Lists.newArrayListWithCapacity(numProjections);
    for (int i = 0; i < numProjections; ++i) {
      scalarProjections.add(Lists.>newArrayList());
    }
  }

  private void initialize(int numDimensions) {
    if (initialized) {
      return;
    }
    basisMatrix = RandomProjector.generateBasisNormal(numProjections, numDimensions);
    initialized = true;
  }

  /**
   * Add a new Vector to the Searcher that will be checked when getting
   * the nearest neighbors.
   * 

* The vector IS NOT CLONED. Do not modify the vector externally otherwise the internal * Searcher data structures could be invalidated. */ @Override public void add(Vector vector) { initialize(vector.size()); pendingAdditions.add(vector); } /** * Returns the number of WeightedVectors being searched for nearest neighbors. */ @Override public int size() { return pendingAdditions.size() + scalarProjections.get(0).size() - numPendingRemovals; } /** * When querying the Searcher for the closest vectors, a list of WeightedThings is * returned. The value of the WeightedThing is the neighbor and the weight is the * the distance (calculated by some metric - see a concrete implementation) between the query * and neighbor. * The actual type of vector in the pair is the same as the vector added to the Searcher. */ @Override public List> search(Vector query, int limit) { reindex(false); Set candidates = Sets.newHashSet(); Vector projection = basisMatrix.times(query); for (int i = 0; i < basisMatrix.numRows(); ++i) { List> currProjections = scalarProjections.get(i); int middle = Collections.binarySearch(currProjections, new WeightedThing(projection.get(i))); if (middle < 0) { middle = -(middle + 1); } for (int j = Math.max(0, middle - searchSize); j < Math.min(currProjections.size(), middle + searchSize + 1); ++j) { if (currProjections.get(j).getValue() == null) { continue; } candidates.add(currProjections.get(j).getValue()); } } List> top = Lists.newArrayListWithCapacity(candidates.size() + pendingAdditions.size()); for (Vector candidate : Iterables.concat(candidates, pendingAdditions)) { top.add(new WeightedThing<>(candidate, distanceMeasure.distance(candidate, query))); } Collections.sort(top); return top.subList(0, Math.min(top.size(), limit)); } /** * Returns the closest vector to the query. * When only one the nearest vector is needed, use this method, NOT search(query, limit) because * it's faster (less overhead). 
* * @param query the vector to search for * @param differentThanQuery if true, returns the closest vector different than the query (this * only matters if the query is among the searched vectors), otherwise, * returns the closest vector to the query (even the same vector). * @return the weighted vector closest to the query */ @Override public WeightedThing searchFirst(Vector query, boolean differentThanQuery) { reindex(false); double bestDistance = Double.POSITIVE_INFINITY; Vector bestVector = null; Vector projection = basisMatrix.times(query); for (int i = 0; i < basisMatrix.numRows(); ++i) { List> currProjections = scalarProjections.get(i); int middle = Collections.binarySearch(currProjections, new WeightedThing(projection.get(i))); if (middle < 0) { middle = -(middle + 1); } for (int j = Math.max(0, middle - searchSize); j < Math.min(currProjections.size(), middle + searchSize + 1); ++j) { if (currProjections.get(j).getValue() == null) { continue; } Vector vector = currProjections.get(j).getValue(); double distance = distanceMeasure.distance(vector, query); if (distance < bestDistance && (!differentThanQuery || !vector.equals(query))) { bestDistance = distance; bestVector = vector; } } } for (Vector vector : pendingAdditions) { double distance = distanceMeasure.distance(vector, query); if (distance < bestDistance && (!differentThanQuery || !vector.equals(query))) { bestDistance = distance; bestVector = vector; } } return new WeightedThing<>(bestVector, bestDistance); } @Override public boolean remove(Vector vector, double epsilon) { WeightedThing closestPair = searchFirst(vector, false); if (distanceMeasure.distance(closestPair.getValue(), vector) > epsilon) { return false; } boolean isProjected = true; Vector projection = basisMatrix.times(vector); for (int i = 0; i < basisMatrix.numRows(); ++i) { List> currProjections = scalarProjections.get(i); WeightedThing searchedThing = new WeightedThing<>(projection.get(i)); int middle = 
Collections.binarySearch(currProjections, searchedThing); if (middle < 0) { isProjected = false; break; } // Elements to be removed are kept in the sorted array until the next reindex, but their inner vector // is set to null. scalarProjections.get(i).set(middle, searchedThing); } if (isProjected) { ++numPendingRemovals; return true; } for (int i = 0; i < pendingAdditions.size(); ++i) { if (pendingAdditions.get(i).equals(vector)) { pendingAdditions.remove(i); break; } } return true; } private void reindex(boolean force) { int numProjected = scalarProjections.get(0).size(); if (force || pendingAdditions.size() > ADDITION_THRESHOLD * numProjected || numPendingRemovals > REMOVAL_THRESHOLD * numProjected) { // We only need to copy the first list because when iterating we use only that list for the Vector // references. // see public Iterator iterator() List>> scalarProjections = Lists.newArrayListWithCapacity(numProjections); for (int i = 0; i < numProjections; ++i) { if (i == 0) { scalarProjections.add(Lists.newArrayList(this.scalarProjections.get(i))); } else { scalarProjections.add(this.scalarProjections.get(i)); } } // Project every pending vector onto every basis vector. for (Vector pending : pendingAdditions) { Vector projection = basisMatrix.times(pending); for (int i = 0; i < numProjections; ++i) { scalarProjections.get(i).add(new WeightedThing<>(pending, projection.get(i))); } } pendingAdditions.clear(); // For each basis vector, sort the resulting list (for binary search) and remove the number // of pending removals (it's the same for every basis vector) at the end (the weights are // set to Double.POSITIVE_INFINITY when removing). 
for (int i = 0; i < numProjections; ++i) { List> currProjections = scalarProjections.get(i); for (WeightedThing v : currProjections) { if (v.getValue() == null) { v.setWeight(Double.POSITIVE_INFINITY); } } Collections.sort(currProjections); for (int j = 0; j < numPendingRemovals; ++j) { currProjections.remove(currProjections.size() - 1); } } numPendingRemovals = 0; this.scalarProjections = scalarProjections; } } @Override public void clear() { pendingAdditions.clear(); for (int i = 0; i < numProjections; ++i) { scalarProjections.get(i).clear(); } numPendingRemovals = 0; } /** * This iterates on the snapshot of the contents first instantiated regardless of any future modifications. * Changes done after the iterator is created will not be visible to the iterator but will be visible * when searching. * @return iterator through the vectors in this searcher. */ @Override public Iterator iterator() { reindex(true); return new AbstractIterator() { private final Iterator> data = scalarProjections.get(0).iterator(); @Override protected Vector computeNext() { do { if (!data.hasNext()) { return endOfData(); } WeightedThing next = data.next(); if (next.getValue() != null) { return next.getValue(); } } while (true); } }; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy