All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.ml.knn.KNN.scala Maven / Gradle / Ivy

The newest version!
package org.apache.spark.ml.knn

import breeze.linalg.{DenseVector, Vector => BV}
import breeze.stats._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.classification.KNNClassificationModel
import org.apache.spark.ml.knn.KNN.{KNNPartitioner, RowWithVector, VectorWithNorm}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.regression.KNNRegressionModel
import org.apache.spark.ml.util._
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.{HashPartitioner, Partitioner}
import org.apache.log4j
import org.apache.spark.mllib.knn.KNNUtils

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import scala.util.hashing.byteswap64

// features column => vector, input columns => auxiliary columns to return by KNN model
private[ml] trait KNNModelParams extends Params with HasFeaturesCol with HasInputCols {
  /**
    * Param for the column name for returned neighbors.
    * Default: "neighbors"
    *
    * @group param
    */
  val neighborsCol = new Param[String](this, "neighborsCol", "column names for returned neighbors")

  /** @group getParam */
  def getNeighborsCol: String = $(neighborsCol)

  /**
    * Param for distance column that will create a distance column of each nearest neighbor
    * Default: no distance column will be used
    *
    * @group param
    */
  val distanceCol = new Param[String](this, "distanceCol", "column that includes each neighbors' distance as an additional column")

  /** @group getParam */
  def getDistanceCol: String = $(distanceCol)

  /**
    * Param for number of neighbors to find (> 0).
    * Default: 5
    *
    * @group param
    */
  val k = new IntParam(this, "k", "number of neighbors to find", ParamValidators.gt(0))

  /** @group getParam */
  def getK: Int = $(k)

  /**
    * Param for maximum distance to find neighbors (> 0). Candidates farther away than this
    * distance are dropped from the search results.
    * Default: Double.PositiveInfinity
    *
    * NOTE(review): the registered param name is "maxNeighbors" although the field is `maxDistance`.
    * It is left unchanged because params are resolved and copied by name (`getParam`, `copyValues`),
    * so renaming it could affect existing callers.
    *
    * @group param
    */
  val maxDistance = new DoubleParam(this, "maxNeighbors", "maximum distance to find neighbors",
    ParamValidators.gt(0))

  /** @group getParam */
  def getMaxDistance: Double = $(maxDistance)

  /**
    * Param for size of buffer used to construct spill trees and top-level tree search.
    * Note the buffer size is 2 * tau as described in the paper.
    *
    * When buffer size is 0.0, the tree itself reverts to a metric tree.
    * -1.0 triggers automatic effective nearest neighbor distance estimation.
    *
    * Default: -1.0
    *
    * @group param
    */
  val bufferSize = new DoubleParam(this, "bufferSize",
    "size of buffer used to construct spill trees and top-level tree search", ParamValidators.gtEq(-1.0))

  /** @group getParam */
  def getBufferSize: Double = $(bufferSize)

  /**
    * Searches the k nearest neighbors of every input vector.
    *
    * Each vector is routed through the top-level tree to every partition whose region it may
    * fall into (within the buffer), searched in the corresponding sub-tree, and the per-partition
    * candidates are merged keeping the k closest within `maxDistance`.
    *
    * @param data     vectors to search
    * @param topTree  broadcast top-level tree used to route vectors to partitions
    * @param subTrees one tree per partition
    * @return (position of the vector in `data`, neighbors' rows with their distances)
    */
  private[ml] def transform(data: RDD[Vector], topTree: Broadcast[Tree], subTrees: RDD[Tree]): RDD[(Long, Array[(Row,Double)])] = {
    // Resolve params once on the driver: using $(...) inside the RDD closures would capture `this`
    // (with every field of the enclosing model) and re-resolve the param for each record.
    // A negative buffer (e.g. the -1 "auto" sentinel) would make searchIndices select no child at
    // all near the splitting plane, so clamp it to 0 like KNN.fit does.
    val tau = math.max(0.0, $(bufferSize))
    val numNeighbors = $(k)
    val maxDist = $(maxDistance)

    val searchData = data.zipWithIndex()
      .flatMap { case (vector, index) =>
        val vectorWithNorm = new VectorWithNorm(vector)
        val idx = KNN.searchIndices(vectorWithNorm, topTree.value, tau)
          .map(i => (i, (vectorWithNorm, index)))

        assert(idx.nonEmpty, s"indices must be non-empty: $vector ($index)")
        idx
      }
      .partitionBy(new HashPartitioner(subTrees.partitions.length))

    // for each partition, search points within corresponding child tree
    val results = searchData.zipPartitions(subTrees) { (childData, trees) =>
      val tree = trees.next()
      assert(!trees.hasNext)
      childData.flatMap { case (_, (point, i)) =>
        tree.query(point, numNeighbors).collect {
          case (neighbor, distance) if distance <= maxDist =>
            (i, (neighbor.row, distance))
        }
      }
    }

    // merge results by point index together and keep topK results
    // (ordering by negated distance makes "top" mean "closest")
    // NOTE(review): the identity map below drops the partitioner set by topByKey; kept in case
    // downstream joins rely on that partitioning.
    results.topByKey(numNeighbors)(Ordering.by(-_._2))
      .map { case (i, seq) => (i, seq) }
  }

  /**
    * Searches the k nearest neighbors of every row's features vector in `dataset`.
    *
    * @return (row position in `dataset`, neighbors' rows with their distances)
    */
  private[ml] def transform(dataset: Dataset[_], topTree: Broadcast[Tree], subTrees: RDD[Tree]): RDD[(Long, Array[(Row, Double)])] = {
    transform(dataset.select($(featuresCol)).rdd.map(_.getAs[Vector](0)), topTree, subTrees)
  }

}

private[ml] trait KNNParams extends KNNModelParams with HasSeed {
  /**
    * Param for number of points to sample for top-level tree (> 0).
    * Default: 1000
    *
    * @group param
    */
  val topTreeSize = new IntParam(this, "topTreeSize", "number of points to sample for top-level tree", ParamValidators.gt(0))

  /** @group getParam */
  def getTopTreeSize: Int = $(topTreeSize)

  /**
    * Param for number of points at which to switch to brute-force for top-level tree (> 0).
    * Default: 10
    *
    * @group param
    */
  val topTreeLeafSize = new IntParam(this, "topTreeLeafSize",
    "number of points at which to switch to brute-force for top-level tree", ParamValidators.gt(0))

  /** @group getParam */
  def getTopTreeLeafSize: Int = $(topTreeLeafSize)

  /**
    * Param for number of points at which to switch to brute-force for distributed sub-trees (> 0).
    * Default: 30
    *
    * @group param
    */
  val subTreeLeafSize = new IntParam(this, "subTreeLeafSize",
    "number of points at which to switch to brute-force for distributed sub-trees", ParamValidators.gt(0))

  /** @group getParam */
  def getSubTreeLeafSize: Int = $(subTreeLeafSize)

  /**
    * Param for the sample sizes to take when estimating buffer size
    * (at least two sizes, each > 0).
    * Default: 100 to 1000 by 100
    *
    * NOTE(review): the registered param name is "bufferSizeSampleSize" (no trailing 's') while the
    * field is `bufferSizeSampleSizes`; it is left as-is because params are resolved by name.
    *
    * @group param
    */
  val bufferSizeSampleSizes = new IntArrayParam(this, "bufferSizeSampleSize",
    "number of sample sizes to take when estimating buffer size", { arr: Array[Int] => arr.length > 1 && arr.forall(_ > 0) })

  /** @group getParam */
  def getBufferSizeSampleSizes: Array[Int] = $(bufferSizeSampleSizes)

  /**
    * Param for fraction of total points at which spill tree reverts back to metric tree
    * if either child contains more points (0 <= rho <= 1).
    * Default: 70%
    *
    * @group param
    */
  val balanceThreshold = new DoubleParam(this, "balanceThreshold",
    "fraction of total points at which spill tree reverts back to metric tree if either child contains more points",
    ParamValidators.inRange(0, 1))

  /** @group getParam */
  def getBalanceThreshold: Double = $(balanceThreshold)

  setDefault(topTreeSize -> 1000, topTreeLeafSize -> 10, subTreeLeafSize -> 30,
    bufferSize -> -1.0, bufferSizeSampleSizes -> (100 to 1000 by 100).toArray, balanceThreshold -> 0.7,
    k -> 5, neighborsCol -> "neighbors", distanceCol -> "", maxDistance -> Double.PositiveInfinity)

  /**
    * Validates and transforms the input schema.
    *
    * Checks that the features column holds vectors, then appends the neighbors column (an array
    * of structs made of the `inputCols` fields) and, when `distanceCol` is non-empty, an array of
    * doubles for the neighbor distances.
    *
    * @param schema input schema
    * @return output schema
    */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
    val auxFeatures = $(inputCols).map(c => schema(c))
    val schemaWithNeighbors = SchemaUtils.appendColumn(schema, $(neighborsCol), ArrayType(StructType(auxFeatures)))

    if ($(distanceCol).isEmpty) {
      schemaWithNeighbors
    } else {
      SchemaUtils.appendColumn(schemaWithNeighbors, $(distanceCol), ArrayType(DoubleType))
    }
  }
}

/**
  * kNN Model facilitates k-Nearest Neighbor search by storing a distributed hybrid spill tree.
  * Top level tree is a MetricTree but instead of using back tracking, it searches all possible leaves in parallel
  * to avoid multiple iterations. It uses the same buffer size that is used in model training, when the search
  * vector falls into the buffer zone of the node, it dispatches search to both children.
  *
  * A high level overview of the search phases is as follows:
  *
  *  1. For each vector to search, go through the top level tree to output a pair of (index, point)
  *  1. Repartition search points by partition index
  *  1. Search each point through the hybrid spill tree in that particular partition
  *  1. For each point, merge results from different partitions and keep top k results.
  *
  */
class KNNModel private[ml](
    override val uid: String,
    val topTree: Broadcast[Tree],
    val subTrees: RDD[Tree])
  extends Model[KNNModel] with KNNModelParams {
  require(subTrees.getStorageLevel != StorageLevel.NONE,
    "KNNModel is not designed to work with Trees that have not been cached")

  /** @group setParam */
  def setNeighborsCol(value: String): this.type = set(neighborsCol, value)

  /** @group setParam */
  def setDistanceCol(value: String): this.type = set(distanceCol, value)

  /** @group setParam */
  def setK(value: Int): this.type = set(k, value)

  /** @group setParam */
  def setMaxDistance(value: Double): this.type = set(maxDistance, value)

  /** @group setParam */
  def setBufferSize(value: Double): this.type = set(bufferSize, value)

  /**
    * Appends the neighbors column (and the distances column when `distanceCol` is non-empty)
    * to every row of `dataset`. Rows without any neighbor get empty arrays.
    */
  //TODO: All these can benefit from DataSet API
  override def transform(dataset: Dataset[_]): DataFrame = {
    // Validate the schema before launching any Spark job so misconfiguration fails fast.
    val outputSchema = transformSchema(dataset.schema)
    val withDistance = $(distanceCol).nonEmpty

    val merged: RDD[(Long, Array[(Row, Double)])] = transform(dataset, topTree, subTrees)

    // NOTE(review): rows are matched to their neighbors by zipWithIndex position, which assumes the
    // dataset yields rows in the same order each time it is evaluated.
    val indexedRows = dataset.toDF().rdd.zipWithIndex().map { case (row, i) => (i, row) }

    val outputRows = indexedRows.leftOuterJoin(merged).map {
      case (_, (row, neighborsAndDistances)) =>
        val (neighbors, distances) = neighborsAndDistances.map(_.unzip)
          .getOrElse((Array.empty[Row], Array.empty[Double]))
        if (withDistance) {
          Row.fromSeq(row.toSeq :+ neighbors :+ distances)
        } else {
          Row.fromSeq(row.toSeq :+ neighbors)
        }
    }

    dataset.sqlContext.createDataFrame(outputRows, outputSchema)
  }

  /**
    * Checks that the features column holds vectors and appends the neighbors column (array of
    * structs built from `inputCols`) plus, if `distanceCol` is non-empty, an array-of-doubles
    * distances column.
    */
  override def transformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
    val auxFeatures = $(inputCols).map(c => schema(c))
    val schemaWithNeighbors = SchemaUtils.appendColumn(schema, $(neighborsCol), ArrayType(StructType(auxFeatures)))
    if ($(distanceCol).isEmpty) {
      schemaWithNeighbors
    } else {
      SchemaUtils.appendColumn(schemaWithNeighbors, $(distanceCol), ArrayType(DoubleType))
    }
  }

  override def copy(extra: ParamMap): KNNModel = {
    val copied = new KNNModel(uid, topTree, subTrees)
    copyValues(copied, extra).setParent(parent)
  }

  /** Builds a classification model sharing this model's trees and param values. */
  def toNewClassificationModel(uid: String, numClasses: Int): KNNClassificationModel = {
    copyValues(new KNNClassificationModel(uid, topTree, subTrees, numClasses))
  }

  /** Builds a regression model sharing this model's trees and param values. */
  def toNewRegressionModel(uid: String): KNNRegressionModel = {
    copyValues(new KNNRegressionModel(uid, topTree, subTrees))
  }
}

