// All Downloads are FREE. Search and download functionalities are using the official Maven repository.

// com.tencent.angel.sona.ml.feature.BucketedRandomProjectionLSH.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tencent.angel.sona.ml.feature

import scala.util.Random
import breeze.linalg.normalize
import org.apache.hadoop.fs.Path
import com.tencent.angel.sona.ml.param._
import com.tencent.angel.sona.ml.param.shared.HasSeed
import com.tencent.angel.sona.ml.util._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.linalg
import org.apache.spark.linalg.{BLAS, Matrices, Matrix, VectorUDT, Vectors}
import org.apache.spark.sql.util.SONASchemaUtils

/**
  * :: Experimental ::
  *
  * Params for [[BucketedRandomProjectionLSH]].
  */
private[sona] trait BucketedRandomProjectionLSHParams extends Params {

  /**
    * Width of a single hash bucket. Wider buckets make it less likely that two nearby points are
    * hashed apart, i.e. they lower the false negative rate. The resulting number of buckets is
    * `(max L2 norm of input vectors) / bucketLength`.
    *
    * For normalized input vectors, a value between 1 and 10 times
    * `pow(numRecords, -1/inputDim)` is a reasonable choice.
    *
    * The value is validated with `ParamValidators.gt(0)`.
    *
    * @group param
    */
  val bucketLength: DoubleParam = {
    val doc = "the length of each hash bucket, a larger bucket lowers the false negative rate."
    new DoubleParam(this, "bucketLength", doc, ParamValidators.gt(0))
  }

  /**
    * Returns the value currently bound to [[bucketLength]].
    *
    * @group getParam
    */
  final def getBucketLength: Double = $(bucketLength)
}

/**
  * :: Experimental ::
  *
  * Model produced by [[BucketedRandomProjectionLSH]], where multiple random vectors are stored. The
  * vectors are normalized to be unit vectors and each vector is used in a hash function:
  * `h_i(x) = floor(r_i.dot(x) / bucketLength)`
  * where `r_i` is the i-th random unit vector. The number of buckets will be `(max L2 norm of input
  * vectors) / bucketLength`.
  *
  * @param randUnitVectors An array of random unit vectors. Each vector represents a hash function.
  */
class BucketedRandomProjectionLSHModel private[sona](
    override val uid: String,
    private[sona] val randUnitVectors: Array[linalg.Vector])
  extends LSHModel[BucketedRandomProjectionLSHModel]
    with BucketedRandomProjectionLSHParams {

  /**
    * Sets the name of the input column.
    *
    * @group setParam
    */
  override def setInputCol(value: String): this.type = {
    super.set(inputCol, value)
  }

  /**
    * Sets the name of the output column.
    *
    * @group setParam
    */
  override def setOutputCol(value: String): this.type = {
    super.set(outputCol, value)
  }

  /**
    * Hashes `elems` once per stored random vector.
    *
    * For each vector `r` in `randUnitVectors` the hash value is
    * `floor(r.dot(elems) / bucketLength)`; every value is wrapped in its own
    * one-element dense vector.
    *
    * @param elems the input vector to hash
    * @return one single-valued dense vector per entry of `randUnitVectors`, in the same order
    */
  override def hashFunction(elems: linalg.Vector): Array[linalg.Vector] = {
    // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
    randUnitVectors.map { projection =>
      val bucketIndex = Math.floor(BLAS.dot(elems, projection) / $(bucketLength))
      Vectors.dense(bucketIndex)
    }
  }

  /**
    * Distance between two keys: the square root of `Vectors.sqdist(x, y)`.
    */
  override def keyDistance(x: linalg.Vector, y: linalg.Vector): Double = {
    val squared = Vectors.sqdist(x, y)
    Math.sqrt(squared)
  }

  /**
    * Distance between two hash signatures: the smallest squared distance over the
    * position-wise pairs of `x` and `y` (pairs are formed with `zip`, so surplus elements of
    * the longer sequence are ignored).
    */
  override def hashDistance(x: Seq[linalg.Vector], y: Seq[linalg.Vector]): Double = {
    // Both sides come out of hashing, so each pair consists of dense vectors.
    val perTableDistances = x.zip(y).map { case (left, right) => Vectors.sqdist(left, right) }
    perTableDistances.min
  }

  /**
    * Creates a new model with the same uid, random vectors and parent, then copies the
    * param values of this instance plus `extra` onto it.
    */
  override def copy(extra: ParamMap): BucketedRandomProjectionLSHModel = {
    copyValues(
      new BucketedRandomProjectionLSHModel(uid, randUnitVectors).setParent(parent),
      extra)
  }

  /** Returns a writer bound to this model instance. */
  override def write: MLWriter =
    new BucketedRandomProjectionLSHModel.BucketedRandomProjectionLSHModelWriter(this)
}

/**
  * :: Experimental ::
  *
  * This [[BucketedRandomProjectionLSH]] implements Locality Sensitive Hashing functions for
  * Euclidean distance metrics.
  *
  * The input is dense or sparse vectors, each of which represents a point in the Euclidean
  * distance space. The output will be vectors of configurable dimension. Hash values in the
  * same dimension are calculated by the same hash function.
  *
  * References:
  *
  * 1. <a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Stable_distributions">
  * Wikipedia on Stable Distributions</a>
  *
  * 2. Wang, Jingdong et al. "Hashing for similarity search: A survey." arXiv preprint
  * arXiv:1408.2927 (2014).
  */
class BucketedRandomProjectionLSH(override val uid: String)
  extends LSH[BucketedRandomProjectionLSHModel]
    with BucketedRandomProjectionLSHParams
    with HasSeed {

  /** Creates an estimator whose uid is obtained from `Identifiable.randomUID("brp-lsh")`. */
  def this() = this(Identifiable.randomUID("brp-lsh"))

  /**
    * Sets the name of the input column.
    *
    * @group setParam
    */
  override def setInputCol(value: String): this.type = {
    super.setInputCol(value)
  }

  /**
    * Sets the name of the output column.
    *
    * @group setParam
    */
  override def setOutputCol(value: String): this.type = {
    super.setOutputCol(value)
  }

  /**
    * Sets the number of hash tables; one random unit vector is generated per table.
    *
    * @group setParam
    */
  override def setNumHashTables(value: Int): this.type = {
    super.setNumHashTables(value)
  }

  /**
    * Sets the length of each hash bucket.
    *
    * @group setParam
    */
  def setBucketLength(value: Double): this.type = {
    set(bucketLength, value)
  }

  /**
    * Sets the seed of the random generator used to draw the projection vectors.
    *
    * @group setParam
    */
  def setSeed(value: Long): this.type = {
    set(seed, value)
  }

  /**
    * Builds a model holding `numHashTables` random vectors of dimension `inputDim`.
    *
    * Each vector is filled with Gaussian samples from a generator seeded with `seed` and then
    * passed through breeze's `normalize`.
    *
    * @param inputDim dimension of the vectors that will be hashed
    * @return a model sharing this estimator's uid
    */
  override protected[this] def createRawLSHModel(
      inputDim: Int): BucketedRandomProjectionLSHModel = {
    val generator = new Random($(seed))

    // Draws `inputDim` Gaussian samples and normalizes them into one projection vector.
    def nextProjection(): linalg.Vector = {
      val samples = Array.fill(inputDim)(generator.nextGaussian())
      Vectors.fromBreeze(normalize(breeze.linalg.Vector(samples)))
    }

    val projections = Array.fill[linalg.Vector]($(numHashTables))(nextProjection())
    new BucketedRandomProjectionLSHModel(uid, projections)
  }

  /**
    * Checks that the input column has the vector type, then delegates to
    * `validateAndTransformSchema` to derive the output schema.
    */
  override def transformSchema(schema: StructType): StructType = {
    val expectedInputType = new VectorUDT
    SONASchemaUtils.checkColumnType(schema, $(inputCol), expectedInputType)
    validateAndTransformSchema(schema)
  }

  /** Copies this estimator via `defaultCopy`, applying the `extra` params. */
  override def copy(extra: ParamMap): this.type = {
    defaultCopy(extra)
  }
}


object BucketedRandomProjectionLSH extends DefaultParamsReadable[BucketedRandomProjectionLSH] {

  /**
    * Loads a [[BucketedRandomProjectionLSH]] estimator from `path` by delegating to the
    * inherited `load`.
    *
    * @param path location the estimator was saved to
    */
  override def load(path: String): BucketedRandomProjectionLSH = {
    super.load(path)
  }
}


object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProjectionLSHModel] {

  /** Returns a reader that restores a model from the layout written by the model's writer. */
  override def read: MLReader[BucketedRandomProjectionLSHModel] = {
    new BucketedRandomProjectionLSHModelReader
  }

  /** Loads a model from `path` by delegating to the inherited `load`. */
  override def load(path: String): BucketedRandomProjectionLSHModel = super.load(path)

  /**
    * Writer that stores the params metadata plus a one-row parquet file under `<path>/data`
    * whose single column `randUnitVectors` holds all random vectors as one matrix, with
    * vector `i` stored as matrix row `i`.
    *
    * @param instance the model to persist
    */
  private[BucketedRandomProjectionLSHModel] class BucketedRandomProjectionLSHModelWriter(
      instance: BucketedRandomProjectionLSHModel) extends MLWriter {

    // TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved.
    private case class Data(randUnitVectors: Matrix)

    /**
      * Saves metadata and the random vectors of `instance` below `path`.
      *
      * Fails with `IllegalArgumentException` if the model holds no random vectors or if the
      * vectors do not all have the same dimension.
      */
    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val rows = instance.randUnitVectors.map(_.toArray)
      val numRows = rows.length
      require(numRows > 0)
      val numCols = instance.randUnitVectors.head.size.toInt
      require(rows.forall(_.length == numCols),
        s"All random unit vectors must have dimension $numCols.")
      // The reader rebuilds the vectors with `rowIter`, so vector i must be row i of the matrix.
      // `Matrices.dense` reads `values` in column-major order (Spark's convention --
      // NOTE(review): confirm for this linalg fork), hence element k of `values` belongs to
      // row (k % numRows) and column (k / numRows). Simply concatenating the vectors would be
      // a row-major layout and would scramble the vectors whenever numRows > 1.
      val values = Array.tabulate(numRows * numCols) { k =>
        rows(k % numRows)(k / numRows)
      }
      val randMatrix = Matrices.dense(numRows, numCols, values)
      val data = Data(randMatrix)
      val dataPath = new Path(path, "data").toString
      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    }
  }

  /**
    * Reader counterpart of the writer: loads the metadata, reads the matrix from
    * `<path>/data` and turns each matrix row back into one random vector.
    */
  private class BucketedRandomProjectionLSHModelReader
    extends MLReader[BucketedRandomProjectionLSHModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[BucketedRandomProjectionLSHModel].getName

    /**
      * Restores a model from `path`, including its uid and the params recorded in the metadata.
      */
    override def load(path: String): BucketedRandomProjectionLSHModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      val dataPath = new Path(path, "data").toString
      val data = sparkSession.read.parquet(dataPath)
      val Row(randUnitVectors: Matrix) = MLUtils.convertMatrixColumnsToML(data, "randUnitVectors")
        .select("randUnitVectors")
        .head()
      // One vector per matrix row, in row order.
      val model = new BucketedRandomProjectionLSHModel(metadata.uid,
        randUnitVectors.rowIter.toArray)

      metadata.getAndSetParams(model)
      model
    }
  }

}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy