// scala/collection/parallel/mutable/ParHashMap.scala
// From scala-library, the standard library for the Scala programming language.
/*                     __                                               *\
**     ________ ___   / /  ___     Scala API                            **
**    / __/ __// _ | / /  / _ |    (c) 2003-2013, LAMP/EPFL             **
**  __\ \/ /__/ __ |/ /__/ __ |    http://scala-lang.org/               **
** /____/\___/_/ |_/____/_/ | |                                         **
**                          |/                                          **
\*                                                                      */
package scala.collection.parallel
package mutable
import scala.collection.generic._
import scala.collection.mutable.DefaultEntry
import scala.collection.mutable.HashEntry
import scala.collection.mutable.HashTable
import scala.collection.mutable.UnrolledBuffer
import scala.collection.parallel.Task
/** A parallel hash map.
*
* `ParHashMap` is a parallel map which internally keeps elements within a hash table.
* It uses chaining to resolve collisions.
*
* @tparam K type of the keys in the parallel hash map
* @tparam V type of the values in the parallel hash map
*
* @define Coll `ParHashMap`
* @define coll parallel hash map
*
* @author Aleksandar Prokopec
* @see [[http://docs.scala-lang.org/overviews/parallel-collections/concrete-parallel-collections.html#parallel_hash_tables Scala's Parallel Collections Library overview]]
* section on Parallel Hash Tables for more information.
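*
* A minimal usage sketch (hypothetical values, not part of this file):
* {{{
* import scala.collection.parallel.mutable.ParHashMap
*
* val pm = ParHashMap("a" -> 1, "b" -> 2)
* pm.put("c", 3) // None: "c" was not present before
* pm.get("a") // Some(1)
* val doubled = pm map { case (k, v) => (k, v * 2) }
* }}}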
*/
@SerialVersionUID(1L)
class ParHashMap[K, V] private[collection] (contents: HashTable.Contents[K, DefaultEntry[K, V]])
extends ParMap[K, V]
with GenericParMapTemplate[K, V, ParHashMap]
with ParMapLike[K, V, ParHashMap[K, V], scala.collection.mutable.HashMap[K, V]]
with ParHashTable[K, DefaultEntry[K, V]]
with Serializable
{
self =>
initWithContents(contents)
type Entry = scala.collection.mutable.DefaultEntry[K, V]
def this() = this(null)
override def mapCompanion: GenericParMapCompanion[ParHashMap] = ParHashMap
override def empty: ParHashMap[K, V] = new ParHashMap[K, V]
protected[this] override def newCombiner = ParHashMapCombiner[K, V]
override def seq = new scala.collection.mutable.HashMap[K, V](hashTableContents)
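// The splitter hands the chain at slot 0 to the iterator as its current entry and starts scanning at index 1.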
def splitter = new ParHashMapIterator(1, table.length, size, table(0).asInstanceOf[DefaultEntry[K, V]])
override def size = tableSize
override def clear() = clearTable()
def get(key: K): Option[V] = {
val e = findEntry(key)
if (e eq null) None
else Some(e.value)
}
def put(key: K, value: V): Option[V] = {
val e = findOrAddEntry(key, value)
if (e eq null) None
else { val v = e.value; e.value = value; Some(v) }
}
def update(key: K, value: V): Unit = put(key, value)
def remove(key: K): Option[V] = {
val e = removeEntry(key)
if (e ne null) Some(e.value)
else None
}
def += (kv: (K, V)): this.type = {
val e = findOrAddEntry(kv._1, kv._2)
if (e ne null) e.value = kv._2
this
}
def -=(key: K): this.type = { removeEntry(key); this }
override def stringPrefix = "ParHashMap"
class ParHashMapIterator(start: Int, untilIdx: Int, totalSize: Int, e: DefaultEntry[K, V])
extends EntryIterator[(K, V), ParHashMapIterator](start, untilIdx, totalSize, e) {
def entry2item(entry: DefaultEntry[K, V]) = (entry.key, entry.value)
def newIterator(idxFrom: Int, idxUntil: Int, totalSz: Int, es: DefaultEntry[K, V]) =
new ParHashMapIterator(idxFrom, idxUntil, totalSz, es)
}
protected def createNewEntry[V1](key: K, value: V1): Entry = {
new Entry(key, value.asInstanceOf[V])
}
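// Custom Java serialization: `serializeTo` and `init`, inherited from the underlying `HashTable`,
// handle the table metadata, while each entry's key and value are written and read individually.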
private def writeObject(out: java.io.ObjectOutputStream) {
serializeTo(out, { entry =>
out.writeObject(entry.key)
out.writeObject(entry.value)
})
}
private def readObject(in: java.io.ObjectInputStream) {
init(in, createNewEntry(in.readObject().asInstanceOf[K], in.readObject()))
}
private[parallel] override def brokenInvariants = {
// bucket by bucket, count elements
val buckets = for (i <- 0 until (table.length / sizeMapBucketSize)) yield checkBucket(i)
// check if each element is in the position corresponding to its key
val elems = for (i <- 0 until table.length) yield checkEntry(i)
buckets.flatMap(x => x) ++ elems.flatMap(x => x)
}
private def checkBucket(i: Int) = {
def count(e: HashEntry[K, DefaultEntry[K, V]]): Int = if (e eq null) 0 else 1 + count(e.next)
val expected = sizemap(i)
val found = ((i * sizeMapBucketSize) until ((i + 1) * sizeMapBucketSize)).foldLeft(0) {
(acc, c) => acc + count(table(c))
}
if (found != expected) List("Found " + found + " elements, while sizemap showed " + expected)
else Nil
}
private def checkEntry(i: Int) = {
def check(e: HashEntry[K, DefaultEntry[K, V]]): List[String] = if (e eq null) Nil else
if (index(elemHashCode(e.key)) == i) check(e.next)
else ("Element " + e.key + " at " + i + " with " + elemHashCode(e.key) + " maps to " + index(elemHashCode(e.key))) :: check(e.next)
check(table(i))
}
}
/** $factoryInfo
* @define Coll `mutable.ParHashMap`
* @define coll parallel hash map
*/
object ParHashMap extends ParMapFactory[ParHashMap] {
var iters = 0
def empty[K, V]: ParHashMap[K, V] = new ParHashMap[K, V]
def newCombiner[K, V]: Combiner[(K, V), ParHashMap[K, V]] = ParHashMapCombiner.apply[K, V]
implicit def canBuildFrom[K, V]: CanCombineFrom[Coll, (K, V), ParHashMap[K, V]] = new CanCombineFromMap[K, V]
}
private[mutable] abstract class ParHashMapCombiner[K, V](private val tableLoadFactor: Int)
extends scala.collection.parallel.BucketCombiner[(K, V), ParHashMap[K, V], DefaultEntry[K, V], ParHashMapCombiner[K, V]](ParHashMapCombiner.numblocks)
with scala.collection.mutable.HashTable.HashUtils[K]
{
private var mask = ParHashMapCombiner.discriminantmask
private var nonmasklen = ParHashMapCombiner.nonmasklength
private var seedvalue = 27
def +=(elem: (K, V)) = {
sz += 1
val hc = improve(elemHashCode(elem._1), seedvalue)
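// the topmost `discriminantbits` bits of the improved hash code select one of the 32 buckets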
val pos = (hc >>> nonmasklen)
if (buckets(pos) eq null) {
// initialize bucket
buckets(pos) = new UnrolledBuffer[DefaultEntry[K, V]]()
}
// add to bucket
buckets(pos) += new DefaultEntry(elem._1, elem._2)
this
}
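// Two strategies: with at least numblocks * sizeMapBucketSize (32 * 32 = 1024) elements, the table
// is filled in parallel, block by block; below that threshold a plain sequential fill is cheaper.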
def result: ParHashMap[K, V] = if (size >= (ParHashMapCombiner.numblocks * sizeMapBucketSize)) { // 1024
// construct table
val table = new AddingHashTable(size, tableLoadFactor, seedvalue)
val bucks = buckets.map(b => if (b ne null) b.headPtr else null)
val insertcount = combinerTaskSupport.executeAndWaitResult(new FillBlocks(bucks, table, 0, bucks.length))
table.setSize(insertcount)
// TODO compare insertcount and size to see if compression is needed
val c = table.hashTableContents
new ParHashMap(c)
} else {
// construct a normal table and fill it sequentially
// TODO parallelize by keeping separate sizemaps and merging them
object table extends HashTable[K, DefaultEntry[K, V]] {
type Entry = DefaultEntry[K, V]
def insertEntry(e: Entry) { super.findOrAddEntry(e.key, e) }
def createNewEntry[E](key: K, entry: E): Entry = entry.asInstanceOf[Entry]
sizeMapInit(table.length)
}
var i = 0
while (i < ParHashMapCombiner.numblocks) {
if (buckets(i) ne null) {
for (elem <- buckets(i)) table.insertEntry(elem)
}
i += 1
}
new ParHashMap(table.hashTableContents)
}
/* classes */
/** A hash table which will never resize itself. Knowing the number of elements in advance,
* it allocates the table of the required size when created.
*
* Entries are added using the `insertEntry` method. This method checks whether the element
* exists and updates the size map. It returns false if the key was already in the table,
* and true if the key was successfully inserted. It does not update the number of elements
* in the table.
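*
* A sketch of the intended call pattern (hypothetical values, for illustration only):
* {{{
* val t = new AddingHashTable(numelems = 2, lf = 750, _seedvalue = 27)
* t.insertEntry(new DefaultEntry("a", 1)) // true: newly inserted
* t.insertEntry(new DefaultEntry("a", 2)) // false: key "a" already present
* t.setSize(1) // the caller accounts for the number of elements itself
* }}}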
*/
private[ParHashMapCombiner] class AddingHashTable(numelems: Int, lf: Int, _seedvalue: Int) extends HashTable[K, DefaultEntry[K, V]] {
import HashTable._
_loadFactor = lf
table = new Array[HashEntry[K, DefaultEntry[K, V]]](capacity(sizeForThreshold(_loadFactor, numelems)))
tableSize = 0
seedvalue = _seedvalue
threshold = newThreshold(_loadFactor, table.length)
sizeMapInit(table.length)
def setSize(sz: Int) = tableSize = sz
def insertEntry(/*block: Int, */e: DefaultEntry[K, V]) = {
var h = index(elemHashCode(e.key))
// assertCorrectBlock(h, block)
var olde = table(h).asInstanceOf[DefaultEntry[K, V]]
// check if key already exists
var ce = olde
while (ce ne null) {
if (ce.key == e.key) {
h = -1
ce = null
} else ce = ce.next
}
// if key does not already exist
if (h != -1) {
e.next = olde
table(h) = e
nnSizeMapAdd(h)
true
} else false
}
private def assertCorrectBlock(h: Int, block: Int) {
val blocksize = table.length / (1 << ParHashMapCombiner.discriminantbits)
if (!(h >= block * blocksize && h < (block + 1) * blocksize)) {
println("trying to put " + h + " into block no.: " + block + ", range: [" + block * blocksize + ", " + (block + 1) * blocksize + ">")
assert(h >= block * blocksize && h < (block + 1) * blocksize)
}
}
protected def createNewEntry[X](key: K, x: X) = ???
}
/* tasks */
import UnrolledBuffer.Unrolled
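/** Fills the preallocated `AddingHashTable` from `howmany` consecutive buckets starting at `offset`,
* in parallel: `split` halves the bucket range, `leaf` inserts every entry of its buckets and counts
* the successful (non-duplicate) insertions, and `merge` sums the counts of the two halves.
*/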
class FillBlocks(buckets: Array[Unrolled[DefaultEntry[K, V]]], table: AddingHashTable, offset: Int, howmany: Int)
extends Task[Int, FillBlocks] {
var result = Int.MinValue
def leaf(prev: Option[Int]) = {
var i = offset
val until = offset + howmany
result = 0
while (i < until) {
result += fillBlock(i, buckets(i))
i += 1
}
}
private def fillBlock(block: Int, elems: Unrolled[DefaultEntry[K, V]]) = {
var insertcount = 0
var unrolled = elems
var i = 0
val t = table
while (unrolled ne null) {
val chunkarr = unrolled.array
val chunksz = unrolled.size
while (i < chunksz) {
val elem = chunkarr(i)
// assertCorrectBlock(block, elem.key)
if (t.insertEntry(elem)) insertcount += 1
i += 1
}
i = 0
unrolled = unrolled.next
}
insertcount
}
private def assertCorrectBlock(block: Int, k: K) {
val hc = improve(elemHashCode(k), seedvalue)
if ((hc >>> nonmasklen) != block) {
println(hc + " goes to " + (hc >>> nonmasklen) + ", while expected block is " + block)
assert((hc >>> nonmasklen) == block)
}
}
def split = {
val fp = howmany / 2
List(new FillBlocks(buckets, table, offset, fp), new FillBlocks(buckets, table, offset + fp, howmany - fp))
}
override def merge(that: FillBlocks) {
this.result += that.result
}
def shouldSplitFurther = howmany > scala.collection.parallel.thresholdFromSize(ParHashMapCombiner.numblocks, combinerTaskSupport.parallelismLevel)
}
}
private[parallel] object ParHashMapCombiner {
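// 5 discriminant bits give 1 << 5 = 32 blocks; an element's block is selected by the top 5 bits
// of its improved hash code (see `+=` above), i.e. hc >>> nonmasklength.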
private[mutable] val discriminantbits = 5
private[mutable] val numblocks = 1 << discriminantbits
private[mutable] val discriminantmask = ((1 << discriminantbits) - 1)
private[mutable] val nonmasklength = 32 - discriminantbits
def apply[K, V] = new ParHashMapCombiner[K, V](HashTable.defaultLoadFactor) {} // was: with EnvironmentPassingCombiner[(K, V), ParHashMap[K, V]]
}