/**
  * k-Nearest Neighbors (kNN) algorithm
  *
  * kNN finds k closest observations in training dataset. It can be used for both classification and regression.
  * Furthermore it can also be used for other purposes such as input to clustering algorithm.
  *
  * While the brute-force approach requires no pre-training, each prediction requires going through the entire training
  * set, resulting in O(n log(k)) runtime per individual prediction when using a heap to keep track of neighbor candidates.
  * Many different implementations have been proposed such as Locality Sensitive Hashing (LSH), KD-Tree, Metric Tree and etc.
  * Each algorithm has its shortcomings that prevent it from being effective on large-scale and/or high-dimensional datasets.
  *
  * This is an implementation of kNN based upon distributed Hybrid Spill-Trees where training points are organized into
  * distributed binary trees. The algorithm is designed to support accurate approximate kNN search but by tuning parameters
  * an exact search can also be performed with cost of additional runtime.
  *
  * Each binary tree node is either a
  *
  * '''Metric Node''':
  * A Metric Node partitions points exclusively into two children by finding two pivot points and dividing by the middle plane.
  * When searched, the child whose pivot is closer to query vector is searched first. Back tracking is required to
  * ensure accuracy in this case, where the other child should be searched if it can possibly contain better neighbor
  * based upon candidates picked during previous search.
  *
  * '''Spill Node''':
  * A Spill Node also partitions points into two children; however, there is an overlapping buffer between the two pivot
  * points. The larger the buffer size, the less effective the node eliminates points thus could increase tree height.
  * When searched, defeatist search is used where only one child is searched and no back tracking happens in this
  * process. Because of the buffer between two children, we are likely to end up with good enough candidates without
  * searching the other part of the tree.
  *
  * While Spill Node promises O(h) runtime where h is the tree height, the tree is deeper than Metric Tree's O(log n)
  * height on average. Furthermore, when it comes down to leaves where points are closer to each other, the static
  * buffer size means more points will end up in the buffer. Therefore a Balance Threshold (rho) is introduced: when
  * either child of Spill Node makes up more than rho fraction of the total points at this level, Spill Node is reverted
  * back to a Metric Node.
  *
  * A high level overview of the algorithm is as follows:
  *
  *  1. Sample M data points (M is relatively small and can be held in driver)
  *  1. Build the top level metric tree
  *  1. Repartition RDD by assigning each point to leaf node of the above tree
  *  1. Build a hybrid spill tree at each partition
  *
  * This concludes the training phase of kNN.
  * See [[KNNModel]] for details on prediction phase.
  *
  *
  * This algorithm is described in [[http://dx.doi.org/10.1109/WACV.2007.18]] where it was shown to scale well in terms of
  * number of observations and dimensions, bounded by the available memory across clusters (billions in paper's example).
  * This implementation adapts the MapReduce algorithm to work with Spark.
  *
  */
class KNN(override val uid: String) extends Estimator[KNNModel] with KNNParams {
  def this() = this(Identifiable.randomUID("knn"))

  /** @group setParam */
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** @group setParam */
  def setK(value: Int): this.type = set(k, value)

  /** @group setParam */
  def setAuxCols(value: Array[String]): this.type = set(inputCols, value)

  /** @group setParam */
  def setTopTreeSize(value: Int): this.type = set(topTreeSize, value)

  /** @group setParam */
  def setTopTreeLeafSize(value: Int): this.type = set(topTreeLeafSize, value)

  /** @group setParam */
  def setSubTreeLeafSize(value: Int): this.type = set(subTreeLeafSize, value)

  /** @group setParam */
  def setBufferSizeSampleSizes(value: Array[Int]): this.type = set(bufferSizeSampleSizes, value)

  /** @group setParam */
  def setBalanceThreshold(value: Double): this.type = set(balanceThreshold, value)

  /** @group setParam */
  def setSeed(value: Long): this.type = set(seed, value)

  /**
    * Builds the distributed hybrid spill tree:
    * samples points for a top-level metric tree, repartitions the data by its leaves, estimates
    * (or takes) the buffer size tau, then builds one hybrid tree per partition and caches them.
    *
    * @throws IllegalArgumentException if `dataset` is empty
    */
  override def fit(dataset: Dataset[_]): KNNModel = {
    val rand = new XORShiftRandom($(seed))
    // Resolve params on the driver so the per-partition closure below does not capture `this`.
    val seedValue = $(seed)
    val leafSize = $(subTreeLeafSize)
    val rho = $(balanceThreshold)

    //prepare data for model estimation
    val data = dataset.selectExpr($(featuresCol), $(inputCols).mkString("struct(", ",", ")"))
      .rdd
      .map(row => new RowWithVector(row.getAs[Vector](0), row.getStruct(1)))

    //sample data to build top-level tree
    val numPoints = dataset.count()
    require(numPoints > 0, "KNN cannot be fit on an empty dataset")
    // Sampling without replacement only accepts fractions in [0, 1]; clamp for datasets that
    // contain fewer points than topTreeSize (then every point is used).
    val fraction = math.min(1.0, $(topTreeSize).toDouble / numPoints)
    val sampled = data.sample(withReplacement = false, fraction, rand.nextLong()).collect()
    val topTree = MetricTree.build(sampled, $(topTreeLeafSize), rand.nextLong())

    //build partitioner using top-level tree
    val part = new KNNPartitioner(topTree)
    //noinspection ScalaStyle
    val repartitioned = new ShuffledRDD[RowWithVector, Null, Null](data.map(v => (v, null)), part).keys

    val tau =
      if (rho > 0 && $(bufferSize) < 0) {
        KNN.estimateTau(data, $(bufferSizeSampleSizes), rand.nextLong())
      } else {
        math.max(0, $(bufferSize))
      }
    logInfo("Tau is: " + tau)

    val trees = repartitioned.mapPartitionsWithIndex { (partitionId, itr) =>
      // per-partition generator, derived deterministically from the seed and partition index
      val partitionRand = new XORShiftRandom(byteswap64(seedValue ^ partitionId))
      val childTree = HybridTree.build(itr.toIndexedSeq, leafSize, tau, rho, partitionRand.nextLong())
      Iterator(childTree)
    }.persist(StorageLevel.MEMORY_AND_DISK)
    // TODO: force persisting trees primarily for benchmark. any reason not to do this for regular runs?
    trees.count()

    val model = new KNNModel(uid, trees.context.broadcast(topTree), trees).setParent(this)
    copyValues(model).setBufferSize(tau)
  }

  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  override def copy(extra: ParamMap): KNN = defaultCopy(extra)
}


object KNN {

  val logger = log4j.Logger.getLogger(classOf[KNN])

  /**
    * VectorWithNorm can use more efficient algorithm to calculate distance
    */
  case class VectorWithNorm(vector: Vector, norm: Double) {
    def this(vector: Vector) = this(vector, Vectors.norm(vector, 2))

    def this(vector: BV[Double]) = this(Vectors.fromBreeze(vector))

    def fastSquaredDistance(v: VectorWithNorm): Double = {
      KNNUtils.fastSquaredDistance(vector, norm, v.vector, v.norm)
    }

    def fastDistance(v: VectorWithNorm): Double = math.sqrt(fastSquaredDistance(v))
  }

  /**
    * VectorWithNorm plus auxiliary row information
    */
  case class RowWithVector(vector: VectorWithNorm, row: Row) {
    def this(vector: Vector, row: Row) = this(new VectorWithNorm(vector), row)
  }

  /**
    * Estimate a suitable buffer size based on dataset
    *
    * A suitable buffer size is the minimum size such that nearest neighbors can be accurately found even at
    * boundary of splitting plane between pivot points. Therefore assuming points are uniformly distributed in
    * high dimensional space, it should be approximately the average distance between points.
    *
    * Specifically the number of points within a certain radius of a given point is proportional to the density of
    * points raised to the effective number of dimensions of the manifold the data points lie on:
    * R_s = \frac{c}{N_s^{1/d}}
    * where R_s is the radius, N_s is the number of points, d is effective number of dimension, and c is a constant.
    *
    * To estimate R_s_all for entire dataset, we can take samples of the dataset of different size N_s to compute R_s.
    * We can estimate c and d using linear regression. Lastly we can calculate R_s_all using total number of observation
    * in dataset.
    *
    * Samples in which no point has a distinct neighbor (e.g. only duplicates) carry no distance
    * information and are ignored; if no usable sample remains, 0.0 is returned so that only metric
    * trees are built.
    *
    * @param data       training points
    * @param sampleSize sample sizes N_s to draw
    * @param seed       random seed for sampling
    * @return estimated tau (buffer size)
    */
  def estimateTau(data: RDD[RowWithVector], sampleSize: Array[Int], seed: Long): Double = {
    val total = data.count()

    // take samples of points for estimation
    val samples = data.mapPartitionsWithIndex {
      case (partitionId, itr) =>
        val rand = new XORShiftRandom(byteswap64(seed ^ partitionId))
        itr.flatMap {
          p => sampleSize.zipWithIndex
            .filter { case (size, _) => rand.nextDouble() * total < size }
            .map { case (_, index) => (index, p) }
        }
    }

    // compute N_s and R_s pairs; drop samples whose average distance is undefined (NaN)
    val estimators = samples
      .groupByKey()
      .map {
        case (_, points) => (points.size, computeAverageDistance(points))
      }
      .filter { case (_, d) => d > 0 }
      .collect().distinct

    if (estimators.isEmpty) {
      logger.error(
        s"""Unable to estimate Tau: no sample contained two distinct points.
            |Setting to 0.0 so that only metric trees are built.
            |Consider manually setting bufferSize instead.""".stripMargin)
      0.0
    } else {
      tauFromEstimators(estimators, total)
    }
  }

  // fit log(R_s) = alpha + beta * log(N_s) over the (N_s, R_s) pairs and extrapolate to `total` points
  private[this] def tauFromEstimators(estimators: Array[(Int, Double)], total: Long): Double = {
    // collect x and y vectors
    val x = DenseVector(estimators.map { case (n, _) => math.log(n) })
    val y = DenseVector(estimators.map { case (_, d) => math.log(d) })

    // estimate log(R_s) = alpha + beta * log(N_s)
    val xMeanVariance: MeanAndVariance = meanAndVariance(x)
    val xmean = xMeanVariance.mean
    val yMeanVariance: MeanAndVariance = meanAndVariance(y)
    val ymean = yMeanVariance.mean

    val corr = (mean(x :* y) - xmean * ymean) / math.sqrt((mean(x :* x) - xmean * xmean) * (mean(y :* y) - ymean * ymean))

    val beta = corr * yMeanVariance.stdDev / xMeanVariance.stdDev
    val alpha = ymean - beta * xmean
    val rs = math.exp(alpha + beta * math.log(total))

    if (beta > 0 || beta.isNaN || rs.isNaN) {
      val yMax = breeze.linalg.max(y)
      logger.error(
        s"""Unable to estimate Tau with positive beta: $beta. This maybe because data is too small.
            |Setting to $yMax which is the maximum average distance we found in the sample.
            |This may leads to poor accuracy. Consider manually set bufferSize instead.
            |You can also try setting balanceThreshold to zero so only metric trees are built.""".stripMargin)
      yMax
    } else {
      // c = alpha, d = - 1 / beta
      rs / math.sqrt(-1 / beta)
    }
  }

  // compute the average distance from each point to its nearest distinct neighbor using brute-force;
  // points without a distinct neighbor are skipped, and NaN is returned when no point has one
  private[this] def computeAverageDistance(points: Iterable[RowWithVector]): Double = {
    val distances = points.flatMap { point =>
      val positive = points.map(p => p.vector.fastSquaredDistance(point.vector)).filter(_ > 0)
      if (positive.isEmpty) None else Some(math.sqrt(positive.min))
    }

    if (distances.isEmpty) Double.NaN else distances.sum / distances.size
  }

  /**
    * Search leaf index used by KNNPartitioner to partition training points
    *
    * @param v one training point to partition
    * @param tree top tree constructed using sampled points
    * @param acc accumulator used to help determining leaf index
    * @return leaf/partition index
    */
  @tailrec
  private[knn] def searchIndex(v: RowWithVector, tree: Tree, acc: Int = 0): Int = {
    tree match {
      case node: MetricTree =>
        val leftDistance = node.leftPivot.fastSquaredDistance(v.vector)
        val rightDistance = node.rightPivot.fastSquaredDistance(v.vector)
        if (leftDistance < rightDistance) {
          searchIndex(v, node.leftChild, acc)
        } else {
          searchIndex(v, node.rightChild, acc + node.leftChild.leafCount)
        }
      case _ => acc // reached leaf
    }
  }

  /**
    * Search all leaf indices a query vector may belong to: a child is visited whenever the vector is
    * no more than `tau` farther from that child's pivot than from the other pivot, so vectors in
    * the buffer zone are dispatched to both children.
    *
    * @param v    query vector
    * @param tree top tree
    * @param tau  buffer size
    * @param acc  accumulator used to help determining leaf index
    * @return leaf/partition indices
    */
  //TODO: Might want to make this tail recursive
  private[ml] def searchIndices(v: VectorWithNorm, tree: Tree, tau: Double, acc: Int = 0): Seq[Int] = {
    tree match {
      case node: MetricTree =>
        val leftDistance = node.leftPivot.fastDistance(v)
        val rightDistance = node.rightPivot.fastDistance(v)

        val buffer = new ArrayBuffer[Int]
        if (leftDistance - rightDistance <= tau) {
          buffer ++= searchIndices(v, node.leftChild, tau, acc)
        }

        if (rightDistance - leftDistance <= tau) {
          buffer ++= searchIndices(v, node.rightChild, tau, acc + node.leftChild.leafCount)
        }

        buffer
      case _ => Seq(acc) // reached leaf
    }
  }

  /**
    * Partitioner used to map vector to leaf node which determines the partition it goes to
    *
    * @param tree `Tree` used to find leaf
    */
  class KNNPartitioner[T <: RowWithVector](tree: Tree) extends Partitioner {
    override def numPartitions: Int = tree.leafCount

    override def getPartition(key: Any): Int = {
      key match {
        case v: RowWithVector => searchIndex(v, tree)
        case _ => throw new IllegalArgumentException(s"Key must be of type RowWithVector but got: $key")
      }
    }

  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy