All downloads are free. Search and download functionality uses the official Maven repository.

com.nvidia.spark.rapids.RegexParser.scala Maven / Gradle / Ivy

There is a newer version: 24.10.1
Show newest version
/*
 * Copyright (c) 2021-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.nvidia.spark.rapids

import java.sql.SQLException

import scala.collection.mutable.ListBuffer

import com.nvidia.spark.rapids.GpuOverrides.regexMetaChars
import com.nvidia.spark.rapids.RegexParser.toReadableString

/**
 * Regular expression parser based on a Pratt Parser design.
 *
 * The goal of this parser is to build a minimal AST that allows us
 * to validate that we can support the expression on the GPU. The goal
 * is not to parse with the level of detail that would be required if
 * we were building an evaluation engine. For example, operator precedence is
 * largely ignored but could be added if we need it later.
 *
 * The Java and cuDF regular expression documentation has been used as a reference:
 *
 * Java regex: https://docs.oracle.com/javase/7/docs/api/java/util/regex/Pattern.html
 * cuDF regex: https://docs.rapids.ai/api/libcudf/stable/md_regex.html
 *
 * The following blog posts provide some background on Pratt Parsers and parsing regex.
 *
 * - https://journal.stuffwithstuff.com/2011/03/19/pratt-parsers-expression-parsing-made-easy/
 * - https://matt.might.net/articles/parsing-regex-with-recursive-descent/
 */
class RegexParser(pattern: String) {
  // Note that [, ] and \ should be part of Punct, but they are handled separately
  private val regexPunct = """!"#$%&'()*+,-./:;<=>?@^_`{|}~"""
  private val escapeChars = Map('n' -> '\n', 'r' -> '\r', 't' -> '\t', 'f' -> '\f', 'a' -> '\u0007',
      'b' -> '\b', 'e' -> '\u001b')

  /** index of current position within the string being parsed */
  private var pos = 0

  /**
   * Parse the whole pattern into an AST.
   *
   * @return the root of the parsed AST
   * @throws RegexUnsupportedException if the pattern cannot be parsed in full
   */
  def parse(): RegexAST = {
    val ast = parseUntil(() => eof())
    if (!eof()) {
      throw new RegexUnsupportedException("Failed to parse full regex. Last character parsed was",
        Some(pos))
    }
    ast
  }

  /**
   * Parse the pattern as a replacement string (e.g. for regexp_replace), recognizing
   * backreferences written as `\n`, `$n` or `${n}`.
   *
   * @param numCaptureGroups number of capture groups in the regex the replacement is used with
   */
  def parseReplacement(numCaptureGroups: Int): RegexReplacement = {
    val sequence = RegexReplacement(new ListBuffer(), numCaptureGroups)
    while (!eof()) {
      parseReplacementBase() match {
        case RegexSequence(parts) =>
          sequence.parts ++= parts
        case other =>
          sequence.parts += other
      }
    }
    sequence
  }

  /** Parse one element of a replacement string: a backref, a literal '\' / '$', or a char. */
  private def parseReplacementBase(): RegexAST = {
    consume() match {
      case '\\' =>
        parseBackrefOrEscaped()
      case '$' =>
        parseBackrefOrLiteralDollar()
      case other =>
        RegexChar(other)
    }
  }


  /** Parse a (possibly alternated) expression until `until` holds or input is exhausted. */
  private def parseUntil(until: () => Boolean): RegexAST = {
    val term = parseTerm(() => until() || peek().contains('|'))
    if (!eof() && peek().contains('|')) {
      consumeExpected('|')
      new RegexChoice(term, parseUntil(until), pos)
    } else {
      term
    }
  }

  /** Parse a sequence of factors, flattening nested sequences into a single one. */
  private def parseTerm(until: () => Boolean): RegexAST = {
    val sequence = new RegexSequence(new ListBuffer(), pos)
    while (!eof() && !until()) {
      parseFactor(until) match {
        case RegexSequence(parts) =>
          sequence.parts ++= parts
        case other =>
          sequence.parts += other
      }
    }
    sequence
  }

  /** Look ahead (without consuming input) for a well-formed `{n}`, `{n,}` or `{n,m}`. */
  private def isValidQuantifierAhead(): Boolean = {
    if (peek().contains('{')) {
      val bookmark = pos
      consumeExpected('{')
      val q = parseQuantifierOrLiteralBrace()
      pos = bookmark
      q match {
        case _: QuantifierFixedLength | _: QuantifierVariableLength => true
        case _ => false
      }
    } else {
      false
    }
  }

  /** Parse a base expression followed by any number of quantifiers. */
  private def parseFactor(until: () => Boolean): RegexAST = {
    var start = pos
    var base = parseBase()
    base.position = Some(start)
    while (!eof() && !until()
        && (peek().exists(ch => ch == '*' || ch == '+' || ch == '?')
        || isValidQuantifierAhead())) {
      start = pos
      val quantifier = if (peek().contains('{')) {
        consumeExpected('{')
        // isValidQuantifierAhead() has confirmed this parses as a quantifier
        parseQuantifierOrLiteralBrace().asInstanceOf[RegexQuantifier]
      } else {
        SimpleQuantifier(consume())
      }
      base = new RegexRepetition(base, quantifier, start)
      quantifier.position = Some(pos-1)
    }
    base
  }

  /** Parse a single base expression: group, character class, escape or literal. */
  private def parseBase(): RegexAST = {
    val start = pos
    val base: RegexAST = consume() match {
      case '(' =>
        parseGroup()
      case '[' =>
        parseCharacterClass()
      case ']' =>
        RegexEscaped(']')
      case '}' =>
        RegexEscaped('}')
      case '\\' =>
        parseEscapedCharacter()
      case '\u0000' =>
        RegexGroup(false, RegexEscaped('0'), None)
      case '*' | '+' | '?' =>
        throw new RegexUnsupportedException(
          "Base expression cannot start with quantifier", Some(pos-1))
      case other =>
        RegexChar(other)
    }
    base.position = Some(start)
    base
  }

  /**
   * Parse a group whose opening '(' has already been consumed.
   *
   * The prefixes `?:` (non-capturing group), `?=` (positive lookahead) and `?!`
   * (negative lookahead) are recognized; lookaheads are recorded as non-capturing.
   * Every other group is a capturing group.
   */
  private def parseGroup(): RegexAST = {
    // Java lookahead syntax is "(?=X)" / "(?!X)": the '?' has to be matched together with
    // the '=' or '!'. Testing only for '=' or '!' would misread the plain group "(=a)" and
    // the non-capturing group "(?:=a)" as lookaheads and reject real lookaheads entirely.
    val (captureGroup, lookahead): (Boolean, Option[RegexLookahead]) =
      if (pattern.startsWith("?:", pos)) {
        pos += 2
        (false, None)
      } else if (pattern.startsWith("?=", pos)) {
        pos += 2
        (false, Some(RegexPositiveLookahead))
      } else if (pattern.startsWith("?!", pos)) {
        pos += 2
        (false, Some(RegexNegativeLookahead))
      } else {
        (true, None)
      }
    val term = parseUntil(() => peek().contains(')'))
    consumeExpected(')')
    RegexGroup(captureGroup, term, lookahead)
  }

  /** Parse a character class whose opening '[' has already been consumed. */
  private def parseCharacterClass(): RegexCharacterClass = {
    val supportedMetaCharacters = "\\^-[]+."

    def getEscapedComponent(): RegexCharacterClassComponent = {
      peek() match {
        case Some('x') =>
          consumeExpected('x')
          val hexChar = parseHexDigit
          hexChar.codePoint match {
            case 0 => hexChar
            case codePoint => RegexChar(codePoint.toChar)
          }
        case Some('0') =>
          val octalChar = parseOctalDigit
          octalChar.codePoint match {
            case 0 => RegexHexDigit("00")
            case codePoint => RegexChar(codePoint.toChar)
          }
        case Some(ch) =>
          consumeExpected(ch) match {
            // NOTE: Should switch to ASCII mode to simplify and expand this fix
            case 'd' => RegexCharacterRange(RegexChar('0'), RegexChar('9'))
            // List of character literals with an escape from here, under "Characters"
            // https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
            case ch if escapeChars.contains(ch) =>
              RegexChar(escapeChars(ch))
            case ch =>
              if (supportedMetaCharacters.contains(ch)) {
                // an escaped metacharacter ('\\', '^', '-', ']', '+')
                RegexEscaped(ch)
              } else {
                throw new RegexUnsupportedException(
                  s"Unsupported escaped character '$ch' in character class", Some(pos-1))
              }
          }
        case None =>
          throw new RegexUnsupportedException(
                s"Unclosed character class", Some(pos))
      }
    }

    val start = pos
    val characterClass = new RegexCharacterClass(negated = false, characters = ListBuffer(), pos)
    // loop until the end of the character class or EOF
    var characterClassComplete = false
    while (!eof() && !characterClassComplete) {
      val ch = consume()
      ch match {
        case '[' =>
          // treat as a literal character and add to the character class
          characterClass.append(new RegexChar(ch, pos-1))
        case ']' if (!characterClass.negated && pos > start + 1) ||
            (characterClass.negated && pos > start + 2) =>
          // "[]" is not a valid character class
          // "[]a]" is a valid character class containing the characters "]" and "a"
          // "[^]a]" is a valid negated character class containing the characters "]" and "a"
          characterClassComplete = true
        case '^' if pos == start + 1 =>
          // Negates the character class, causing it to match a single character not listed in
          // the character class. Only valid immediately after the opening '['
          characterClass.negated = true
        case ch =>
          val nextChar: RegexCharacterClassComponent = ch match {
            case '\\' =>
              getEscapedComponent() match {
                case RegexChar(ch) if supportedMetaCharacters.contains(ch) =>
                  // A hex or octal representation of a meta character gets treated as an escaped
                  // char. Example: [\x5ea] is treated as [\^a], not just [^a]
                  RegexEscaped(ch)
                case other => other
              }
            case '&' =>
              peek() match {
                case Some('&') =>
                  throw new RegexUnsupportedException("" +
                    "cuDF does not support class intersection operator &&", Some(pos-1))
                case _ => // ignore
              }
              RegexChar('&')
            case '\u0000' =>
              RegexHexDigit("00")
            case ch =>
              RegexChar(ch)
          }
          nextChar.position = Some(pos-1)
          peek() match {
            case Some('-') =>
              consumeExpected('-')
              peek() match {
                case Some(']') =>
                  // '-' at end of class e.g. "[abc-]"
                  characterClass.append(nextChar)
                  characterClass.append('-')
                case Some('\\') =>
                  consumeExpected('\\')
                  characterClass.appendRange(nextChar, getEscapedComponent())
                case Some(end) =>
                  skip()
                  characterClass.appendRange(nextChar, RegexChar(end))
                case _ =>
                  throw new RegexUnsupportedException(
                    "Unexpected EOF while parsing character range", Some(pos))
              }
            case _ =>
              characterClass.append(nextChar)
          }
      }
    }
    if (!characterClassComplete) {
      throw new RegexUnsupportedException(s"Unclosed character class", Some(pos))
    }
    characterClass
  }


  /**
   * Parse a quantifier in one of the following formats:
   *
   * {n}
   * {n,}
   * {n,m} (only valid if m >= n)
   *
   * Anything else is treated as a literal '{' and the position is rewound to just after it.
   */
  private def parseQuantifierOrLiteralBrace(): RegexAST = {

    // assumes that '{' has already been consumed
    val start = pos

    def treatAsLiteralBrace() = {
      // this was not a quantifier, just a literal '{'
      pos = start + 1
      RegexChar('{')
    }

    consumeInt() match {
      case Some(minLength) =>
        peek() match {
          case Some(',') =>
            consumeExpected(',')
            val max = consumeInt()
            if (peek().contains('}')) {
              consumeExpected('}')
              max match {
                case None =>
                  QuantifierVariableLength(minLength, None)
                case Some(m) =>
                  if (m >= minLength) {
                    QuantifierVariableLength(minLength, max)
                  } else {
                    treatAsLiteralBrace()
                  }
              }
            } else {
              treatAsLiteralBrace()
            }
          case Some('}') =>
            consumeExpected('}')
            QuantifierFixedLength(minLength)
          case _ =>
            treatAsLiteralBrace()
        }
      case None =>
        treatAsLiteralBrace()
    }
  }

  /** In a replacement string, parse `\n` as a backref, or a lone '\' as a literal. */
  private def parseBackrefOrEscaped(): RegexAST = {
    val start = pos

    consumeInt() match {
      case Some(refNum) =>
        RegexBackref(refNum)
      case None =>
        pos = start
        RegexChar('\\')
    }
  }

  /** In a replacement string, parse `$n` / `${n}` as a backref, otherwise a literal '$'. */
  private def parseBackrefOrLiteralDollar(): RegexAST = {
    val start = pos

    def treatAsLiteralDollar() = {
      pos = start
      RegexChar('$')
    }

    peek() match {
      case Some('{') =>
        consumeExpected('{')
        val num = consumeInt()
        if (peek().contains('}')) {
          consumeExpected('}')
          num match {
            case Some(n) =>
              RegexBackref(n)
            case _ =>
              treatAsLiteralDollar()
          }
        } else {
          treatAsLiteralDollar()
        }
      case Some(ch) if ch >= '1' && ch <= '9' =>
        val num = consumeInt()
        num match {
          case Some(n) =>
            RegexBackref(n)
          case _ =>
            treatAsLiteralDollar()
        }
      case _ =>
        treatAsLiteralDollar()
    }
  }

  /** Parse the character(s) following a '\' outside of a character class. */
  private def parseEscapedCharacter(): RegexAST = {
    peek() match {
      case None =>
        throw new RegexUnsupportedException("Pattern may not end with trailing escape", Some(pos))
      case Some(ch) =>
        ch match {
          case 'A' | 'Z' | 'z' =>
            // string anchors
            consumeExpected(ch)
            RegexEscaped(ch)
          case 's' | 'S' | 'd' | 'D' | 'w' | 'W' | 'v' | 'V' | 'h' | 'H' | 'R' =>
            // meta sequences
            consumeExpected(ch)
            RegexEscaped(ch)
          case 'B' | 'b' =>
            // word boundaries
            consumeExpected(ch)
            RegexEscaped(ch)
          case '[' | ']' | '\\' | '^' | '$' | '.' | '|' | '?' | '*' | '+' | '(' | ')' | '{' | '}' =>
            // escaped metacharacter
            consumeExpected(ch)
            RegexEscaped(ch)
          case 'x' =>
            consumeExpected(ch)
            parseHexDigit
          case '0' =>
            parseOctalDigit
          case 'p' | 'P' =>
            parsePredefinedClass
          case _ if escapeChars.contains(ch) =>
            consumeExpected(ch)
            RegexChar(escapeChars(ch))
          case _ if regexPunct.contains(ch) =>
            // other punctuation
            // note that this may include metacharacters from earlier, this is just to
            // handle characters not covered by the previous cases earlier
            consumeExpected(ch)
            RegexEscaped(ch)
          case other =>
            throw new RegexUnsupportedException(
              s"Invalid or unsupported escape character '$other'", Some(pos - 1))
        }
    }
  }

  /** Parse `\p{Name}` / `\P{Name}` (negated) POSIX classes; the 'p' or 'P' is next. */
  private def parsePredefinedClass: RegexCharacterClass = {
    val negated = consume().isUpper
    consumeExpected('{')
    val start = pos
    while(!eof() && pattern.charAt(pos).isLetter) {
      pos += 1
    }
    val className = pattern.substring(start, pos)
    def getCharacters(className: String): ListBuffer[RegexCharacterClassComponent] = {
      // Character lists from here:
      // https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
      className match {
        case "Lower" =>
          ListBuffer(RegexCharacterRange(RegexChar('a'), RegexChar('z')))
        case "Upper" =>
          ListBuffer(RegexCharacterRange(RegexChar('A'), RegexChar('Z')))
        case "ASCII" =>
          ListBuffer(RegexCharacterRange(RegexHexDigit("00"), RegexChar('\u007f')))
        case "Alpha" =>
          ListBuffer(getCharacters("Lower"), getCharacters("Upper")).flatten
        case "Digit" =>
          ListBuffer(RegexCharacterRange(RegexChar('0'), RegexChar('9')))
        case "Alnum" =>
          ListBuffer(getCharacters("Alpha"), getCharacters("Digit")).flatten
        case "Punct" =>
          val res:ListBuffer[RegexCharacterClassComponent] =
              ListBuffer(regexPunct.map(RegexChar): _*)
          res ++= ListBuffer(RegexEscaped('['), RegexEscaped(']'), RegexEscaped('\\'))
        case "Graph" =>
          ListBuffer(getCharacters("Alnum"), getCharacters("Punct")).flatten
        case "Print" =>
          val res = getCharacters("Graph")
          res += RegexChar('\u0020')
        case "Blank" =>
          ListBuffer(RegexChar(' '), RegexEscaped('t'))
        case "Cntrl" =>
          ListBuffer(RegexCharacterRange(RegexHexDigit("00"), RegexChar('\u001f')),
            RegexChar('\u007f'))
        case "XDigit" =>
          ListBuffer(RegexCharacterRange(RegexChar('0'), RegexChar('9')),
            RegexCharacterRange(RegexChar('a'), RegexChar('f')),
            RegexCharacterRange(RegexChar('A'), RegexChar('F')))
        case "Space" =>
          ListBuffer(" \t\n\u000B\f\r".map(RegexChar): _*)
        case _ =>
          throw new RegexUnsupportedException(
            s"Predefined character class ${className} is not supported", Some(start))
      }
    }
    consumeExpected('}')
    RegexCharacterClass(negated, characters = getCharacters(className))
  }

  private def isHexDigit(ch: Char): Boolean = ch.isDigit ||
    (ch >= 'a' && ch <= 'f') ||
    (ch >= 'A' && ch <= 'F')

  private def parseHexDigit: RegexHexDigit = {
    // \xhh      The character with hexadecimal value 0xhh
    // \x{h...h} The character with hexadecimal value 0xh...h
    //           (Character.MIN_CODE_POINT  <= 0xh...h <=  Character.MAX_CODE_POINT)

    // peek() rather than charAt(pos): a pattern ending in "\x" must raise
    // RegexUnsupportedException, not StringIndexOutOfBoundsException
    val varHex = peek().contains('{')
    if (varHex) {
      consumeExpected('{')
    }
    val start = pos
    while (!eof() && isHexDigit(pattern.charAt(pos))) {
      pos += 1
    }
    val hexDigit = pattern.substring(start, pos)
    if (varHex) {
      consumeExpected('}')
    } else if (hexDigit.length != 2) {
      throw new RegexUnsupportedException(s"Invalid hex digit: $hexDigit", Some(start))
    }

    // "\x{}" yields an empty string and long digit runs overflow an Int; parseInt throws
    // NumberFormatException for both, which callers do not handle
    val value = try {
      Integer.parseInt(hexDigit, 16)
    } catch {
      case _: NumberFormatException =>
        throw new RegexUnsupportedException(s"Invalid hex digit: $hexDigit", Some(start))
    }
    if (value < Character.MIN_CODE_POINT || value > Character.MAX_CODE_POINT) {
      throw new RegexUnsupportedException(s"Invalid hex digit: $hexDigit", Some(start))
    }
    new RegexHexDigit(hexDigit, start - 2)
  }

  private def isOctalDigit(ch: Char): Boolean = ch >= '0' && ch <= '7'

  private def parseOctalDigit: RegexOctalChar = {
    // \0n   The character with octal value 0n (0 <= n <= 7)
    // \0nn  The character with octal value 0nn (0 <= n <= 7)
    // \0mnn The character with octal value 0mnn (0 <= m <= 3, 0 <= n <= 7)

    def parseOctalDigits(n: Integer): RegexOctalChar = {
      val octal = pattern.substring(pos, pos + n)
      pos += n
      new RegexOctalChar(octal, pos)
    }

    if (!eof() && isOctalDigit(pattern.charAt(pos))) {
      if (pos + 1 < pattern.length && isOctalDigit(pattern.charAt(pos + 1))) {
        if (pos + 2 < pattern.length && isOctalDigit(pattern.charAt(pos + 2))
            && pattern.charAt(pos) <= '3') {
          if (pos + 3 < pattern.length && isOctalDigit(pattern.charAt(pos + 3))
              && pattern.charAt(pos+1) <= '3' && pattern.charAt(pos) == '0') {
            parseOctalDigits(4)
          } else {
            parseOctalDigits(3)
          }
        } else {
          parseOctalDigits(2)
        }
      } else {
        parseOctalDigits(1)
      }
    } else {
      throw new RegexUnsupportedException(
        "Invalid octal digit", Some(pos))
    }
  }

  /** Determine if we are at the end of the input */
  private def eof(): Boolean = pos == pattern.length

  /** Advance the index by one */
  private def skip(): Unit = {
    if (eof()) {
      throw new RegexUnsupportedException("Unexpected EOF", Some(pos))
    }
    pos += 1
  }

  /** Get the next character and advance the index by one */
  private def consume(): Char = {
    if (eof()) {
      throw new RegexUnsupportedException("Unexpected EOF", Some(pos))
    } else {
      pos += 1
      pattern.charAt(pos - 1)
    }
  }

  /** Consume the next character if it is the one we expect */
  private def consumeExpected(expected: Char): Char = {
    val consumed = consume()
    if (consumed != expected) {
      throw new RegexUnsupportedException(
        s"Expected '$expected' but found '$consumed'", Some(pos-1))
    }
    consumed
  }

  /** Peek at the next character without consuming it */
  private def peek(): Option[Char] = {
    if (eof()) {
      None
    } else {
      Some(pattern.charAt(pos))
    }
  }

  /**
   * Consume a run of digits and return its value, or None (consuming nothing) when the
   * next character is not a digit.
   *
   * @throws RegexUnsupportedException if the digits do not fit in an Int
   */
  private def consumeInt(): Option[Int] = {
    val start = pos
    while (!eof() && peek().exists(_.isDigit)) {
      skip()
    }
    if (start == pos) {
      None
    } else {
      val digits = pattern.substring(start, pos)
      // e.g. "a{99999999999}" would otherwise escape as a raw NumberFormatException
      try {
        Some(digits.toInt)
      } catch {
        case _: NumberFormatException =>
          throw new RegexUnsupportedException(s"Number is too large: $digits", Some(start))
      }
    }
  }

}

object RegexParser {
  // characters that make a pattern a "real" regex even when they appear as plain RegexChars
  private val regexpChars = Set('\u0000', '\\', '.', '^', '$', '\u0007', '\u001b', '\f')

  /**
   * Parse `pattern` into an AST.
   *
   * @throws RegexUnsupportedException if the pattern cannot be parsed
   */
  def parse(pattern: String): RegexAST = new RegexParser(pattern).parse()

  /**
   * Decide whether `s` must be treated as a regular expression rather than a literal string.
   *
   * A pattern is considered literal only when it parses into plain characters (or a
   * sequence of them) none of which is in `regexpChars`. Patterns the parser rejects are
   * conservatively reported as regular expressions.
   */
  def isRegExpString(s: String): Boolean = {

    def containsRegexSyntax(ast: RegexAST): Boolean = ast match {
      case RegexChar(ch) => regexpChars.contains(ch)
      case RegexSequence(parts) => parts.exists(containsRegexSyntax)
      // escapes, groups, classes, repetitions, choices, ... always carry regex meaning
      case _ => true
    }

    try {
      val parser = new RegexParser(s)
      val ast = parser.parse()
      containsRegexSyntax(ast)
    } catch {
      case _: RegexUnsupportedException =>
        // if we cannot parse it then assume that it might be valid regexp
        true
    }
  }

  /** Replace line terminators, NUL and similar invisible characters with escaped forms. */
  def toReadableString(x: String): String = {
    x.map {
      case '\r' => "\\r"
      case '\n' => "\\n"
      case '\t' => "\\t"
      case '\f' => "\\f"
      case '\u0000' => "\\u0000"
      case '\u000b' => "\\u000b"
      case '\u0085' => "\\u0085"
      case '\u2028' => "\\u2028"
      case '\u2029' => "\\u2029"
      case other => other
    }.mkString
  }

}

/** The kind of cuDF operation a transpiled regular expression will be used for. */
sealed trait RegexMode
/** Matching only, e.g. `rlike`. */
case object RegexFindMode extends RegexMode
/** Replacement, e.g. `regexp_replace`. */
case object RegexReplaceMode extends RegexMode
/** Splitting, e.g. `string_split`. */
case object RegexSplitMode extends RegexMode

/** The kind of lookahead a group represents: `(?=X)` or `(?!X)`. */
sealed trait RegexLookahead
/** Negative lookahead `(?!X)`. */
case object RegexNegativeLookahead extends RegexLookahead
/** Positive lookahead `(?=X)`. */
case object RegexPositiveLookahead extends RegexLookahead

/**
 * Options carried along while rewriting a regular expression.
 *
 * @param emptyRepetition NOTE(review): its meaning is not visible in this excerpt; presumably
 *                        it records whether an empty (zero-length) repetition is involved.
 *                        Confirm this at the call sites.
 */
sealed class RegexRewriteFlags(val emptyRepetition: Boolean)

/**
 * Transpile Java/Spark regular expression to a format that cuDF supports, or throw an exception
 * if this is not possible.
 *
 * @param mode  RegexFindMode    if matching only (rlike)
 *              RegexReplaceMode if performing a replacement (regexp_replace)
 *              RegexSplitMode   if performing a split (string_split)
 */
class CudfRegexTranspiler(mode: RegexMode) {
  // every ASCII punctuation character, including '[', '\' and ']'
  private val regexPunct = """!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"""
  // escape letter => the control character it denotes in Java regex syntax
  private val escapeChars = Map(
    'n' -> '\n',
    'r' -> '\r',
    't' -> '\t',
    'f' -> '\f',
    'a' -> '\u0007',
    'b' -> '\b',
    'e' -> '\u001b')

  /**
   * Count the capturing groups in a parsed regex, as Java's `Matcher.groupCount()` does.
   *
   * Every capturing group counts wherever it is nested, including inside alternations
   * (`(a)|(b)` has two groups) and repetitions (`(a)+` has one). Non-capturing groups and
   * lookaheads do not count themselves, but capturing groups nested inside them do.
   */
  private def countCaptureGroups(regex: RegexAST): Int = {
    regex match {
      case RegexSequence(parts) => parts.foldLeft(0)((c, re) => c + countCaptureGroups(re))
      case RegexGroup(capture, base, _) =>
        (if (capture) 1 else 0) + countCaptureGroups(base)
      case RegexChoice(left, right) =>
        countCaptureGroups(left) + countCaptureGroups(right)
      case RegexRepetition(base, _) =>
        countCaptureGroups(base)
      case _ => 0
    }
  }

  /**
   * Parse Java regular expression and translate into cuDF regular expression.
   *
   * @param pattern Regular expression that is valid in Java's engine
   * @param extractIndex extraction index for regular expression
   * @param repl Optional replacement pattern
   * @return Regular expression and optional replacement in cuDF format
   */
  def transpile(pattern: String, extractIndex: Option[Int], repl: Option[String]):
        (String, Option[String]) = {
    val transpiled = getTranspiledAST(pattern, extractIndex, repl)
    // render the ASTs back to text; toRegexString performs minor transformations such as
    // adding additional escaping
    val regexText = transpiled._1.toRegexString
    val replacementText = transpiled._2.map(r => r.toRegexString)
    (regexText, replacementText)
  }

  /**
   * Translate an already-parsed Java regex AST into cuDF form.
   *
   * @param regex parsed Java regular expression
   * @param extractIndex extraction index for regular expression
   * @param repl optional replacement pattern; it is parsed with knowledge of how many
   *             capture groups `regex` has so that backrefs can be recognized
   * @return the cuDF regex AST and the parsed replacement, if any
   */
  def getTranspiledAST(
      regex: RegexAST,
      extractIndex: Option[Int],
      repl: Option[String]): (RegexAST, Option[RegexReplacement]) = {
    val replacement = repl.map { replText =>
      val groups = countCaptureGroups(regex)
      new RegexParser(replText).parseReplacement(groups)
    }
    // validation that cuDF supports the regex happens inside transpile
    (transpile(regex, extractIndex, replacement, None), replacement)
  }

  /**
   * Parse Java regular expression and translate into cuDF regular expression in AST form.
   *
   * @param pattern Regular expression that is valid in Java's engine
   * @param extractIndex extraction index for regular expression
   * @param repl Optional replacement pattern
   * @return Regular expression AST and optional replacement in cuDF format
   */
  def getTranspiledAST(
      pattern: String,
      extractIndex: Option[Int],
      repl: Option[String]): (RegexAST, Option[RegexReplacement]) = {
    // parse the Java pattern, then delegate to the AST-based overload
    val parsed: RegexAST = new RegexParser(pattern).parse()
    getTranspiledAST(parsed, extractIndex, repl)
  }

  /**
   * Convert a regex AST into the equivalent literal string, if it is just a sequence of
   * literal characters, so a cheaper non-regex split can be used.
   *
   * @return the literal string, or None when the AST has any real regex meaning
   */
  def transpileToSplittableString(e: RegexAST): Option[String] = {
    e match {
      // The parser turns `\b` into RegexEscaped('b') as a word boundary. It must not be
      // mapped to the backspace character even though escapeChars contains 'b'.
      case RegexEscaped('b') => None
      case RegexEscaped(ch) if escapeChars.contains(ch) => Some(escapeChars(ch).toString)
      case RegexEscaped(ch) if regexPunct.contains(ch) => Some(ch.toString)
      case RegexChar(ch) if !regexMetaChars.contains(ch) => Some(ch.toString)
      case RegexSequence(parts) =>
        parts.foldLeft[Option[String]](Some("")) { (all, x) =>
          all match {
            case Some(current) =>
              transpileToSplittableString(x) match {
                case Some(y) => Some(current + y)
                case _ => None
              }
            case _ => None
          }
        }
      case _ => None
    }
  }

  /**
   * Parse `pattern` and return its literal-string form if it has one.
   * Patterns the parser rejects are treated as real regular expressions (None).
   */
  def transpileToSplittableString(pattern: String): Option[String] = {
    try {
      transpileToSplittableString(new RegexParser(pattern).parse())
    } catch {
      case _: RegexUnsupportedException => None
    }
  }

  /**
   * Whether `e` ends in a repetition, looking through groups and at the last element of
   * sequences. With `checkZeroLength`, only repetitions that can match zero times
   * (`*`, `?`, `{0}`, `{0,...}`) count.
   */
  @scala.annotation.tailrec
  private def isRepetition(e: RegexAST, checkZeroLength: Boolean): Boolean = e match {
    case RegexRepetition(_, quantifier) =>
      !checkZeroLength || (quantifier match {
        case SimpleQuantifier(ch) => ch == '*' || ch == '?'
        case QuantifierFixedLength(length) => length == 0
        case QuantifierVariableLength(min, _) => min == 0
        case _ => false
      })
    case RegexGroup(_, term, _) =>
      isRepetition(term, checkZeroLength)
    case RegexSequence(parts) if parts.nonEmpty =>
      isRepetition(parts.last, checkZeroLength)
    case _ =>
      false
  }

  /**
   * Find the first sub-expression of `e` that cannot be the base of a repetition.
   *
   * Unsupported bases are the escapes other than `\d \w \s \S \h \H \v \V`, the anchors
   * '$' and '^', and a repetition that is itself repeated. Sequences are searched left to
   * right and groups are searched through.
   *
   * @return the offending sub-expression, or None when `e` is a supported base
   */
  private def getUnsupportedRepetitionBaseOption(e: RegexAST): Option[RegexAST] = {
    e match {
      case RegexEscaped(ch) => ch match {
        case 'd' | 'w' | 's' | 'S' | 'h' | 'H' | 'v' | 'V' => None
        case _ => Some(e)
      }
      case RegexChar(a) if "$^".contains(a) =>
        // example: "$*"
        Some(e)

      case RegexRepetition(_, _) =>
        // example: "a*+"
        Some(e)

      case RegexSequence(parts) =>
        // lazily search for the first unsupported part; this replaces a non-local `return`
        // from inside a foreach lambda
        parts.iterator
          .map(getUnsupportedRepetitionBaseOption)
          .collectFirst { case Some(unsupported) => unsupported }

      case RegexGroup(_, term, _) =>
        getUnsupportedRepetitionBaseOption(term)

      case _ => None
    }
  }

  /**
   * Like getUnsupportedRepetitionBaseOption, but for callers that already know `e` is
   * unsupported.
   *
   * @throws NoSuchElementException if `e` turns out to be a supported repetition base
   */
  private def getUnsupportedRepetitionBase(e: RegexAST): RegexAST =
    getUnsupportedRepetitionBaseOption(e).getOrElse {
      throw new NoSuchElementException(
        s"Expected repetition base ${e.toRegexString} to be unsupported but was actully supported")
    }

  /** True when no part of `e` is disallowed as the base of a repetition. */
  private def isSupportedRepetitionBase(e: RegexAST): Boolean =
    getUnsupportedRepetitionBaseOption(e).isEmpty

  private val lineTerminatorChars = Seq('\n', '\r', '\u0085', '\u2028', '\u2029')

  // From the Java 8 documentation: a line terminator is a one- or two-character sequence that
  // marks the end of a line of an input character sequence.
  // This method produces a RegexAST which outputs a regular expression to match any possible
  // combination of line terminators.
  //
  //   exclude     - single-character terminators to leave out
  //   excludeCRLF - when true, the two-character "\r\n" sequence is not matched
  //   capture     - whether the resulting group is a capturing group
  private def lineTerminatorMatcher(exclude: Set[Char], excludeCRLF: Boolean,
      capture: Boolean): RegexAST = {
    val kept = new ListBuffer[RegexCharacterClassComponent]()
    lineTerminatorChars.filterNot(exclude.contains).foreach(ch => kept += RegexChar(ch))

    (kept.isEmpty, excludeCRLF) match {
      case (true, true) =>
        // nothing is left to match
        RegexEmpty()
      case (true, false) =>
        // only the two-character CRLF sequence remains
        RegexGroup(capture = capture, RegexSequence(ListBuffer(RegexChar('\r'), RegexChar('\n'))),
          None)
      case (false, true) =>
        RegexGroup(capture = capture,
          RegexCharacterClass(negated = false, characters = kept),
          None)
      case (false, false) =>
        // NOTE(review): this branch uses a fixed pattern that ignores `exclude` and never
        // matches '\n'; it appears to assume callers exclude '\n' here -- confirm
        RegexGroup(capture = capture, RegexParser.parse("\r|\u0085|\u2028|\u2029|\r\n"), None)
    }
  }

  private def negateCharacterClass(
      components: ListBuffer[RegexCharacterClassComponent]): RegexAST = {
    // cuDF and Java disagree about `\r` in negated character classes: Java's `[^a]` matches
    // `\r` but cuDF's does not. To keep Java semantics, the class is wrapped as
    // `(?:[\r]|[^a])` unless it already mentions `\r` itself, in which case it is emitted
    // unchanged.
    //
    //   `[^a]`     => `(?:[\r]|[^a])`
    //   `[^a\n]`   => `(?:[\r]|[^a\n])`
    //   `[^a\r]`   => `[^a\r]`
    //   `[^a\r\n]` => `[^a\r\n]`
    //
    // NOTE(review): a range is only recognized as covering `\r` when its end is a RegexChar.
    def mentionsCarriageReturn(component: RegexCharacterClassComponent): Boolean =
      component match {
        case RegexChar(ch) => ch == '\r'
        case RegexEscaped(ch) => ch == 'r'
        case RegexCharacterRange(startRegex, RegexChar(end)) =>
          val start = startRegex match {
            case RegexChar(ch) => ch
            case r @ RegexOctalChar(_) => r.codePoint.toChar
            case r @ RegexHexDigit(_) => r.codePoint.toChar
            case other => throw new RegexUnsupportedException(
              s"Unexpected expression at start of character range: ${other.toRegexString}",
              other.position)
          }
          start <= '\r' && end >= '\r'
        case _ =>
          false
      }

    // count (not exists) so that every component is inspected and an unsupported range
    // start anywhere in the class is always reported
    val alreadyCoversCarriageReturn = components.count(mentionsCarriageReturn) > 0
    val negatedClass = RegexCharacterClass(negated = true, ListBuffer(components.toSeq: _*))

    if (alreadyCoversCarriageReturn) {
      negatedClass
    } else {
      RegexGroup(capture = false,
        RegexChoice(
          RegexCharacterClass(negated = false, characters = ListBuffer(RegexChar('\r'))),
          negatedClass),
        None)
    }
  }

  /**
   * Validates `regex` for constructs that are not supported on the GPU and then rewrites it
   * into an AST that cuDF can evaluate.
   *
   * @param regex the parsed Java regular expression
   * @param extractIndex when `Some(n)`, only the n-th capture group (numbered the way Java
   *                     numbers them: by the position of the opening parenthesis, left to
   *                     right) stays capturing; every other capture group becomes
   *                     non-capturing
   * @param replacement the parsed replacement, passed through to `rewrite`
   * @param previous the node preceding `regex`, if any, passed through to `rewrite`
   * @return the rewritten AST
   * @throws RegexUnsupportedException if an end-of-line/string anchor is combined with a
   *                                   newline, optional item or begin anchor in an unsupported
   *                                   way, or if `rewrite` rejects the pattern
   */
  private def transpile(regex: RegexAST, extractIndex: Option[Int],
      replacement: Option[RegexReplacement],
      previous: Option[RegexAST]): RegexAST = {

    def containsBeginAnchor(regex: RegexAST): Boolean = {
      contains(regex, {
        case RegexChar('^') | RegexEscaped('A') => true
        case _ => false
      })
    }

    def containsEndAnchor(regex: RegexAST): Boolean = {
      contains(regex, {
        case RegexChar('$') | RegexEscaped('z') | RegexEscaped('Z') => true
        case _ => false
      })
    }

    def containsNewline(regex: RegexAST): Boolean = {
      contains(regex, {
        case RegexChar('\r') | RegexEscaped('r') => true
        case RegexChar('\n') | RegexEscaped('n') => true
        case RegexChar('\u0085') | RegexChar('\u2028') | RegexChar('\u2029') => true
        case RegexEscaped('s') | RegexEscaped('v') | RegexEscaped('R') => true
        case RegexEscaped('W') | RegexEscaped('D') |
          RegexEscaped('S') | RegexEscaped('V') =>
          // these would get transpiled to negated character classes
          // that include newlines
          true
        case RegexCharacterClass(true, _) => true
        case _ => false
      })
    }

    // true for a repetition whose quantifier allows zero occurrences: `*`, `?`, `{0}`, `{0,m}`
    def isZeroMinRepetition(regex: RegexAST): Boolean = regex match {
      case RegexRepetition(_, SimpleQuantifier('*' | '?')) => true
      case RegexRepetition(_, QuantifierFixedLength(0)) => true
      case RegexRepetition(_, QuantifierVariableLength(0, _)) => true
      case _ => false
    }

    // true if any node nested in `regex` can match zero occurrences
    def containsEmpty(regex: RegexAST): Boolean = contains(regex, isZeroMinRepetition)

    // Check a pair of adjacent regex ast nodes for unsupported combinations of end string/line
    // anchors with newlines, optional items or begin anchors. In split mode an optional item
    // *before* an end anchor is rejected as well.
    def checkEndAnchorContext(r1: RegexAST, r2: RegexAST): Unit = {
      val splitMode = mode == RegexSplitMode
      val anchorThenUnsupported = containsEndAnchor(r1) &&
        (containsNewline(r2) || containsEmpty(r2) || containsBeginAnchor(r2))
      val unsupportedThenAnchor = containsEndAnchor(r2) &&
        (containsNewline(r1) || (splitMode && containsEmpty(r1)) || containsBeginAnchor(r1))
      if (anchorThenUnsupported || unsupportedThenAnchor) {
        throw new RegexUnsupportedException(
          s"End of line/string anchor is not supported in this context: " +
            s"${toReadableString(r1.toRegexString)}" +
            s"${toReadableString(r2.toRegexString)}", r1.position)
      }
    }

    // applies checkEndAnchorContext to every pair of neighbouring nodes, left to right
    def checkNeighbours(nodes: scala.collection.Seq[RegexAST]): Unit = {
      nodes.zip(nodes.drop(1)).foreach { case (prev, next) =>
        checkEndAnchorContext(prev, next)
      }
    }

    def checkUnsupported(regex: RegexAST): Unit = regex match {
      case RegexSequence(parts) =>
        checkNeighbours(parts)
      case RegexChoice(l, r) =>
        checkUnsupported(l)
        checkUnsupported(r)
      case RegexGroup(_, term, _) => checkUnsupported(term)
      case RegexRepetition(ast, _) => checkUnsupported(ast)
      case RegexCharacterClass(_, components) =>
        checkNeighbours(components)
      case _ =>
        // nothing to check for other node types
    }

    // true if the whole pattern can match the empty string through optional repetition
    def isEmptyRepetition(regex: RegexAST): Boolean = regex match {
      case RegexRepetition(_, _) => isZeroMinRepetition(regex)
      case RegexGroup(_, term, _) => isEmptyRepetition(term)
      case RegexSequence(parts) => parts.forall(isEmptyRepetition)
      case RegexChoice(l, r) => isEmptyRepetition(l) || isEmptyRepetition(r)
      case _ => false
    }

    checkUnsupported(regex)

    // Java numbers capture groups by the position of their opening parenthesis, i.e. a
    // pre-order, left-to-right walk of the AST. Capture groups can be nested, so we do this
    // outside of the rewrite. The walk must also enter non-capturing groups and both sides of
    // a choice; otherwise capture groups inside them are neither counted nor turned off.
    var current = 0
    def updateGroupsForExtract(regex: RegexAST, n: Int): RegexAST = {
      regex match {
        case RegexGroup(capture, term, lookahead) =>
          // the counter is advanced before visiting `term` so outer groups number first
          val keepCapture = if (capture) {
            current += 1
            n == current
          } else {
            false
          }
          RegexGroup(keepCapture, updateGroupsForExtract(term, n), lookahead)
        case RegexSequence(parts) =>
          RegexSequence(parts.map(updateGroupsForExtract(_, n)))
        case RegexChoice(l, r) =>
          // the left side's groups precede the right side's in Java numbering
          val left = updateGroupsForExtract(l, n)
          val right = updateGroupsForExtract(r, n)
          RegexChoice(left, right)
        case RegexRepetition(term, quantifier) =>
          RegexRepetition(updateGroupsForExtract(term, n), quantifier)
        case _ => regex
      }
    }

    val withUpdatedGroups = extractIndex match {
      case Some(n) =>
        updateGroupsForExtract(regex, n)
      case _ => regex
    }

    val flags = new RegexRewriteFlags(isEmptyRepetition(regex))

    rewrite(withUpdatedGroups, replacement, previous, flags)
  }

  /**
   * Recursively rewrites one node of a parsed Java regex AST into an equivalent node that
   * cuDF can evaluate. Throws [[RegexUnsupportedException]] for constructs that cannot be
   * supported in the current `mode`.
   *
   * @param regex the node to rewrite
   * @param replacement the parsed replacement (used in replace mode). Backreferences may be
   *                    appended to it or popped from it when the line-anchor handling below
   *                    introduces or drops a capture group.
   * @param previous the node that preceded `regex` in the enclosing sequence, if any. It is
   *                 used for context-sensitive rewrites, e.g. `$` after a line terminator.
   * @param flags flags computed once for the whole pattern by `transpile`
   * @return the rewritten node
   */
  private def rewrite(regex: RegexAST, replacement: Option[RegexReplacement],
      previous: Option[RegexAST], flags: RegexRewriteFlags): RegexAST = {
    regex match {

      // literal characters, plus the metacharacters '.', '$' and '^'
      case RegexChar(ch) => ch match {
        case '.' =>
          // workaround for https://github.com/rapidsai/cudf/issues/9619
          // '.' is rewritten as a negated class of the line terminator characters
          val terminatorChars = new ListBuffer[RegexCharacterClassComponent]()
          terminatorChars ++= lineTerminatorChars.map(RegexChar)
          RegexCharacterClass(negated = true, terminatorChars)
        case '$' if mode == RegexSplitMode =>
          RegexEscaped('Z')
        case '$' =>
          // in the case of the line anchor $, the JVM has special conditions when handling line
          // terminators in and around the anchor
          // this handles cases where the line terminator characters are *before* the anchor ($)
          // NOTE: this applies to when using *standard* mode. In multiline mode, all these
          // conditions will change. Currently Spark does not use multiline mode.
          previous match {
            case Some(RegexChar('$')) | Some(RegexEscaped('Z')) =>
              // repeating the line anchor in cuDF (for example b$$) causes matches to fail, but in
              // Java, it's treated as a single (b$ and b$$ are synonymous), so we create
              // an empty RegexAST that outputs to empty string
              RegexEmpty()
            case Some(RegexChar(ch)) if mode == RegexReplaceMode
                && lineTerminatorChars.contains(ch) =>
                throw new RegexUnsupportedException("Regex sequences with a line terminator "
                    + "character followed by '$' are not supported in replace mode", regex.position)
            case Some(RegexChar(ch)) if ch == '\r' =>
              // when using the CR (\r), it prevents the line anchor from handling any other
              // line terminator sequences, so we just output the anchor and we are finished
              // for example: \r$ -> \r$ (no transpilation)
              RegexChar('$')
            case Some(RegexChar(ch)) if lineTerminatorChars.contains(ch) =>
              // when using any other line terminator character, you can match any of the other
              // line terminator characters individually as part of the line anchor match.
              // for example: \n$ -> \n[\r\u0085\u2028\u2029]?$
              // NOTE(review): in replace mode a line terminator before '$' already threw in the
              // case above, so here `mode == RegexReplaceMode` appears to always be false
              if (mode == RegexReplaceMode) {
                replacement match {
                  case Some(rr) => rr.appendBackref(rr.numCaptureGroups + 1)
                  case _ =>
                }
              }
              RegexSequence(ListBuffer(
                RegexRepetition(lineTerminatorMatcher(Set(ch), true,
                    mode == RegexReplaceMode), SimpleQuantifier('?')),
                RegexChar('$')))
            case Some(RegexEscaped('b')) | Some(RegexEscaped('B')) =>
              throw new RegexUnsupportedException(
                      "Regex sequences with \\b or \\B not supported around $", regex.position)
            case _ =>
              // otherwise by default we can match any or none the full set of line terminators
              // in replace mode the matcher is built as a capture group, so a backreference
              // to it is appended to the replacement
              if (mode == RegexReplaceMode) {
                replacement match {
                  case Some(rr) => rr.appendBackref(rr.numCaptureGroups + 1)
                  case _ =>
                }
              }
              RegexSequence(ListBuffer(
                RegexRepetition(lineTerminatorMatcher(Set.empty, false,
                    mode == RegexReplaceMode), SimpleQuantifier('?')),
                RegexChar('$')))
          }
        case '^' if mode == RegexSplitMode =>
          RegexEscaped('A')
        case '\r' | '\n' if mode == RegexFindMode =>
          // in find mode a CR or LF directly after '$' is dropped (the '$' rewrite already
          // matches optional line terminators)
          previous match {
            case Some(RegexChar('$')) =>
              RegexEmpty()
            case _ =>
              regex
          }
        case _ =>
          regex
      }

      case r @ RegexOctalChar(digits) =>
        // a 4-character octal digit string with a leading '0' has that '0' dropped
        val octal = if (digits.charAt(0) == '0' && digits.length == 4) {
          digits.substring(1)
        } else  {
          digits
        }

        // metacharacters become escaped literals; code points >= 128 become plain chars
        if (regexMetaChars.map(_.toInt).contains(r.codePoint)) {
          RegexEscaped(r.codePoint.toChar)
        } else if(r.codePoint >= 128) {
          RegexChar(r.codePoint.toChar)
        } else {
          RegexOctalChar(octal)
        }

      case r @ RegexHexDigit(_) =>
        if (regexMetaChars.map(_.toInt).contains(r.codePoint)) {
          RegexEscaped(r.codePoint.toChar)
        } else if (r.codePoint >= 128) {
          // cuDF only supports 0x00 to 0x7f hexadecimal chars
          RegexChar(r.codePoint.toChar)
        } else {
          // normalized to (at least) two lowercase hex digits
          RegexHexDigit(String.format("%02x", Int.box(r.codePoint)))
        }

      // escape sequences; predefined classes are expanded to explicit character classes
      case RegexEscaped(ch) => ch match {
        case 'd' | 'D' =>
          // cuDF is not compatible with Java for \d  so we transpile to Java's definition
          // of [0-9]
          // https://github.com/rapidsai/cudf/issues/10894
          val components = ListBuffer[RegexCharacterClassComponent](
            RegexCharacterRange(RegexChar('0'), RegexChar('9')))
          if (ch.isUpper) {
            negateCharacterClass(components)
          } else {
            RegexCharacterClass(negated = false, components)
          }
        case 'w' | 'W' =>
          // cuDF is not compatible with Java for \w so we transpile to Java's definition
          // of `[a-zA-Z_0-9]`
          val components = ListBuffer[RegexCharacterClassComponent](
            RegexCharacterRange(RegexChar('a'), RegexChar('z')),
            RegexCharacterRange(RegexChar('A'), RegexChar('Z')),
            RegexChar('_'),
            RegexCharacterRange(RegexChar('0'), RegexChar('9')))
          if (ch.isUpper) {
            negateCharacterClass(components)
          } else {
            RegexCharacterClass(negated = false, components)
          }
        case 'b' | 'B' if mode == RegexSplitMode =>
          // see https://github.com/NVIDIA/spark-rapids/issues/5478
          throw new RegexUnsupportedException(
              "Word boundaries are not supported in split mode", regex.position)
        case 'b' | 'B' =>
          // NOTE(review): the `ch` bound in the first case below shadows the outer `ch`;
          // `RegexEscaped(ch)` in the fallback case still refers to the outer 'b'/'B'
          previous match {
            case Some(RegexEscaped(ch)) if "DWSHV".contains(ch) =>
              throw new RegexUnsupportedException(
                  "Word boundaries around \\D, \\S,\\W, \\H, or \\V are not supported",
                  regex.position)
            case Some(RegexCharacterClass(negated, _)) if negated =>
              throw new RegexUnsupportedException(
                  "Word boundaries around negated character classes are not supported",
                  regex.position)
            case _ =>
              RegexEscaped(ch)
          }
        case 'z' if mode == RegexSplitMode =>
          RegexEscaped('Z')
        case 'z' =>
          // cuDF does not support "\z" except for in split mode
          throw new RegexUnsupportedException(
            "\\z is not supported on GPU for find or replace",
            regex.position)
        case 'Z' =>
          // \Z is really a synonym for $. It's used in Java to preserve that behavior when
          // using modes that change the meaning of $ (such as MULTILINE or UNIX_LINES)
          previous match {
            case Some(RegexEscaped('Z')) =>
              RegexEmpty()
            case _ =>
              rewrite(RegexChar('$'), replacement, previous, flags)
          }
        case 's' | 'S' =>
          // whitespace characters
          val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
            RegexChar(' '), RegexChar('\u000b'))
          chars ++= Seq('n', 't', 'r', 'f').map(RegexEscaped)
          if (ch.isUpper) {
            negateCharacterClass(chars)
          } else {
            RegexCharacterClass(negated = false, characters = chars)
          }
        case 'h' | 'H' =>
          // horizontal whitespace
          // see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
          // under "Predefined character classes"
          val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
            RegexChar(' '), RegexChar('\u00A0'), RegexChar('\u1680'), RegexChar('\u180e'),
            RegexChar('\u202f'), RegexChar('\u205f'), RegexChar('\u3000')
          )
          chars += RegexEscaped('t')
          chars += RegexCharacterRange(RegexChar('\u2000'), RegexChar('\u200a'))
          if (ch.isUpper) {
            negateCharacterClass(chars)
          } else {
            RegexCharacterClass(negated = false, characters = chars)
          }
        case 'v' | 'V' =>
          // vertical whitespace
          // see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
          // under "Predefined character classes"
          val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
            RegexChar('\u000B'), RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029')
          )
          chars ++= Seq('n', 'f', 'r').map(RegexEscaped)
          if (ch.isUpper) {
            negateCharacterClass(chars)
          } else {
            RegexCharacterClass(negated = false, characters = chars)
          }
        case 'R' =>
          // linebreak sequence
          // see https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
          // under "Linebreak matcher"
          // emitted as a capture group: (CR LF | [single line break character])
          val l = RegexSequence(ListBuffer(RegexChar('\u000D'), RegexChar('\u000A')))
          val r = RegexCharacterClass(false, ListBuffer[RegexCharacterClassComponent](
            RegexChar('\u000A'), RegexChar('\u000B'), RegexChar('\u000C'), RegexChar('\u000D'),
            RegexChar('\u0085'), RegexChar('\u2028'), RegexChar('\u2029')
          ))
          RegexGroup(true, RegexChoice(l, r), None)
        case _ if escapeChars.contains(ch) =>
          RegexChar(escapeChars(ch))
        case _ if regexPunct.contains(ch) && !regexMetaChars.contains(ch) =>
          // escaped punctuation that is not a metacharacter is emitted as a plain literal
          RegexChar(ch)
        case _ =>
          regex
      }

      case RegexCharacterRange(_, _) =>
        regex

      case RegexCharacterClass(negated, characters) =>
        characters.foreach {
          case r @ RegexChar(ch) if ch == '[' || ch == ']' =>
            // examples:
            // - "[a[]" should match the literal characters "a" and "["
            // - "[a-b[c-d]]" is supported by Java but not cuDF
            throw new RegexUnsupportedException(
              "Nested character classes are not supported", r.position)
          case _ =>
        }
        // '^', '$' and '.' are literal inside a class and are kept as-is; every other
        // component is rewritten individually and must still be a class component
        val components = ListBuffer(characters.toSeq
          .map {
            case r @ RegexChar(ch) if "^$.".contains(ch) => r
            case ch => rewrite(ch, replacement, None, flags) match {
              case valid: RegexCharacterClassComponent => valid
              case _ =>
                // this can happen when a character class contains a meta-sequence such as
                // `\s` that gets transpiled into another character class
                throw new RegexUnsupportedException("Character class contains one or more " +
                  "characters that cannot be transpiled to supported character-class components",
                  ch.position)
            }
          }: _*)

        if (negated) {
          negateCharacterClass(components)
        } else {
          RegexCharacterClass(negated, components)
        }

      case sequence @ RegexSequence(parts) =>
        if (parts.isEmpty) {
          // examples: "", "()", "a|", "|b"
          throw new RegexUnsupportedException("Empty sequence not supported",
            sequence.position)
        }
        if (isRegexChar(parts.head, '|')) {
          // example: "|b"
          throw new RegexUnsupportedException("Choice with one empty side not supported",
            parts.head.position)
        }
        if (isRegexChar(parts.last, '|')) {
          // example: "a|"
          throw new RegexUnsupportedException("Choice with one empty side not supported",
            parts.last.position)
        }
        if (isRegexChar(parts.head, '{')) {
          // example: "{"
          // cuDF would treat this as a quantifier even though in this
          // context (being at the start of a sequence) it is not quantifying anything
          // note that we could choose to escape this in the transpiler rather than
          // falling back to CPU
          throw new RegexUnsupportedException("Token preceding '{' is not quantifiable",
            parts.head.position)
        }
        if (parts.forall(isBeginOrEndLineAnchor)) {
          throw new RegexUnsupportedException(
            "Sequences that only contain '^' or '$' are not supported", sequence.position)
        }

        // In replace mode, when a rewritten line-anchor matcher below is built without a
        // capture group, the backreference previously appended for it is removed again.
        def popBackrefIfNecessary(capture: Boolean): Unit = {
          if (mode == RegexReplaceMode && !capture) {
            replacement match {
              case Some(repl) =>
                repl.popBackref()
              case _ =>
            }
          }
        }

        // Special handling for line anchor ($)
        // This code is implemented here because to make it work in cuDF, we have to reorder
        // the items in the regex.
        // In the JVM, regexes like "\n$" and "$\n" have similar treatment
        // Fold state: (rewritten parts so far, the original part passed as `previous` to the
        // next rewrite). That part is left unchanged when the buffer ends in a RegexEmpty.
        RegexSequence(parts.foldLeft((new ListBuffer[RegexAST](),
          Option.empty[RegexAST]))((m, part) => {
            val (r, last) = m
            last match {
              // when the previous character is a line anchor ($), the JVM has special handling
              // when matching against line terminator characters
              case Some(RegexChar('$')) | Some(RegexEscaped('Z')) =>
                // index of the last rewritten node that is not a RegexEmpty placeholder;
                // the cases below overwrite it in place instead of appending
                val j = r.lastIndexWhere {
                  case RegexEmpty() => false
                  case _ => true
                }
                part match {
                  case RegexGroup(capture, RegexSequence(
                      ListBuffer(RegexCharacterClass(true, parts))), _)
                      if parts.forall(!isBeginOrEndLineAnchor(_)) =>
                    r(j) = RegexSequence(ListBuffer(lineTerminatorMatcher(Set.empty, true, capture),
                        RegexChar('$')))
                    popBackrefIfNecessary(capture)
                  case RegexGroup(capture, RegexCharacterClass(true, parts), _)
                      if parts.forall(!isBeginOrEndLineAnchor(_)) =>
                    r(j) = RegexSequence(ListBuffer(lineTerminatorMatcher(Set.empty, true, capture),
                        RegexChar('$')))
                    popBackrefIfNecessary(capture)
                  case RegexCharacterClass(true, parts)
                      if parts.forall(!isBeginOrEndLineAnchor(_)) =>
                    r(j) = RegexSequence(
                      ListBuffer(lineTerminatorMatcher(Set.empty, true, false), RegexChar('$')))
                    popBackrefIfNecessary(false)
                  case RegexChar(ch) if ch == '\n' =>
                    // what's really needed here is negative lookahead, but that is not
                    // supported by cuDF
                    // in this case: $\n would transpile to (?!\r)\n$
                    throw new RegexUnsupportedException("Regex sequence $\\n is not supported",
                      part.position)
                  case RegexChar(ch) if "\r\u0085\u2028\u2029".contains(ch) =>
                    r(j) = RegexSequence(
                      ListBuffer(
                        rewrite(part, replacement, None, flags),
                        RegexSequence(ListBuffer(
                          RegexRepetition(lineTerminatorMatcher(Set(ch), true, false),
                            SimpleQuantifier('?')), RegexChar('$')))))
                    popBackrefIfNecessary(false)
                  case RegexEscaped('z') =>
                    // \Z\z or $\z transpiles to $
                    r(j) = RegexChar('$')
                    popBackrefIfNecessary(false)
                  case RegexEscaped(a) if "bB".contains(a) =>
                    throw new RegexUnsupportedException(
                      "Regex sequences with \\b or \\B not supported around $", part.position)
                  case _ =>
                    r.append(rewrite(part, replacement, last, flags))
                }
              case _ =>
                r.append(rewrite(part, replacement, last, flags))
            }
            r.last match {
              case RegexEmpty() =>
                (r, last)
              case _ =>
                (r, Some(part))
            }
        })._1)

      case RegexRepetition(base, quantifier) => (base, quantifier) match {
        case (_, SimpleQuantifier(ch)) if mode == RegexSplitMode
            && flags.emptyRepetition && "?*".contains(ch) =>
          // example: pattern " ?", input "] b[", replace with "X":
          // java: X]XXbX[X
          // cuDF: XXXX] b[
          // see https://github.com/NVIDIA/spark-rapids/issues/4884
          throw new RegexUnsupportedException(
            "regexp_split on GPU does not support empty match repetition consistently with Spark",
            quantifier.position)

        case (_, QuantifierVariableLength(0, _)) if mode == RegexSplitMode
            && flags.emptyRepetition =>
          // see https://github.com/NVIDIA/spark-rapids/issues/4884
          throw new RegexUnsupportedException(
            "regexp_split on GPU does not support empty match repetition consistently with Spark",
            quantifier.position)

        case (_, QuantifierVariableLength(0, Some(0))) if mode != RegexFindMode =>
          throw new RegexUnsupportedException(
            "regex_replace and regex_split on GPU do not support repetition with {0,0}",
            quantifier.position)

        case (_, QuantifierFixedLength(0)) if mode != RegexFindMode =>
          throw new RegexUnsupportedException(
            "regex_replace and regex_split on GPU do not support repetition with {0}",
            quantifier.position)

        case (RegexGroup(capture, term, _), SimpleQuantifier(ch))
            if "+*".contains(ch) && !isSupportedRepetitionBase(term) =>
          (term, ch) match {
            // \Z is not supported in groups
            case (RegexEscaped('A'), '+') |
                (RegexSequence(ListBuffer(RegexEscaped('A'))), '+') =>
              // (\A)+ can be transpiled to (\A) (dropping the repetition)
              // we use rewrite(...) here to handle logic regarding modes
              // (\A is not supported in RegexSplitMode)
              RegexGroup(capture, rewrite(term, replacement, previous, flags), None)
            // NOTE: (\A)* can be transpiled to (\A)?
            // however, (\A)? is not supported in libcudf yet
            case _ =>
              val unsupportedTerm = getUnsupportedRepetitionBase(term)
              throw new RegexUnsupportedException(
                s"cuDF does not support repetition of group containing: " +
                  s"${unsupportedTerm.toRegexString}", term.position)
          }
        case (RegexGroup(capture, term, _), QuantifierVariableLength(n, _))
            if !isSupportedRepetitionBase(term) =>
          term match {
            // \Z is not supported in groups
            case RegexEscaped('A') |
              RegexSequence(ListBuffer(RegexEscaped('A'))) if n > 0 =>
              // (\A){1,} can be transpiled to (\A) (dropping the repetition)
              // we use rewrite(...) here to handle logic regarding modes
              // (\A is not supported in RegexSplitMode)
              RegexGroup(capture, rewrite(term, replacement, previous, flags), None)
            // NOTE: (\A)* can be transpiled to (\A)?
            // however, (\A)? is not supported in libcudf yet
            case _ =>
              val unsupportedTerm = getUnsupportedRepetitionBase(term)
              throw new RegexUnsupportedException(
                s"cuDF does not support repetition of group containing: " +
                  s"${unsupportedTerm.toRegexString}", term.position)
          }
        case (RegexGroup(capture, term, _), QuantifierFixedLength(n))
            if !isSupportedRepetitionBase(term) =>
          term match {
            // \Z is not supported in groups
            case RegexEscaped('A') |
              RegexSequence(ListBuffer(RegexEscaped('A'))) if n > 0 =>
              // (\A){2} can be transpiled to (\A) (dropping the repetition)
              // we use rewrite(...) here to handle logic regarding modes
              // (\A is not supported in RegexSplitMode)
              RegexGroup(capture, rewrite(term, replacement, previous, flags), None)
            // NOTE: (\A)* can be transpiled to (\A)?
            // however, (\A)? is not supported in libcudf yet
            case _ =>
              val unsupportedTerm = getUnsupportedRepetitionBase(term)
              throw new RegexUnsupportedException(
                s"cuDF does not support repetition of group containing: " +
                  s"${unsupportedTerm.toRegexString}", term.position)
          }
        case (RegexGroup(_, term, _), SimpleQuantifier(ch)) if ch == '?' =>
          // an optional group made up only of word boundaries or line anchors is rejected
          if (isEntirelyWordBoundary(term) || isEntirelyLineAnchor(term)) {
            throw new RegexUnsupportedException(
                s"cuDF does not support repetition of: ${term.toRegexString}", term.position)
          }
          RegexRepetition(rewrite(base, replacement, None, flags), quantifier)
        case (RegexEscaped(ch), SimpleQuantifier('+')) if "AZ".contains(ch) =>
          // \A+ can be transpiled to \A (dropping the repetition)
          // \Z+ can be transpiled to \Z (dropping the repetition)
          // we use rewrite(...) here to handle logic regarding modes
          // (\A and \Z are not supported in RegexSplitMode)
          rewrite(base, replacement, previous, flags)
        // NOTE: \A* can be transpiled to \A?
        // however, \A? is not supported in libcudf yet
        case (RegexEscaped(ch), QuantifierFixedLength(n)) if n > 0 && "AZ".contains(ch) =>
          // \A{2} can be transpiled to \A (dropping the repetition)
          // \Z{2} can be transpiled to \Z (dropping the repetition)
          rewrite(base, replacement, previous, flags)
        case (RegexEscaped(ch), QuantifierVariableLength(n,_)) if n > 0 && "AZ".contains(ch) =>
          // \A{1,5} can be transpiled to \A (dropping the repetition)
          // \Z{1,} can be transpiled to \Z (dropping the repetition)
          rewrite(base, replacement, previous, flags)
        case _ if isSupportedRepetitionBase(base) =>
          RegexRepetition(rewrite(base, replacement, None, flags), quantifier)
        case (RegexRepetition(_, SimpleQuantifier('*')), SimpleQuantifier('+')) =>
          throw new RegexUnsupportedException("Possessive quantifier *+ not supported",
            quantifier.position)
        case (RegexRepetition(_, SimpleQuantifier('?' | '*' | '+')), SimpleQuantifier('?')) =>
          // lazy quantifiers (??, *?, +?) are passed through
          RegexRepetition(rewrite(base, replacement, None, flags), quantifier)
        case _ =>
          throw new RegexUnsupportedException("Preceding token cannot be quantified",
            quantifier.position)

      }

      case RegexChoice(l, r) =>
        val ll = rewrite(l, replacement, None, flags)
        val rr = rewrite(r, replacement, None, flags)

        // cuDF does not support zero-length repetition in replace or split mode
        // cuDF does support +, fixed-length, and variable length with min > 0
        if (mode != RegexFindMode) {
          if (isRepetition(ll, true)) {
            throw new RegexUnsupportedException(
              "cuDF does not support replace or split with zero-length repetition on one side of a"
              + " choice",
              l.position)
          } else if (isRepetition(rr, true)) {
            throw new RegexUnsupportedException(
              "cuDF does not support replace or split with zero-length repetition on one side of a"
              + " choice",
              r.position)
          }
        }

        if (mode == RegexSplitMode) {
          if (beginsWithLineAnchor(ll) || beginsWithLineAnchor(rr) ||
              endsWithLineAnchor(ll) || endsWithLineAnchor(rr)) {
            throw new RegexUnsupportedException(
              "cuDF does not support either side of a choice containing a line anchor in split "
               + "mode", l.position)
          }
        }

        // cuDF does not support terms ending with word boundaries on one side
        // of a choice, such as "\\b|a"
        if (endsWithWordBoundary(ll)) {
          throw new RegexUnsupportedException(
            "cuDF does not support terms ending with word boundaries on one side of a choice",
            l.position)
        } else if (endsWithWordBoundary(rr)) {
          throw new RegexUnsupportedException(
            "cuDF does not support terms ending with word boundaries on one side of a choice",
            r.position)
        }

        // NOTE(review): the lazy-quantifier error below reports r.position even when the
        // lazy quantifier is on the left side (ll)
        (ll, rr) match {
          // ll = lazyQuantifier inside a choice
          case (RegexSequence(ListBuffer(RegexRepetition(
          RegexRepetition(_, SimpleQuantifier('?')), SimpleQuantifier('?')))), _) |
               // rr = lazyQuantifier inside a choice
               (_, RegexSequence(ListBuffer(RegexRepetition(
               RegexRepetition(_, SimpleQuantifier('?')), SimpleQuantifier('?'))))) =>
            throw new RegexUnsupportedException(
              "cuDF does not support lazy quantifier inside choice", r.position)
          case (_, RegexChoice(RegexSequence(_), RegexSequence(ListBuffer(RegexRepetition(
          RegexEscaped('A'), SimpleQuantifier('?')), _)))) =>
            throw new RegexUnsupportedException("Invalid regex pattern at position", r.position)
          case _ =>
        }
        RegexChoice(ll, rr)

      case g @ RegexGroup(_, _, Some(lookahead)) =>
        val msg = lookahead match {
          case RegexPositiveLookahead =>
            "Positive lookahead groups are not supported"
          case RegexNegativeLookahead =>
            "Negative lookahead groups are not supported"
        }
        throw new RegexUnsupportedException(msg, g.position)

      case RegexGroup(capture, term, _) =>
        // groups whose term is a sequence are checked for anchors and {0}/{0,0} repetitions
        // NOTE(review): "Reptition" below is a typo in the runtime message; it is left
        // unchanged in case callers or tests match on the exact text
        term match {
          case RegexSequence(parts) =>
            parts.foreach { part =>
              if (isBeginOrEndLineAnchor(part)) {
                throw new RegexUnsupportedException(
                  "Line and string anchors are not supported in capture groups", part.position)
              }
              part match {
                case RegexRepetition(base, quantifier) => (base, quantifier) match {
                  case (_, QuantifierVariableLength(0, Some(0))) =>
                    throw new RegexUnsupportedException(
                      "Repetition with {0,0} not supported in capture groups",
                      quantifier.position)

                  case (_, QuantifierFixedLength(0)) =>
                    throw new RegexUnsupportedException(
                      "Reptition with {0} not supported in capture groups",
                      quantifier.position)
                  case _ =>
                }
                case _ =>
              }
            }
            RegexGroup(capture, rewrite(term, replacement, None, flags), None)
          case _ =>
            RegexGroup(capture, rewrite(term, replacement, None, flags), None)
        }

      case other =>
        throw new RegexUnsupportedException(s"Unhandled expression in transpiler: $other",
          other.position)
    }
  }

  /**
   * Deep search of the AST.
   *
   * @param regex the root node to search
   * @param f predicate applied to `regex` and to every node nested in it: sequence parts,
   *          group terms, both sides of a choice, repetition bases and character-class
   *          components
   * @return true if `f` holds for `regex` or for any nested node
   */
  private def contains(regex: RegexAST, f: RegexAST => Boolean): Boolean = {
    f(regex) || (regex match {
      case RegexSequence(parts) => parts.exists(contains(_, f))
      case RegexGroup(_, term, _) => contains(term, f)
      case RegexChoice(l, r) => contains(l, f) || contains(r, f)
      case RegexRepetition(term, _) => contains(term, f)
      case RegexCharacterClass(_, chars) => chars.exists(contains(_, f))
      // a leaf has no children, and `f` was already evaluated (false) for it above
      case _ => false
    })
  }

  /**
   * Tests `f` against `regex` after unwrapping any enclosing groups. A non-empty sequence
   * passes only when `f` holds for every one of its direct parts (the parts themselves are
   * not unwrapped further); any other node, including an empty sequence, is tested directly.
   */
  @scala.annotation.tailrec
  private def isEntirely(regex: RegexAST, f: RegexAST => Boolean): Boolean = regex match {
    case RegexGroup(_, inner, _) => isEntirely(inner, f)
    case RegexSequence(parts) if parts.nonEmpty => parts.forall(part => f(part))
    case node => f(node)
  }

  /** True when `regex` is, in the sense of `isEntirely`, made only of `\b` / `\B` escapes. */
  private def isEntirelyWordBoundary(regex: RegexAST): Boolean = {
    val isBoundaryEscape: RegexAST => Boolean = {
      case RegexEscaped('b') | RegexEscaped('B') => true
      case _ => false
    }
    isEntirely(regex, isBoundaryEscape)
  }

  /**
   * True when `regex` is, in the sense of `isEntirely`, made only of `\A` escapes or nodes
   * accepted by `isBeginOrEndLineAnchor`.
   */
  private def isEntirelyLineAnchor(regex: RegexAST): Boolean =
    isEntirely(regex, part => part == RegexEscaped('A') || isBeginOrEndLineAnchor(part))

  /**
   * Tests `f` against the first meaningful element of `regex`.
   *
   * Groups are unwrapped and sequences are descended into via their first part that is not
   * a [[RegexEmpty]]. Any other node is tested directly. So is a sequence that has no such
   * part: an empty sequence, or one made only of `RegexEmpty` parts.
   */
  @scala.annotation.tailrec
  private def beginsWith(regex: RegexAST, f: RegexAST => Boolean): Boolean = {
    regex match {
      case RegexSequence(parts) =>
        parts.find {
          case RegexEmpty() => false
          case _ => true
        } match {
          case Some(first) => beginsWith(first, f)
          // nothing to descend into; indexing with the -1 from indexWhere would throw here
          case None => f(regex)
        }
      case RegexGroup(_, term, _) =>
        beginsWith(term, f)
      case _ => f(regex)
    }
  }

  /**
   * Tests `f` against the last meaningful element of `regex`.
   *
   * Groups are unwrapped and sequences are descended into via their last part that is not
   * a [[RegexEmpty]]. Any other node is tested directly. So is a sequence that has no such
   * part: an empty sequence, or one made only of `RegexEmpty` parts.
   */
  @scala.annotation.tailrec
  private def endsWith(regex: RegexAST, f: RegexAST => Boolean): Boolean = {
    regex match {
      case RegexSequence(parts) =>
        parts.reverseIterator.find {
          case RegexEmpty() => false
          case _ => true
        } match {
          case Some(last) => endsWith(last, f)
          // nothing to descend into; indexing with the -1 from lastIndexWhere would throw
          case None => f(regex)
        }
      case RegexGroup(_, term, _) =>
        endsWith(term, f)
      case _ => f(regex)
    }
  }

  /**
   * True when the last meaningful element of `e` (see `endsWith`) is `\A` or `\Z`, or is
   * accepted by `isBeginOrEndLineAnchor`.
   */
  private def endsWithLineAnchor(e: RegexAST): Boolean = {
    val anchor: RegexAST => Boolean = {
      case RegexEscaped('A') | RegexEscaped('Z') => true
      case other => isBeginOrEndLineAnchor(other)
    }
    endsWith(e, anchor)
  }

  /**
   * True when the first meaningful element of `e` (see `beginsWith`) is `\A` or `\Z`, or is
   * accepted by `isBeginOrEndLineAnchor`.
   */
  private def beginsWithLineAnchor(e: RegexAST): Boolean = {
    val anchor: RegexAST => Boolean = {
      case RegexEscaped('A') | RegexEscaped('Z') => true
      case other => isBeginOrEndLineAnchor(other)
    }
    beginsWith(e, anchor)
  }

  /** True when the last meaningful element of `e` (see `endsWith`) is `\b` or `\B`. */
  private def endsWithWordBoundary(e: RegexAST): Boolean =
    endsWith(e, {
      case RegexEscaped(c) => c == 'b' || c == 'B'
      case _ => false
    })

  /**
   * True when `regex` is `^`, `$`, `\A`, `\z` or `\Z`. A group or repetition also qualifies
   * when its body does, a choice when both sides do, and a non-empty sequence when every
   * part does.
   */
  private def isBeginOrEndLineAnchor(regex: RegexAST): Boolean = regex match {
    case RegexChar(ch) => ch == '^' || ch == '$'
    // \z gets translated to $
    case RegexEscaped(ch) => ch == 'A' || ch == 'z' || ch == 'Z'
    case RegexSequence(parts) => parts.nonEmpty && parts.forall(isBeginOrEndLineAnchor)
    case RegexGroup(_, inner, _) => isBeginOrEndLineAnchor(inner)
    case RegexRepetition(inner, _) => isBeginOrEndLineAnchor(inner)
    case RegexChoice(left, right) =>
      isBeginOrEndLineAnchor(left) && isBeginOrEndLineAnchor(right)
    case _ => false
  }

  /** True when `expr` is the plain character node for `value`. */
  private def isRegexChar(expr: RegexAST, value: Char): Boolean = expr match {
    case RegexChar(`value`) => true
    case _ => false
  }
}

/** Base type of every regex syntax-tree node defined in this file. */
sealed trait RegexAST {
  /** Offset of this node in the source pattern, when known; reported in error messages. */
  var position: Option[Int] = None

  /** Direct sub-nodes of this node. */
  def children(): Seq[RegexAST]

  /** Renders this node back to pattern text. */
  def toRegexString: String
}

/** The empty expression: it has no children and renders as an empty string. */
sealed case class RegexEmpty() extends RegexAST {
  override def toRegexString: String = ""
  override def children(): Seq[RegexAST] = Nil
}

/** A concatenation of `parts`, rendered as each part's text joined with no separator. */
sealed case class RegexSequence(parts: ListBuffer[RegexAST]) extends RegexAST {
  def this(parts: ListBuffer[RegexAST], position: Int) = {
    this(parts)
    this.position = Some(position)
  }

  override def toRegexString: String = {
    val out = new StringBuilder
    parts.foreach(part => out.append(part.toRegexString))
    out.toString
  }

  override def children(): Seq[RegexAST] = parts.toSeq
}

/**
 * A parenthesized group.
 *
 * @param capture   true for a capturing group `( ... )`
 * @param term      the body of the group
 * @param lookahead `Some` for a lookahead group; `None` otherwise
 */
sealed case class RegexGroup(capture: Boolean, term: RegexAST,
    lookahead: Option[RegexLookahead])
    extends RegexAST {
  def this(capture: Boolean, term: RegexAST) = {
    this(capture, term, None)
  }
  def this(capture: Boolean, term: RegexAST, position: Int) = {
    this(capture, term, None)
    this.position = Some(position)
  }
  def this(capture: Boolean, term: RegexAST, position: Int, lookahead: Option[RegexLookahead]) = {
    this(capture, term, lookahead)
    this.position = Some(position)
  }
  override def children(): Seq[RegexAST] = Seq(term)

  /**
   * Renders the group in Java regex syntax.
   *
   * Lookaheads use the `(?=` / `(?!` prefixes. Without the `?` the text would read as a
   * capturing group that starts with a literal `=` or `!`.
   */
  override def toRegexString: String = {
    val body = term.toRegexString
    if (capture) {
      s"($body)"
    } else if (lookahead.isDefined) {
      lookahead match {
        case Some(RegexPositiveLookahead) => s"(?=$body)"
        case Some(RegexNegativeLookahead) => s"(?!$body)"
        case _ => throw new IllegalStateException("Should not reach here")
      }
    } else {
      s"(?:$body)"
    }
  }
}

/** An alternation `a|b`. */
sealed case class RegexChoice(a: RegexAST, b: RegexAST) extends RegexAST {
  def this(a: RegexAST, b: RegexAST, position: Int) = {
    this(a, b)
    this.position = Some(position)
  }

  override def toRegexString: String = children().map(_.toRegexString).mkString("|")

  override def children(): Seq[RegexAST] = Seq(a, b)
}

/** The term `a` followed by `quantifier`; only `a` is reported as a child. */
sealed case class RegexRepetition(a: RegexAST, quantifier: RegexQuantifier) extends RegexAST {
  def this(a: RegexAST, quantifier: RegexQuantifier, position: Int) = {
    this(a, quantifier)
    this.position = Some(position)
  }

  override def toRegexString: String = a.toRegexString + quantifier.toRegexString

  override def children(): Seq[RegexAST] = List(a)
}

/** Marker for the quantifier nodes that follow the repeated term of a [[RegexRepetition]]. */
sealed trait RegexQuantifier extends RegexAST

/** A single-character quantifier `ch`, rendered as that character. */
sealed case class SimpleQuantifier(ch: Char) extends RegexQuantifier {
  def this(ch: Char, position: Int) = {
    this(ch)
    this.position = Some(position)
  }

  override def toRegexString: String = String.valueOf(ch)

  override def children(): Seq[RegexAST] = Nil
}

/** An exact-count quantifier, rendered as `{length}`. */
sealed case class QuantifierFixedLength(length: Int)
    extends RegexQuantifier {
  def this(length: Int, position: Int) = {
    this(length)
    this.position = Some(position)
  }

  override def toRegexString: String = s"{$length}"

  override def children(): Seq[RegexAST] = Nil
}

/**
 * A bounded or open-ended count quantifier.
 *
 * It renders as `{minLength,maxLength}`, or as `{minLength,}` when there is no upper bound.
 */
sealed case class QuantifierVariableLength(minLength: Int, maxLength: Option[Int])
    extends RegexQuantifier {
  def this(minLength: Int, maxLength: Option[Int], position: Int) = {
    this(minLength, maxLength)
    this.position = Some(position)
  }

  override def toRegexString: String =
    maxLength.fold(s"{$minLength,}")(max => s"{$minLength,$max}")

  override def children(): Seq[RegexAST] = Nil
}

/** Marker for the node types that may be stored in a [[RegexCharacterClass]]. */
sealed trait RegexCharacterClassComponent extends RegexAST

/**
 * A hexadecimal escape whose digits are held in `a`.
 *
 * Two digits render as `\xhh`; any other length renders in braces as `\x{...}`.
 */
sealed case class RegexHexDigit(a: String) extends RegexCharacterClassComponent {
  def this(a: String, position: Int) = {
    this(a)
    this.position = Some(position)
  }

  /** Value of `a` parsed as base 16; construction throws if `a` is not valid hex. */
  val codePoint: Int = Integer.parseInt(a, 16)

  override def toRegexString: String = a.length match {
    case 2 => s"\\x$a"
    case _ => s"\\x{$a}"
  }

  override def children(): Seq[RegexAST] = Nil
}

/** An octal escape: `a` is the text after the backslash and is parsed as base 8. */
sealed case class RegexOctalChar(a: String) extends RegexCharacterClassComponent {
  def this(a: String, position: Int) = {
    this(a)
    this.position = Some(position)
  }

  /** Value of `a` parsed as base 8; construction throws if `a` is not valid octal. */
  val codePoint: Int = Integer.parseInt(a, 8)

  override def toRegexString: String = s"\\$a"

  override def children(): Seq[RegexAST] = Nil
}

/** A single unescaped character, rendered as itself. */
sealed case class RegexChar(ch: Char) extends RegexCharacterClassComponent {
  def this(ch: Char, position: Int) = {
    this(ch)
    this.position = Some(position)
  }

  override def toRegexString: String = String.valueOf(ch)

  override def children(): Seq[RegexAST] = Nil
}

/** A backslash escape of the character `a`, rendered as `\` followed by `a`. */
sealed case class RegexEscaped(a: Char) extends RegexCharacterClassComponent {
  def this(a: Char, position: Int) = {
    this(a)
    this.position = Some(position)
  }

  override def toRegexString: String = s"\\$a"

  override def children(): Seq[RegexAST] = Nil
}

/**
 * A character-class range `start-end`.
 *
 * Note that `start` and `end` are not reported through `children()`.
 */
sealed case class RegexCharacterRange(start: RegexCharacterClassComponent,
    end: RegexCharacterClassComponent)
  extends RegexCharacterClassComponent {
  def this(start: RegexCharacterClassComponent,
           end: RegexCharacterClassComponent,
           position: Int) = {
    this(start, end)
    this.position = Some(position)
  }

  override def toRegexString: String =
    Seq(start, end).map(_.toRegexString).mkString("-")

  override def children(): Seq[RegexAST] = Nil
}

/**
 * A bracketed character class, rendered as `[...]`, or `[^...]` when `negated`.
 *
 * The `append*` helpers add components to `characters` in place.
 */
sealed case class RegexCharacterClass(
    var negated: Boolean,
    var characters: ListBuffer[RegexCharacterClassComponent])
  extends RegexAST {
  def this(
      negated: Boolean,
      characters: ListBuffer[RegexCharacterClassComponent],
      position: Int) = {
    this(negated, characters)
    this.position = Some(position)
  }

  override def children(): Seq[RegexAST] = characters.toSeq

  /** Appends `component` to the end of `characters`. */
  def append(component: RegexCharacterClassComponent): Unit = {
    characters += component
  }

  /** Appends the plain character `ch`. */
  def append(ch: Char): Unit = append(RegexChar(ch))

  /** Appends the escape `\ch`. */
  def appendEscaped(ch: Char): Unit = append(RegexEscaped(ch))

  /** Appends the range `start-end`. */
  def appendRange(start: RegexCharacterClassComponent,
      end: RegexCharacterClassComponent): Unit = append(RegexCharacterRange(start, end))

  override def toRegexString: String = {
    val out = new StringBuilder("[")
    if (negated) out.append("^")
    characters.foreach {
      case RegexChar(ch) if requiresEscaping(ch) =>
        // cuDF has stricter escaping requirements for certain characters
        // within a character class compared to Java or Python regex
        out.append(s"\\$ch")
      case component =>
        out.append(component.toRegexString)
    }
    out.append("]").toString()
  }

  /**
   * Characters that must be backslash-escaped when emitted as plain characters inside a
   * class. Currently only '-', which cuDF would otherwise read as the range syntax `a-b`.
   * Other characters may need adding; this covers everything seen so far during fuzzing.
   */
  private def requiresEscaping(ch: Char): Boolean = ch == '-'
}

/**
 * A back-reference to capture group `num`, rendered as `$num`.
 *
 * @param num   the referenced capture group number
 * @param isNew flag checked by `RegexReplacement.popBackref`, which only removes
 *              back-references with `isNew = true`; `RegexReplacement.appendBackref`
 *              creates them that way
 */
sealed case class RegexBackref(num: Int, isNew: Boolean = false) extends RegexAST {
  def this(num: Int, isNew: Boolean, position: Int) = {
    this(num, isNew)
    this.position = Some(position)
  }
  override def children(): Seq[RegexAST] = Seq.empty
  // Declared without `()` to match the parameterless `RegexAST.toRegexString`. Overriding a
  // parameterless method with an empty-paren one is deprecated in Scala 2.13.
  override def toRegexString: String = s"$$$num"
}

/**
 * A replacement expression made of `parts`, rendered as the parts' text concatenated.
 *
 * @param parts            the replacement components, in order
 * @param numCaptureGroups counter incremented by [[appendBackref]] and decremented by
 *                         [[popBackref]]; [[hasBackrefs]] reports whether it is positive
 */
sealed case class RegexReplacement(parts: ListBuffer[RegexAST],
    var numCaptureGroups: Int = 0) extends RegexAST {
  def this(parts: ListBuffer[RegexAST], numCaptureGroups: Int, position: Int) = {
    this(parts, numCaptureGroups)
    this.position = Some(position)
  }
  override def children(): Seq[RegexAST] = parts.toSeq
  override def toRegexString: String = parts.map(_.toRegexString).mkString

  /** Appends a back-reference to group `num` flagged `isNew` and counts it. */
  def appendBackref(num: Int): Unit = {
    numCaptureGroups += 1
    parts += RegexBackref(num, isNew = true)
  }

  /**
   * If the last part is a back-reference flagged `isNew`, removes it and decrements
   * `numCaptureGroups`. Otherwise, including when `parts` is empty, does nothing.
   */
  def popBackref(): Unit = {
    parts.lastOption match {
      case Some(RegexBackref(_, true)) =>
        numCaptureGroups -= 1
        // remove(index, count) exists in both Scala 2.12 and 2.13; trimEnd is deprecated
        parts.remove(parts.length - 1, 1)
      case _ =>
    }
  }

  def hasBackrefs: Boolean = numCaptureGroups > 0
}

/**
 * Signals a regex construct that cannot be handled.
 *
 * @param message description of the problem
 * @param index   offset in the pattern near which the problem was found, if known; it is
 *                appended to the message as " near index i"
 */
class RegexUnsupportedException(message: String, index: Option[Int])
  extends SQLException {
  override def getMessage: String =
    index.fold(message)(i => s"$message near index $i")
}

/**
 * Classification of a regex pattern that can be evaluated without a general regex engine,
 * as returned by `RegexRewrite.matchSimplePattern`.
 */
sealed trait RegexOptimizationType
object RegexOptimizationType {
  /** The pattern is a literal anchored at the start. */
  case class StartsWith(literal: String) extends RegexOptimizationType

  /** The pattern reduces to containment of `literal`. */
  case class Contains(literal: String) extends RegexOptimizationType

  /** The pattern reduces to containment of any one of `literals`. */
  case class MultipleContains(literals: Seq[String]) extends RegexOptimizationType

  /**
   * The pattern is `literal` followed by a repeated character range.
   *
   * `length` is the minimum repetition count; `rangeStart` and `rangeEnd` are the range
   * endpoints as `Char.toInt` values.
   */
  case class PrefixRange(literal: String, length: Int, rangeStart: Int, rangeEnd: Int)
    extends RegexOptimizationType

  /** No rewrite applies. */
  case object NoOptimization extends RegexOptimizationType
}

object RegexRewrite {

  /**
   * While `astLs` consists of exactly one non-lookahead group whose body is a
   * [[RegexSequence]], replaces it with that sequence's parts. Any other input is returned
   * unchanged.
   */
  @scala.annotation.tailrec
  private def removeBrackets(astLs: collection.Seq[RegexAST]): collection.Seq[RegexAST] =
    astLs match {
      case collection.Seq(RegexGroup(_, RegexSequence(inner), None)) => removeBrackets(inner)
      case unchanged => unchanged
    }

  /**
   * Extracts the prefix range pattern info from the given AST sequence.
   *
   * @param astLs The AST sequence to extract the prefix range pattern from.
   * @return Some((prefix, length, start, end)) if astLs is a `prefix[start-end]{x}` pattern,
   *         None otherwise. start and end are the `Char.toInt` values of the range endpoints
   *         and length is the minimum repetition count.
   */
  private def getPrefixRangePattern(astLs: collection.Seq[RegexAST]):
      Option[(String, Int, Int, Int)] = {
    val prefixParts = astLs.dropRight(1)
    if (isLiteralString(prefixParts)) {
      astLs.lastOption.flatMap(rangeRepetitionInfo).map { case (length, start, end) =>
        (regexCharsToString(prefixParts), length, start, end)
      }
    } else {
      None
    }
  }

  /**
   * `(minRepetitions, start, end)` when `ast` (after `removeBrackets`) is a non-negated
   * single-range class `[start-end]` of plain characters under a supported quantifier.
   */
  private def rangeRepetitionInfo(ast: RegexAST): Option[(Int, Int, Int)] =
    removeBrackets(collection.Seq(ast)) match {
      case collection.Seq(RegexRepetition(
          RegexCharacterClass(false,
            ListBuffer(RegexCharacterRange(RegexChar(first), RegexChar(last)))),
          quantifier)) =>
        // convert the range endpoints to their numeric values
        minRepetitions(quantifier).map(count => (count, first.toInt, last.toInt))
      case _ =>
        None
    }

  /** Minimum number of repetitions `quantifier` requires, or None if unsupported. */
  private def minRepetitions(quantifier: RegexQuantifier): Option[Int] = quantifier match {
    // In Rlike, contains [a-b]{minLen,maxLen} pattern is equivalent to contains
    // [a-b]{minLen} because the matching will return the result once it finds the
    // minimum match so maxLen here is unnecessary.
    case QuantifierVariableLength(minLen, _) => Some(minLen)
    case QuantifierFixedLength(len) => Some(len)
    case SimpleQuantifier('*') | SimpleQuantifier('?') => Some(0)
    case SimpleQuantifier('+') => Some(1)
    case _ => None
  }

  /**
   * True when, after `removeBrackets`, every element is a plain non-meta character
   * (vacuously true for an empty sequence).
   */
  private def isLiteralString(astLs: collection.Seq[RegexAST]): Boolean =
    removeBrackets(astLs).forall(isLiteralChar)

  /** True for a [[RegexChar]] that is not a regex meta character. */
  private def isLiteralChar(ast: RegexAST): Boolean = ast match {
    case RegexChar(ch) => !regexMetaChars.contains(ch)
    case _ => false
  }

  /**
   * Collects the literal alternatives of `ast`. Groups are unwrapped, and a choice
   * contributes its left side when that is a literal [[RegexSequence]]; its right side is
   * processed recursively. A literal sequence yields itself. Any other shape, or any
   * non-literal alternative, yields an empty result.
   */
  private def getMultipleContainsLiterals(ast: RegexAST): Seq[String] = ast match {
    case RegexGroup(_, inner, _) =>
      getMultipleContainsLiterals(inner)
    case RegexChoice(RegexSequence(left), right) if isLiteralString(left) =>
      val rest = getMultipleContainsLiterals(right)
      if (rest.isEmpty) Seq.empty else regexCharsToString(left) +: rest
    case RegexSequence(parts) if isLiteralString(parts) =>
      Seq(regexCharsToString(parts))
    case _ =>
      Seq.empty
  }

  /**
   * True for `.*`, or for a group or sequence made up only of `.*` (an empty sequence also
   * counts).
   */
  private def isWildcard(ast: RegexAST): Boolean = ast match {
    case RegexRepetition(RegexChar('.'), SimpleQuantifier('*')) => true
    case RegexSequence(parts) => parts.forall(isWildcard)
    case RegexGroup(_, inner, _) => isWildcard(inner)
    case _ => false
  }

  /** Drops the leading elements of `astLs` that are wildcards. */
  private def stripLeadingWildcards(astLs: collection.Seq[RegexAST]):
      collection.Seq[RegexAST] =
    astLs.dropWhile(isWildcard)

  /** Drops the trailing elements of `astLs` that are wildcards. */
  private def stripTrailingWildcards(astLs: collection.Seq[RegexAST]):
      collection.Seq[RegexAST] =
    astLs.take(astLs.lastIndexWhere(ast => !isWildcard(ast)) + 1)

  /**
   * Concatenates the characters of `chars` after `removeBrackets`.
   *
   * @throws IllegalArgumentException if any element is not a [[RegexChar]]
   */
  private def regexCharsToString(chars: collection.Seq[RegexAST]): String = {
    val out = new StringBuilder
    removeBrackets(chars).foreach {
      case RegexChar(ch) => out.append(ch)
      case _ => throw new IllegalArgumentException("Invalid character")
    }
    out.toString
  }

  /**
   * Matches the given regex ast to a regex optimization type for regex rewrite
   * optimization.
   *
   * @param ast Abstract Syntax Tree parsed from a regex pattern.
   * @return The `RegexOptimizationType` for the given pattern.
   */
  def matchSimplePattern(ast: RegexAST): RegexOptimizationType = {
    val topLevel: collection.Seq[RegexAST] = ast match {
      case RegexSequence(_) => ast.children()
      case single => Seq(single)
    }
    val withoutTrailing = stripTrailingWildcards(topLevel)

    // ^literal or \Aliteral (trailing wildcards already removed) => starts with literal
    val anchoredLiteral = withoutTrailing.headOption
      .filter(first => first == RegexChar('^') || first == RegexEscaped('A'))
      .map(_ => withoutTrailing.drop(1))
      .filter(rest => isLiteralString(rest))

    anchoredLiteral match {
      case Some(literal) =>
        RegexOptimizationType.StartsWith(regexCharsToString(literal))
      case None =>
        matchUnanchored(removeBrackets(stripLeadingWildcards(withoutTrailing)))
    }
  }

  /**
   * Classifies a pattern body whose leading and trailing wildcards and single outer group
   * have been removed. It tries, in order: contains literal, multiple contains, prefix range.
   */
  private def matchUnanchored(body: collection.Seq[RegexAST]): RegexOptimizationType = {
    if (isLiteralString(body)) {
      // literal or .*(literal).* => contains literal
      RegexOptimizationType.Contains(regexCharsToString(body))
    } else {
      // multiple contains literal pattern (e.g. "abc|def|ghi")
      val alternatives = body match {
        case collection.Seq(only) => getMultipleContainsLiterals(only)
        case _ => Seq.empty[String]
      }
      if (alternatives.nonEmpty) {
        RegexOptimizationType.MultipleContains(alternatives)
      } else {
        // (literal[a-b]{x,y}) => prefix range pattern (e.g. "abc[a-z]{3}")
        getPrefixRangePattern(body) match {
          case Some((prefix, length, start, end)) =>
            RegexOptimizationType.PrefixRange(prefix, length, start, end)
          case None =>
            // not a simple pattern; use cuDF
            RegexOptimizationType.NoOptimization
        }
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy