All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.spark.rdd.ZippedPartitionsWithLocalityRDD.scala Maven / Gradle / Ivy

/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import org.apache.logging.log4j.{LogManager, Logger}
import org.apache.spark.util.Utils
import org.apache.spark.{Partition, SparkContext}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

object ZippedPartitionsWithLocalityRDD {

  /** Logger shared by all instances; used to report partitions zipped without locality. */
  val logger: Logger = LogManager.getLogger(getClass)

  /**
   * Zips `rdd1` with `rdd2`, pairing partitions that share a preferred location where
   * possible, and combines each pair of partition iterators with `f`.
   *
   * `f` is passed through `SparkContext.clean` before being captured, and the resulting RDD
   * is created inside `rdd1`'s operation scope (`withScope`).
   *
   * @param rdd1 first RDD to zip
   * @param rdd2 second RDD to zip; must have the same number of partitions as `rdd1`
   * @param preservesPartitioning forwarded to the [[ZippedPartitionsWithLocalityRDD]] constructor
   * @param f function combining the iterators of each zipped partition pair
   * @return the zipped RDD
   */
  def apply[T: ClassTag, B: ClassTag, V: ClassTag]
  (rdd1: RDD[T], rdd2: RDD[B], preservesPartitioning: Boolean = false)
    (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = rdd1.withScope {
    val context = rdd1.sparkContext
    val cleanedF = context.clean(f)
    new ZippedPartitionsWithLocalityRDD[T, B, V](
      context, cleanedF, rdd1, rdd2, preservesPartitioning)
  }
}

/**
 * Prefer to zip partitions of rdd1 and rdd2 that are in the same location.
 * Remaining partitions that cannot be co-located are zipped by order.
 * For example:
 * Say we have two RDDs, rdd1 and rdd2. The first partition of rdd1 is on node A, and the second
 * is on node B. The first partition of rdd2 is on node B and the second one is on node A.
 * If we just use rdd1.zipPartitions(rdd2), the first partition of rdd1 is zipped with the first
 * partition of rdd2, so there will be cross-node communication, which is bad for performance.
 * That's why we introduce the ZippedPartitionsWithLocalityRDD.
 * In our method, the first partition of rdd1 will be zipped with the second partition of rdd2,
 * as they are on the same node. This reduces the network communication cost and results in
 * better performance.
 *
 * @param sc spark context
 * @param _f function combining the iterators of each pair of zipped partitions
 * @param _rdd1 first RDD to zip; output partition `i` always contains partition `i` of it
 * @param _rdd2 second RDD to zip; its partitions are re-paired with those of `_rdd1` by locality
 * @param preservesPartitioning forwarded unchanged to [[ZippedPartitionsRDD2]]
 */
class ZippedPartitionsWithLocalityRDD[A: ClassTag, B: ClassTag, V: ClassTag](
  sc: SparkContext,
  _f: (Iterator[A], Iterator[B]) => Iterator[V],
  _rdd1: RDD[A],
  _rdd2: RDD[B],
  preservesPartitioning: Boolean = false)
  extends ZippedPartitionsRDD2[A, B, V](sc, _f, _rdd1, _rdd2, preservesPartitioning) {

  /**
   * Pairs every partition `i` of the first parent with one partition of the second parent.
   *
   * Partitions of the first parent are visited in index order. Each one claims the
   * lowest-indexed, still-unclaimed partition of the second parent whose preferred locations
   * overlap its own; the overlap becomes the zipped partition's preferred locations.
   * Partitions that find no overlap are paired, in order, with the leftover partitions of the
   * second parent and keep the first parent's preferred locations.
   *
   * @throws IllegalArgumentException if there are not exactly two parents, or if the parents
   *                                  have different numbers of partitions
   */
  override def getPartitions: Array[Partition] = {
    if (rdds.length != 2) {
      //     scalastyle:off
      throw new IllegalArgumentException("this is only for 2 rdd zip")
      //     scalastyle:on
    }
    val numParts = rdds.head.partitions.length
    if (!rdds.forall(rdd => rdd.partitions.length == numParts)) {
      //     scalastyle:off
      throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
      //     scalastyle:on
    }

    // Distinct preferred locations of partition `p` of `rdd`, rendered as strings.
    def prefsOf(rdd: RDD[_], p: Int): Seq[String] =
      rdd.context.getPreferredLocs(rdd, p).map(_.toString).distinct

    // Partitions of the second parent not yet paired, kept in ascending index order so the
    // greedy search below always picks the lowest-indexed compatible candidate.
    val candidateLocs = new ArrayBuffer[(Int, Seq[String])]()
    candidateLocs ++= (0 until numParts).map(p => (p, prefsOf(rdds(1), p)))
    // Computed once and reused by both the matching pass and the fallback pass.
    val firstPrefs = (0 until numParts).map(p => prefsOf(rdds(0), p))

    val nonmatchPartitionId = new ArrayBuffer[Int]()
    val parts = new Array[Partition](numParts)

    (0 until numParts).foreach { i =>
      val curPrefs = firstPrefs(i)
      val matchIdx =
        candidateLocs.indexWhere { case (_, locs) => locs.intersect(curPrefs).nonEmpty }
      if (matchIdx >= 0) {
        val (rdd2Idx, rdd2Locs) = candidateLocs.remove(matchIdx)
        // Only the locations shared by both parents are advertised for the zipped partition.
        parts(i) = new ZippedPartitionsLocalityPartition(
          i, Seq(i, rdd2Idx), rdds, rdd2Locs.intersect(curPrefs))
      } else {
        ZippedPartitionsWithLocalityRDD.logger.warn(s"can't find locality partition " +
          s"for partition $i. Partition locations are (${curPrefs}). Candidate partition" +
          s" locations are\n" + s"${candidateLocs.mkString("\n")}.")
        nonmatchPartitionId.append(i)
      }
    }

    // Each successful match consumes exactly one candidate, so the unmatched partitions and
    // the unclaimed candidates must pair up one-to-one.
    if (nonmatchPartitionId.size != candidateLocs.size) {
      //     scalastyle:off
      throw new IllegalArgumentException("unmatched partition size should be the same " +
        "with candidateLocs size")
      //     scalastyle:on
    }
    // Fall back to pairing the leftovers by order; they keep the first parent's locations.
    nonmatchPartitionId.zip(candidateLocs).foreach { case (i, (rdd2Idx, _)) =>
      parts(i) = new ZippedPartitionsLocalityPartition(i, Seq(i, rdd2Idx), rdds, firstPrefs(i))
    }
    parts
  }
}


/**
 * A zipped partition whose parent partitions are selected per parent RDD:
 * `rdds(k)` contributes its partition number `indexes(k)`, instead of every parent
 * contributing partition `idx`.
 *
 * @param idx index of this partition in the zipped RDD
 * @param indexes partition index to take from each parent, aligned with `rdds`
 * @param rdds the parent RDDs
 * @param preferredLocations preferred locations reported for this partition
 */
private[spark] class ZippedPartitionsLocalityPartition(
  idx: Int,
  @transient val indexes: Seq[Int],
  @transient val rdds: Seq[RDD[_]],
  @transient override val preferredLocations: Seq[String])
  extends ZippedPartitionsPartition(idx, rdds, preferredLocations) {

  override val index: Int = idx

  // The selected parent partitions; refreshed again just before serialization.
  var _partitionValues: Seq[Partition] = lookupParentPartitions()

  override def partitions: Seq[Partition] = _partitionValues

  /** Resolves partition `indexes(k)` of `rdds(k)` for each aligned pair. */
  private def lookupParentPartitions(): Seq[Partition] =
    for ((parent, partIdx) <- rdds.zip(indexes)) yield parent.partitions(partIdx)

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Re-resolve the parent partitions at task-serialization time, then write the fields.
    _partitionValues = lookupParentPartitions()
    oos.defaultWriteObject()
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy