scala.collection.concurrent.TrieMap.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of spark-core Show documentation
Show all versions of spark-core Show documentation
Shaded version of Apache Spark 2.x.x for Presto
The newest version!
/* __ *\
** ________ ___ / / ___ Scala API **
** / __/ __// _ | / / / _ | (c) 2003-2013, LAMP/EPFL **
** __\ \/ /__/ __ |/ /__/ __ | http://scala-lang.org/ **
** /____/\___/_/ |_/____/_/ | | **
** |/ **
\* */
package scala
package collection
package concurrent
import java.util.concurrent.atomic._
import scala.collection.immutable.{ ListMap => ImmutableListMap }
import scala.collection.parallel.mutable.ParTrieMap
import scala.util.hashing.Hashing
import scala.util.control.ControlThrowable
import generic._
import scala.annotation.tailrec
import scala.annotation.switch
/** An indirection node of the concurrent hash trie.
 *
 *  An `INode` belongs to generation `g` and holds one main node in its
 *  `mainnode` field (a `CNode`, `TNode` or `LNode`; it may be `null` when built
 *  through the `this(g)` constructor until a main node is written).
 *  Updates of the main node go through `GCAS`. This commits only when the trie's
 *  current root has the same generation as this node and the trie is not read-only.
 *  Otherwise the proposed update is rolled back.
 *
 *  @param bn initial main node
 *  @param g  generation of this i-node
 */
private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends INodeBase[K, V](g) {
  import INodeBase._

  WRITE(bn)

  def this(g: Gen) = this(null, g)

  /** Unconditionally stores `nval` as the main node through `INodeBase.updater`. */
  def WRITE(nval: MainNode[K, V]) = INodeBase.updater.set(this, nval)

  /** Plain compare-and-set of the main node; performs no generation check. */
  def CAS(old: MainNode[K, V], n: MainNode[K, V]) = INodeBase.updater.compareAndSet(this, old, n)

  /** Same as `GCAS_READ`. */
  def gcasRead(ct: TrieMap[K, V]): MainNode[K, V] = GCAS_READ(ct)

  /** Reads the main node. If a GCAS is still pending on it (its `prev` is
   *  non-null), that GCAS is completed first.
   */
  def GCAS_READ(ct: TrieMap[K, V]): MainNode[K, V] = {
    val m = /*READ*/mainnode
    // A null main node has no pending GCAS. Return it directly instead of
    // dereferencing `m.prev`, so that `isNullInode` can report it instead of throwing.
    if (m eq null) m
    else {
      val prevval = /*READ*/m.prev
      if (prevval eq null) m
      else GCAS_Complete(m, ct)
    }
  }

  /** Drives a pending GCAS on `m` to completion.
   *
   *  The GCAS is committed (`prev` cleared) when the root's generation matches
   *  this node's. Otherwise it is aborted: `prev` is replaced by a `FailedNode`
   *  and the main node is CAS-ed back to the previous value.
   *  @return the main node that is current once the GCAS has been resolved
   */
  @tailrec private def GCAS_Complete(m: MainNode[K, V], ct: TrieMap[K, V]): MainNode[K, V] = if (m eq null) null else {
    // complete the GCAS
    val prev = /*READ*/m.prev
    val ctr = ct.readRoot(abort = true)

    prev match {
      case null =>
        m
      case fn: FailedNode[_, _] => // try to commit to previous value
        if (CAS(m, fn.prev)) fn.prev
        else GCAS_Complete(/*READ*/mainnode, ct)
      case vn: MainNode[_, _] =>
        // Assume that you've read the root from the generation G.
        // Assume that the snapshot algorithm is correct.
        // ==> you can only reach nodes in generations <= G.
        // ==> `gen` is <= G.
        // We know that `ctr.gen` is >= G.
        // ==> if `ctr.gen` = `gen` then they are both equal to G.
        // ==> otherwise, we know that either `ctr.gen` > G, `gen` < G,
        //     or both
        if ((ctr.gen eq gen) && ct.nonReadOnly) {
          // try to commit
          if (m.CAS_PREV(prev, null)) m
          else GCAS_Complete(m, ct)
        } else {
          // try to abort
          m.CAS_PREV(prev, new FailedNode(prev))
          GCAS_Complete(/*READ*/mainnode, ct)
        }
    }
  }

  /** Generation-compare-and-swap of the main node.
   *
   *  Links `n.prev` to `old`, CASes the main node from `old` to `n`, and then
   *  completes the protocol.
   *  @return true only if `n` ended up committed (its `prev` was cleared)
   */
  def GCAS(old: MainNode[K, V], n: MainNode[K, V], ct: TrieMap[K, V]): Boolean = {
    n.WRITE_PREV(old)
    if (CAS(old, n)) {
      GCAS_Complete(n, ct)
      /*READ*/n.prev eq null
    } else false
  }

  /** Key equality as configured on the trie (`ct.equality`). */
  private def equal(k1: K, k2: K, ct: TrieMap[K, V]) = ct.equality.equiv(k1, k2)

  /** Wraps `cn` in a fresh i-node of this node's generation. */
  private def inode(cn: MainNode[K, V]) = {
    val nin = new INode[K, V](gen)
    nin.WRITE(cn)
    nin
  }

  /** Returns a new i-node of generation `ngen` that shares this node's current main node. */
  def copyToGen(ngen: Gen, ct: TrieMap[K, V]) = {
    val nin = new INode[K, V](ngen)
    val main = GCAS_READ(ct)
    nin.WRITE(main)
    nin
  }

  /** Inserts a key value pair, overwriting the old pair if the keys match.
   *
   *  @param k        key to insert
   *  @param v        value to bind
   *  @param hc       hash of `k` as computed by `ct.computeHash`
   *  @param lev      bit offset into `hc` for this level (advanced by 5 per level)
   *  @param parent   i-node above this one (`null` at the root)
   *  @param startgen generation of the root the operation started from
   *  @param ct       the owning trie
   *  @return true if successful, false if the caller must restart from the root
   */
  @tailrec def rec_insert(k: K, v: V, hc: Int, lev: Int, parent: INode[K, V], startgen: Gen, ct: TrieMap[K, V]): Boolean = {
    val m = GCAS_READ(ct) // use -Yinline!

    m match {
      case cn: CNode[K, V] => // 1) a multiway node
        val idx = (hc >>> lev) & 0x1f
        val flag = 1 << idx
        val bmp = cn.bitmap
        val mask = flag - 1
        val pos = Integer.bitCount(bmp & mask)
        if ((bmp & flag) != 0) {
          // 1a) insert below
          cn.array(pos) match {
            case in: INode[K, V] =>
              if (startgen eq in.gen) in.rec_insert(k, v, hc, lev + 5, this, startgen, ct)
              else {
                // child belongs to an older generation: renew this level first, then retry here
                if (GCAS(cn, cn.renewed(startgen, ct), ct)) rec_insert(k, v, hc, lev, parent, startgen, ct)
                else false
              }
            case sn: SNode[K, V] =>
              if (sn.hc == hc && equal(sn.k, k, ct)) GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)
              else {
                // hash chunks collide at this level: push both bindings one level down
                val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct)
                val nn = rn.updatedAt(pos, inode(CNode.dual(sn, sn.hc, new SNode(k, v, hc), hc, lev + 5, gen)), gen)
                GCAS(cn, nn, ct)
              }
          }
        } else {
          val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct)
          val ncnode = rn.insertedAt(pos, flag, new SNode(k, v, hc), gen)
          GCAS(cn, ncnode, ct)
        }
      case tn: TNode[K, V] =>
        clean(parent, ct, lev - 5)
        false
      case ln: LNode[K, V] => // 3) an l-node
        val nn = ln.inserted(k, v)
        GCAS(ln, nn, ct)
    }
  }

  /** Inserts a new key value pair, given that a specific condition is met.
   *
   *  @param cond null - don't care if the key was there; KEY_ABSENT - key wasn't there; KEY_PRESENT - key was there; other value `v` - key must be bound to `v`
   *  @return null if unsuccessful (the caller restarts), Option[V] otherwise (indicating previous value bound to the key)
   */
  @tailrec def rec_insertif(k: K, v: V, hc: Int, cond: AnyRef, lev: Int, parent: INode[K, V], startgen: Gen, ct: TrieMap[K, V]): Option[V] = {
    val m = GCAS_READ(ct) // use -Yinline!

    m match {
      case cn: CNode[K, V] => // 1) a multiway node
        val idx = (hc >>> lev) & 0x1f
        val flag = 1 << idx
        val bmp = cn.bitmap
        val mask = flag - 1
        val pos = Integer.bitCount(bmp & mask)
        if ((bmp & flag) != 0) {
          // 1a) insert below
          cn.array(pos) match {
            case in: INode[K, V] =>
              if (startgen eq in.gen) in.rec_insertif(k, v, hc, cond, lev + 5, this, startgen, ct)
              else {
                if (GCAS(cn, cn.renewed(startgen, ct), ct)) rec_insertif(k, v, hc, cond, lev, parent, startgen, ct)
                else null
              }
            case sn: SNode[K, V] => cond match {
              case null =>
                if (sn.hc == hc && equal(sn.k, k, ct)) {
                  if (GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)) Some(sn.v) else null
                } else {
                  val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct)
                  val nn = rn.updatedAt(pos, inode(CNode.dual(sn, sn.hc, new SNode(k, v, hc), hc, lev + 5, gen)), gen)
                  if (GCAS(cn, nn, ct)) None
                  else null
                }
              case INode.KEY_ABSENT =>
                if (sn.hc == hc && equal(sn.k, k, ct)) Some(sn.v)
                else {
                  val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct)
                  val nn = rn.updatedAt(pos, inode(CNode.dual(sn, sn.hc, new SNode(k, v, hc), hc, lev + 5, gen)), gen)
                  if (GCAS(cn, nn, ct)) None
                  else null
                }
              case INode.KEY_PRESENT =>
                if (sn.hc == hc && equal(sn.k, k, ct)) {
                  if (GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)) Some(sn.v) else null
                } else None
              case otherv =>
                if (sn.hc == hc && equal(sn.k, k, ct) && sn.v == otherv) {
                  if (GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)) Some(sn.v) else null
                } else None
            }
          }
        } else cond match {
          case null | INode.KEY_ABSENT =>
            val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct)
            val ncnode = rn.insertedAt(pos, flag, new SNode(k, v, hc), gen)
            if (GCAS(cn, ncnode, ct)) None else null
          case INode.KEY_PRESENT => None
          case otherv => None
        }
      case tn: TNode[K, V] =>
        clean(parent, ct, lev - 5)
        null
      case ln: LNode[K, V] => // 3) an l-node
        def insertln() = {
          val nn = ln.inserted(k, v)
          GCAS(ln, nn, ct)
        }
        cond match {
          case null =>
            val optv = ln.get(k)
            if (insertln()) optv else null
          case INode.KEY_ABSENT =>
            ln.get(k) match {
              case None => if (insertln()) None else null
              case optv => optv
            }
          case INode.KEY_PRESENT =>
            ln.get(k) match {
              case Some(v0) => if (insertln()) Some(v0) else null
              case None => None
            }
          case otherv =>
            ln.get(k) match {
              case Some(v0) if v0 == otherv => if (insertln()) Some(otherv.asInstanceOf[V]) else null
              case _ => None
            }
        }
    }
  }

  /** Looks up the value associated with the key.
   *
   *  @return null if no value has been found, RESTART if the operation wasn't successful, or any other value otherwise
   */
  @tailrec def rec_lookup(k: K, hc: Int, lev: Int, parent: INode[K, V], startgen: Gen, ct: TrieMap[K, V]): AnyRef = {
    val m = GCAS_READ(ct) // use -Yinline!

    m match {
      case cn: CNode[K, V] => // 1) a multinode
        val idx = (hc >>> lev) & 0x1f
        val flag = 1 << idx
        val bmp = cn.bitmap
        if ((bmp & flag) == 0) null // 1a) bitmap shows no binding
        else { // 1b) bitmap contains a value - descend
          // with a full bitmap the dense position equals the hash chunk
          val pos = if (bmp == 0xffffffff) idx else Integer.bitCount(bmp & (flag - 1))
          val sub = cn.array(pos)
          sub match {
            case in: INode[K, V] =>
              if (ct.isReadOnly || (startgen eq in.gen)) in.rec_lookup(k, hc, lev + 5, this, startgen, ct)
              else {
                if (GCAS(cn, cn.renewed(startgen, ct), ct)) rec_lookup(k, hc, lev, parent, startgen, ct)
                else RESTART // used to be throw RestartException
              }
            case sn: SNode[K, V] => // 2) singleton node
              if (sn.hc == hc && equal(sn.k, k, ct)) sn.v.asInstanceOf[AnyRef]
              else null
          }
        }
      case tn: TNode[K, V] => // 3) non-live node
        def cleanReadOnly(tn: TNode[K, V]) = if (ct.nonReadOnly) {
          clean(parent, ct, lev - 5)
          RESTART // used to be throw RestartException
        } else {
          // Compare keys with the trie's equality, exactly like the SNode case above.
          if (tn.hc == hc && equal(tn.k, k, ct)) tn.v.asInstanceOf[AnyRef]
          else null
        }
        cleanReadOnly(tn)
      case ln: LNode[K, V] => // 5) an l-node
        // NOTE(review): LNode.get delegates to the ListMap's own key comparison, not ct.equality.
        ln.get(k).asInstanceOf[Option[AnyRef]].orNull
    }
  }

  /** Removes the key associated with the given value.
   *
   *  @param v if null, will remove the key regardless of the value; otherwise removes only if binding contains that exact key and value
   *  @return null if not successful, an Option[V] indicating the previous value otherwise
   */
  def rec_remove(k: K, v: V, hc: Int, lev: Int, parent: INode[K, V], startgen: Gen, ct: TrieMap[K, V]): Option[V] = {
    val m = GCAS_READ(ct) // use -Yinline!

    m match {
      case cn: CNode[K, V] =>
        val idx = (hc >>> lev) & 0x1f
        val bmp = cn.bitmap
        val flag = 1 << idx
        if ((bmp & flag) == 0) None
        else {
          val pos = Integer.bitCount(bmp & (flag - 1))
          val sub = cn.array(pos)
          val res = sub match {
            case in: INode[K, V] =>
              if (startgen eq in.gen) in.rec_remove(k, v, hc, lev + 5, this, startgen, ct)
              else {
                if (GCAS(cn, cn.renewed(startgen, ct), ct)) rec_remove(k, v, hc, lev, parent, startgen, ct)
                else null
              }
            case sn: SNode[K, V] =>
              if (sn.hc == hc && equal(sn.k, k, ct) && (v == null || sn.v == v)) {
                val ncn = cn.removedAt(pos, flag, gen).toContracted(lev)
                if (GCAS(cn, ncn, ct)) Some(sn.v) else null
              } else None
          }

          if (res == None || (res eq null)) res
          else {
            // After a successful removal this node may have become a TNode:
            // try to replace it in the parent by its untombed SNode.
            @tailrec def cleanParent(nonlive: AnyRef) {
              val pm = parent.GCAS_READ(ct)
              pm match {
                case cn: CNode[K, V] =>
                  val idx = (hc >>> (lev - 5)) & 0x1f
                  val bmp = cn.bitmap
                  val flag = 1 << idx
                  if ((bmp & flag) == 0) {} // somebody already removed this i-node, we're done
                  else {
                    val pos = Integer.bitCount(bmp & (flag - 1))
                    val sub = cn.array(pos)
                    if (sub eq this) nonlive match {
                      case tn: TNode[K, V] =>
                        val ncn = cn.updatedAt(pos, tn.copyUntombed, gen).toContracted(lev - 5)
                        if (!parent.GCAS(cn, ncn, ct))
                          if (ct.readRoot().gen == startgen) cleanParent(nonlive)
                    }
                  }
                case _ => // parent is no longer a cnode, we're done
              }
            }

            if (parent ne null) { // never tomb at root
              val n = GCAS_READ(ct)
              if (n.isInstanceOf[TNode[_, _]])
                cleanParent(n)
            }

            res
          }
        }
      case tn: TNode[K, V] =>
        clean(parent, ct, lev - 5)
        null
      case ln: LNode[K, V] =>
        if (v == null) {
          val optv = ln.get(k)
          val nn = ln.removed(k, ct)
          if (GCAS(ln, nn, ct)) optv else null
        } else ln.get(k) match {
          case optv @ Some(v0) if v0 == v =>
            val nn = ln.removed(k, ct)
            if (GCAS(ln, nn, ct)) optv else null
          case _ => None
        }
    }
  }

  /** Tries once to compress the CNode held by `nd` (see `CNode.toCompressed`).
   *  The GCAS result is ignored; callers restart their operation anyway.
   */
  private def clean(nd: INode[K, V], ct: TrieMap[K, V], lev: Int) {
    val m = nd.GCAS_READ(ct)
    m match {
      case cn: CNode[K, V] => nd.GCAS(cn, cn.toCompressed(ct, lev, gen), ct)
      case _ =>
    }
  }

  /** True if the current main node is `null`. */
  def isNullInode(ct: TrieMap[K, V]) = GCAS_READ(ct) eq null

  /** Number of bindings below this node, as reported by the main node's `cachedSize`. */
  def cachedSize(ct: TrieMap[K, V]): Int = {
    val m = GCAS_READ(ct)
    m.cachedSize(ct)
  }

  /* this is a quiescent method! */
  def string(lev: Int) = "%sINode -> %s".format(" " * lev, mainnode match {
    case null => ""
    case tn: TNode[_, _] => "TNode(%s, %s, %d, !)".format(tn.k, tn.v, tn.hc)
    case cn: CNode[_, _] => cn.string(lev)
    case ln: LNode[_, _] => ln.string(lev)
    case x => "".format(x)
  })
}
/** Companion of `INode`: the sentinel conditions understood by
 *  `INode.rec_insertif`, plus a factory for empty root nodes.
 */
private[concurrent] object INode {
  /** `rec_insertif` condition: only insert when the key is already bound. */
  val KEY_PRESENT = new AnyRef

  /** `rec_insertif` condition: only insert when the key is not bound yet. */
  val KEY_ABSENT = new AnyRef

  /** Creates a root i-node holding an empty `CNode`; both share one fresh generation. */
  def newRootNode[K, V] = {
    val rootGen = new Gen
    new INode[K, V](new CNode[K, V](0, new Array[BasicNode](0), rootGen), rootGen)
  }
}
/** Marker stored in a main node's `prev` field when a GCAS is aborted.
 *  `p` is the main node that the i-node is rolled back to.
 *  It is never sized or rendered; both operations throw.
 */
private[concurrent] final class FailedNode[K, V](p: MainNode[K, V]) extends MainNode[K, V] {
  WRITE_PREV(p)

  override def toString = "FailedNode(%s)".format(p)

  def cachedSize(ct: AnyRef): Int = throw new UnsupportedOperationException

  def string(lev: Int) = throw new UnsupportedOperationException
}
/** A node that carries exactly one key/value binding (`SNode` and `TNode`). */
private[concurrent] trait KVNode[K, V] {
  /** The binding held by this node. */
  def kvPair: Tuple2[K, V]
}
/** A live singleton leaf holding the binding `k -> v`, where `hc` is the trie hash of `k`. */
private[collection] final class SNode[K, V](final val k: K, final val v: V, final val hc: Int)
  extends BasicNode with KVNode[K, V] {

  /** A fresh `SNode` with identical contents. */
  final def copy = new SNode(k, v, hc)

  /** The same binding as a tombed `TNode`. */
  final def copyTombed = new TNode(k, v, hc)

  /** An `SNode` is already untombed, so this is just a copy. */
  final def copyUntombed = copy

  final def kvPair = (k, v)

  final def string(lev: Int) = " " * lev + "SNode(%s, %s, %x)".format(k, v, hc)
}
/** A tombed main node: the single binding `k -> v` left in an i-node.
 *  It is meant to be compressed into its parent by later accesses.
 *  Its cached size is always 1.
 */
private[collection] final class TNode[K, V](final val k: K, final val v: V, final val hc: Int)
  extends MainNode[K, V] with KVNode[K, V] {

  /** A fresh `TNode` with identical contents. */
  final def copy = new TNode(k, v, hc)

  /** A `TNode` is already tombed, so this is just a copy. */
  final def copyTombed = copy

  /** The same binding as a live `SNode`. */
  final def copyUntombed = new SNode(k, v, hc)

  final def kvPair = (k, v)

  final def cachedSize(ct: AnyRef): Int = 1

  final def string(lev: Int) = " " * lev + "TNode(%s, %s, %x, !)".format(k, v, hc)
}
/** A list node for keys whose hashes cannot be told apart by the trie levels.
 *  Its bindings live in an immutable `ListMap`.
 */
private[collection] final class LNode[K, V](final val listmap: immutable.ListMap[K, V])
  extends MainNode[K, V] {

  def this(k: K, v: V) = this(immutable.ListMap(k -> v))
  def this(k1: K, v1: V, k2: K, v2: V) = this(immutable.ListMap(k1 -> v1, k2 -> v2))

  /** A new `LNode` with `k` bound to `v`. */
  def inserted(k: K, v: V) = new LNode(listmap.updated(k, v))

  /** Removes `k`. If at most one binding remains, the survivor is returned as a
   *  `TNode`, so that it gets compressed on subsequent accesses.
   */
  def removed(k: K, ct: TrieMap[K, V]): MainNode[K, V] = {
    val remaining = listmap - k
    if (remaining.size <= 1) {
      val (lastKey, lastValue) = remaining.iterator.next()
      new TNode(lastKey, lastValue, ct.computeHash(lastKey))
    } else new LNode(remaining)
  }

  def get(k: K) = listmap.get(k)

  def cachedSize(ct: AnyRef): Int = listmap.size

  def string(lev: Int) = " " * lev + "LNode(%s)".format(listmap.mkString(", "))
}
/** A multiway node of the trie.
 *
 *  Bit `i` of `bitmap` is set when there is a child for the 5-bit hash chunk `i`
 *  at this node's level. The children are stored densely in `array` in chunk
 *  order, so chunk `i` lives at index `Integer.bitCount(bitmap & ((1 << i) - 1))`.
 *  Every element of `array` is expected to be an `SNode` or an `INode`.
 *  The update methods (`updatedAt`, `removedAt`, `insertedAt`, `renewed`,
 *  `toCompressed`) build a new `CNode` instead of modifying this one.
 *
 *  @param bitmap occupancy bitmap of the (up to 32) children
 *  @param array  the present children, in bitmap order
 *  @param gen    generation this node belongs to
 */
private[collection] final class CNode[K, V](val bitmap: Int, val array: Array[BasicNode], val gen: Gen) extends CNodeBase[K, V] {
// this should only be called from within read-only snapshots
// Returns the number of bindings below this node. A size of -1 means "not computed yet".
// The first computed value is published with CAS_SIZE(-1, sz) and reused afterwards.
// Concurrent callers may each compute it, but only one value is stored.
def cachedSize(ct: AnyRef) = {
val currsz = READ_SIZE()
if (currsz != -1) currsz
else {
val sz = computeSize(ct.asInstanceOf[TrieMap[K, V]])
while (READ_SIZE() == -1) CAS_SIZE(-1, sz)
READ_SIZE()
}
}
// lends itself towards being parallelizable by choosing
// a random starting offset in the array
// => if there are concurrent size computations, they start
// at different positions, so they are more likely to
// be independent
// Each SNode counts as 1; each INode contributes its own cachedSize.
private def computeSize(ct: TrieMap[K, V]): Int = {
var i = 0
var sz = 0
val offset =
if (array.length > 0)
//util.Random.nextInt(array.length) /* <-- benchmarks show that this causes observable contention */
scala.concurrent.forkjoin.ThreadLocalRandom.current.nextInt(0, array.length)
else 0
while (i < array.length) {
val pos = (i + offset) % array.length
array(pos) match {
case sn: SNode[_, _] => sz += 1
case in: INode[K, V] => sz += in.cachedSize(ct)
}
i += 1
}
sz
}
/** Returns a copy of this node in generation `gen` with the child at `pos` replaced by `nn`. */
def updatedAt(pos: Int, nn: BasicNode, gen: Gen) = {
val len = array.length
val narr = new Array[BasicNode](len)
Array.copy(array, 0, narr, 0, len)
narr(pos) = nn
new CNode[K, V](bitmap, narr, gen)
}
/** Returns a copy of this node in generation `gen` without the child at `pos`.
 *  Bit `flag` is toggled off in the bitmap (it is expected to be set).
 */
def removedAt(pos: Int, flag: Int, gen: Gen) = {
val arr = array
val len = arr.length
val narr = new Array[BasicNode](len - 1)
Array.copy(arr, 0, narr, 0, pos)
Array.copy(arr, pos + 1, narr, pos, len - pos - 1)
new CNode[K, V](bitmap ^ flag, narr, gen)
}
/** Returns a copy of this node in generation `gen` with `nn` inserted at `pos` and `flag` set in the bitmap. */
def insertedAt(pos: Int, flag: Int, nn: BasicNode, gen: Gen) = {
val len = array.length
val bmp = bitmap
val narr = new Array[BasicNode](len + 1)
Array.copy(array, 0, narr, 0, pos)
narr(pos) = nn
Array.copy(array, pos, narr, pos + 1, len - pos)
new CNode[K, V](bmp | flag, narr, gen)
}
/** Returns a copy of this cnode such that all the i-nodes below it are copied
* to the specified generation `ngen`.
*/
def renewed(ngen: Gen, ct: TrieMap[K, V]) = {
var i = 0
val arr = array
val len = arr.length
val narr = new Array[BasicNode](len)
while (i < len) {
arr(i) match {
case in: INode[K, V] => narr(i) = in.copyToGen(ngen, ct)
case bn: BasicNode => narr(i) = bn
}
i += 1
}
new CNode[K, V](bitmap, narr, ngen)
}
// Replaces an i-node whose main node is a TNode by the untombed SNode; any other i-node is kept as is.
private def resurrect(inode: INode[K, V], inodemain: AnyRef): BasicNode = inodemain match {
case tn: TNode[_, _] => tn.copyUntombed
case _ => inode
}
// If this node has a single child which is an SNode and it is not the root level
// (lev > 0), returns a tombed (TNode) copy of that SNode so the parent can absorb it.
// Otherwise returns this node unchanged.
def toContracted(lev: Int): MainNode[K, V] = if (array.length == 1 && lev > 0) array(0) match {
case sn: SNode[K, V] => sn.copyTombed
case _ => this
} else this
// Returns a copy of this node in generation `gen` in which every i-node whose
// main node is a TNode is replaced by the untombed SNode (see `resurrect`).
// The result is then passed through `toContracted(lev)`, so a single remaining
// SNode comes back tombed.
// The main node of each child i-node is asserted to be non-null; this method
// never returns null.
def toCompressed(ct: TrieMap[K, V], lev: Int, gen: Gen) = {
val bmp = bitmap
var i = 0
val arr = array
val tmparray = new Array[BasicNode](arr.length)
while (i < arr.length) { // construct new bitmap
val sub = arr(i)
sub match {
case in: INode[K, V] =>
val inodemain = in.gcasRead(ct)
assert(inodemain ne null)
tmparray(i) = resurrect(in, inodemain)
case sn: SNode[K, V] =>
tmparray(i) = sn
}
i += 1
}
new CNode[K, V](bmp, tmparray, gen).toContracted(lev)
}
private[concurrent] def string(lev: Int): String = "CNode %x\n%s".format(bitmap, array.map(_.string(lev + 1)).mkString("\n"))
/* quiescently consistent - don't call concurrently to anything involving a GCAS!! */
// NOTE(review): in this file collectElems is only referenced by its own recursive call.
private def collectElems: Seq[(K, V)] = array flatMap {
case sn: SNode[K, V] => Some(sn.kvPair)
case in: INode[K, V] => in.mainnode match {
case tn: TNode[K, V] => Some(tn.kvPair)
case ln: LNode[K, V] => ln.listmap.toList
case cn: CNode[K, V] => cn.collectElems
}
}
// Debug description of the direct children: values of SNodes, and a trimmed
// `toString` plus generation for INodes.
// NOTE(review): the reason for `drop(14)` is not visible here.
private def collectLocalElems: Seq[String] = array flatMap {
case sn: SNode[K, V] => Some(sn.kvPair._2.toString)
case in: INode[K, V] => Some(in.toString.drop(14) + "(" + in.gen + ")")
}
override def toString = {
val elems = collectLocalElems
"CNode(sz: %d; %s)".format(elems.size, elems.sorted.mkString(", "))
}
}
/** Companion of `CNode`. */
private[concurrent] object CNode {
  /** Builds a subtree, starting at hash offset `lev`, that contains both leaves `x` and `y`.
   *
   *  While the 5-bit hash chunks of `xhc` and `yhc` coincide, a single-child
   *  `CNode` is created and the pair is pushed one level deeper behind a fresh `INode`.
   *  Once the chunks differ, both leaves go into one `CNode`, ordered by chunk.
   *  From `lev` 35 onwards the 32-bit hashes are exhausted and an `LNode` is used.
   */
  def dual[K, V](x: SNode[K, V], xhc: Int, y: SNode[K, V], yhc: Int, lev: Int, gen: Gen): MainNode[K, V] =
    if (lev >= 35) new LNode(x.k, x.v, y.k, y.v)
    else {
      val xChunk = (xhc >>> lev) & 0x1f
      val yChunk = (yhc >>> lev) & 0x1f
      val bitmap = (1 << xChunk) | (1 << yChunk)
      if (xChunk == yChunk) {
        val child = new INode[K, V](gen)
        child.mainnode = dual(x, xhc, y, yhc, lev + 5, gen)
        new CNode(bitmap, Array[BasicNode](child), gen)
      } else if (xChunk < yChunk) new CNode(bitmap, Array[BasicNode](x, y), gen)
      else new CNode(bitmap, Array[BasicNode](y, x), gen)
    }
}
/** Value placed in `TrieMap.root` while the root is swapped from `old` to `nv`.
 *  The swap happens only if `old`'s main node is still `expectedmain`.
 */
private[concurrent] case class RDCSS_Descriptor[K, V](old: INode[K, V], expectedmain: MainNode[K, V], nv: INode[K, V]) {
  /** Becomes `true` once `nv` has been installed as the new root. */
  @volatile var committed: Boolean = false
}
/** A concurrent hash-trie or TrieMap is a concurrent thread-safe lock-free
 *  implementation of a hash array mapped trie. It is used to implement the
 *  concurrent map abstraction. It has particularly scalable concurrent insert
 *  and remove operations and is memory-efficient. It supports O(1), atomic,
 *  lock-free snapshots which are used to implement linearizable lock-free size,
 *  iterator and clear operations. The cost of evaluating the (lazy) snapshot is
 *  distributed across subsequent updates, thus making snapshot evaluation horizontally scalable.
 *
 *  For details, see: http://lampwww.epfl.ch/~prokopec/ctries-snapshot.pdf
 *
 *  @author Aleksandar Prokopec
 *  @since 2.10
 */
@SerialVersionUID(0L - 6402774413839597105L)
final class TrieMap[K, V] private (r: AnyRef, rtupd: AtomicReferenceFieldUpdater[TrieMap[K, V], AnyRef], hashf: Hashing[K], ef: Equiv[K])
extends scala.collection.concurrent.Map[K, V]
   with scala.collection.mutable.MapLike[K, V, TrieMap[K, V]]
   with CustomParallelizable[(K, V), ParTrieMap[K, V]]
   with Serializable
{
  // The default hashing is replaced by a byte-swapped variant (TrieMap.MangledHashing).
  private var hashingobj = if (hashf.isInstanceOf[Hashing.Default[_]]) new TrieMap.MangledHashing[K] else hashf
  private var equalityobj = ef
  // A null updater marks this instance as a read-only snapshot (see isReadOnly).
  private var rootupdater = rtupd
  def hashing = hashingobj
  def equality = equalityobj
  // Holds either an INode (the root) or an RDCSS_Descriptor while the root is being swapped.
  @volatile var root = r

  def this(hashf: Hashing[K], ef: Equiv[K]) = this(
    INode.newRootNode,
    AtomicReferenceFieldUpdater.newUpdater(classOf[TrieMap[K, V]], classOf[AnyRef], "root"),
    hashf,
    ef
  )

  def this() = this(Hashing.default, Equiv.universal)

  /* internal methods */

  // Serialized form: hashing, equality, then alternating keys and values, then TrieMapSerializationEnd.
  private def writeObject(out: java.io.ObjectOutputStream) {
    out.writeObject(hashingobj)
    out.writeObject(equalityobj)

    val it = iterator
    while (it.hasNext) {
      val (k, v) = it.next()
      out.writeObject(k)
      out.writeObject(v)
    }
    out.writeObject(TrieMapSerializationEnd)
  }

  private def readObject(in: java.io.ObjectInputStream) {
    root = INode.newRootNode
    rootupdater = AtomicReferenceFieldUpdater.newUpdater(classOf[TrieMap[K, V]], classOf[AnyRef], "root")

    hashingobj = in.readObject().asInstanceOf[Hashing[K]]
    equalityobj = in.readObject().asInstanceOf[Equiv[K]]

    var obj: AnyRef = null
    do {
      obj = in.readObject()
      if (obj != TrieMapSerializationEnd) {
        val k = obj.asInstanceOf[K]
        val v = in.readObject().asInstanceOf[V]
        update(k, v)
      }
    } while (obj != TrieMapSerializationEnd)
  }

  def CAS_ROOT(ov: AnyRef, nv: AnyRef) = rootupdater.compareAndSet(this, ov, nv)

  def readRoot(abort: Boolean = false): INode[K, V] = RDCSS_READ_ROOT(abort)

  /** Returns the root i-node. A pending root swap is completed first, or aborted if `abort` is true. */
  def RDCSS_READ_ROOT(abort: Boolean = false): INode[K, V] = {
    val r = /*READ*/root
    r match {
      case in: INode[K, V] => in
      case desc: RDCSS_Descriptor[K, V] => RDCSS_Complete(abort)
    }
  }

  @tailrec private def RDCSS_Complete(abort: Boolean): INode[K, V] = {
    val v = /*READ*/root
    v match {
      case in: INode[K, V] => in
      case desc: RDCSS_Descriptor[K, V] =>
        val RDCSS_Descriptor(ov, exp, nv) = desc
        if (abort) {
          if (CAS_ROOT(desc, ov)) ov
          else RDCSS_Complete(abort)
        } else {
          val oldmain = ov.gcasRead(this)
          if (oldmain eq exp) {
            if (CAS_ROOT(desc, nv)) {
              desc.committed = true
              nv
            } else RDCSS_Complete(abort)
          } else {
            if (CAS_ROOT(desc, ov)) ov
            else RDCSS_Complete(abort)
          }
        }
    }
  }

  /** Replaces root `ov` with `nv` only if `ov`'s main node is still `expectedmain`; returns whether it committed. */
  private def RDCSS_ROOT(ov: INode[K, V], expectedmain: MainNode[K, V], nv: INode[K, V]): Boolean = {
    val desc = RDCSS_Descriptor(ov, expectedmain, nv)
    if (CAS_ROOT(ov, desc)) {
      RDCSS_Complete(abort = false)
      /*READ*/desc.committed
    } else false
  }

  @tailrec private def inserthc(k: K, hc: Int, v: V) {
    val r = RDCSS_READ_ROOT()
    if (!r.rec_insert(k, v, hc, 0, null, r.gen, this)) inserthc(k, hc, v)
  }

  @tailrec private def insertifhc(k: K, hc: Int, v: V, cond: AnyRef): Option[V] = {
    val r = RDCSS_READ_ROOT()

    val ret = r.rec_insertif(k, v, hc, cond, 0, null, r.gen, this)
    if (ret eq null) insertifhc(k, hc, v, cond)
    else ret
  }

  @tailrec private def lookuphc(k: K, hc: Int): AnyRef = {
    val r = RDCSS_READ_ROOT()
    val res = r.rec_lookup(k, hc, 0, null, r.gen, this)
    if (res eq INodeBase.RESTART) lookuphc(k, hc)
    else res
  }

  /* slower:
  //@tailrec
  private def lookuphc(k: K, hc: Int): AnyRef = {
    val r = RDCSS_READ_ROOT()
    try {
      r.rec_lookup(k, hc, 0, null, r.gen, this)
    } catch {
      case RestartException =>
        lookuphc(k, hc)
    }
  }
  */

  @tailrec private def removehc(k: K, v: V, hc: Int): Option[V] = {
    val r = RDCSS_READ_ROOT()
    val res = r.rec_remove(k, v, hc, 0, null, r.gen, this)
    if (res ne null) res
    else removehc(k, v, hc)
  }

  def string = RDCSS_READ_ROOT().string(0)

  /* public methods */

  override def seq = this

  override def par = new ParTrieMap(this)

  override def empty: TrieMap[K, V] = new TrieMap[K, V]

  def isReadOnly = rootupdater eq null

  def nonReadOnly = rootupdater ne null

  /** Returns a snapshot of this TrieMap.
   *  This operation is lock-free and linearizable.
   *
   *  The snapshot is lazily updated - the first time some branch
   *  in the snapshot or this TrieMap are accessed, they are rewritten.
   *  This means that the work of rebuilding both the snapshot and this
   *  TrieMap is distributed across all the threads doing updates or accesses
   *  subsequent to the snapshot creation.
   */
  @tailrec def snapshot(): TrieMap[K, V] = {
    val r = RDCSS_READ_ROOT()
    val expmain = r.gcasRead(this)
    if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r.copyToGen(new Gen, this), rootupdater, hashing, equality)
    else snapshot()
  }

  /** Returns a read-only snapshot of this TrieMap.
   *  This operation is lock-free and linearizable.
   *
   *  The snapshot is lazily updated - the first time some branch
   *  of this TrieMap are accessed, it is rewritten. The work of creating
   *  the snapshot is thus distributed across subsequent updates
   *  and accesses on this TrieMap by all threads.
   *  Note that the snapshot itself is never rewritten unlike when calling
   *  the `snapshot` method, but the obtained snapshot cannot be modified.
   *
   *  This method is used by other methods such as `size` and `iterator`.
   */
  @tailrec def readOnlySnapshot(): scala.collection.Map[K, V] = {
    val r = RDCSS_READ_ROOT()
    val expmain = r.gcasRead(this)
    if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r, null, hashing, equality)
    else readOnlySnapshot()
  }

  @tailrec override def clear() {
    val r = RDCSS_READ_ROOT()
    if (!RDCSS_ROOT(r, r.gcasRead(this), INode.newRootNode[K, V])) clear()
  }

  def computeHash(k: K) = hashingobj.hash(k)

  /** Returns the value bound to `k`, or `null` if there is none. */
  def lookup(k: K): V = {
    val hc = computeHash(k)
    lookuphc(k, hc).asInstanceOf[V]
  }

  override def apply(k: K): V = {
    val hc = computeHash(k)
    val res = lookuphc(k, hc)
    if (res eq null) throw new NoSuchElementException
    else res.asInstanceOf[V]
  }

  def get(k: K): Option[V] = {
    val hc = computeHash(k)
    Option(lookuphc(k, hc)).asInstanceOf[Option[V]]
  }

  // NOTE(review): unlike getOrElseUpdate, put/update do not reject null values,
  // and a null value is reported as absent by get.
  override def put(key: K, value: V): Option[V] = {
    val hc = computeHash(key)
    insertifhc(key, hc, value, null)
  }

  override def update(k: K, v: V) {
    val hc = computeHash(k)
    inserthc(k, hc, v)
  }

  def +=(kv: (K, V)) = {
    update(kv._1, kv._2)
    this
  }

  override def remove(k: K): Option[V] = {
    val hc = computeHash(k)
    removehc(k, null.asInstanceOf[V], hc)
  }

  def -=(k: K) = {
    remove(k)
    this
  }

  def putIfAbsent(k: K, v: V): Option[V] = {
    val hc = computeHash(k)
    insertifhc(k, hc, v, INode.KEY_ABSENT)
  }

  // TODO once computeIfAbsent is added to concurrent.Map,
  // move the comment there and tweak the 'at most once' part
  /** If the specified key is not already in the map, computes its value using
   *  the given thunk `op` and enters it into the map.
   *
   *  Since concurrent maps cannot contain `null` for keys or values,
   *  a `NullPointerException` is thrown if the thunk `op`
   *  returns `null`.
   *
   *  If the specified mapping function throws an exception,
   *  that exception is rethrown.
   *
   *  Note: This method will invoke op at most once.
   *  However, `op` may be invoked without the result being added to the map if
   *  a concurrent process is also trying to add a value corresponding to the
   *  same key `k`.
   *
   *  @param k  the key to modify
   *  @param op the expression that computes the value
   *  @return   the newly added value
   */
  override def getOrElseUpdate(k: K, op: =>V): V = {
    // Hash the key once and reuse it for both the lookup and the conditional insert.
    val hc = computeHash(k)
    val oldv = lookuphc(k, hc)
    if (oldv != null) oldv.asInstanceOf[V]
    else {
      val v = op
      if (v == null) {
        throw new NullPointerException("Concurrent TrieMap values cannot be null.")
      } else {
        insertifhc(k, hc, v, INode.KEY_ABSENT) match {
          case Some(oldv) => oldv
          case None => v
        }
      }
    }
  }

  def remove(k: K, v: V): Boolean = {
    val hc = computeHash(k)
    removehc(k, v, hc).nonEmpty
  }

  def replace(k: K, oldvalue: V, newvalue: V): Boolean = {
    val hc = computeHash(k)
    insertifhc(k, hc, newvalue, oldvalue.asInstanceOf[AnyRef]).nonEmpty
  }

  def replace(k: K, v: V): Option[V] = {
    val hc = computeHash(k)
    insertifhc(k, hc, v, INode.KEY_PRESENT)
  }

  /** Iterates a read-only snapshot; a read-only TrieMap is iterated directly. */
  def iterator: Iterator[(K, V)] =
    if (nonReadOnly) readOnlySnapshot().iterator
    else new TrieMapIterator(0, this)

  private def cachedSize() = {
    val r = RDCSS_READ_ROOT()
    r.cachedSize(this)
  }

  /** Size of a read-only snapshot; a read-only TrieMap uses its cached sizes directly. */
  override def size: Int =
    if (nonReadOnly) readOnlySnapshot().size
    else cachedSize()

  override def stringPrefix = "TrieMap"
}
/** Factory for `TrieMap` instances, plus helpers shared by all of them. */
object TrieMap extends MutableMapFactory[TrieMap] {
  /** Field updater for the `mainnode` field declared in `INodeBase`. */
  val inodeupdater = AtomicReferenceFieldUpdater.newUpdater(classOf[INodeBase[_, _]], classOf[MainNode[_, _]], "mainnode")

  /** An empty `TrieMap`. */
  def empty[K, V]: TrieMap[K, V] = new TrieMap[K, V]

  implicit def canBuildFrom[K, V]: CanBuildFrom[Coll, (K, V), TrieMap[K, V]] = new MapCanBuildFrom[K, V]

  /** Hashing used in place of the default one: the byte-swapped `##` of the key. */
  class MangledHashing[K] extends Hashing[K] {
    def hash(k: K): Int = scala.util.hashing.byteswap32(k.##)
  }
}
/** Iterator over the bindings of a read-only `TrieMap`.
 *
 *  Traversal state:
 *   - `stack(d)` is the child array at depth `d` and `stackpos(d)` is the index last
 *     visited in it. There are at most 7 levels, and `depth` is the current top (-1 when empty).
 *   - `current` is the next SNode/TNode binding to return.
 *   - `subiter`, when non-null, iterates an LNode's bindings and takes priority over
 *     `current` in `next()`.
 *
 *  @param level    subdivision depth; incremented by each `subdivide` call
 *  @param ct       the trie being iterated (`initialize` asserts it is read-only)
 *  @param mustInit when false the iterator is not positioned on the root, and its
 *                  state is filled in by the caller (as `subdivide` does)
 */
private[collection] class TrieMapIterator[K, V](var level: Int, private var ct: TrieMap[K, V], mustInit: Boolean = true) extends Iterator[(K, V)] {
private val stack = new Array[Array[BasicNode]](7)
private val stackpos = new Array[Int](7)
private var depth = -1
private var subiter: Iterator[(K, V)] = null
private var current: KVNode[K, V] = null
if (mustInit) initialize()
def hasNext = (current ne null) || (subiter ne null)
// Bindings from an LNode (subiter) are returned before `current`; when exhausted,
// delegates to Iterator.empty.next(), which throws NoSuchElementException.
def next() = if (hasNext) {
var r: (K, V) = null
if (subiter ne null) {
r = subiter.next()
checkSubiter()
} else {
r = current.kvPair
advance()
}
r
} else Iterator.empty.next()
// Positions the iterator on the contents of `in`: a CNode's array is pushed on the
// stack, a TNode becomes `current`, and an LNode becomes `subiter`.
private def readin(in: INode[K, V]) = in.gcasRead(ct) match {
case cn: CNode[K, V] =>
depth += 1
stack(depth) = cn.array
stackpos(depth) = -1
advance()
case tn: TNode[K, V] =>
current = tn
case ln: LNode[K, V] =>
subiter = ln.listmap.iterator
checkSubiter()
case null =>
current = null
}
// Drops an exhausted LNode iterator and moves on to the next binding.
private def checkSubiter() = if (!subiter.hasNext) {
subiter = null
advance()
}
private def initialize() {
assert(ct.isReadOnly)
val r = ct.RDCSS_READ_ROOT()
readin(r)
}
// Moves to the next element of the array on top of the stack, popping exhausted
// levels. Sets `current` to null when the stack is empty.
def advance(): Unit = if (depth >= 0) {
val npos = stackpos(depth) + 1
if (npos < stack(depth).length) {
stackpos(depth) = npos
stack(depth)(npos) match {
case sn: SNode[K, V] =>
current = sn
case in: INode[K, V] =>
readin(in)
}
} else {
depth -= 1
advance()
}
} else current = null
protected def newIterator(_lev: Int, _ct: TrieMap[K, V], _mustInit: Boolean) = new TrieMapIterator[K, V](_lev, _ct, _mustInit)
// Copies this iterator's traversal state into `it`.
protected def dupTo(it: TrieMapIterator[K, V]) = {
it.level = this.level
it.ct = this.ct
it.depth = this.depth
it.current = this.current
// these need a deep copy
// (the contents of stack/stackpos are copied into `it`'s own arrays, because both
// iterators update their own stack and stackpos; the node arrays referenced from
// `stack` are only read, so they are shared)
Array.copy(this.stack, 0, it.stack, 0, 7)
Array.copy(this.stackpos, 0, it.stackpos, 0, 7)
// this one needs to be evaluated
// (the remaining LNode bindings are materialized into a list so that both
// iterators can traverse them independently)
if (this.subiter == null) it.subiter = null
else {
val lst = this.subiter.toList
this.subiter = lst.iterator
it.subiter = lst.iterator
}
}
/** Returns a sequence of iterators over subsets of this iterator.
* It's used to ease the implementation of splitters for a parallel version of the TrieMap.
*/
protected def subdivide(): Seq[Iterator[(K, V)]] = if (subiter ne null) {
// the case where an LNode is being iterated
val it = newIterator(level + 1, ct, _mustInit = false)
it.depth = -1
it.subiter = this.subiter
it.current = null
this.subiter = null
advance()
this.level += 1
Seq(it, this)
} else if (depth == -1) {
this.level += 1
Seq(this)
} else {
// Find the shallowest level with unvisited entries and give the second half of
// those remaining entries to a new iterator.
var d = 0
while (d <= depth) {
val rem = stack(d).length - 1 - stackpos(d)
if (rem > 0) {
val (arr1, arr2) = stack(d).drop(stackpos(d) + 1).splitAt(rem / 2)
stack(d) = arr1
stackpos(d) = -1
val it = newIterator(level + 1, ct, _mustInit = false)
it.stack(0) = arr2
it.stackpos(0) = -1
it.depth = 0
it.advance() // <-- fix it
// NOTE(review): the original author flagged the line above as needing a fix; the reason is not recorded.
this.level += 1
return Seq(this, it)
}
d += 1
}
this.level += 1
Seq(this)
}
// Prints the traversal state to standard output (debugging aid).
def printDebug() {
println("ctrie iterator")
println(stackpos.mkString(","))
println("depth: " + depth)
println("curr.: " + current)
println(stack.mkString("\n"))
}
}
/** Control-flow throwable previously used to restart trie operations.
 *  The lookup path now returns the `INodeBase.RESTART` sentinel instead.
 */
private[concurrent] object RestartException extends ControlThrowable {}
/** End marker written after the last key/value pair when a `TrieMap` is serialized.
 *  It is read back to detect the end of the stream. Not used for anything else.
 */
@SerialVersionUID(0L - 7237891413820527142L)
private[concurrent] case object TrieMapSerializationEnd {}
/** Debug log: messages are queued with `log` and printed by `flush`. */
private[concurrent] object Debug {
  lazy val logbuffer = new java.util.concurrent.ConcurrentLinkedQueue[AnyRef]

  /** Appends `s` to the log buffer. */
  def log(s: AnyRef) = logbuffer.add(s)

  /** Prints every buffered message in insertion order, removing each one as it is printed.
   *
   *  The queue is drained with `poll()` rather than iterated and then cleared.
   *  With iterate-then-clear, a message logged concurrently could be removed
   *  by `clear()` without ever being printed. The deprecated `JavaConversions`
   *  is no longer needed either.
   */
  def flush() {
    Iterator.continually(logbuffer.poll())
      .takeWhile(_ ne null)
      .foreach(s => Console.out.println(s.toString))
  }

  /** Discards all buffered messages without printing them. */
  def clear() {
    logbuffer.clear()
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy