Please wait. This can take a few minutes...
Downloading a project requires significant resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1.
You can buy this project and download/modify it as often as you want.
com.johnsnowlabs.nlp.finisher.DocumentSimilarityRankerFinisher.scala Maven / Gradle / Ivy
package com.johnsnowlabs.nlp.finisher
import com.johnsnowlabs.nlp.AnnotatorType
import com.johnsnowlabs.nlp.util.FinisherUtil
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}
import scala.util.hashing.MurmurHash3
/** Finisher that flattens the annotations produced by a `DocumentSimilarityRanker` into plain
  * DataFrame columns.
  *
  * `transform` reads the `metadata` map of the first annotation in the input column and extracts
  * its `lshId` entry (cast to `int`) and its `lshNeighbors` entry (kept as a string). Depending on
  * [[extractNearestNeighbor]] it then either:
  *   - `true`: adds `nearest_neighbor_id` (int) and `nearest_neighbor_distance` (string) taken
  *     from the first `(id,distance)` tuple of the neighbors string, keeping all other columns;
  *   - `false`: explodes every neighbor tuple, hashes the `text` column of the input dataset with
  *     MurmurHash3 and inner-joins on that hash, returning only `text`, `nearest_neighbor_id` and
  *     `nearest_neighbor_distance`.
  *
  * In both modes rows whose `nearest_neighbor_id` is null are dropped.
  *
  * @param uid
  *   unique identifier of this transformer
  */
case class DocumentSimilarityRankerFinisher(override val uid: String)
    extends Transformer
    with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("DOCUMENT_SIMILARITY_RANKER_FINISHER"))

  /** Metadata key holding the LSH id of a document. */
  val LSH_ID_COL_NAME = "lshId"

  /** Metadata key holding the neighbor list, parsed here as `[(id,distance),(id,distance),...]`. */
  val LSH_NEIGHBORS_COL_NAME = "lshNeighbors"

  /** Id output column name used when exactly one output column is configured. */
  val FINISHED_DOC_SIM_RANKER_ID_DEFAULT = "finished_doc_similarity_rankings_id"

  /** Neighbors output column name used when exactly one output column is configured. */
  val FINISHED_DOC_SIM_RANKER_NEIGHBORS_DEFAULT = "finished_doc_similarity_rankings_neighbors"

  /** Name of input annotation cols containing similar documents. `transform` reads only the first
    * one; when the param is unset it falls back to the
    * [[AnnotatorType.DOC_SIMILARITY_RANKINGS]] column.
    *
    * @group param
    */
  val inputCols: StringArrayParam =
    new StringArrayParam(
      this,
      "inputCols",
      "Name of input annotation cols containing similar documents")

  /** Name of input annotation cols containing similar documents
    *
    * @group setParam
    */
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** Name of input annotation cols containing similar documents
    *
    * @group setParam
    */
  def setInputCols(value: String*): this.type = setInputCols(value.toArray)

  /** Name of input annotation cols containing similar documents
    *
    * @group getParam
    */
  def getInputCols: Array[String] = $(inputCols)

  /** Name of DocumentSimilarityRankerFinisher output cols. Must have length 1 (the default id and
    * neighbors column names are used) or 2 (`Array(idCol, neighborsCol)`).
    *
    * @group param
    */
  val outputCols: StringArrayParam =
    new StringArrayParam(
      this,
      "outputCols",
      "Name of DocumentSimilarityRankerFinisher output cols")

  /** Name of DocumentSimilarityRankerFinisher output cols
    *
    * @group setParam
    */
  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

  /** Name of DocumentSimilarityRankerFinisher output cols
    *
    * @group setParam
    */
  def setOutputCols(value: String*): this.type = setOutputCols(value.toArray)

  /** Name of DocumentSimilarityRankerFinisher output cols; defaults to the input cols prefixed with
    * `finished_`.
    *
    * @group getParam
    */
  def getOutputCols: Array[String] = get(outputCols).getOrElse(getInputCols.map("finished_" + _))

  /** Whether to extract only the nearest neighbor (id and distance) of each row instead of joining
    * every neighbor back to its text. Default: `false`.
    *
    * @group param
    */
  val extractNearestNeighbor: BooleanParam =
    new BooleanParam(
      this,
      "extractNearestNeighbor",
      doc = "Extract the best neighbors with distance")

  /** Set flag to extract best neighbor with distance
    *
    * @group setParam
    */
  def setExtractNearestNeighbor(value: Boolean): this.type = set(extractNearestNeighbor, value)

  /** Flag to extract best neighbor with distance
    *
    * @group getParam
    */
  def getExtractNearestNeighbor: Boolean = $(extractNearestNeighbor)

  setDefault(extractNearestNeighbor -> false)

  override def transform(dataset: Dataset[_]): DataFrame = {
    val outCols = getOutputCols
    require(
      outCols.length == 1 || outCols.length == 2,
      "Output column array should have length 1 (default case) or 2 when value id and neighbors are assigned.")

    // A single output column keeps the default names; two columns name (id, neighbors) explicitly.
    val (idColName, neighborsColName) =
      if (outCols.length == 2) (outCols(0), outCols(1))
      else (FINISHED_DOC_SIM_RANKER_ID_DEFAULT, FINISHED_DOC_SIM_RANKER_NEIGHBORS_DEFAULT)

    // Honour the configured input column; fall back to the historical hardcoded name when unset.
    val inputColName =
      get(inputCols).flatMap(_.headOption).getOrElse(AnnotatorType.DOC_SIMILARITY_RANKINGS)
    val firstMetadata = element_at(col(s"$inputColName.metadata"), 1)

    val transformed = dataset
      .withColumn(idColName, firstMetadata.getItem(LSH_ID_COL_NAME).cast("int"))
      .withColumn(neighborsColName, firstMetadata.getItem(LSH_NEIGHBORS_COL_NAME))

    // Intermediate parsing columns: strip [], keep the first (...) tuple, then strip ().
    val formatted = transformed
      .withColumn(
        s"no_squared_$neighborsColName",
        regexp_replace(col(neighborsColName), "[\\[\\]]", ""))
      .withColumn(
        s"tuple_extract_$neighborsColName",
        regexp_extract(col(s"no_squared_$neighborsColName"), "\\((.*?)\\)", 0))
      .withColumn(
        s"no_rounded_$neighborsColName",
        regexp_replace(col(s"tuple_extract_$neighborsColName"), "[\\(\\)]", ""))

    val result =
      if (getExtractNearestNeighbor) withNearestNeighbor(formatted, neighborsColName)
      else joinNeighborTexts(dataset, formatted, neighborsColName)

    // Dropping columns that are absent (e.g. in the join mode) is a no-op in Spark.
    result
      .where(col("nearest_neighbor_id").isNotNull)
      .drop(
        s"no_squared_$neighborsColName",
        s"tuple_extract_$neighborsColName",
        s"no_rounded_$neighborsColName",
        s"split_$neighborsColName")
  }

  /** Splits the first `(id,distance)` tuple into `nearest_neighbor_id` (int) and
    * `nearest_neighbor_distance` (string) columns, keeping every other column.
    */
  private def withNearestNeighbor(formatted: DataFrame, neighborsColName: String): DataFrame = {
    val splitColName = s"split_$neighborsColName"
    formatted
      .withColumn(splitColName, split(col(s"no_rounded_$neighborsColName"), ","))
      .withColumn("nearest_neighbor_id", element_at(col(splitColName), 1).cast(IntegerType))
      .withColumn("nearest_neighbor_distance", element_at(col(splitColName), 2))
  }

  /** Explodes every `(id,distance)` neighbor tuple and inner-joins them to the rows of `dataset`
    * whose MurmurHash3 `text` hash equals the neighbor id.
    *
    * NOTE(review): assumes neighbor ids were produced by `MurmurHash3.stringHash(text,
    * stringSeed)` and that `dataset` has a `text` column — confirm against DocumentSimilarityRanker.
    */
  private def joinNeighborTexts(
      dataset: Dataset[_],
      formatted: DataFrame,
      neighborsColName: String): DataFrame = {
    val cleanedCol = "nearest_neighbors_array_exploded_cleaned"

    val neighbors = formatted
      .where(col(neighborsColName) =!= "[]")
      // `no_squared_*` already has the square brackets removed; split between tuples.
      .select(split(col(s"no_squared_$neighborsColName"), "\\),\\(").as("split_nearest_neighbors"))
      .withColumn("nearest_neighbors_array_exploded", explode(col("split_nearest_neighbors")))
      .withColumn(cleanedCol, regexp_replace(col("nearest_neighbors_array_exploded"), "[()]", ""))
      .withColumn("nearest_neighbor_id", split(col(cleanedCol), ",")(0))
      .withColumn("nearest_neighbor_distance", split(col(cleanedCol), ",")(1))
      .select("nearest_neighbor_id", "nearest_neighbor_distance")

    // Null-safe hash: a null `text` yields a null key (dropped by the inner join) instead of an NPE.
    val textHash: String => Option[Int] =
      text => Option(text).map(MurmurHash3.stringHash(_, MurmurHash3.stringSeed))
    val textHashUDF = udf(textHash)

    dataset
      .withColumn("nearest_neighbor_id", textHashUDF(col("text")))
      .join(neighbors, usingColumn = "nearest_neighbor_id")
      .select("text", "nearest_neighbor_id", "nearest_neighbor_distance")
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  /** Validates that every input column exists, is a Spark NLP annotation column and has
    * `annotatorType` [[AnnotatorType.DOC_SIMILARITY_RANKINGS]].
    *
    * NOTE(review): the schema is returned unchanged, so it does not describe the columns that
    * `transform` adds (or, in join mode, removes).
    */
  override def transformSchema(schema: StructType): StructType = {
    val documentSimilarityRankerAnnotators = Seq(AnnotatorType.DOC_SIMILARITY_RANKINGS)
    val annotationColumns = getInputCols

    // The existence check covers all columns at once, so it runs once rather than per column.
    FinisherUtil.checkIfInputColsExist(annotationColumns, schema)

    annotationColumns.foreach { annotationColumn =>
      FinisherUtil.checkIfAnnotationColumnIsSparkNLPAnnotation(schema, annotationColumn)
      require(
        documentSimilarityRankerAnnotators.contains(
          schema(annotationColumn).metadata.getString("annotatorType")),
        s"column [$annotationColumn] must be of type DocumentSimilarityRanker")
    }

    schema
  }
}
/** Companion providing `read` / `load` for persisted [[DocumentSimilarityRankerFinisher]]
  * instances via Spark ML's default params reader.
  */
object DocumentSimilarityRankerFinisher
    extends DefaultParamsReadable[DocumentSimilarityRankerFinisher] {}