
com.johnsnowlabs.nlp.embeddings.HasEmbeddings.scala Maven / Gradle / Ivy
package com.johnsnowlabs.nlp.embeddings
import com.johnsnowlabs.nlp.AnnotatorType
import org.apache.spark.ml.param.{BooleanParam, IntParam, Params}
import org.apache.spark.sql.Column
import org.apache.spark.sql.types.MetadataBuilder
trait HasEmbeddings extends Params {
val dimension = new IntParam(this, "dimension", "Number of embedding dimensions")
val caseSensitive = new BooleanParam(this, "caseSensitive", "whether to ignore case in tokens for embeddings matching")
setDefault(caseSensitive, false)
def setDimension(value: Int): this.type = set(this.dimension, value)
def setCaseSensitive(value: Boolean): this.type = set(this.caseSensitive, value)
def getDimension: Int = $(dimension)
def getCaseSensitive: Boolean = $(caseSensitive)
protected def wrapEmbeddingsMetadata(col: Column, embeddingsDim: Int, embeddingsRef: Option[String] = None): Column = {
val metadataBuilder: MetadataBuilder = new MetadataBuilder()
metadataBuilder.putString("annotatorType", AnnotatorType.WORD_EMBEDDINGS)
metadataBuilder.putLong("dimension", embeddingsDim.toLong)
embeddingsRef.foreach(ref => metadataBuilder.putString("ref", ref))
col.as(col.toString, metadataBuilder.build)
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy