All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.expressions.udaf.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.expressions

import org.apache.spark.annotation.Stable
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.internal.{InvokeInlineUserDefinedFunction, UserDefinedFunctionLike}
import org.apache.spark.sql.types._

/**
 * Base class that user-defined aggregate functions (UDAFs) extend.
 *
 * An implementation declares the types it works with (`inputSchema`, `bufferSchema` and
 * `dataType`) and supplies the aggregation steps (`initialize`, `update`, `merge` and
 * `evaluate`). It can then be turned into a `Column` through `apply` or `distinct`.
 *
 * @since 1.5.0
 * @deprecated
 *   UserDefinedAggregateFunction is deprecated. Aggregator[IN, BUF, OUT] should now be registered
 *   as a UDF via the functions.udaf(agg) method.
 */
@Stable
@deprecated(
  "Aggregator[IN, BUF, OUT] should now be registered as a UDF" +
    " via the functions.udaf(agg) method.",
  "3.0.0")
abstract class UserDefinedAggregateFunction extends Serializable with UserDefinedFunctionLike {

  /**
   * Describes the data types of the input arguments of this aggregate function. As an example,
   * a [[UserDefinedAggregateFunction]] expecting two input arguments of type `DoubleType` and
   * `LongType` would return a `StructType` such as
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * Field names of this `StructType` serve only to identify the corresponding input argument,
   * so any names may be chosen.
   *
   * @since 1.5.0
   */
  def inputSchema: StructType

  /**
   * Describes the data types of the values held in the aggregation buffer. As an example, a
   * [[UserDefinedAggregateFunction]] whose buffer holds two intermediate values of type
   * `DoubleType` and `LongType` would return a `StructType` such as
   *
   * ```
   *   new StructType()
   *    .add("doubleValue", DoubleType)
   *    .add("longValue", LongType)
   * ```
   *
   * Field names of this `StructType` serve only to identify the corresponding buffer value, so
   * any names may be chosen.
   *
   * @since 1.5.0
   */
  def bufferSchema: StructType

  /**
   * The `DataType` of the value this [[UserDefinedAggregateFunction]] returns.
   *
   * @since 1.5.0
   */
  def dataType: DataType

  /**
   * Whether this function is deterministic: true iff the same input always yields the same
   * output.
   *
   * @since 1.5.0
   */
  def deterministic: Boolean

  /**
   * Sets the given aggregation buffer to its initial state, i.e. the zero value of the
   * aggregation buffer.
   *
   * Implementations are expected to honour the following contract: merging two initial buffers
   * must leave the initial buffer unchanged, i.e. `merge(initialBuffer, initialBuffer)` should
   * equal `initialBuffer`.
   *
   * @since 1.5.0
   */
  def initialize(buffer: MutableAggregationBuffer): Unit

  /**
   * Folds the new input data in `input` into the aggregation buffer `buffer`.
   *
   * Invoked once for every input row.
   *
   * @since 1.5.0
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /**
   * Combines two aggregation buffers, writing the combined values back into `buffer1`.
   *
   * Invoked when two partially aggregated results are merged together.
   *
   * @since 1.5.0
   */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /**
   * Computes the final result of this [[UserDefinedAggregateFunction]] from the given
   * aggregation buffer.
   *
   * @since 1.5.0
   */
  def evaluate(buffer: Row): Any

  /**
   * Builds a `Column` that invokes this UDAF with the given `Column`s as its input arguments.
   *
   * @since 1.5.0
   */
  @scala.annotation.varargs
  def apply(exprs: Column*): Column =
    Column(InvokeInlineUserDefinedFunction(this, exprs.map(arg => arg.node)))

  /**
   * Builds a `Column` that invokes this UDAF on the distinct values of the given `Column`s,
   * which are used as its input arguments.
   *
   * @since 1.5.0
   */
  @scala.annotation.varargs
  def distinct(exprs: Column*): Column =
    Column(
      InvokeInlineUserDefinedFunction(this, exprs.map(arg => arg.node), isDistinct = true))
}

/**
 * A `Row` representing a mutable aggregation buffer.
 *
 * It extends `Row` and adds an `update` method for replacing individual values.
 *
 * This is not meant to be extended outside of Spark.
 *
 * @since 1.5.0
 */
@Stable
abstract class MutableAggregationBuffer extends Row {

  /**
   * Update the ith value of this buffer.
   *
   * Because the method is named `update`, Scala's assignment syntax applies: `buffer(i) = value`
   * is rewritten by the compiler to `buffer.update(i, value)`.
   *
   * @param i
   *   position of the value to replace
   * @param value
   *   the new value to store at position `i`
   */
  def update(i: Int, value: Any): Unit
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy