com.nvidia.spark.rapids.GpuRegExpReplaceMeta.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of rapids-4-spark_2.12 Show documentation
Show all versions of rapids-4-spark_2.12 Show documentation
Creates the distribution package of the RAPIDS plugin for Apache Spark
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpReplaceWithBackref, GpuRegExpUtils}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String
trait GpuRegExpReplaceOpt extends Serializable
@SerialVersionUID(100L)
object GpuRegExpStringReplace extends GpuRegExpReplaceOpt
@SerialVersionUID(100L)
object GpuRegExpStringReplaceMulti extends GpuRegExpReplaceOpt
class GpuRegExpReplaceMeta(
expr: RegExpReplace,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends QuaternaryExprMeta[RegExpReplace](expr, conf, parent, rule) {
private var javaPattern: Option[String] = None
private var cudfPattern: Option[String] = None
private var replacement: Option[String] = None
private var searchList: Option[Seq[String]] = None
private var replaceOpt: Option[GpuRegExpReplaceOpt] = None
private var containsBackref: Boolean = false
override def tagExprForGpu(): Unit = {
replacement = expr.rep match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null => Some(s.toString)
case _ => None
}
expr.regexp match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
javaPattern = Some(s.toString())
try {
val (pat, repl) =
new CudfRegexTranspiler(RegexReplaceMode).getTranspiledAST(s.toString, None,
replacement)
repl.map { r => GpuRegExpUtils.backrefConversion(r.toRegexString) }.foreach {
case (hasBackref, convertedRep) =>
containsBackref = hasBackref
replacement = Some(GpuRegExpUtils.unescapeReplaceString(convertedRep))
}
if (!containsBackref && GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
replaceOpt = Some(GpuRegExpStringReplace)
} else {
searchList = GpuRegExpUtils.getChoicesFromRegex(pat)
searchList match {
case Some(_) if !containsBackref =>
replaceOpt = Some(GpuRegExpStringReplaceMulti)
case _ =>
GpuRegExpUtils.tagForRegExpEnabled(this)
GpuRegExpUtils.validateRegExpComplexity(this, pat)
cudfPattern = Some(pat.toRegexString)
}
}
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
}
case _ =>
willNotWorkOnGpu(s"only non-null literal strings are supported on GPU")
}
GpuOverrides.extractLit(expr.pos).foreach { lit =>
if (lit.value.asInstanceOf[Int] != 1) {
willNotWorkOnGpu("only a search starting position of 1 is supported")
}
}
}
override def convertToGpu(
lhs: Expression,
regexp: Expression,
rep: Expression,
pos: Expression): GpuExpression = {
// ignore the pos expression which must be a literal 1 after tagging check
require(childExprs.length == 4,
s"Unexpected child count for RegExpReplace: ${childExprs.length}")
replaceOpt match {
case None =>
(javaPattern, cudfPattern, replacement) match {
case (Some(javaPattern), Some(cudfPattern), Some(cudfReplacement)) =>
if (containsBackref) {
GpuRegExpReplaceWithBackref(lhs, regexp, rep)(javaPattern,
cudfPattern, cudfReplacement)
} else {
GpuRegExpReplace(lhs, regexp, rep)(javaPattern, cudfPattern, cudfReplacement,
None, None)
}
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
case _ =>
(javaPattern, replacement) match {
case (Some(javaPattern), Some(replacement)) =>
GpuRegExpReplace(lhs, regexp, rep)(javaPattern, javaPattern, replacement,
searchList, replaceOpt)
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
}
}
}