All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.catalyst.expressions.EachTopK.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue
import org.apache.spark.sql.types._

trait TopKHelper {

  def k: Int
  def scoreType: DataType

  @transient val ScoreTypes = TypeCollection(
    ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType
  )

  protected case class ScoreWriter(writer: UnsafeRowWriter, ordinal: Int) {

    def write(v: Any): Unit = scoreType match {
      case ByteType => writer.write(ordinal, v.asInstanceOf[Byte])
      case ShortType => writer.write(ordinal, v.asInstanceOf[Short])
      case IntegerType => writer.write(ordinal, v.asInstanceOf[Int])
      case LongType => writer.write(ordinal, v.asInstanceOf[Long])
      case FloatType => writer.write(ordinal, v.asInstanceOf[Float])
      case DoubleType => writer.write(ordinal, v.asInstanceOf[Double])
      case d: DecimalType => writer.write(ordinal, v.asInstanceOf[Decimal], d.precision, d.scale)
    }
  }

  protected lazy val scoreOrdering = {
    val ordering = TypeUtils.getInterpretedOrdering(scoreType)
    if (k > 0) ordering else ordering.reverse
  }

  protected lazy val reverseScoreOrdering = scoreOrdering.reverse

  protected lazy val queue: InternalRowPriorityQueue = {
    new InternalRowPriorityQueue(Math.abs(k), (x: Any, y: Any) => scoreOrdering.compare(x, y))
  }
}

case class EachTopK(
    k: Int,
    scoreExpr: Expression,
    groupExprs: Seq[Expression],
    elementSchema: StructType,
    children: Seq[Attribute])
  extends Generator with TopKHelper with CodegenFallback {

  override val scoreType: DataType = scoreExpr.dataType

  private lazy val groupingProjection: UnsafeProjection = UnsafeProjection.create(groupExprs)
  private lazy val scoreProjection: UnsafeProjection = UnsafeProjection.create(scoreExpr :: Nil)

  // The grouping key of the current partition
  private var currentGroupingKeys: UnsafeRow = _

  override def checkInputDataTypes(): TypeCheckResult = {
    if (!ScoreTypes.acceptsType(scoreExpr.dataType)) {
      TypeCheckResult.TypeCheckFailure(s"$scoreExpr must have a comparable type")
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  private def topKRowsForGroup(): Seq[InternalRow] = if (queue.size > 0) {
    val outputRows = queue.iterator.toSeq.sortBy(_._1)(scoreOrdering).reverse
    val (headScore, _) = outputRows.head
    val rankNum = outputRows.scanLeft((1, headScore)) {
      case ((rank, prevScore), (score, _)) =>
        if (prevScore == score) (rank, score) else (rank + 1, score)
    }.tail
    val buf = mutable.ArrayBuffer[InternalRow]()
    var i = 0
    while (rankNum.length > i) {
      val rank = rankNum(i)._1
      val row = new JoinedRow(InternalRow.fromSeq(rank :: Nil), outputRows(i)._2)
      buf.append(row)
      i += 1
    }
    buf
  } else {
    Seq.empty
  }

  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
    val groupingKeys = groupingProjection(input)
    val ret = if (currentGroupingKeys != groupingKeys) {
      val topKRows = topKRowsForGroup()
      currentGroupingKeys = groupingKeys.copy()
      queue.clear()
      topKRows
    } else {
      Iterator.empty
    }
    queue += Tuple2(scoreProjection(input).get(0, scoreType), input)
    ret
  }

  override def terminate(): TraversableOnce[InternalRow] = {
    if (queue.size > 0) {
      val topKRows = topKRowsForGroup()
      queue.clear()
      topKRows
    } else {
      Iterator.empty
    }
  }

  // TODO: Need to support codegen
  // protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy