All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.catalyst.expressions.aggregate.CountMinSketchAgg.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch

/**
 * This function returns a count-min sketch of a column with the given eps, confidence and seed.
 * A count-min sketch is a probabilistic data structure used for summarizing streams of data in
 * sub-linear space, which is useful for equality predicates and join size estimation.
 * The result returned by the function is an array of bytes, which should be deserialized to a
 * `CountMinSketch` before usage.
 *
 * The eps, confidence and seed expressions must be foldable and non-null; this is enforced by
 * [[checkInputDataTypes]], which also rejects NaN and out-of-range values so that invalid
 * parameters are reported during analysis instead of when the sketch is built.
 *
 * @param child child expression that can produce column value with `child.eval(inputRow)`
 * @param epsExpression relative error, must be positive and finite
 * @param confidenceExpression confidence, must be positive and less than 1.0
 * @param seedExpression random seed
 * @param mutableAggBufferOffset offset of this aggregate's buffer in the mutable buffer row
 * @param inputAggBufferOffset offset of this aggregate's buffer in the input buffer row
 */
@ExpressionDescription(
  usage = """
    _FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given eps,
      confidence and seed. The result is an array of bytes, which can be deserialized to a
      `CountMinSketch` before usage. Count-min sketch is a probabilistic data structure used for
      cardinality estimation using sub-linear space.
  """)
case class CountMinSketchAgg(
    child: Expression,
    epsExpression: Expression,
    confidenceExpression: Expression,
    seedExpression: Expression,
    override val mutableAggBufferOffset: Int,
    override val inputAggBufferOffset: Int)
  extends TypedImperativeAggregate[CountMinSketch] with ExpectsInputTypes {

  /** Convenience constructor that starts with both aggregation buffer offsets set to 0. */
  def this(
      child: Expression,
      epsExpression: Expression,
      confidenceExpression: Expression,
      seedExpression: Expression) = {
    this(child, epsExpression, confidenceExpression, seedExpression, 0, 0)
  }

  // Mark as lazy so that they are not evaluated during tree transformation.
  // The casts rely on `inputTypes` (DoubleType, DoubleType, IntegerType) having been enforced and
  // on the null check in `checkInputDataTypes`, which runs before these values are first read.
  private lazy val eps: Double = epsExpression.eval().asInstanceOf[Double]
  private lazy val confidence: Double = confidenceExpression.eval().asInstanceOf[Double]
  private lazy val seed: Int = seedExpression.eval().asInstanceOf[Int]

  /**
   * Validates the arguments on top of the default `ExpectsInputTypes` check:
   * eps, confidence and seed must be foldable and non-null, eps must be a finite positive number
   * and confidence must lie strictly between 0.0 and 1.0.
   *
   * @return `TypeCheckSuccess` when every argument is acceptable, otherwise the first failure
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    val defaultCheck = super.checkInputDataTypes()
    if (defaultCheck.isFailure) {
      defaultCheck
    } else if (!epsExpression.foldable || !confidenceExpression.foldable ||
      !seedExpression.foldable) {
      TypeCheckFailure(
        "The eps, confidence or seed provided must be a literal or foldable")
    } else if (epsExpression.eval() == null || confidenceExpression.eval() == null ||
      seedExpression.eval() == null) {
      TypeCheckFailure("The eps, confidence or seed provided should not be null")
    } else if (eps.isNaN || eps <= 0.0) {
      // NaN compares false against every bound, so it has to be rejected explicitly; otherwise
      // it would slip through the range check and only surface once the sketch is created/used.
      TypeCheckFailure(s"Relative error must be positive (current value = $eps)")
    } else if (eps.isInfinity) {
      // Positive infinity passes the `> 0` requirement but is not a usable relative error.
      TypeCheckFailure(s"Relative error must be finite (current value = $eps)")
    } else if (confidence.isNaN || confidence <= 0.0 || confidence >= 1.0) {
      // Same NaN caveat as for eps: NaN is neither <= 0.0 nor >= 1.0.
      TypeCheckFailure(s"Confidence must be within range (0.0, 1.0) (current value = $confidence)")
    } else {
      TypeCheckSuccess
    }
  }

  /** Creates a fresh sketch from the evaluated eps, confidence and seed. */
  override def createAggregationBuffer(): CountMinSketch = {
    CountMinSketch.create(eps, confidence, seed)
  }

  /**
   * Adds the value produced by `child` for the given row to the sketch.
   *
   * @param buffer sketch to update in place
   * @param input row the child expression is evaluated against
   * @return the same `buffer` instance
   */
  override def update(buffer: CountMinSketch, input: InternalRow): CountMinSketch = {
    val value = child.eval(input)
    // Ignore null values: they are not counted in the sketch.
    if (value != null) {
      child.dataType match {
        // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary`
        // instead of `addString` to avoid unnecessary conversion.
        case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes)
        case _ => buffer.add(value)
      }
    }
    buffer
  }

  /**
   * Merges `input` into `buffer` in place.
   *
   * @return the same `buffer` instance
   */
  override def merge(buffer: CountMinSketch, input: CountMinSketch): CountMinSketch = {
    buffer.mergeInPlace(input)
    buffer
  }

  /** The final result is the serialized form of the sketch (see `serialize`). */
  override def eval(buffer: CountMinSketch): Any = serialize(buffer)

  /** Converts the sketch to its byte-array form via `CountMinSketch.toByteArray`. */
  override def serialize(buffer: CountMinSketch): Array[Byte] = {
    buffer.toByteArray
  }

  /** Rebuilds a sketch from the bytes produced by `serialize`, via `CountMinSketch.readFrom`. */
  override def deserialize(storageFormat: Array[Byte]): CountMinSketch = {
    CountMinSketch.readFrom(storageFormat)
  }

  /** Returns a copy of this aggregate that uses the given mutable buffer offset. */
  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CountMinSketchAgg =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  /** Returns a copy of this aggregate that uses the given input buffer offset. */
  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): CountMinSketchAgg =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  /**
   * Expected argument types, in the order of `children`: the column (integral, string or
   * binary), eps (double), confidence (double) and seed (integer).
   */
  override def inputTypes: Seq[AbstractDataType] = {
    Seq(TypeCollection(IntegralType, StringType, BinaryType), DoubleType, DoubleType, IntegerType)
  }

  // A serialized sketch is always produced, even when no value was added (see `defaultResult`).
  override def nullable: Boolean = false

  override def dataType: DataType = BinaryType

  /** Result for empty input: the serialized form of a freshly created (empty) sketch. */
  override def defaultResult: Option[Literal] =
    Option(Literal.create(eval(createAggregationBuffer()), dataType))

  override def children: Seq[Expression] =
    Seq(child, epsExpression, confidenceExpression, seedExpression)

  override def prettyName: String = "count_min_sketch"
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy