All downloads are free. Search and download functionality uses the official Maven repository.

com.github.mrpowers.spark.stringmetric.expressions.similarityFunctions.scala Maven / Gradle / Ivy

The newest version!
package com.github.mrpowers.spark.stringmetric.expressions

import com.github.mrpowers.spark.stringmetric.unsafe.UTF8StringFunctions
import org.apache.commons.text.similarity.CosineDistance
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{
  CodegenContext,
  ExprCode
}
import org.apache.spark.sql.types.{ DataType, IntegerType, StringType }

/**
 * Mixin that exposes the fully-qualified name of the `UTF8StringFunctions`
 * helper, so codegen templates can reference it without spelling out the
 * whole package path every time.
 *
 * TODO: look for a cleaner way to get imports into the generated code than
 * splicing a fully-qualified name into each template.
 */
trait UTF8StringFunctionsHelper {

  /** Fully-qualified class name, interpolated verbatim into generated code. */
  val stringFuncs: String =
    "com.github.mrpowers.spark.stringmetric.unsafe.UTF8StringFunctions"
}

/**
 * Common shape for binary Catalyst expressions that take two strings and
 * return an integer.
 *
 * - Both inputs are declared as `StringType`. Through `ImplicitCastInputTypes`,
 *   implicit casts to string are requested for the inputs.
 * - The result type is always `IntegerType`.
 * - Must be mixed into a `BinaryExpression` (enforced by the self type).
 * - Also pulls in `stringFuncs` for use in codegen templates.
 */
trait StringString2IntegerExpression
extends ImplicitCastInputTypes
with NullIntolerant
with UTF8StringFunctionsHelper { self: BinaryExpression =>

  /** Two string operands: (left, right). */
  override def inputTypes: Seq[DataType] = Seq.fill(2)(StringType)

  /** Every metric built on this trait yields an integer. */
  override def dataType: DataType = IntegerType

  /**
   * Placeholder evaluation that always returns -1.
   *
   * NOTE(review): -1 silently becomes the result if a concrete expression
   * forgets to override this; consider making it abstract instead.
   */
  protected override def nullSafeEval(left: Any, right: Any): Any = -1
}

/**
 * Catalyst expression computing the Hamming distance between two strings via
 * `UTF8StringFunctions.hammingDistance`. It appears as `hamming` in plans and
 * SQL output.
 *
 * @param left  expression producing the first string
 * @param right expression producing the second string
 */
case class HammingDistance(left: Expression, right: Expression)
extends BinaryExpression with StringString2IntegerExpression {
  override def prettyName: String = "hamming"

  /**
   * Interpreted-mode evaluation on the already-evaluated child values.
   *
   * Presumably only reached with non-null inputs, since Spark's
   * `BinaryExpression.eval` handles nulls before calling this. Verify against
   * the targeted Spark version.
   *
   * @param leftVal  evaluated value of `left` (a `UTF8String` for `StringType`)
   * @param rightVal evaluated value of `right` (a `UTF8String` for `StringType`)
   * @return the result of `UTF8StringFunctions.hammingDistance` on the two strings
   */
  override def nullSafeEval(leftVal: Any, rightVal: Any): Any = {
    // Cast the evaluated arguments, not the child `Expression` fields
    // `left`/`right`: those are Catalyst trees, and casting them to
    // UTF8String throws ClassCastException.
    val leftStr = leftVal.asInstanceOf[UTF8String]
    val rightStr = rightVal.asInstanceOf[UTF8String]
    UTF8StringFunctions.hammingDistance(leftStr, rightStr)
  }

  /** Code-generation path; emits the same `hammingDistance` call as `nullSafeEval`. */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, (s1, s2) => s"$stringFuncs.hammingDistance($s1, $s2)")
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy