All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.spotify.scio.extra.nn.NearestNeighbor.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 Spotify AB.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.spotify.scio.extra.nn

import breeze.linalg._
import breeze.math._
import com.google.common.collect.MinMaxPriorityQueue
import info.debatty.java.lsh.LSHSuperBit

import scala.collection.JavaConverters._
import scala.collection.mutable.{Buffer => MBuffer, Map => MMap, Set => MSet}
import scala.reflect.ClassTag
import scala.{specialized => sp}

/**
 * Factory methods for [[NearestNeighborBuilder]] instances.
 *
 * Pick the matrix builder for exact, brute-force look ups, or the LSH builder to only
 * score candidates that share a hash bin with the query.
 */
object NearestNeighbor {

  /**
   * Create a new builder for matrix based [[NearestNeighbor]].
   *
   * @param dimension length that every added vector must have
   * @tparam K item key type
   * @tparam V numeric element type of the vectors
   */
  def newMatrixBuilder[K: ClassTag, @sp(Double, Int, Float, Long) V: ClassTag : Numeric : Semiring]
  (dimension: Int): NearestNeighborBuilder[K, V] =
    new MatrixNNBuilder[K, V](dimension = dimension)

  /**
   * Create a new builder for LSH based [[NearestNeighbor]].
   *
   * @param dimension length that every added vector must have
   * @param stages number of times a vector is bucketed
   * @param buckets number of buckets per stage
   * @tparam K item key type
   * @tparam V numeric element type of the vectors
   */
  def newLSHBuilder[K: ClassTag, @sp(Double, Int, Float, Long) V: ClassTag : Numeric : Semiring]
  (dimension: Int, stages: Int, buckets: Int): NearestNeighborBuilder[K, V] =
    new LSHNNBuilder[K, V](dimension = dimension, stages = stages, buckets = buckets)

}

/**
 * Mutable builder for immutable [[NearestNeighbor]] instances.
 *
 * Items receive consecutive integer ids in insertion order. The three collections below are
 * kept aligned by [[addVector]], so id `i` refers to `idToKey(i)` and `vectors(i)`.
 */
trait NearestNeighborBuilder[K, @sp(Double, Int, Float, Long) V] extends Serializable {

  /** Required length of every item vector. */
  protected val dimension: Int

  /** Item key to numeric id mapping. */
  protected val keyToId: MMap[K, Int] = MMap.empty

  /** Numeric id to item key mapping; position `i` holds the key with id `i`. */
  protected val idToKey: MBuffer[K] = MBuffer.empty

  /** Raw item vectors, positioned by numeric id. */
  protected val vectors: MBuffer[DenseVector[V]] = MBuffer.empty

  /**
   * Register a key->vector pair in the common storage.
   *
   * @return the numeric id assigned to `key`
   * @throws IllegalArgumentException if `vec.length != dimension` or `key` was already added
   */
  protected def addVector(key: K, vec: DenseVector[V]): Int = {
    require(vec.length == dimension, s"Vector dimension ${vec.length} != $dimension")
    require(!keyToId.contains(key), s"Key $key already exists")

    // Ids are dense: the next id equals the number of keys registered so far.
    val newId = keyToId.size
    keyToId += key -> newId
    idToKey += key
    vectors += vec
    newId
  }

  /** Add a key->vector pair. The vector should be normalized. */
  def add(key: K, vec: DenseVector[V]): Unit

  /** Build an immutable NearestNeighbor instance from the pairs added so far. */
  def build: NearestNeighbor[K, V]

}

/**
 * Immutable index for fast nearest neighbor look up.
 *
 * {{{
 * import com.spotify.scio.extra.nn._
 *
 * // Tuples of (item ID, vector) to look up from
 * val vectors: Seq[(String, DenseVector[Double])] = // ...
 *
 * // Builder for a nearest neighbor index backed by a matrix
 * val builder = NearestNeighbor.newMatrixBuilder[String, Double](40)
 * vectors.foreach(kv => builder.add(kv._1, kv._2))
 * val nn = builder.build
 *
 * // Candidate to look up nearest neighbors for
 * val candidate: DenseVector[Double] = // ...
 *
 * // Look up the 10 most similar items
 * nn.lookup(candidate, 10)
 * }}}
 */
trait NearestNeighbor[K, @sp(Double, Int, Float, Long) V] extends Serializable {

  /** Name of the nearest neighbor method. */
  val name: String

  /** Dimension of item vectors. */
  protected val dimension: Int

  /** Item key to numeric id mapping. */
  protected val keyToId: Map[K, Int]

  /** Numeric id to item key mapping. */
  protected val idToKey: Array[K]

  /** Raw item vectors, indexed by numeric id. */
  protected val vectors: Array[DenseVector[V]]

  /** Key of the item with numeric id `id`. */
  @inline protected def getKey(id: Int): K = idToKey(id)

  /** Numeric id of `key`; throws `NoSuchElementException` if the key is not indexed. */
  @inline protected def getId(key: K): Int = keyToId(key)

  /**
   * Lookup nearest neighbors of a vector. The vector should be normalized.
   *
   * @param vec query vector
   * @param maxResult maximum number of neighbors to return
   * @param minSimilarity neighbors whose similarity is below this value are excluded
   * @return (key, similarity) pairs of the neighbors found
   */
  def lookup(vec: DenseVector[V], maxResult: Int,
             minSimilarity: Double = Double.NegativeInfinity): Iterable[(K, Double)]

  /**
   * Lookup nearest neighbors of an existing vector.
   *
   * The stored vector of `key` is used as the query unchanged, so the item itself may appear
   * in the results.
   *
   * @throws NoSuchElementException if `key` is not indexed
   */
  def lookupKey(key: K, maxResult: Int,
                minSimilarity: Double = Double.NegativeInfinity): Iterable[(K, Double)] =
    lookup(vectors(getId(key)), maxResult, minSimilarity)

}

/** Builder for [[MatrixNN]]. */
private class
MatrixNNBuilder[K: ClassTag, @sp(Double, Int, Float, Long) V: ClassTag : Numeric : Semiring]
  (override val dimension: Int)
  extends NearestNeighborBuilder[K, V] {

  // Reject a meaningless dimension up front, consistent with LSHNNBuilder.
  require(dimension > 0, "dimension must be > 0")

  /** Add a key->vector pair. The vector should be normalized. */
  override def add(key: K, vec: DenseVector[V]): Unit = addVector(key, vec)

  /**
   * Build an immutable NearestNeighbor instance.
   *
   * Each item vector becomes one matrix row, in id order, so that row `i` of the matrix
   * product computed by [[MatrixNN]] is the similarity of the item with id `i`.
   */
  override def build: NearestNeighbor[K, V] =
    new MatrixNN(
      dimension, keyToId.toMap, idToKey.toArray, vectors.toArray,
      DenseMatrix(vectors.map(_.toArray): _*))
}

/**
 * Nearest neighbor using vector dot product via matrix multiplication.
 *
 * Every indexed item is scored against the query, so results are exact. For normalized
 * vectors the dot product equals the cosine similarity.
 */
private class MatrixNN[K, @sp(Double, Int, Float, Long) V: ClassTag : Numeric : Semiring]
  (override protected val dimension: Int,
   override protected val keyToId: Map[K, Int],
   override protected val idToKey: Array[K],
   override val vectors: Array[DenseVector[V]],
   private val matrix: Matrix[V])
  extends NearestNeighbor[K, V] {

  /** Name of the nearest neighbor method. */
  override val name: String = "Matrix"

  /**
   * Lookup nearest neighbors of a vector. The vector should be normalized.
   *
   * @param vec query vector, must have length `dimension`
   * @param maxResult maximum number of neighbors to return, must be > 0
   * @param minSimilarity items whose dot product with `vec` is below this value are dropped
   * @return up to `maxResult` (key, similarity) pairs sorted by descending similarity;
   *         empty if the index holds no items
   */
  override def lookup(vec: DenseVector[V], maxResult: Int,
                      minSimilarity: Double): Iterable[(K, Double)] = {
    require(vec.length == dimension, s"Vector dimension ${vec.length} != $dimension")
    require(maxResult > 0, s"maxResult must be > 0")

    if (idToKey.isEmpty) {
      // An empty builder yields a 0x0 matrix, which cannot be multiplied by `vec`.
      Iterable.empty
    } else {
      val sim = matrix * vec

      // Ordering by negated similarity makes the bounded queue evict its lowest-similarity
      // entry on overflow, so it keeps the top `maxResult` items. The initial capacity is
      // capped by the item count so a huge `maxResult` does not pre-allocate a huge array.
      val pq = MinMaxPriorityQueue.orderedBy[(Int, Double)](Ordering.by(-_._2))
        .expectedSize(scala.math.min(maxResult, idToKey.length))
        .maximumSize(maxResult)
        .create[(Int, Double)]()

      val numeric = implicitly[Numeric[V]]
      var i = 0
      while (i < idToKey.length) {
        val cosine = numeric.toDouble(sim(i))
        if (cosine >= minSimilarity) {
          pq.add((i, cosine))
        }
        i += 1
      }

      // The queue's iteration order is unspecified, so sort the survivors explicitly.
      pq.asScala.toVector
        .sortBy { case (_, similarity) => -similarity }
        .map { case (id, similarity) => (getKey(id), similarity) }
    }
  }

}

/**
 * Builder for [[LSHNN]].
 *
 * Every added vector is hashed with `lsh`, and its numeric id is appended to each bin that
 * the hash selects. [[LSHNN]] later scores only the ids found in the query's bins.
 */
private class
LSHNNBuilder[K: ClassTag, @sp(Double, Int, Float, Long) V: ClassTag : Numeric : Semiring]
  (override val dimension: Int, val stages: Int, val buckets: Int)
  extends NearestNeighborBuilder[K, V] {

  require(stages > 0, "stages must be > 0")
  require(buckets > 0, "buckets must be > 0")
  require(dimension > 0, "dimension must be > 0")

  private val lsh = new LSHSuperBit(stages, buckets, dimension)

  // NOTE(review): a single array of `buckets` bins is shared by all stages, so equal bucket
  // indices from different stages land in the same bin. Candidates are still scored exactly
  // at lookup time; confirm whether per-stage bins were intended.
  private val bins = Array.fill(buckets)(MBuffer.empty[Int])

  /** Add a key->vector pair. The vector should be normalized. */
  override def add(key: K, vec: DenseVector[V]): Unit = {
    val id = addVector(key, vec)
    val numeric = implicitly[Numeric[V]]
    // presumably one bucket index in [0, buckets) per stage — verify against java-LSH docs
    val hashed = lsh.hash(vec.toArray.map(numeric.toDouble))
    hashed.foreach(bucket => bins(bucket).append(id))
  }

  /** Build an immutable NearestNeighbor instance; bins are frozen into arrays. */
  override def build: NearestNeighbor[K, V] = {
    val frozenBins = bins.map(_.toArray)
    new LSHNN[K, V](
      dimension = dimension,
      keyToId = keyToId.toMap,
      idToKey = idToKey.toArray,
      vectors = vectors.toArray,
      lsh = lsh,
      bins = frozenBins)
  }

}

/**
 * Nearest neighbor using Locality Sensitive Hashing.
 *
 * Only items that share at least one bin with the query (as selected by `lsh`) are
 * considered, so results are approximate. Every considered candidate is scored exactly
 * with a dot product.
 */
private class LSHNN[K, @sp(Double, Int, Float, Long) V: ClassTag : Numeric : Semiring]
  (override protected val dimension: Int,
   override protected val keyToId: Map[K, Int],
   override protected val idToKey: Array[K],
   override val vectors: Array[DenseVector[V]],
   private val lsh: LSHSuperBit,
   private val bins: Array[Array[Int]])
  extends NearestNeighbor[K, V] {

  /** Name of the nearest neighbor method. */
  override val name: String = "LSH"

  /**
   * Lookup nearest neighbors of a vector. The vector should be normalized.
   *
   * @param vec query vector, must have length `dimension`
   * @param maxResult maximum number of neighbors to return, must be > 0
   * @param minSimilarity candidates whose dot product with `vec` is below this value are dropped
   * @return up to `maxResult` (key, similarity) pairs sorted by descending similarity
   */
  override def lookup(vec: DenseVector[V], maxResult: Int,
                      minSimilarity: Double): Iterable[(K, Double)] = {
    require(vec.length == dimension, s"Vector dimension ${vec.length} != $dimension")
    require(maxResult > 0, s"maxResult must be > 0")

    val numeric = implicitly[Numeric[V]]
    val buckets = lsh.hash(vec.toArray.map(numeric.toDouble))

    // Ordering by negated similarity makes the bounded queue evict its lowest-similarity
    // entry on overflow, so it keeps the top `maxResult` candidates. The initial capacity is
    // capped by the item count so a huge `maxResult` does not pre-allocate a huge array.
    val pq = MinMaxPriorityQueue.orderedBy[(Int, Double)](Ordering.by(-_._2))
      .expectedSize(scala.math.min(maxResult, idToKey.length))
      .maximumSize(maxResult)
      .create[(Int, Double)]()

    // An id can sit in several of the query's bins; score each id only once.
    val seen = MSet.empty[Int]
    var i = 0
    while (i < buckets.length) {
      val bin = bins(buckets(i))
      var j = 0
      while (j < bin.length) {
        val id = bin(j)
        if (seen.add(id)) {
          val cosine = numeric.toDouble(vec dot vectors(id))
          if (cosine >= minSimilarity) {
            pq.add((id, cosine))
          }
        }
        j += 1
      }
      i += 1
    }

    // The queue's iteration order is unspecified, so sort the survivors explicitly.
    pq.asScala.toVector
      .sortBy { case (_, similarity) => -similarity }
      .map { case (id, similarity) => (getKey(id), similarity) }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy