All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.annotators.similarity.DocumentSimilarityRankerApproach.scala Maven / Gradle / Ivy

There is a newer version: 5.5.0
Show newest version
package com.johnsnowlabs.nlp.annotators.similarity

import com.johnsnowlabs.nlp.AnnotatorType.{DOC_SIMILARITY_RANKINGS, SENTENCE_EMBEDDINGS}
import com.johnsnowlabs.nlp.annotators.similarity.DocumentSimilarityUtil._
import com.johnsnowlabs.nlp.{AnnotatorApproach, HasEnableCachingProperties}
import com.johnsnowlabs.util.spark.SparkUtil.retrieveColumnName
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.feature.{BucketedRandomProjectionLSH, MinHashLSH}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{BooleanParam, Param}
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{DataFrame, Dataset}

/** Common shape of the nearest-neighbour payload attached to a query document.
  *
  * Sealed: the only implementations are [[IndexedNeighbors]] (neighbour indexes only) and
  * [[IndexedNeighborsWithDistance]] (neighbour indexes paired with a distance).
  */
sealed trait NeighborAnnotation {

  /** The ranked neighbours; the element type depends on the concrete implementation. */
  def neighbors: Array[_]
}

/** Neighbours of a query document, identified only by their integer document index.
  *
  * @param neighbors
  *   indexes of the neighbouring documents, in the order they were returned by the search
  */
case class IndexedNeighbors(neighbors: Array[Int]) extends NeighborAnnotation

/** Neighbours of a query document together with their distance to the query.
  *
  * @param neighbors
  *   `(document index, distance)` pairs, in the order they were returned by the search
  */
case class IndexedNeighborsWithDistance(neighbors: Array[(Int, Double)])
    extends NeighborAnnotation

/** Outcome of one nearest-neighbour query.
  *
  * @param result
  *   pair of the query document's index and the neighbours found for it
  */
case class NeighborsResultSet(result: (Int, NeighborAnnotation))

/** Annotator that uses LSH techniques present in Spark ML lib to execute approximate nearest
  * neighbors search on top of sentence embeddings.
  *
  * It aims to capture the semantic meaning of a document in a dense, continuous vector space and
  * return it to the ranker search.
  *
  * For instantiated/pretrained models, see [[DocumentSimilarityRankerModel]].
  *
  * For extended examples of usage, see the jupyter notebook
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/text/english/text-similarity/doc-sim-ranker/test_doc_sim_ranker.ipynb Document Similarity Ranker for Spark NLP]].
  *
  * ==Example==
  * {{{
  * import com.johnsnowlabs.nlp.base._
  * import com.johnsnowlabs.nlp.annotator._
  * import com.johnsnowlabs.nlp.annotators.similarity.DocumentSimilarityRankerApproach
  * import com.johnsnowlabs.nlp.finisher.DocumentSimilarityRankerFinisher
  * import org.apache.spark.ml.Pipeline
  *
  * import spark.implicits._
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val sentenceEmbeddings = RoBertaSentenceEmbeddings
  *   .pretrained()
  *   .setInputCols("document")
  *   .setOutputCol("sentence_embeddings")
  *
  * val documentSimilarityRanker = new DocumentSimilarityRankerApproach()
  *   .setInputCols("sentence_embeddings")
  *   .setOutputCol("doc_similarity_rankings")
  *   .setSimilarityMethod("brp")
  *   .setNumberOfNeighbours(1)
  *   .setBucketLength(2.0)
  *   .setNumHashTables(3)
  *   .setVisibleDistances(true)
  *   .setIdentityRanking(false)
  *
  * val documentSimilarityRankerFinisher = new DocumentSimilarityRankerFinisher()
  *   .setInputCols("doc_similarity_rankings")
  *   .setOutputCols(
  *     "finished_doc_similarity_rankings_id",
  *     "finished_doc_similarity_rankings_neighbors")
  *   .setExtractNearestNeighbor(true)
  *
  * // Let's use a dataset where we can visually control similarity
  * // Documents are paired as 1-2, 3-4, 5-6, 7-8 and were created to be similar on purpose
  * val data = Seq(
  *   "First document, this is my first sentence. This is my second sentence.",
  *   "Second document, this is my second sentence. This is my second sentence.",
  *   "Third document, climate change is arguably one of the most pressing problems of our time.",
  *   "Fourth document, climate change is definitely one of the most pressing problems of our time.",
  *   "Fifth document, Florence in Italy, is among the most beautiful cities in Europe.",
  *   "Sixth document, Florence in Italy, is a very beautiful city in Europe like Lyon in France.",
  *   "Seventh document, the French Riviera is the Mediterranean coastline of the southeast corner of France.",
  *   "Eighth document, the warmest place in France is the French Riviera coast in Southern France.")
  *   .toDF("text")
  *
  * val pipeline = new Pipeline().setStages(
  *   Array(
  *     documentAssembler,
  *     sentenceEmbeddings,
  *     documentSimilarityRanker,
  *     documentSimilarityRankerFinisher))
  *
  * val result = pipeline.fit(data).transform(data)
  *
  * result
  *   .select("finished_doc_similarity_rankings_id", "finished_doc_similarity_rankings_neighbors")
  *   .show(10, truncate = false)
  * +-----------------------------------+------------------------------------------+
  * |finished_doc_similarity_rankings_id|finished_doc_similarity_rankings_neighbors|
  * +-----------------------------------+------------------------------------------+
  * |1510101612                         |[(1634839239,0.12448559591306324)]        |
  * |1634839239                         |[(1510101612,0.12448559591306324)]        |
  * |-612640902                         |[(1274183715,0.1220122862046063)]         |
  * |1274183715                         |[(-612640902,0.1220122862046063)]         |
  * |-1320876223                        |[(1293373212,0.17848855164122393)]        |
  * |1293373212                         |[(-1320876223,0.17848855164122393)]       |
  * |-1548374770                        |[(-1719102856,0.23297156732534166)]       |
  * |-1719102856                        |[(-1548374770,0.23297156732534166)]       |
  * +-----------------------------------+------------------------------------------+
  * }}}
  *
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class DocumentSimilarityRankerApproach(override val uid: String)
    extends AnnotatorApproach[DocumentSimilarityRankerModel]
    with HasEnableCachingProperties {

  /** Short human-readable description of this annotator. */
  override val description: AnnotatorType = "LSH based document similarity annotator"

  /** Creates the annotator with a randomly generated `uid` prefixed with
    * `DocumentSimilarityRankerApproach`.
    */
  def this() = this(Identifiable.randomUID("DocumentSimilarityRankerApproach"))

  /** Input annotator types: SENTENCE_EMBEDDINGS
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[AnnotatorType] = Array(SENTENCE_EMBEDDINGS)

  /** Output annotator type: DOC_SIMILARITY_RANKINGS
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = DOC_SIMILARITY_RANKINGS

  /** Name of the column holding the aggregated embedding vector that is fed to the LSH estimator. */
  val LSH_INPUT_COL_NAME = "features"

  /** Name of the column the LSH model writes its hashes to. */
  val LSH_OUTPUT_COL_NAME = "hashes"

  /** Name of the column holding the integer document index computed from the text column. */
  val INDEX_COL_NAME = "index"

  /** Name of the distance column read from the approximate nearest neighbours result. */
  val DISTANCE = "distCol"

  /** Name of the column holding the raw document text. */
  val TEXT = "text"

  /** The similarity method used to calculate the neighbours. (Default: `"brp"`, Bucketed Random
    * Projection for Euclidean Distance). The other supported value is `"mh"` (MinHash).
    *
    * @group param
    */
  val similarityMethod = new Param[String](
    this,
    "similarityMethod",
    """The similarity method used to calculate the neighbours.
      |(Default: `"brp"`, Bucketed Random Projection for Euclidean Distance)
      |""".stripMargin)

  /** @group setParam */
  def setSimilarityMethod(value: String): this.type = set(similarityMethod, value)

  /** @group getParam */
  def getSimilarityMethod: String = $(similarityMethod)

  /** The number of neighbours the model will return for each document (Default: `10`).
    *
    * @group param
    */
  val numberOfNeighbours = new Param[Int](
    this,
    "numberOfNeighbours",
    """The number of neighbours the model will return for each document (Default:`"10"`)""")

  /** @group setParam */
  def setNumberOfNeighbours(value: Int): this.type = set(numberOfNeighbours, value)

  /** @group getParam */
  def getNumberOfNeighbours: Int = $(numberOfNeighbours)

  /** The bucket length that controls the average size of hash buckets; only used by the `"brp"`
    * similarity method (Default: `2.0`).
    *
    * @group param
    */
  val bucketLength = new Param[Double](
    this,
    "bucketLength",
    """The bucket length that controls the average size of hash buckets.
      |A larger bucket length (i.e., fewer buckets) increases the probability of features being hashed
      |to the same bucket (increasing the numbers of true and false positives)
      |""".stripMargin)

  /** @group setParam */
  def setBucketLength(value: Double): this.type = set(bucketLength, value)

  /** @group getParam */
  def getBucketLength: Double = $(bucketLength)

  /** Number of hash tables used by the LSH estimator (Default: `3`).
    *
    * @group param
    */
  val numHashTables = new Param[Int](
    this,
    "numHashTables",
    """number of hash tables, where increasing number of hash tables lowers the false negative rate,
      |and decreasing it improves the running performance.
      |""".stripMargin)

  /** @group setParam */
  def setNumHashTables(value: Int): this.type = set(numHashTables, value)

  /** @group getParam */
  def getNumHashTables: Int = $(numHashTables)

  /** Whether the ranking output carries the distance of every neighbour (Default: `false`).
    *
    * @group param
    */
  val visibleDistances = new BooleanParam(
    this,
    "visibleDistances",
    "Whether to set visibleDistances in ranking output (Default: `false`)")

  /** @group setParam */
  def setVisibleDistances(value: Boolean): this.type = set(visibleDistances, value)

  /** @group getParam */
  def getVisibleDistances: Boolean = $(visibleDistances)

  /** Whether the query document itself is kept among the candidate neighbours (Default:
    * `false`).
    *
    * @group param
    */
  val identityRanking = new BooleanParam(
    this,
    "identityRanking",
    "Whether to include identity in ranking result set. Useful for debug. (Default: `false`)")

  /** @group setParam */
  def setIdentityRanking(value: Boolean): this.type = set(identityRanking, value)

  /** @group getParam */
  def getIdentityRanking: Boolean = $(identityRanking)

  /** Query text for retriever mode. When non-empty, only the rows whose text column equals this
    * string are used as queries during training (Default: empty string, i.e. every document is
    * a query).
    *
    * @group param
    */
  val asRetrieverQuery = new Param[String](
    this,
    "asRetrieverQuery",
    "Whether to set the model as retriever RAG with a specific query string. (Default: `empty`)")

  /** Sets the retriever query text, see [[asRetrieverQuery]].
    *
    * @group setParam
    */
  def asRetriever(value: String): this.type = set(asRetrieverQuery, value)

  /** @group getParam */
  def getAsRetrieverQuery: String = $(asRetrieverQuery)

  /** Specifies the method used to aggregate multiple sentence embeddings into a single vector
    * representation. Options include 'AVERAGE' (compute the mean of all embeddings), 'FIRST' (use
    * the first embedding only), 'MAX' (compute the element-wise maximum across embeddings)
    *
    * Default AVERAGE
    *
    * @group param
    */
  val aggregationMethod = new Param[String](
    this,
    "aggregationMethod",
    "Specifies the method used to aggregate multiple sentence embeddings into a single vector representation.")

  /** Set the method used to aggregate multiple sentence embeddings into a single vector
    * representation. Options include 'AVERAGE' (compute the mean of all embeddings), 'FIRST' (use
    * the first embedding only), 'MAX' (compute the element-wise maximum across embeddings)
    *
    * The value is matched case-insensitively and stored in upper case.
    *
    * Default AVERAGE
    *
    * @throws scala.MatchError
    *   if `strategy` is not one of the supported options
    * @group setParam
    */
  def setAggregationMethod(strategy: String): this.type = {
    // Locale.ROOT keeps the normalisation independent of the JVM default locale (with a Turkish
    // default locale "FIRST".toLowerCase() does not yield "first").
    strategy.toLowerCase(java.util.Locale.ROOT) match {
      case "average" => set(aggregationMethod, "AVERAGE")
      case "first" => set(aggregationMethod, "FIRST")
      case "max" => set(aggregationMethod, "MAX")
      // MatchError is kept (rather than IllegalArgumentException) so existing callers that
      // catch it keep working.
      case _ => throw new MatchError("aggregationMethod must be AVERAGE, FIRST or MAX")
    }
  }

  /** @group getParam */
  def getAggregationMethod: String = $(aggregationMethod)

  setDefault(
    similarityMethod -> "brp",
    numberOfNeighbours -> 10,
    bucketLength -> 2.0,
    numHashTables -> 3,
    visibleDistances -> false,
    identityRanking -> false,
    asRetrieverQuery -> "",
    aggregationMethod -> "AVERAGE")

  /** A fitted nearest-neighbour search: `(candidates, query vector, k) => ranked candidates`.
    *
    * Expressed as a function so the fitted LSH model can be passed around without naming
    * Spark's model base type.
    */
  private type NeighborSearch = (Dataset[_], Vector, Int) => Dataset[_]

  /** Builds the LSH estimator selected by [[similarityMethod]], fits it on `similarityDataset`
    * and returns the fitted model's approximate nearest neighbour search.
    *
    * @throws IllegalArgumentException
    *   if [[similarityMethod]] is neither `"brp"` nor `"mh"`
    */
  private def fitNeighborSearch(similarityDataset: DataFrame): NeighborSearch =
    $(similarityMethod) match {
      case "brp" =>
        val model = new BucketedRandomProjectionLSH()
          .setBucketLength($(bucketLength))
          .setNumHashTables($(numHashTables))
          .setInputCol(LSH_INPUT_COL_NAME)
          .setOutputCol(LSH_OUTPUT_COL_NAME)
          .fit(similarityDataset)
        (candidates: Dataset[_], key: Vector, k: Int) =>
          model.approxNearestNeighbors(candidates, key, k)
      case "mh" =>
        val model = new MinHashLSH()
          .setNumHashTables($(numHashTables))
          .setInputCol(LSH_INPUT_COL_NAME)
          .setOutputCol(LSH_OUTPUT_COL_NAME)
          .fit(similarityDataset)
        (candidates: Dataset[_], key: Vector, k: Int) =>
          model.approxNearestNeighbors(candidates, key, k)
      case _ =>
        throw new IllegalArgumentException(s"${$(similarityMethod)} is not a valid value.")
    }

  /** Runs one query against an already fitted search.
    *
    * Unless [[identityRanking]] is enabled, the row whose index equals the query index is
    * removed from the candidates first. The result carries distances only when
    * [[visibleDistances]] is enabled.
    */
  private def rankNeighbors(
      search: NeighborSearch,
      query: (Int, Vector),
      similarityDataset: DataFrame): NeighborsResultSet =
    query match {
      case (index, queryVector) =>
        val candidates =
          if (getIdentityRanking) similarityDataset
          else similarityDataset.where(col(INDEX_COL_NAME) =!= index)

        val similarRankedDocs = search(candidates, queryVector, getNumberOfNeighbours)

        val neighbors: NeighborAnnotation =
          if (getVisibleDistances)
            IndexedNeighborsWithDistance(
              similarRankedDocs
                .select(INDEX_COL_NAME, DISTANCE)
                .collect()
                .map(row => (row.getInt(0), row.getDouble(1))))
          else
            IndexedNeighbors(
              similarRankedDocs
                .select(INDEX_COL_NAME)
                .collect()
                .map(_.getInt(0)))

        NeighborsResultSet((index, neighbors))
      // Only reachable when `query` is null: a non-null pair always matches the case above.
      case _ =>
        throw new IllegalArgumentException("asRetrieverQuery is not of type (Int, DenseVector)")
    }

  /** Fits the configured LSH model on `similarityDataset` and returns the approximate nearest
    * neighbours of `query` within it.
    *
    * Every call fits a new model; `train` therefore fits once and reuses the fitted search for
    * all of its queries instead of calling this method per document.
    *
    * @param query
    *   pair of the query document's index and its feature vector
    * @param similarityDataset
    *   dataset containing the [[INDEX_COL_NAME]] and [[LSH_INPUT_COL_NAME]] columns
    * @return
    *   the query index paired with its ranked neighbours
    * @throws IllegalArgumentException
    *   if [[similarityMethod]] is not a supported value
    */
  def getNeighborsResultSet(
      query: (Int, Vector),
      similarityDataset: DataFrame): NeighborsResultSet =
    rankNeighbors(fitNeighborSearch(similarityDataset), query, similarityDataset)

  /** Aggregates the sentence embeddings of every row into one feature vector, derives an integer
    * index from the text column, and computes the nearest neighbours of each query document.
    *
    * The queries are all documents, or, when [[asRetrieverQuery]] is non-empty, only the rows
    * whose text equals that string.
    *
    * @param embeddingsDataset
    *   dataset holding a SENTENCE_EMBEDDINGS annotation column and the [[TEXT]] column
    * @param recursivePipeline
    *   not used by this annotator
    * @return
    *   a model holding the index -> neighbours mappings
    * @throws IllegalArgumentException
    *   if [[aggregationMethod]] or [[similarityMethod]] holds an unsupported value
    */
  override def train(
      embeddingsDataset: Dataset[_],
      recursivePipeline: Option[PipelineModel]): DocumentSimilarityRankerModel = {

    val inputEmbeddingsColumn =
      s"${retrieveColumnName(embeddingsDataset, SENTENCE_EMBEDDINGS)}.embeddings"

    val aggregatedEmbeddings = getAggregationMethod match {
      case "AVERAGE" => averageAggregation(col(inputEmbeddingsColumn))
      case "FIRST" => firstEmbeddingAggregation(col(inputEmbeddingsColumn))
      case "MAX" => maxAggregation(col(inputEmbeddingsColumn))
      // The param can be set without going through setAggregationMethod, so fail with a clear
      // message instead of an implicit MatchError.
      case other =>
        throw new IllegalArgumentException(
          s"$other is not a valid aggregationMethod, it must be AVERAGE, FIRST or MAX")
    }

    val similarityDatasetWithHashIndex = embeddingsDataset
      .withColumn(LSH_INPUT_COL_NAME, aggregatedEmbeddings)
      .withColumn(INDEX_COL_NAME, mh3UDF(col(TEXT)))

    val retrieverQuery = getAsRetrieverQuery

    // Only the rows that act as queries are brought to the driver.
    val queryRows =
      if (retrieverQuery.isEmpty) similarityDatasetWithHashIndex
      else similarityDatasetWithHashIndex.where(col(TEXT) === retrieverQuery)

    val queries: Array[(Int, Vector)] = queryRows
      .select(INDEX_COL_NAME, LSH_INPUT_COL_NAME)
      .collect()
      .map(row => (row.getAs[Int](INDEX_COL_NAME), row.getAs[Vector](LSH_INPUT_COL_NAME)))

    // The estimator configuration does not depend on the query, so a single fit serves every
    // query. Kept lazy so that nothing is fitted (or validated) when there is no query to run.
    lazy val neighborSearch = fitNeighborSearch(similarityDatasetWithHashIndex)

    val similarityMappings: Map[Int, NeighborAnnotation] = queries
      .map(query => rankNeighbors(neighborSearch, query, similarityDatasetWithHashIndex).result)
      .toMap

    new DocumentSimilarityRankerModel()
      .setSimilarityMappings(Map("similarityMappings" -> similarityMappings))
  }
}

/** This is the companion object of [[DocumentSimilarityRankerApproach]]. Please refer to that
  * class for the documentation.
  *
  * Mixing in `DefaultParamsReadable` provides the reader used to load a saved instance of the
  * approach.
  */
object DocumentSimilarityRankerApproach
    extends DefaultParamsReadable[DocumentSimilarityRankerApproach]




© 2015 - 2024 Weber Informatics LLC | Privacy Policy