// dev.tauri.choam.internal.mcas.Hamt.scala Maven / Gradle / Ivy
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2016-2024 Daniel Urban and contributors listed in NOTICE.txt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.tauri.choam
package internal
package mcas
import java.util.Arrays
/**
* Immutable HAMT (hash array mapped trie)
*
* HAMTs are described in "Ideal Hash Trees" by Phil Bagwell
* (https://infoscience.epfl.ch/record/64398/files/idealhashtrees.pdf).
* An improved encoding called CHAMP (compressed hash array mapped
* prefix trie) is described in "Optimizing Hash-Array Mapped Tries for
* Fast and Lean Immutable JVM Collections" by Michael J. Steindorfer
* and Jurgen J. Vinju (https://ir.cwi.nl/pub/24029/24029B.pdf).
* (TODO: Currently we don't use CHAMP here, because we need an iteration
* order that is globally consistent, and it's not clear how to do
* that in a CHAMP.)
*
* Unlike most HAMTs, we use 64-bit hashes in a 64-ary tree. (We can
* do this, because keys eventually will be `Ref` IDs, which are `Long`s.)
* We're using the IDs directly, without hashing, since they're already
* generated with Fibonacci hashing. We don't store the keys separately,
* since we can always get them from the values (which will be HWDs).
* Values can't be `null` (we use `null` as the "not found" sentinel).
* Deletion is not implemented, because we don't need it. We also don't
* implement collision nodes, because `Ref` IDs are globally unique,
* so it's not possible to have a collision on every level of the tree.
*
* This class contains the generic HAMT implementation. Abstract methods
* are provided for the "MCAS-specific" things. The API is weird, because
* we need this separation, but really don't want to allocate unnecessary
* objects. (Overall, this is not really a general HAMT, although it's
* somewhat abstracted away from the MCAS details.) For the MCAS-specific
* parts, see `LogMap2`.
*
* Type parameters are as follows:
* `K` is the type of keys.
* `V` is the type of values in the map (i.e., HWDs; keys/hashes are `Long`s).
* `E` is the type `toArray` converts the values to (`emcas.WordDescriptor[A]`
* on the JVM, and equal to `A` on JS).
* `T1` is the type of the "extra" value passed to `toArray`, which it just
* passes on to `convertForArray` (this is the parent `EmcasDescriptor`).
* `T2` is similarly the type of the "extra" value passed to `forAll`, which
* it just passes on to `predicateForForAll` (an `Mcas.ThreadContext` to
* implement revalidation).
* `H` is the self-type (we use F-bounded polymorphism here, because we need
* to create new nodes on insert/update, and `Hamt` is also the type of the
* sub-nodes, not just the root).
*
* Public methods are the "external" API. We take care never to call them
* on a node in lower levels (they assume they're called on the root).
*/
private[mcas] abstract class Hamt[K <: Hamt.HasHash, V <: Hamt.HasKey[K], E, T1, T2, H <: Hamt[K, V, E, T1, T2, H]] protected[mcas] (
  /**
   * Size and "blueness" packed into a single `Int` (see
   * `packSizeAndBlueInternal`): the absolute value is the number
   * of values in this subtree, and the value is non-negative iff
   * this subtree is "blue" (see `isBlueSubtree`). Note that a size
   * of 0 always reads as blue, because `-0 == 0`.
   */
  private val sizeAndBlue: Int,
  /**
   * Contains 1 bits in exactly the places where the imaginary 64-element
   * sparse array has "something" (either a value or a sub-node).
   */
  private val bitmap: Long,
  /**
   * The dense array containing the values and/or sub-nodes. At most
   * 64 elements long, but shorter for a "not full" node. Can be a
   * zero-element array (only for the root node of an empty tree).
   * This class never writes into an existing `contents` array: every
   * modification below works on a fresh copy (see `arrReplacedValue`
   * and the insertion path of `insertOrOverwrite`).
   */
  private val contents: Array[AnyRef],
) extends AbstractHamt[K, V, E, T1, T2, H] { this: H =>

  /**
   * The highest 6 bits set; we start masking
   * the hash with this, and go to lower bits
   * as we go down in the tree.
   *
   * We use the highest 6 bits first, and
   * lower ones later, because for a
   * multiplicative hash (like the Fibonacci
   * we're using to generate IDs) the low
   * bits are the "worst quality" ones (see
   * https://stackoverflow.com/a/11872511).
   */
  private[this] final val START_MASK = 0xFC00000000000000L

  /** 6 bits for indexing a 64-element array */
  private[this] final val W = 6

  // Operation codes for `insertOrOverwrite`:

  /** The key must already be present (otherwise `IllegalArgumentException`). */
  private[this] final val OP_UPDATE = 0

  /** The key must not be present yet (otherwise `IllegalArgumentException`). */
  private[this] final val OP_INSERT = 1

  /** The key may or may not be present. */
  private[this] final val OP_UPSERT = 2

  /**
   * Creates a node of the concrete self-type `H` from its three
   * fields. Every node built by this class (new roots as well as
   * new sub-nodes) is created through this factory.
   */
  protected def newNode(sizeAndBlue: Int, bitmap: Long, contents: Array[AnyRef]): H

  /**
   * Exposes the dense `contents` array to `AbstractHamt`.
   *
   * NOTE(review): this returns the internal array itself, not a
   * copy, so callers must not modify it. Confirm that `AbstractHamt`
   * only reads it.
   */
  protected final override def contentsArr: Array[AnyRef] =
    this.contents

  /** Hook for `AbstractHamt`: same as `inserted` (the key of `v` must not be present). */
  protected final override def insertInternal(v: V): H =
    this.inserted(v)

  /**
   * Whether this subtree is "blue", i.e., whether `sizeAndBlue` is
   * non-negative.
   *
   * Within this class, a node is only marked blue if the node it
   * was derived from was blue and the value (or sub-node) being
   * added is blue too. So "blue" implies that every value in the
   * subtree is blue, but not vice versa: overwriting a non-blue
   * value with a blue one does not make the subtree blue again.
   */
  private[mcas] final def isBlueSubtree: Boolean = {
    this.sizeAndBlue >= 0
  }

  // API (should only be called on a root node!):

  /**
   * The number of values in `this` subtree (i.e., if `this` is the
   * root, then this number is size of the whole tree).
   */
  final override def size: Int = {
    java.lang.Math.abs(this.sizeAndBlue)
  }

  /** Whether the tree contains at least one value. */
  final def nonEmpty: Boolean = {
    this.size > 0
  }

  /**
   * Returns the value whose key has exactly the given `hash`,
   * or `null` if there is no such value.
   */
  final def getOrElseNull(hash: Long): V = {
    this.lookupOrNull(hash, 0)
  }

  /**
   * Returns a tree in which the value stored under the key of `a`
   * is replaced by `a`.
   *
   * Must already contain the key of `a`; otherwise an
   * `IllegalArgumentException` is thrown. Returns `this` itself if
   * the existing value is already `equ` to `a`.
   */
  final def updated(a: V): H = {
    this.insertOrOverwrite(a.key.hash, a, 0, OP_UPDATE) match {
      case null => this // nothing changed
      case newRoot => newRoot
    }
  }

  /**
   * Returns a tree with `a` added.
   *
   * Mustn't already contain the key of `a`; otherwise an
   * `IllegalArgumentException` is thrown.
   */
  final def inserted(a: V): H = {
    val newRoot = this.insertOrOverwrite(a.key.hash, a, 0, OP_INSERT)
    // with `OP_INSERT`, `insertOrOverwrite` throws on an existing key
    // instead of returning `null`:
    assert(newRoot ne null)
    newRoot
  }

  /**
   * Delegates to `that.insertIntoHamt(this)`, which is defined in
   * `AbstractHamt`.
   *
   * NOTE(review): presumably this inserts every value of `that`
   * into `this`. If it uses `inserted`, the two trees must not
   * share keys. Confirm in `AbstractHamt`.
   */
  final def insertedAllFrom(that: H): H = {
    that.insertIntoHamt(this)
  }

  /**
   * Returns a tree in which `a` is stored under its key.
   *
   * May or may not already contain the key of `a`. Returns `this`
   * itself if the existing value is already `equ` to `a`.
   */
  final def upserted(a: V): H = {
    this.insertOrOverwrite(a.key.hash, a, 0, OP_UPSERT) match {
      case null => this // nothing changed
      case newRoot => newRoot
    }
  }

  /**
   * Looks up the value for `k`.
   *
   * If there is none, calls `visitor.entryAbsent(k, tok)` and
   * inserts its result, unless that result is `null`. If there is
   * one, calls `visitor.entryPresent(k, v, tok)`, which must return
   * the existing value unchanged; this is only checked with
   * `assert`. Returns `this` if nothing was inserted.
   */
  final def computeIfAbsent[T](k: K, tok: T, visitor: Hamt.EntryVisitor[K, V, T]): H = {
    this.visit(k, k.hash, tok, visitor, modify = false, shift = 0) match {
      case null => this // nothing changed
      case newRoot => newRoot
    }
  }

  /**
   * Like `computeIfAbsent`, but if there already is a value for `k`,
   * the result of `visitor.entryPresent` replaces it, unless that
   * result is `equ` to the existing value. Returns `this` if nothing
   * changed.
   */
  final def computeOrModify[T](k: K, tok: T, visitor: Hamt.EntryVisitor[K, V, T]): H = {
    this.visit(k, k.hash, tok, visitor, modify = true, shift = 0) match {
      case null => this // nothing changed
      case newRoot => newRoot
    }
  }

  /**
   * Converts all values with `convertForArray`
   * (implemented in a subclass), and copies the
   * results into an array (created with `newArray`,
   * also implemented in a subclass).
   *
   * `tok`, `flag` and `nullIfBlue` are passed through unchanged
   * to `copyToArrayInternal` (defined in `AbstractHamt`).
   */
  final def toArray(tok: T1, flag: Boolean, nullIfBlue: Boolean): Array[E] =
    this.copyToArrayInternal(tok, flag = flag, nullIfBlue = nullIfBlue)

  /**
   * Returns `true` if `equ(this, that)` holds. Otherwise returns
   * `true` if `that` is any `Hamt` for which `equalsInternal`
   * returns `true`; that check compares the bitmaps first as a
   * fast path.
   */
  final override def equals(that: Any): Boolean = {
    if (equ(this, that)) {
      true
    } else {
      that match {
        case that: Hamt[_, _, _, _, _, _] =>
          this.equalsInternal(that)
        case _ =>
          false
      }
    }
  }

  /** Delegates to `hashCodeInternal` (in `AbstractHamt`) with a fixed seed. */
  final override def hashCode: Int = {
    this.hashCodeInternal(0xf9ee8a53)
  }

  /** Renders the tree as `Hamt(...)` via `AbstractHamt`'s `toString(pre, post)`. */
  final override def toString: String = {
    this.toString(pre = "Hamt(", post = ")")
  }

  // Internal:

  /**
   * Looks up `hash` in the subtree rooted at `this`. The subtree is
   * at the level selected by `shift`: 0 for the root, increasing by
   * `W` per level.
   *
   * Descends into sub-nodes. When it reaches a value, it checks that
   * the value's hash is exactly `hash`, because a different value
   * may occupy the same slot if it only shares the bits examined so
   * far. Returns `null` if not found.
   *
   * Note: `@tailrec` is commented out. The recursive call is made on
   * a child typed `Hamt[_, _, _, _, _, _]` rather than on `this`,
   * which may be the reason (NOTE(review): confirm before
   * re-enabling it).
   */
  // @tailrec
  private final def lookupOrNull(hash: Long, shift: Int): V = {
    this.getValueOrNodeOrNull(hash, shift) match {
      case null =>
        // the slot for `hash` at this level is empty:
        nullOf[V]
      case node: Hamt[_, _, _, _, _, _] =>
        // sub-node: continue with the next `W` bits of the hash:
        node.lookupOrNull(hash, shift + W).asInstanceOf[V]
      case value =>
        // a value; it's only a match if the whole hash is equal:
        val a = value.asInstanceOf[V]
        val hashA = a.key.hash
        if (hash == hashA) {
          a
        } else {
          nullOf[V]
        }
    }
  }

  /**
   * Returns what this node stores in the slot of `hash` at the level
   * selected by `shift`: a value, a sub-node, or `null` if that slot
   * is empty. Does not descend into sub-nodes.
   */
  private[this] final def getValueOrNodeOrNull(hash: Long, shift: Int): AnyRef = {
    val bitmap = this.bitmap
    if (bitmap != 0L) {
      val flag: Long = 1L << logicalIdx(hash, shift) // only 1 bit set, at the position in bitmap
      if ((bitmap & flag) != 0L) {
        // we have an entry for this:
        val idx: Int = physicalIdx(bitmap, flag)
        this.contents(idx)
      } else {
        // no entry for this hash:
        null
      }
    } else {
      // empty HAMT
      null
    }
  }

  /**
   * Shared implementation of `computeIfAbsent` (`modify = false`)
   * and `computeOrModify` (`modify = true`) for the subtree rooted
   * at `this`, at the level selected by `shift`.
   *
   * - If there is no value for `hash`, calls `visitor.entryAbsent`.
   *   A `null` result means "don't insert"; any other result (whose
   *   hash must equal `hash`) is inserted.
   * - If there is a value for `hash`, calls `visitor.entryPresent`.
   *   With `modify = false` the result must be the existing value
   *   (checked with `assert` only). With `modify = true` a result
   *   that is not `equ` to the existing value replaces it.
   *
   * @return the new node, or `null` if nothing changed.
   */
  private final def visit[T](k: K, hash: Long, tok: T, visitor: Hamt.EntryVisitor[K, V, T], modify: Boolean, shift: Int): H = {
    this.getValueOrNodeOrNull(hash, shift) match {
      case null =>
        // empty slot, so the key is absent:
        visitor.entryAbsent(k, tok) match {
          case null =>
            nullOf[H]
          case newVal =>
            assert(newVal.key.hash == hash)
            // TODO: this will compute physIdx again:
            this.insertOrOverwrite(hash, newVal, shift, op = OP_INSERT)
        }
      case node: Hamt[_, _, _, _, _, _] =>
        // sub-node: visit it one level deeper:
        node.asInstanceOf[H].visit(k, hash, tok, visitor, modify = modify, shift = shift + W) match {
          case null =>
            nullOf[H]
          case newNode =>
            // the child changed, so copy this node with the new child;
            // our size changes by the same amount as the child's:
            val oldSize = this.size
            val newSize = oldSize + (newNode.size - node.size)
            // the child either grew by one (insertion), or (only
            // when modifying) kept its size (replacement):
            assert((modify && ((newSize == oldSize) || (newSize == (oldSize + 1)))) || (newSize == (oldSize + 1)))
            val bitmap = this.bitmap
            // TODO: we're computing physIdx twice:
            val physIdx: Int = physicalIdx(bitmap, 1L << logicalIdx(hash, shift))
            this.withNode(newSize, bitmap, newNode, physIdx)
        }
      case value =>
        val a = value.asInstanceOf[V]
        val hashA = a.key.hash
        if (hash == hashA) {
          // the key is present:
          val newEntry = visitor.entryPresent(k, a, tok)
          if (modify) {
            if (equ(newEntry, a)) {
              nullOf[H]
            } else {
              assert(newEntry.key.hash == hashA)
              this.insertOrOverwrite(hashA, newEntry, shift, op = OP_UPDATE)
            }
          } else {
            assert(equ(newEntry, a))
            nullOf[H]
          }
        } else {
          // a different value occupies this slot, so the key is absent
          // (inserting will push both values into a new sub-node, see
          // `insertOrOverwrite`):
          visitor.entryAbsent(k, tok) match {
            case null =>
              nullOf[H]
            case newVal =>
              assert(newVal.key.hash == hash)
              // TODO: this will compute physIdx again:
              this.insertOrOverwrite(hash, newVal, shift, op = OP_INSERT)
          }
        }
    }
  }

  /**
   * Inserts or overwrites `value` (whose key hash is `hash`) in the
   * subtree rooted at `this`, at the level selected by `shift`.
   *
   * @param op one of `OP_UPDATE`, `OP_INSERT` or `OP_UPSERT`.
   * @return the new node, or `null` if the existing value for `hash`
   *         is already `equ` to `value` (i.e., nothing changed).
   * @throws IllegalArgumentException if `op` is `OP_INSERT` and
   *         `hash` is already present, or `op` is `OP_UPDATE` and
   *         `hash` is absent.
   */
  private final def insertOrOverwrite(hash: Long, value: V, shift: Int, op: Int): H = {
    val flag: Long = 1L << logicalIdx(hash, shift) // only 1 bit set, at the position in bitmap
    val bitmap = this.bitmap
    if (bitmap != 0L) {
      val contents = this.contents
      val physIdx: Int = physicalIdx(bitmap, flag)
      if ((bitmap & flag) != 0L) {
        // we have an entry for this:
        contents(physIdx) match {
          case node: Hamt[_, _, _, _, _, _] =>
            // sub-node: recurse, then copy this node with the new child
            // (our size changes by the same amount as the child's):
            node.asInstanceOf[H].insertOrOverwrite(hash, value, shift + W, op) match {
              case null =>
                nullOf[H]
              case newNode =>
                this.withNode(this.size + (newNode.size - node.size), bitmap, newNode, physIdx)
            }
          case ov =>
            val oh = ov.asInstanceOf[V].key.hash
            if (hash == oh) {
              // same key: overwrite (unless inserting, or unchanged)
              if (op == OP_INSERT) {
                throw new IllegalArgumentException
              } else if (equ(ov, value)) {
                nullOf[H]
              } else {
                this.withValue(bitmap, value, physIdx)
              }
            } else {
              // hash collision at this level,
              // so we go down one level:
              // put the old value alone into a fresh one-element child,
              // then insert the new value into that child (which recurses
              // further if the two hashes collide on the next 6 bits too):
              val childNode = {
                val cArr = new Array[AnyRef](1)
                cArr(0) = ov
                val oFlag = 1L << logicalIdx(oh, shift + W)
                this.newNode(sizeAndBlue = packSizeAndBlueInternal(1, isBlue(ov.asInstanceOf[V])), bitmap = oFlag, contents = cArr)
              }
              val childNode2 = childNode.insertOrOverwrite(hash, value, shift + W, op)
              // the old value moved into the child, so the size grows by
              // `childNode2.size - 1`:
              this.withNode(this.size + (childNode2.size - 1), bitmap, childNode2, physIdx)
            }
        }
      } else {
        // no entry for this hash:
        if (op == OP_UPDATE) {
          throw new IllegalArgumentException
        } else {
          // copy `contents` into an array that is one longer, leaving
          // a gap at `physIdx` for the new value:
          val newBitmap: Long = bitmap | flag
          val len = contents.length
          val newArr = new Array[AnyRef](len + 1)
          System.arraycopy(contents, 0, newArr, 0, physIdx)
          System.arraycopy(contents, physIdx, newArr, physIdx + 1, len - physIdx)
          newArr(physIdx) = box(value)
          this.newNode(
            sizeAndBlue = packSizeAndBlueInternal(this.size + 1, this.isBlueSubtree && isBlue(value)),
            bitmap = newBitmap,
            contents = newArr
          )
        }
      }
    } else {
      // empty node
      if (op == OP_UPDATE) {
        throw new IllegalArgumentException
      } else {
        val newArr = new Array[AnyRef](1)
        newArr(0) = box(value)
        this.newNode(packSizeAndBlueInternal(1, isBlue(value)), flag, newArr)
      }
    }
  }

  /**
   * Compares the bitmaps first as a cheap fast path. Only if they are
   * equal does it defer to `AbstractHamt.equalsInternal`.
   *
   * NOTE(review): `that` is cast to `H` without a check. The only
   * caller in this class (`equals`) passes `Hamt` instances only.
   */
  protected override def equalsInternal(that: AbstractHamt[_, _, _, _, _, _]): Boolean = {
    if (this.bitmap == that.asInstanceOf[H].bitmap) {
      super.equalsInternal(that)
    } else {
      // fast path:
      false
    }
  }

  /**
   * Returns a copy of this node in which the value at `physIdx` is
   * replaced by `value`. The size is unchanged. The copy is blue only
   * if `this` was blue and `value` is blue.
   */
  private[this] final def withValue(bitmap: Long, value: V, physIdx: Int): H = {
    this.newNode(
      sizeAndBlue = packSizeAndBlueInternal(this.size, this.isBlueSubtree && isBlue(value)),
      bitmap = bitmap,
      contents = arrReplacedValue(this.contents, box(value), physIdx),
    )
  }

  /**
   * Returns a copy of this node in which the sub-node at `physIdx` is
   * replaced by `node`, with the given total `size`. The copy is blue
   * only if both `this` and `node` are blue.
   */
  private[this] final def withNode(size: Int, bitmap: Long, node: Hamt[K, V, E, _, _, _], physIdx: Int): H = {
    this.newNode(
      sizeAndBlue = packSizeAndBlueInternal(size, this.isBlueSubtree && node.isBlueSubtree),
      bitmap = bitmap,
      contents = arrReplacedValue(this.contents, node, physIdx),
    )
  }

  /** Returns a copy of `arr` with slot `idx` set to `value`; `arr` itself is not modified. */
  private[this] final def arrReplacedValue(arr: Array[AnyRef], value: AnyRef, idx: Int): Array[AnyRef] = {
    val newArr = Arrays.copyOf(arr, arr.length)
    newArr(idx) = value
    newArr
  }

  /**
   * Index into the imaginary 64-element sparse array.
   *
   * Extracts the `W` bits of `hash` selected by `shift`, starting
   * from the highest bits. At `shift == 60` only the lowest 4 bits
   * remain (64 = 10 * 6 + 4), so the result is in 0-15 there and in
   * 0-63 at every other level.
   */
  private[this] final def logicalIdx(hash: Long, shift: Int): Int = {
    // Note: this logic is duplicated in `MemoryLocationOrdering`.
    val mask = START_MASK >>> shift // masks the bits we're interested in
    val sh = java.lang.Long.numberOfTrailingZeros(mask) // we'll shift the masked result
    // we do it this way, because at the end, when `shift` is 60,
    // we don't actually need to shift (i.e., `sh` will be 0),
    // because we just need the 4 lowest bits
    ((hash & mask) >>> sh).toInt
    // TODO: Is it measurably slower this way than
    // TODO: just using the lowest bits first (see
    // TODO: `ShiftBench`)? We should check if it
    // TODO: really matters.
  }

  /** For testing only! */
  private[mcas] final def logicalIdx_public(hash: Long, shift: Int): Int = {
    this.logicalIdx(hash, shift)
  }

  /**
   * Index into the actual dense array (`contents`).
   *
   * `flag` must have exactly one bit set. The result is the number of
   * `bitmap` bits set below that bit, i.e., the number of entries that
   * precede it in `contents`.
   */
  private[this] final def physicalIdx(bitmap: Long, flag: Long): Int = {
    java.lang.Long.bitCount(bitmap & (flag - 1L))
  }

  /**
   * Packs `size` and `isBlue` into one `Int`: `size` if blue,
   * `-size` otherwise.
   *
   * `Boolean.compare(isBlue, true)` is 0 for blue and -1 otherwise,
   * so `x` is 0 or -1, and the multiplier `(x << 1) + 1` is 1 or -1.
   */
  // TODO: this is duplicated with `AbstractHamt`
  private[this] final def packSizeAndBlueInternal(size: Int, isBlue: Boolean): Int = {
    val x = (-1) * java.lang.Math.abs(java.lang.Boolean.compare(isBlue, true))
    size * ((x << 1) + 1)
  }
}
private[choam] object Hamt {

  /**
   * Anything with a 64-bit hash. The HAMT uses this hash directly,
   * without re-hashing it, to choose a path through the trie.
   */
  trait HasHash {
    def hash: Long
  }

  /**
   * A value that can be stored in a HAMT. The key is always obtained
   * from the value itself, so the tree never stores keys separately.
   */
  trait HasKey[K <: HasHash] {
    def key: K
  }

  /**
   * Callbacks invoked by `Hamt#computeIfAbsent` and
   * `Hamt#computeOrModify` for the entry of a single key.
   */
  trait EntryVisitor[K, V, T] {

    /** Called when `k` has no value; returning `null` means "don't insert anything". */
    def entryAbsent(k: K, tok: T): V

    /** Called with the existing value `v` of `k`; returns the value to keep, or its replacement. */
    def entryPresent(k: K, v: V, tok: T): V
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy