All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.spark.sql.execution.aggregate.ObjectAggregationMap.scala Maven / Gradle / Ivy

There is a newer version: 2.4.8
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.aggregate

import java.{util => ju}

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.internal.config
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, TypedImperativeAggregate}
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

/**
 * An aggregation map for [[ObjectHashAggregateExec]] whose aggregation buffers are safe
 * `SpecificInternalRow`s instead of unsafe rows. Because the buffers are safe rows, aggregate
 * functions can keep arbitrary Java objects as their intermediate state.
 *
 * Entries are held in a `java.util.LinkedHashMap`, so [[iterator]] and [[dumpToExternalSorter]]
 * visit them in insertion order. This class is only used together with
 * [[ObjectHashAggregateExec]].
 */
class ObjectAggregationMap() {
  // Grouping key -> aggregation buffer, kept in insertion order.
  private[this] val hashMap = new ju.LinkedHashMap[UnsafeRow, InternalRow]

  /**
   * Looks up the aggregation buffer stored for `groupingKey`.
   *
   * @return the stored buffer, or `null` if no buffer has been put for this key
   */
  def getAggregationBuffer(groupingKey: UnsafeRow): InternalRow = hashMap.get(groupingKey)

  /**
   * Stores `aggBuffer` under `groupingKey`, replacing any buffer previously stored for that key.
   * The key row is stored as given; no copy is made here.
   */
  def putAggregationBuffer(groupingKey: UnsafeRow, aggBuffer: InternalRow): Unit = {
    hashMap.put(groupingKey, aggBuffer)
  }

  /** The number of grouping keys currently held in the map. */
  def size: Int = hashMap.size()

  /**
   * Iterates over the map in insertion order and wraps each entry in a new
   * [[AggregationBufferEntry]]. The iterator is backed by the map itself, not by a snapshot.
   */
  def iterator: Iterator[AggregationBufferEntry] = {
    val entries = hashMap.entrySet().iterator()
    new Iterator[AggregationBufferEntry] {
      override def hasNext: Boolean = entries.hasNext

      override def next(): AggregationBufferEntry = {
        val current = entries.next()
        new AggregationBufferEntry(current.getKey, current.getValue)
      }
    }
  }

  /**
   * Moves every entry into a newly created [[UnsafeKVExternalSorter]] and then empties this map.
   *
   * For each entry, every [[TypedImperativeAggregate]] in `aggregateFunctions` first serializes
   * its state in place inside the entry's buffer. The buffer is then projected to an `UnsafeRow`
   * over the aggregation buffer attributes and inserted into the sorter under the entry's key.
   * The in-place serialization mutates the buffers that were stored in this map.
   *
   * @param groupingAttributes attributes of the grouping key rows; they form the sorter's key schema
   * @param aggregateFunctions aggregate functions whose buffer attributes form the sorter's value
   *                           schema
   * @return the sorter holding all former entries; this map is empty afterwards
   */
  def dumpToExternalSorter(
      groupingAttributes: Seq[Attribute],
      aggregateFunctions: Seq[AggregateFunction]): UnsafeKVExternalSorter = {
    val aggBufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes)
    val env = SparkEnv.get
    val sorter = new UnsafeKVExternalSorter(
      StructType.fromAttributes(groupingAttributes),
      StructType.fromAttributes(aggBufferAttributes),
      env.blockManager,
      env.serializerManager,
      TaskContext.get().taskMemoryManager().pageSizeBytes,
      env.conf.get(config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD),
      null) // NOTE(review): presumably "no existing BytesToBytesMap to adopt"; confirm.

    // Only typed imperative aggregates need their buffer state serialized before projection.
    val typedImperativeAggs = aggregateFunctions.collect {
      case agg: TypedImperativeAggregate[_] => agg
    }
    // Projects a safe buffer into an UnsafeRow laid out by the buffer attributes' data types.
    val toUnsafeBuffer = UnsafeProjection.create(aggBufferAttributes.map(_.dataType).toArray)

    for (entry <- iterator) {
      val buffer = entry.aggregationBuffer
      typedImperativeAggs.foreach(_.serializeAggregateBufferInPlace(buffer))
      sorter.insertKV(entry.groupingKey, toUnsafeBuffer(buffer))
    }

    hashMap.clear()
    sorter
  }

  /** Removes every entry from the map. */
  def clear(): Unit = hashMap.clear()
}

/**
 * A grouping key paired with its aggregation buffer, as produced by
 * `ObjectAggregationMap.iterator`. Both fields are mutable `var`s.
 *
 * @param groupingKey       the grouping key row
 * @param aggregationBuffer the aggregation buffer associated with `groupingKey`
 */
class AggregationBufferEntry(
    var groupingKey: UnsafeRow,
    var aggregationBuffer: InternalRow)




© 2015 - 2024 Weber Informatics LLC | Privacy Policy