All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.math.neighborhood.ProjectionSearch Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.neighborhood;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import com.google.common.base.Preconditions;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.BoundType;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.collect.TreeMultiset;
import org.apache.mahout.math.random.RandomProjector;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.random.WeightedThing;

/**
 * Performs approximate nearest-neighbor search by projecting the data onto random basis vectors.
 */
public class ProjectionSearch extends UpdatableSearcher {

  /**
   * A lists of tree sets containing the scalar projections of each vector.
   * The elements in a TreeMultiset are WeightedThing, where the weight is the scalar
   * projection of the vector at the index pointed to by the Integer from the referenceVectors list
   * on the basis vector whose index is the same as the index of the TreeSet in the List.
   */
  private List>> scalarProjections;

  /**
   * The list of random normalized projection vectors forming a basis.
   * The TreeSet of scalar projections at index i in scalarProjections corresponds to the vector
   * at index i from basisVectors.
   */
  private Matrix basisMatrix;

  /**
   * The number of elements to consider on both sides in the ball around the vector found by the
   * search in a TreeSet from scalarProjections.
   */
  private final int searchSize;

  private final int numProjections;
  private boolean initialized = false;

  private void initialize(int numDimensions) {
    if (initialized) {
      return;
    }
    initialized = true;
    basisMatrix = RandomProjector.generateBasisNormal(numProjections, numDimensions);
    scalarProjections = Lists.newArrayList();
    for (int i = 0; i < numProjections; ++i) {
      scalarProjections.add(TreeMultiset.>create());
    }
  }

  public ProjectionSearch(DistanceMeasure distanceMeasure, int numProjections,  int searchSize) {
    super(distanceMeasure);
    Preconditions.checkArgument(numProjections > 0 && numProjections < 100,
        "Unreasonable value for number of projections. Must be: 0 < numProjections < 100");

    this.searchSize = searchSize;
    this.numProjections = numProjections;
  }

  /**
   * Adds a WeightedVector into the set of projections for later searching.
   * @param vector  The WeightedVector to add.
   */
  @Override
  public void add(Vector vector) {
    initialize(vector.size());
    Vector projection = basisMatrix.times(vector);
    // Add the the new vector and the projected distance to each set separately.
    int i = 0;
    for (TreeMultiset> s : scalarProjections) {
      s.add(new WeightedThing<>(vector, projection.get(i++)));
    }
    int numVectors = scalarProjections.get(0).size();
    for (TreeMultiset> s : scalarProjections) {
      Preconditions.checkArgument(s.size() == numVectors, "Number of vectors in projection sets "
          + "differ");
      double firstWeight = s.firstEntry().getElement().getWeight();
      for (WeightedThing w : s) {
        Preconditions.checkArgument(firstWeight <= w.getWeight(), "Weights not in non-decreasing "
            + "order");
        firstWeight = w.getWeight();
      }
    }
  }

  /**
   * Returns the number of scalarProjections that we can search
   * @return  The number of scalarProjections added to the search so far.
   */
  @Override
  public int size() {
    if (scalarProjections == null) {
      return 0;
    }
    return scalarProjections.get(0).size();
  }

  /**
   * Searches for the query vector returning the closest limit referenceVectors.
   *
   * @param query the vector to search for.
   * @param limit the number of results to return.
   * @return a list of Vectors wrapped in WeightedThings where the "thing"'s weight is the
   * distance.
   */
  @Override
  public List> search(Vector query, int limit) {
    Set candidates = Sets.newHashSet();

    Iterator projections = basisMatrix.iterator();
    for (TreeMultiset> v : scalarProjections) {
      Vector basisVector = projections.next();
      WeightedThing projectedQuery = new WeightedThing<>(query,
          query.dot(basisVector));
      for (WeightedThing candidate : Iterables.concat(
          Iterables.limit(v.tailMultiset(projectedQuery, BoundType.CLOSED), searchSize),
          Iterables.limit(v.headMultiset(projectedQuery, BoundType.OPEN).descendingMultiset(), searchSize))) {
        candidates.add(candidate.getValue());
      }
    }

    // If searchSize * scalarProjections.size() is small enough not to cause much memory pressure,
    // this is probably just as fast as a priority queue here.
    List> top = Lists.newArrayList();
    for (Vector candidate : candidates) {
      top.add(new WeightedThing<>(candidate, distanceMeasure.distance(query, candidate)));
    }
    Collections.sort(top);
    return top.subList(0, Math.min(limit, top.size()));
  }

  /**
   * Returns the closest vector to the query.
   * When only one the nearest vector is needed, use this method, NOT search(query, limit) because
   * it's faster (less overhead).
   *
   * @param query the vector to search for
   * @param differentThanQuery if true, returns the closest vector different than the query (this
   *                           only matters if the query is among the searched vectors), otherwise,
   *                           returns the closest vector to the query (even the same vector).
   * @return the weighted vector closest to the query
   */
  @Override
  public WeightedThing searchFirst(Vector query, boolean differentThanQuery) {
    double bestDistance = Double.POSITIVE_INFINITY;
    Vector bestVector = null;

    Iterator projections = basisMatrix.iterator();
    for (TreeMultiset> v : scalarProjections) {
      Vector basisVector = projections.next();
      WeightedThing projectedQuery = new WeightedThing<>(query, query.dot(basisVector));
      for (WeightedThing candidate : Iterables.concat(
          Iterables.limit(v.tailMultiset(projectedQuery, BoundType.CLOSED), searchSize),
          Iterables.limit(v.headMultiset(projectedQuery, BoundType.OPEN).descendingMultiset(), searchSize))) {
        double distance = distanceMeasure.distance(query, candidate.getValue());
        if (distance < bestDistance && (!differentThanQuery || !candidate.getValue().equals(query))) {
          bestDistance = distance;
          bestVector = candidate.getValue();
        }
      }
    }

    return new WeightedThing<>(bestVector, bestDistance);
  }

  @Override
  public Iterator iterator() {
    return new AbstractIterator() {
      private final Iterator> projected = scalarProjections.get(0).iterator();
      @Override
      protected Vector computeNext() {
        if (!projected.hasNext()) {
          return endOfData();
        }
        return projected.next().getValue();
      }
    };
  }

  @Override
  public boolean remove(Vector vector, double epsilon) {
    WeightedThing toRemove = searchFirst(vector, false);
    if (toRemove.getWeight() < epsilon) {
      Iterator basisVectors = basisMatrix.iterator();
      for (TreeMultiset> projection : scalarProjections) {
        if (!projection.remove(new WeightedThing<>(vector, vector.dot(basisVectors.next())))) {
          throw new RuntimeException("Internal inconsistency in ProjectionSearch");
        }
      }
      return true;
    } else {
      return false;
    }
  }

  @Override
  public void clear() {
    if (scalarProjections == null) {
      return;
    }
    for (TreeMultiset> set : scalarProjections) {
      set.clear();
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy