shark.execution.ReduceKey.scala

/*
 * Copyright (C) 2012 The Regents of The University of California.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package shark.execution

import java.io.{Externalizable, ObjectOutput, ObjectInput}

import com.google.common.primitives.UnsignedBytes

import org.apache.hadoop.io.{BytesWritable, WritableComparator}

import org.apache.spark.HashPartitioner


/**
 * A data structure used for shuffling data that supports comparison. ReduceKey
 * wraps a normal key byte array so that the bytes can be compared directly in
 * binary form, and so that data can be partitioned on the partitionCode field
 * by a ReduceKeyPartitioner.
 *
 * Note that this data structure needs to be serializable because, in the case of a
 * total order sort, Spark's range partitioner uses a collect operation to find the
 * ranges. The collect serializes the ReduceKey objects and sends them back to the master.
 */
sealed trait ReduceKey extends Comparable[ReduceKey] {

  def length: Int

  def byteArray: Array[Byte]

  override def hashCode: Int

  override def toString: String

  override def equals(other: Any): Boolean

  override def compareTo(that: ReduceKey): Int
}


class ReduceKeyMapSide(var bytesWritable: BytesWritable) extends ReduceKey
  with Externalizable {

  // A default no-arg constructor for the Java serializer to use.
  def this() = this(new BytesWritable)

  /** Used by ReduceKeyPartitioner to determine the hash partition. */
  @transient var partitionCode = 0

  def createDeepCopy(): ReduceKeyMapSide = {
    val writable = new BytesWritable
    writable.set(bytesWritable.getBytes, 0, bytesWritable.getLength)
    val copy = new ReduceKeyMapSide(writable)
    copy.partitionCode = partitionCode
    copy
  }

  override def length: Int = bytesWritable.getLength()

  override def byteArray: Array[Byte] = bytesWritable.getBytes()

  override def equals(other: Any): Boolean = {
    other match {
      case other: ReduceKey => {
        if (length != other.length) {
          false
        } else {
          WritableComparator.compareBytes(
            byteArray, 0, length, other.byteArray, 0, other.length) == 0
        }
      }
      case _ => false
    }
  }

  override def compareTo(that: ReduceKey): Int = {
    bytesWritable.compareTo(that.asInstanceOf[ReduceKeyMapSide].bytesWritable)
  }

  override def hashCode = bytesWritable.hashCode()

  override def toString = {
    this.getClass.getName + "(" + partitionCode + ": " + bytesWritable.toString + ")"
  }

  override def writeExternal(out: ObjectOutput) {
    bytesWritable.write(out)
  }

  override def readExternal(in: ObjectInput) {
    bytesWritable = new BytesWritable
    bytesWritable.readFields(in)
  }
}
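
// Illustrative sketch (not part of the original Shark source; the object name is
// hypothetical): two map-side keys wrapping identical bytes are equal and compare
// as 0, and createDeepCopy() yields an independent backing BytesWritable that
// carries the same partitionCode.
object ReduceKeyMapSideExample {
  def main(args: Array[String]): Unit = {
    val k1 = new ReduceKeyMapSide(new BytesWritable(Array[Byte](1, 2, 3)))
    val k2 = new ReduceKeyMapSide(new BytesWritable(Array[Byte](1, 2, 3)))
    k1.partitionCode = 42

    assert(k1 == k2)               // byte-wise equality via WritableComparator
    assert(k1.compareTo(k2) == 0)  // same bytes, so the keys compare as equal

    val copy = k1.createDeepCopy()
    assert(copy.partitionCode == 42)
    assert(copy.bytesWritable ne k1.bytesWritable)  // a distinct backing BytesWritable
  }
}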


class ReduceKeyReduceSide(private val _byteArray: Array[Byte])
  extends ReduceKey
  with Serializable {
  // This is serializable because of the possible external spill.

  def this() = this(null)

  override def byteArray: Array[Byte] = _byteArray

  override def length: Int = byteArray.length

  override def equals(other: Any): Boolean = {
    other match {
      case that: ReduceKeyReduceSide => {
        (this.byteArray.length == that.byteArray.length) && (this.compareTo(that) == 0)
      }
      case _ => false
    }
  }

  override def compareTo(that: ReduceKey): Int = {
    ReduceKeyReduceSide.comparator.compare(this.byteArray, that.byteArray)
  }

  override def hashCode = WritableComparator.hashBytes(_byteArray, _byteArray.length)

  override def toString = {
    this.getClass.getName + "(" + new BytesWritable(byteArray).toString + ")"
  }
}


object ReduceKeyReduceSide {
  val comparator = UnsignedBytes.lexicographicalComparator()
}
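
// Illustrative sketch (not part of the original Shark source; the object name is
// hypothetical): reduce-side keys compare as unsigned bytes, so 0xFF sorts after
// 0x01 even though its signed byte value is negative.
object ReduceKeyReduceSideExample {
  def main(args: Array[String]): Unit = {
    val small = new ReduceKeyReduceSide(Array[Byte](0x01))
    val large = new ReduceKeyReduceSide(Array[Byte](0xFF.toByte))
    assert(small.compareTo(large) < 0)  // unsigned ordering: 0x01 < 0xFF
    assert(large.compareTo(small) > 0)
    assert(small != large)              // same length but different bytes
  }
}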


/**
 * A special Spark partitioner that allows hash partitioning of data based on
 * the partitionCode field in ReduceKey.
 */
class ReduceKeyPartitioner(partitions: Int) extends HashPartitioner(partitions) {

  override def getPartition(key: Any): Int = {
    key match {
      case k: ReduceKeyMapSide => {
        val mod = k.partitionCode % partitions
        if (mod < 0) mod + partitions else mod  // Guard against negative hash codes.
      }
      case other => {
        throw new Exception("ReduceKeyPartitioner expects object of class ReduceKeyMapSide, " +
          "but got " + other.getClass.getName)
      }
    }
  }
}
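
// Illustrative sketch (not part of the original Shark source; the object name and
// values are hypothetical): the partitioner routes each key to
// partitionCode % numPartitions, adding the partition count when the modulo is
// negative so the result is always a valid partition index.
object ReduceKeyPartitionerExample {
  def main(args: Array[String]): Unit = {
    val partitioner = new ReduceKeyPartitioner(4)
    val key = new ReduceKeyMapSide(new BytesWritable(Array[Byte](7)))

    key.partitionCode = 10
    assert(partitioner.getPartition(key) == 2)  // 10 % 4 == 2

    key.partitionCode = -3
    assert(partitioner.getPartition(key) == 1)  // -3 % 4 == -3, guarded to -3 + 4 == 1
  }
}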



