All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.spark.sql.catalyst.expressions.ExpressionSet.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import scala.collection.{mutable, GenTraversableOnce}
import scala.collection.mutable.ArrayBuffer

object ExpressionSet {
  /**
   * Builds an [[ExpressionSet]] holding `expressions`. Deterministic entries are deduplicated by
   * their canonical form (see [[Canonicalize]]). Non-deterministic entries are always kept.
   * The input is traversed exactly once.
   */
  def apply(expressions: TraversableOnce[Expression]): ExpressionSet = {
    val result = new ExpressionSet()
    for (expr <- expressions) {
      result.add(expr)
    }
    result
  }

  /** Returns a fresh, empty [[ExpressionSet]]. */
  def apply(): ExpressionSet = new ExpressionSet()
}

/**
 * A set-like collection where membership is determined based on determinacy and a canonical
 * representation of an [[Expression]] (i.e. one that attempts to ignore cosmetic differences).
 * See [[Canonicalize]] for more details.
 *
 * Internally this set uses the canonical representation, but also keeps track of the original
 * expressions to ease debugging.  Since different expressions can share the same canonical
 * representation, operations that extract expressions from this set are only guaranteed to see
 * at least one such expression.  For example:
 *
 * {{{
 *   val set = ExpressionSet(Seq(a + 1, 1 + a))
 *
 *   set.iterator => Iterator(a + 1)
 *   set.contains(a + 1) => true
 *   set.contains(1 + a) => true
 *   set.contains(a + 2) => false
 * }}}
 *
 * Non-deterministic expressions are always considered as not contained in the set. Adding a
 * non-deterministic expression simply appends it to the original expressions.
 * This is consistent with how we define `semanticEquals` between two expressions.
 */
class ExpressionSet protected(
    private val baseSet: mutable.Set[Expression] = new mutable.HashSet,
    private val originals: mutable.Buffer[Expression] = new ArrayBuffer)
  extends Iterable[Expression] {

  //  Note: this class supports Scala 2.12. A parallel source tree has a 2.13 implementation.

  /**
   * Adds `e` to this set in place.
   *
   * A deterministic expression is recorded only if no expression with the same canonical form
   * is already present. A non-deterministic expression is always appended to `originals` and
   * never enters `baseSet`.
   */
  protected def add(e: Expression): Unit = {
    if (!e.deterministic) {
      originals += e
    } else if (baseSet.add(e.canonicalized)) {
      // `mutable.Set.add` returns true only when the element was not yet present, so one hash
      // lookup both tests and inserts the canonical form.
      originals += e
    }
  }

  /**
   * Removes `e` from this set in place.
   *
   * This drops the canonical form of `e` from `baseSet` and every original expression that shares
   * that canonical form. Non-deterministic expressions are never members, so removing one is a
   * no-op.
   */
  protected def remove(e: Expression): Unit = {
    if (e.deterministic) {
      val canonical = e.canonicalized
      // Direct removal by key avoids scanning every element of `baseSet`.
      baseSet -= canonical
      // Single order-preserving pass over `originals`. The buffer is rebuilt only when something
      // actually matched.
      val kept = originals.filterNot(_.canonicalized == canonical)
      if (kept.length != originals.length) {
        originals.clear()
        originals ++= kept
      }
    }
  }

  /**
   * Returns true if an expression with the same canonical form as `elem` is in this set.
   * Non-deterministic expressions are never added to the canonical set, so they are never
   * reported as contained.
   */
  def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized)

  /**
   * Returns a new set keeping the expressions whose canonical form satisfies `p`.
   * The predicate is applied to canonical forms for both the canonical set and the originals.
   */
  override def filter(p: Expression => Boolean): ExpressionSet = {
    val newBaseSet = baseSet.filter(e => p(e.canonicalized))
    val newOriginals = originals.filter(e => p(e.canonicalized))
    new ExpressionSet(newBaseSet, newOriginals)
  }

  /**
   * Returns a new set dropping the expressions whose canonical form satisfies `p`.
   * The predicate is applied to canonical forms for both the canonical set and the originals.
   */
  override def filterNot(p: Expression => Boolean): ExpressionSet = {
    val newBaseSet = baseSet.filterNot(e => p(e.canonicalized))
    val newOriginals = originals.filterNot(e => p(e.canonicalized))
    new ExpressionSet(newBaseSet, newOriginals)
  }

  /** Returns a copy of this set with `elem` added; `this` is left unchanged. */
  def +(elem: Expression): ExpressionSet = {
    val newSet = clone()
    newSet.add(elem)
    newSet
  }

  /** Returns a copy of this set with every element of `elems` added; `this` is left unchanged. */
  def ++(elems: GenTraversableOnce[Expression]): ExpressionSet = {
    val newSet = clone()
    elems.foreach(newSet.add)
    newSet
  }

  /** Returns a copy of this set with `elem` removed; `this` is left unchanged. */
  def -(elem: Expression): ExpressionSet = {
    val newSet = clone()
    newSet.remove(elem)
    newSet
  }

  /** Returns a copy of this set with every element of `elems` removed; `this` is left unchanged. */
  def --(elems: GenTraversableOnce[Expression]): ExpressionSet = {
    val newSet = clone()
    elems.foreach(newSet.remove)
    newSet
  }

  /**
   * Applies `f` to every original expression, including non-deterministic ones. The results are
   * collected into a new set and deduplicated by canonical form.
   */
  def map(f: Expression => Expression): ExpressionSet = {
    val newSet = new ExpressionSet()
    this.iterator.foreach(elem => newSet.add(f(elem)))
    newSet
  }

  /**
   * Applies `f` to every original expression and collects all produced expressions into a new
   * set, deduplicated by canonical form.
   */
  def flatMap(f: Expression => Iterable[Expression]): ExpressionSet = {
    val newSet = new ExpressionSet()
    this.iterator.foreach(f(_).foreach(newSet.add))
    newSet
  }

  /** Iterates over the original (non-canonicalized) expressions in insertion order. */
  def iterator: Iterator[Expression] = originals.iterator

  /** Returns a new set containing the expressions of both `this` and `that`. */
  def union(that: ExpressionSet): ExpressionSet = {
    val newSet = clone()
    that.iterator.foreach(newSet.add)
    newSet
  }

  /**
   * Returns true if every original expression of this set is contained in `that`.
   * Because non-deterministic expressions are never contained, a set holding one is never a
   * subset.
   */
  def subsetOf(that: ExpressionSet): Boolean = this.iterator.forall(that.contains)

  /** Returns a new set with the expressions of this set that are also contained in `that`. */
  def intersect(that: ExpressionSet): ExpressionSet = this.filter(that.contains)

  /** Returns a new set with every expression of `that` removed from this set. */
  def diff(that: ExpressionSet): ExpressionSet = this -- that

  /** Alias for [[contains]]. */
  def apply(elem: Expression): Boolean = this.contains(elem)

  /**
   * Two sets are equal when their canonical forms match. Non-deterministic originals do not take
   * part in the comparison.
   */
  override def equals(obj: Any): Boolean = obj match {
    case other: ExpressionSet => this.baseSet == other.baseSet
    case _ => false
  }

  /** Hashes only the canonical forms, consistent with [[equals]]. */
  override def hashCode(): Int = baseSet.hashCode()

  /** Returns an independent copy; both internal collections are cloned. */
  override def clone(): ExpressionSet = new ExpressionSet(baseSet.clone(), originals.clone())

  /**
   * Returns a string containing both the post [[Canonicalize]] expressions and the original
   * expressions in this set.
   */
  def toDebugString: String =
    s"""
       |baseSet: ${baseSet.mkString(", ")}
       |originals: ${originals.mkString(", ")}
     """.stripMargin
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy