All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.soundcloud.lsh.NearestNeighbours.scala Maven / Gradle / Ivy

The newest version!
package com.soundcloud.lsh

import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, IndexedRow, IndexedRowMatrix, MatrixEntry}

/**
 * Brute force O(n^2) method to compute exact nearest neighbours.
 * As this is a very expensive computation, an additional sample fraction may be passed such
 * that neighbours are only computed for a random subset of the rows.
 *
 * @param distance  function applied to every candidate pair of row vectors; its result becomes
 *                  the value of the corresponding matrix entry
 * @param threshold pairs whose computed value is below this threshold are discarded
 *                  (pairs with a value >= threshold are kept)
 * @param fraction  fraction of the rows (sampled without replacement) for which neighbours
 *                  are computed
 *
 */
class NearestNeighbours(
                         distance: VectorDistance,
                         threshold: Double,
                         fraction: Double) extends Joiner with Serializable {

  /**
   * Pairs a random sample of the rows with all rows and keeps the pairs that pass the threshold.
   *
   * Only pairs in which the sampled row's index is strictly smaller than the other row's index
   * are considered, so the result is upper triangular and contains no self pairs. Note that
   * with `fraction` < 1 this also means a sampled row is never paired with rows of lower index.
   *
   * @param inputMatrix matrix whose rows are compared against each other
   * @return a coordinate matrix with one entry (sampledIndex, otherIndex, value) per kept pair
   */
  def join(inputMatrix: IndexedRowMatrix): CoordinateMatrix = {
    val rows = inputMatrix.rows
    val sampledRows = rows.sample(withReplacement = false, fraction = fraction)
    // the sample is reused for every partition of `rows` in the cartesian product
    sampledRows.cache()

    // Drop self pairs and the lower triangle *before* evaluating the distance function,
    // so the expensive computation only runs on pairs that can end up in the result.
    val candidatePairs = sampledRows.cartesian(rows).filter {
      case (rowA, rowB) => rowA.index < rowB.index
    }

    val entries = candidatePairs
      .map { case (rowA, rowB) =>
        MatrixEntry(rowA.index, rowB.index, distance(rowA.vector, rowB.vector))
      }
      .filter(_.value >= threshold)

    new CoordinateMatrix(entries)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy