All Downloads are FREE. Search and download functionalities are using the official Maven repository.

scala.collection.mutable.HashMap.scala Maven / Gradle / Ivy

The newest version!
/*
 * Scala (https://www.scala-lang.org)
 *
 * Copyright EPFL and Lightbend, Inc.
 *
 * Licensed under Apache License 2.0
 * (http://www.apache.org/licenses/LICENSE-2.0).
 *
 * See the NOTICE file distributed with this work for
 * additional information regarding copyright ownership.
 */

package scala.collection
package mutable

import scala.annotation.{nowarn, tailrec}
import scala.collection.Stepper.EfficientSplit
import scala.collection.generic.DefaultSerializationProxy
import scala.util.hashing.MurmurHash3

/** This class implements mutable maps using a hashtable.
  *
  *  @see [[https://docs.scala-lang.org/overviews/collections-2.13/concrete-mutable-collection-classes.html#hash-tables "Scala's Collection Library overview"]]
  *  section on `Hash Tables` for more information.
  *
  *  @tparam K    the type of the keys contained in this hash map.
  *  @tparam V    the type of the values assigned to keys in this hash map.
  *
  *  @define Coll `mutable.HashMap`
  *  @define coll mutable hash map
  *  @define mayNotTerminateInf
  *  @define willNotTerminateInf
  */
@deprecatedInheritance("HashMap will be made final; use .withDefault for the common use case of computing a default value", "2.13.0")
class HashMap[K, V](initialCapacity: Int, loadFactor: Double)
  extends AbstractMap[K, V]
    with MapOps[K, V, HashMap, HashMap[K, V]]
    with StrictOptimizedIterableOps[(K, V), Iterable, HashMap[K, V]]
    with StrictOptimizedMapOps[K, V, HashMap, HashMap[K, V]]
    with MapFactoryDefaults[K, V, HashMap, Iterable]
    with Serializable {

  /* The HashMap class maintains the following invariants:
   * - For each i with 0 <= i < table.length, the bucket at table(i) only contains keys whose hash-index is i.
   * - Every bucket is sorted in ascending hash order.
   * - The sum of the lengths of all buckets is equal to contentSize.
   */
  /** Creates an empty map using `HashMap.defaultInitialCapacity` and `HashMap.defaultLoadFactor`. */
  def this() = this(HashMap.defaultInitialCapacity, HashMap.defaultLoadFactor)

  import HashMap.Node

  /** The actual hash table: one slot per bucket, holding the head node of that bucket's chain, or null. */
  private[this] var table = new Array[Node[K, V]](tableSizeFor(initialCapacity))

  /** The next size value at which to resize (capacity * load factor). */
  private[this] var threshold: Int = newThreshold(table.length)

  /** The number of entries currently stored; this is the value reported by `size`. */
  private[this] var contentSize = 0

  /** The number of key-value pairs currently stored, as tracked by `contentSize`. */
  override def size: Int = {
    contentSize
  }

  /** Recovers the original (`any.##`) hash from an improved hash.
    *
    * The improvement step is its own inverse, so this just applies `improveHash` a second time.
    */
  @`inline` private[collection] def unimproveHash(improvedHash: Int): Int = {
    improveHash(improvedHash)
  }

  /** Computes the improved hash of an original (`any.##`) hash.
    *
    * @param originalHash the raw hash code of a key
    * @return the hash with its upper 16 bits folded into its lower 16 bits
    */
  @`inline` private[this] def improveHash(originalHash: Int): Int = {
    // Only the lowest bits of a hash select the bucket, so the high 16 bits are xor-ed into the
    // low 16 bits in case the entropy is concentrated in the high-value bits. java.util.HashMap
    // uses the same spreading step.
    //
    // The operation is an involution: improveHash(improveHash(i)) == i for every Int i. That lets
    // the original hash be recovered when needed (for instance when appending to an
    // immutable.HashMap), which is why unimproveHash simply forwards to this method.
    val upperHalf = originalHash >>> 16
    upperHalf ^ originalHash
  }

  /** Computes the improved hash of this key: `o.##` passed through `improveHash`. */
  @`inline` private[this] def computeHash(o: K): Int = {
    val rawHash = o.##
    improveHash(rawHash)
  }

  /** Maps an improved hash to a bucket position by masking it with `table.length - 1`. */
  @`inline` private[this] def index(hash: Int) = {
    val mask = table.length - 1
    mask & hash
  }

  /** Tests whether `key` is present, i.e. whether a node for it can be found. */
  override def contains(key: K): Boolean = {
    val node = findNode(key)
    node ne null
  }

  /** Looks up the node holding `key`.
    *
    * @return the matching node, or null when the bucket is empty or contains no such key
    */
  @`inline` private[this] def findNode(key: K): Node[K, V] = {
    val keyHash = computeHash(key)
    val head = table(index(keyHash))
    if (head eq null) null else head.findNode(key, keyHash)
  }

  /** Grows the table ahead of time when the hinted number of entries calls for a larger one.
    *
    * The wanted table length is `tableSizeFor((size + 1) / loadFactor)`; the table is never shrunk.
    */
  override def sizeHint(size: Int): Unit = {
    val requiredCapacity = ((size + 1).toDouble / loadFactor).toInt
    val wanted = tableSizeFor(requiredCapacity)
    if (table.length < wanted) growTable(wanted)
  }

  /** Adds every entry of `xs` to this map, overwriting the values of keys that are already present.
    *
    * Hash-based sources are special-cased so that hashes they already carry are reused instead of
    * being recomputed; anything that is not a recognised map type is handed to the inherited
    * implementation.
    */
  override def addAll(xs: IterableOnce[(K, V)]): this.type = {
    // Give the table a chance to grow up front, based on the hint derived from `xs`.
    sizeHint(xs)

    xs match {
      case source: immutable.HashMap[K, V] =>
        // The hash supplied here is passed through improveHash before being used.
        source.foreachWithHash { (k, v, originalHash) =>
          put0(k, v, improveHash(originalHash), getOld = false)
        }
        this
      case source: mutable.HashMap[K, V] =>
        // Node hashes of another mutable.HashMap are used as-is.
        source.nodeIterator.foreach { node =>
          put0(node.key, node.value, node.hash, getOld = false)
        }
        this
      case source: mutable.LinkedHashMap[K, V] =>
        // Entry hashes of a LinkedHashMap are likewise used as-is.
        source.entryIterator.foreach { entry =>
          put0(entry.key, entry.value, entry.hash, getOld = false)
        }
        this
      case source: Map[K, V] =>
        source.foreachEntry { (k: K, v: V) =>
          put0(k, v, improveHash(k.##), getOld = false)
        }
        this
      case _ =>
        super.addAll(xs)
    }
  }

  // updateWith is overridden for performance: the key is hashed only once and the
  // bucket chain is searched only once for the whole read-modify-write cycle.
  override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = {
    if (getClass == classOf[HashMap[_, _]]) {
      val hash = computeHash(key)
      val bucket = index(hash)

      // Search the chain for `key`, keeping track of the node in front of the cursor so
      // that a found node can later be unlinked without walking the chain again.
      var foundNode: Node[K, V] = null
      var predecessor: Node[K, V] = null
      var cursor = table(bucket)
      while (cursor ne null) {
        if (hash == cursor.hash && key == cursor.key) {
          foundNode = cursor
          cursor = null
        } else if ((cursor.next eq null) || (cursor.hash > hash)) {
          // Reached the end of the chain, or a node whose hash is already larger than
          // the one searched for (chains are kept in ascending hash order).
          cursor = null
        } else {
          predecessor = cursor
          cursor = cursor.next
        }
      }

      val previousValue: Option[V] =
        if (foundNode eq null) None else Some(foundNode.value)

      val nextValue = remappingFunction(previousValue)

      (previousValue, nextValue) match {
        case (Some(_), Some(newValue)) =>
          // Key stays: overwrite the value in place.
          foundNode.value = newValue

        case (Some(_), None) =>
          // Key goes away: unlink its node from the chain.
          if (predecessor eq null) table(bucket) = foundNode.next
          else predecessor.next = foundNode.next
          contentSize -= 1

        case (None, Some(value)) =>
          // New key: grow first if needed; the bucket position must be recomputed after a resize.
          val mustGrow = contentSize + 1 >= threshold
          if (mustGrow) growTable(table.length * 2)
          val targetBucket = if (mustGrow) index(hash) else bucket
          put0(key, value, getOld = false, hash, targetBucket)

        case (None, None) => // absent before and after: nothing to do
      }
      nextValue
    } else {
      // subclasses of HashMap might customise `get` ...
      super.updateWith(key)(remappingFunction)
    }
  }

  /** Removes every key of `xs` from this map.
    *
    * Nothing is traversed when the map is already empty, and the hash-set fast paths stop as soon
    * as the map becomes empty. Hashes carried by hash-based sets are reused instead of recomputed.
    */
  override def subtractAll(xs: IterableOnce[K]): this.type = {
    if (size == 0) this
    else xs match {
      case keys: immutable.HashSet[K] =>
        // The callback reports whether the map still has entries left
        // (presumably what lets the traversal stop early - see foreachWithHashWhile).
        keys.foreachWithHashWhile { (k, originalHash) =>
          remove0(k, improveHash(originalHash))
          size > 0
        }
        this
      case keys: mutable.HashSet[K] =>
        val nodes = keys.nodeIterator
        var pending = nodes.hasNext
        while (pending) {
          val node = nodes.next()
          remove0(node.key, node.hash)
          // Stop once the map is drained; only then ask the source for more.
          pending = size != 0 && nodes.hasNext
        }
        this
      case keys: mutable.LinkedHashSet[K] =>
        val entries = keys.entryIterator
        var pending = entries.hasNext
        while (pending) {
          val entry = entries.next()
          remove0(entry.key, entry.hash)
          pending = size != 0 && entries.hasNext
        }
        this
      case _ => super.subtractAll(xs)
    }
  }

  /** Inserts or overwrites a binding whose improved hash is already known.
    *
    * @param key the key to add
    * @param value the value to add
    * @param hash the **improved** hashcode of `key` (see computeHash)
    * @param getOld if true, the previous value for `key` is returned wrapped in `Some`;
    *               otherwise (or when there was no previous value) the result is null
    */
  private[this] def put0(key: K, value: V, hash: Int, getOld: Boolean): Some[V] = {
    val mustGrow = contentSize + 1 >= threshold
    if (mustGrow) growTable(table.length * 2)
    // The bucket position is computed only after a possible resize.
    put0(key, value, getOld, hash, index(hash))
  }

  /** Inserts or overwrites a binding, hashing `key` itself.
    *
    * The table is grown first when one more entry would reach the threshold; the key's hash and
    * bucket position are computed afterwards.
    */
  private[this] def put0(key: K, value: V, getOld: Boolean): Some[V] = {
    if (threshold <= contentSize + 1) growTable(table.length * 2)
    val keyHash = computeHash(key)
    put0(key, value, getOld, keyHash, index(keyHash))
  }


  /** Core insertion routine: writes `key -> value` into the bucket at `idx`.
    *
    * An existing node for `key` has its value replaced; otherwise a new node is linked into the
    * chain at the position that keeps it in ascending hash order and `contentSize` is incremented.
    *
    * @param hash the improved hash of `key`
    * @param idx  the bucket position for `hash` in the current table
    * @return `Some(previous value)` when a value was replaced and `getOld` is true, null otherwise
    */
  private[this] def put0(key: K, value: V, getOld: Boolean, hash: Int, idx: Int): Some[V] = {
    val head = table(idx)
    var predecessor: Node[K, V] = null
    var cursor = head
    var existing: Node[K, V] = null
    // Advance while nodes have a hash no greater than ours, stopping early on a key match.
    while ((existing eq null) && (cursor ne null) && cursor.hash <= hash) {
      if (cursor.hash == hash && key == cursor.key) existing = cursor
      else {
        predecessor = cursor
        cursor = cursor.next
      }
    }

    if (existing ne null) {
      val previous = existing.value
      existing.value = value
      if (getOld) Some(previous) else null
    } else {
      if (predecessor eq null) table(idx) = new Node[K, V](key, hash, value, head)
      else predecessor.next = new Node[K, V](key, hash, value, predecessor.next)
      contentSize += 1
      null
    }
  }

  /** Removes `elem`, hashing it first; returns the removed node, or null if it was absent. */
  private def remove0(elem: K) : Node[K, V] = {
    val elemHash = computeHash(elem)
    remove0(elem, elemHash)
  }

  /** Unlinks the node for a key, if there is one.
    *
    * @param elem the element to remove
    * @param hash the **improved** hashcode of `element` (see computeHash)
    * @return the node that contained element if it was present, otherwise null
    */
  private[this] def remove0(elem: K, hash: Int) : Node[K, V] = {
    val idx = index(hash)
    val head = table(idx)
    if (head eq null) null
    else if (head.hash == hash && head.key == elem) {
      // The head of the chain matches: the bucket now starts at its successor.
      table(idx) = head.next
      contentSize -= 1
      head
    } else {
      // Scan the rest of the chain, giving up once the hashes exceed the one searched for.
      var before = head
      var cursor = head.next
      var removed: Node[K, V] = null
      while ((removed eq null) && (cursor ne null) && cursor.hash <= hash) {
        if (cursor.hash == hash && cursor.key == elem) {
          before.next = cursor.next
          contentSize -= 1
          removed = cursor
        } else {
          before = cursor
          cursor = cursor.next
        }
      }
      removed
    }
  }

  /** Base class of the iterators over this map's nodes.
    *
    * Buckets are visited in table order and each chain is followed through `next`; subclasses
    * decide through `extract` what is produced for every node. The table length is sampled once,
    * when the iterator is created.
    */
  private[this] abstract class HashMapIterator[A] extends AbstractIterator[A] {
    private[this] val bucketCount = table.length
    private[this] var nextBucket = 0
    private[this] var pending: Node[K, V] = null

    /** Turns a node into the element this iterator yields. */
    protected[this] def extract(nd: Node[K, V]): A

    def hasNext: Boolean = {
      // With no node pending, move forward to the next non-empty bucket (if any).
      while ((pending eq null) && nextBucket < bucketCount) {
        pending = table(nextBucket)
        nextBucket += 1
      }
      pending ne null
    }

    def next(): A =
      if (hasNext) {
        val result = extract(pending)
        pending = pending.next
        result
      } else Iterator.empty.next()
  }

  /** An iterator over all key-value pairs; the shared empty iterator is used for an empty map. */
  override def iterator: Iterator[(K, V)] =
    if (size != 0) {
      new HashMapIterator[(K, V)] {
        protected[this] def extract(nd: Node[K, V]): (K, V) = (nd.key, nd.value)
      }
    } else Iterator.empty

  /** An iterator over all keys; the shared empty iterator is used for an empty map. */
  override def keysIterator: Iterator[K] =
    if (size != 0) {
      new HashMapIterator[K] {
        protected[this] def extract(nd: Node[K, V]): K = nd.key
      }
    } else Iterator.empty

  /** An iterator over all values; the shared empty iterator is used for an empty map. */
  override def valuesIterator: Iterator[V] =
    if (size != 0) {
      new HashMapIterator[V] {
        protected[this] def extract(nd: Node[K, V]): V = nd.value
      }
    } else Iterator.empty


  /** An iterator over the internal nodes of this HashMap, giving access to key, value and stored hash. */
  private[collection] def nodeIterator: Iterator[Node[K, V]] =
    if (size != 0) {
      new HashMapIterator[Node[K, V]] {
        protected[this] def extract(nd: Node[K, V]): Node[K, V] = nd
      }
    } else Iterator.empty

  /** A stepper over the key-value pairs, built on a table stepper that spans the whole table. */
  override def stepper[S <: Stepper[_]](implicit shape: StepperShape[(K, V), S]): S with EfficientSplit = {
    val pairStepper = new convert.impl.AnyTableStepper[(K, V), Node[K, V]](
      size, table, nd => nd.next, nd => (nd.key, nd.value), 0, table.length)
    shape.parUnbox(pairStepper).asInstanceOf[S with EfficientSplit]
  }

  /** A stepper over the keys.
    *
    * For the `Int`, `Long` and `Double` shapes the matching primitive table stepper is created and
    * keys are cast to that primitive type; any other shape gets a generic stepper passed through
    * `shape.parUnbox`.
    */
  override def keyStepper[S <: Stepper[_]](implicit shape: StepperShape[K, S]): S with EfficientSplit = {
    import convert.impl._
    (shape.shape match {
      case StepperShape.IntShape =>
        new IntTableStepper[Node[K, V]](size, table, nd => nd.next, nd => nd.key.asInstanceOf[Int], 0, table.length)
      case StepperShape.LongShape =>
        new LongTableStepper[Node[K, V]](size, table, nd => nd.next, nd => nd.key.asInstanceOf[Long], 0, table.length)
      case StepperShape.DoubleShape =>
        new DoubleTableStepper[Node[K, V]](size, table, nd => nd.next, nd => nd.key.asInstanceOf[Double], 0, table.length)
      case _ =>
        shape.parUnbox(new AnyTableStepper[K, Node[K, V]](size, table, nd => nd.next, nd => nd.key, 0, table.length))
    }).asInstanceOf[S with EfficientSplit]
  }

  /** A stepper over the values.
    *
    * For the `Int`, `Long` and `Double` shapes the matching primitive table stepper is created and
    * values are cast to that primitive type; any other shape gets a generic stepper passed through
    * `shape.parUnbox`.
    */
  override def valueStepper[S <: Stepper[_]](implicit shape: StepperShape[V, S]): S with EfficientSplit = {
    import convert.impl._
    (shape.shape match {
      case StepperShape.IntShape =>
        new IntTableStepper[Node[K, V]](size, table, nd => nd.next, nd => nd.value.asInstanceOf[Int], 0, table.length)
      case StepperShape.LongShape =>
        new LongTableStepper[Node[K, V]](size, table, nd => nd.next, nd => nd.value.asInstanceOf[Long], 0, table.length)
      case StepperShape.DoubleShape =>
        new DoubleTableStepper[Node[K, V]](size, table, nd => nd.next, nd => nd.value.asInstanceOf[Double], 0, table.length)
      case _ =>
        shape.parUnbox(new AnyTableStepper[V, Node[K, V]](size, table, nd => nd.next, nd => nd.value, 0, table.length))
    }).asInstanceOf[S with EfficientSplit]
  }

  /** Resizes the table to `newlen` slots, redistributing the existing nodes, and recomputes `threshold`.
    *
    * @param newlen the new table length; a negative value (e.g. the result of `table.length * 2`
    *               overflowing `Int`) is rejected
    * @throws RuntimeException if `newlen` is negative
    */
  private[this] def growTable(newlen: Int) = {
    if (newlen < 0)
      throw new RuntimeException(s"new HashMap table size $newlen exceeds maximum")
    var oldlen = table.length
    threshold = newThreshold(newlen)
    // An empty map has no nodes to redistribute, so a fresh array is sufficient.
    if(size == 0) table = new Array(newlen)
    else {
      table = java.util.Arrays.copyOf(table, newlen)
      // Reusable dummy head nodes for the "low" and "high" chains built while splitting one bucket.
      val preLow: Node[K, V] = new Node(null.asInstanceOf[K], 0, null.asInstanceOf[V], null)
      val preHigh: Node[K, V] = new Node(null.asInstanceOf[K], 0, null.asInstanceOf[V], null)
      // Split buckets until the new length has been reached. This could be done more
      // efficiently when growing an already filled table to more than double the size.
      while(oldlen < newlen) {
        var i = 0
        while (i < oldlen) {
          val old = table(i)
          if(old ne null) {
            preLow.next = null
            preHigh.next = null
            var lastLow: Node[K, V] = preLow
            var lastHigh: Node[K, V] = preHigh
            var n = old
            // Nodes are appended to one of the two chains in traversal order, so each
            // resulting chain keeps the relative order of the chain being split.
            while(n ne null) {
              val next = n.next
              // The `oldlen` bit of the hash decides whether the node stays in bucket i
              // or moves to bucket i + oldlen.
              if((n.hash & oldlen) == 0) { // keep low
                lastLow.next = n
                lastLow = n
              } else { // move to high
                lastHigh.next = n
                lastHigh = n
              }
              n = next
            }
            // Terminate the low chain: its last node may still point at a node that moved high.
            lastLow.next = null
            // Only write the slot back when the head of the low chain actually changed.
            if(old ne preLow.next) table(i) = preLow.next
            if(preHigh.next ne null) {
              table(i + oldlen) = preHigh.next
              lastHigh.next = null
            }
          }
          i += 1
        }
        oldlen *= 2
      }
    }
  }

  /** Computes the table length to use for a requested capacity.
    *
    * The result is always a power of two between 8 and `1 << 30` (inclusive): the highest one-bit
    * of `capacity - 1` (but at least 4), doubled.
    *
    * The bounds are applied ''before'' doubling. Doubling first and capping afterwards overflows
    * `Int` for any `capacity` above `(1 << 30) + 1`: the product wraps to `Int.MinValue`, which a
    * trailing `.min(1 << 30)` does not repair, and the negative length then fails array
    * allocation or makes `sizeHint` ignore the request. For the same reason `capacity - 1` is
    * evaluated as a `Long`, so that `Int.MinValue` cannot wrap around to `Int.MaxValue`.
    *
    * @param capacity the requested capacity; values that are too small (including negative ones)
    *                 yield the minimum table length of 8
    * @return a power of two in the range `[8, 1 << 30]`
    */
  private[this] def tableSizeFor(capacity: Int) = {
    val maxTableSize = 1 << 30
    // Largest input whose highest one-bit can still be doubled without exceeding maxTableSize.
    val bounded = (capacity.toLong - 1L).max(4L).min(maxTableSize.toLong - 1L).toInt
    Integer.highestOneBit(bounded) * 2
  }

  /** The entry count at which a table of length `size` is due to grow: `size * loadFactor`, truncated to an Int. */
  private[this] def newThreshold(size: Int) = {
    val scaled = size.toDouble * loadFactor
    scaled.toInt
  }

  /** Removes all entries. Every slot is reset to null; the table array itself is kept. */
  override def clear(): Unit = {
    val slots = table.asInstanceOf[Array[AnyRef]]
    java.util.Arrays.fill(slots, null)
    contentSize = 0
  }

  /** Optionally returns the value bound to `key`: `Some(value)` if a node exists, `None` otherwise. */
  def get(key: K): Option[V] = {
    val node = findNode(key)
    if (node eq null) None else Some(node.value)
  }

  /** Returns the value bound to `key`, deferring to `default(key)` when the key is absent. */
  @throws[NoSuchElementException]
  override def apply(key: K): V = {
    val node = findNode(key)
    if (node eq null) default(key) else node.value
  }

  /** Returns the value bound to `key`, or the (lazily evaluated) `default` when the key is absent. */
  override def getOrElse[V1 >: V](key: K, default: => V1): V1 = {
    if (getClass == classOf[HashMap[_, _]]) {
      // In the common case the node is inspected directly, which avoids the Option boxing.
      findNode(key) match {
        case null => default
        case node => node.value
      }
    } else {
      // subclasses of HashMap might customise `get` ...
      super.getOrElse(key, default)
    }
  }

  /** Returns the value bound to `key`; when the key is absent, evaluates `defaultValue`, stores
    * it under `key` and returns it.
    */
  override def getOrElseUpdate(key: K, defaultValue: => V): V = {
    if (getClass == classOf[HashMap[_, _]]) {
      val keyHash = computeHash(key)
      val bucket = index(keyHash)
      val head = table(bucket)
      val existing = if (head eq null) null else head.findNode(key, keyHash)
      if (existing ne null) existing.value
      else {
        // Remember the table instance before running user code, so a resize can be detected.
        val tableBefore = table
        val computed = defaultValue
        if (contentSize + 1 >= threshold) growTable(table.length * 2)
        // The bucket position is only recomputed if evaluating `defaultValue` or making room
        // for the new element replaced the table.
        val targetBucket = if (tableBefore eq table) bucket else index(keyHash)
        put0(key, computed, getOld = false, keyHash, targetBucket)
        computed
      }
    } else {
      // subclasses of HashMap might customise `get` ...
      super.getOrElseUpdate(key, defaultValue)
    }
  }

  /** Binds `key` to `value` and returns the value previously bound to it, if any. */
  override def put(key: K, value: V): Option[V] = {
    // put0 signals "no previous value" with null rather than None.
    val previous = put0(key, value, getOld = true)
    if (previous eq null) None else previous
  }

  /** Removes `key` and returns the value that was bound to it, if any. */
  override def remove(key: K): Option[V] = {
    val removed = remove0(key)
    if (removed eq null) None else Some(removed.value)
  }

  /** Binds `key` to `value`, replacing any existing binding; the previous value is not retrieved. */
  override def update(key: K, value: V): Unit = {
    put0(key, value, getOld = false)
    ()
  }

  /** Adds the binding `elem._1 -> elem._2`, replacing any existing binding, and returns this map. */
  def addOne(elem: (K, V)): this.type = {
    put0(elem._1, elem._2, getOld = false)
    this
  }

  /** Removes the key `elem` if it is present and returns this map. */
  def subtractOne(elem: K): this.type = {
    remove0(elem)
    this
  }

  /** The size is always known; it is the same value as `size`. */
  override def knownSize: Int = {
    size
  }

  /** Tests whether the map holds no entries, i.e. whether `size` is zero. */
  override def isEmpty: Boolean = {
    0 == size
  }

  /** Applies `f` to every key-value pair, bucket by bucket in table order.
    *
    * The table length is sampled once at the start; each slot is read from the current table.
    */
  override def foreach[U](f: ((K, V)) => U): Unit = {
    val bucketCount = table.length
    var bucket = 0
    while (bucket < bucketCount) {
      table(bucket) match {
        case null => // empty bucket
        case head => head.foreach(f)
      }
      bucket += 1
    }
  }

  /** Applies `f` to every key and value (passed as two arguments), bucket by bucket in table order.
    *
    * The table length is sampled once at the start; each slot is read from the current table.
    */
  override def foreachEntry[U](f: (K, V) => U): Unit = {
    val bucketCount = table.length
    var bucket = 0
    while (bucket < bucketCount) {
      table(bucket) match {
        case null => // empty bucket
        case head => head.foreachEntry(f)
      }
      bucket += 1
    }
  }

  /** Serialization hook: substitutes a `DefaultSerializationProxy` for this map.
    *
    * The proxy is given a `DeserializationFactory` carrying the current table length and the load factor.
    */
  protected[this] def writeReplace(): AnyRef = {
    val factory = new mutable.HashMap.DeserializationFactory[K, V](table.length, loadFactor)
    new DefaultSerializationProxy(factory, this)
  }

  /** Retains only the entries for which `p` holds, unlinking all others from their bucket chains.
    *
    * @param p the predicate applied to each key and value; entries for which it is false are removed
    * @return this map
    */
  override def filterInPlace(p: (K, V) => Boolean): this.type = {
    if (nonEmpty) {
      var bucket = 0

      while (bucket < table.length) {
        var head = table(bucket)

        // Drop rejected nodes from the front of the chain until a node to keep is found.
        while ((head ne null) && !p(head.key, head.value)) {
          head = head.next
          contentSize -= 1
        }

        if (head ne null) {
          // `prev` is the last node known to be kept; rejected successors are spliced out after it.
          var prev = head
          var next = head.next

          while (next ne null) {
            if (p(next.key, next.value)) {
              prev = next
            } else {
              prev.next = next.next
              contentSize -= 1
            }
            next = next.next
          }
        }

        // Store the (possibly new, possibly null) head of the filtered chain.
        table(bucket) = head
        bucket += 1
      }
    }
    this
  }

  // TODO: rename to `mapValuesInPlace` and override the base version (not binary compatible)
  // Replaces the value of every entry with `f(key, value)`, visiting the buckets in table order.
  private[mutable] def mapValuesInPlaceImpl(f: (K, V) => V): this.type = {
    val bucketCount = table.length
    var bucket = 0
    while (bucket < bucketCount) {
      var node = table(bucket)
      while (node ne null) {
        val replacement = f(node.key, node.value)
        node.value = replacement
        node = node.next
      }
      bucket += 1
    }
    this
  }

  /** The factory for maps of this kind: the `HashMap` companion object. */
  override def mapFactory: MapFactory[HashMap] = {
    HashMap
  }

  /** The prefix used when rendering this collection as a string. */
  @nowarn("""cat=deprecation&origin=scala\.collection\.Iterable\.stringPrefix""")
  override protected[this] def stringPrefix = {
    "HashMap"
  }

  /** The hash code of this map: `MurmurHash3.emptyMapHash` when empty, otherwise the unordered
    * hash (seeded with `MurmurHash3.mapSeed`) of the per-entry tuple hashes.
    */
  override def hashCode: Int = {
    if (!isEmpty) {
      // One mutable object stands in for every entry: `extract` stores the entry's tuple hash
      // and returns the iterator itself, whose hashCode reports that stored hash
      // (presumably so that no tuple has to be allocated per entry).
      val entryHashes = new HashMapIterator[Any] {
        var currentHash: Int = 0
        override def hashCode: Int = currentHash
        override protected[this] def extract(nd: Node[K, V]): Any = {
          // The key's original hash is recovered from the improved hash stored in the node.
          val keyHash = unimproveHash(nd.hash)
          currentHash = MurmurHash3.tuple2Hash(keyHash, nd.value.##)
          this
        }
      }
      MurmurHash3.unorderedHash(entryHashes, MurmurHash3.mapSeed)
    } else MurmurHash3.emptyMapHash
  }
}

/**
  * $factoryInfo
  *  @define Coll `mutable.HashMap`
  *  @define coll mutable hash map
  */
@SerialVersionUID(3L)
object HashMap extends MapFactory[HashMap] {

  /** The default load factor for the hash table */
  final def defaultLoadFactor: Double = 0.75

  /** The default initial capacity for the hash table */
  final def defaultInitialCapacity: Int = 16

  /** A new, empty map created with the default capacity and load factor. */
  def empty[K, V]: HashMap[K, V] = new HashMap[K, V]

  /** Builds a map from the entries of `it`.
    *
    * When `it` reports a positive known size, the initial capacity is derived from that size and
    * the default load factor; otherwise the default initial capacity is used.
    */
  def from[K, V](it: collection.IterableOnce[(K, V)]): HashMap[K, V] = {
    val expected = it.knownSize
    val capacity =
      if (expected <= 0) defaultInitialCapacity
      else ((expected + 1).toDouble / defaultLoadFactor).toInt
    val result = new HashMap[K, V](capacity, defaultLoadFactor)
    result.addAll(it)
  }

  /** A builder backed by a map with the default capacity and load factor. */
  def newBuilder[K, V]: Builder[(K, V), HashMap[K, V]] =
    newBuilder(defaultInitialCapacity, defaultLoadFactor)

  /** A builder backed by a map with the given capacity and load factor; size hints are
    * forwarded to that map.
    */
  def newBuilder[K, V](initialCapacity: Int, loadFactor: Double): Builder[(K, V), HashMap[K, V]] = {
    val underlying = new HashMap[K, V](initialCapacity, loadFactor)
    new GrowableBuilder[(K, V), HashMap[K, V]](underlying) {
      override def sizeHint(size: Int): Unit = elems.sizeHint(size)
    }
  }

  /** Serializable factory that recreates a map with a given table length and load factor. */
  @SerialVersionUID(3L)
  private final class DeserializationFactory[K, V](val tableLength: Int, val loadFactor: Double) extends Factory[(K, V), HashMap[K, V]] with Serializable {
    def fromSpecific(it: IterableOnce[(K, V)]): HashMap[K, V] = {
      val target = new HashMap[K, V](tableLength, loadFactor)
      target.addAll(it)
    }

    def newBuilder: Builder[(K, V), HashMap[K, V]] =
      HashMap.newBuilder(tableLength, loadFactor)
  }

  /** A node of a bucket chain: an immutable key and hash plus a mutable value and successor link. */
  private[collection] final class Node[K, V](_key: K, _hash: Int, private[this] var _value: V, private[this] var _next: Node[K, V]) {
    def key: K = _key
    def hash: Int = _hash

    def value: V = _value
    def value_= (v: V): Unit = { _value = v }

    def next: Node[K, V] = _next
    def next_= (n: Node[K, V]): Unit = { _next = n }

    /** Searches this node and its successors for key `k` with hash `h`.
      *
      * The search continues only while there is a successor and this node's hash does not
      * exceed `h`; null is returned when it stops without a match.
      */
    @tailrec
    def findNode(k: K, h: Int): Node[K, V] =
      if (h == _hash && k == _key) this
      else if ((_next ne null) && _hash <= h) _next.findNode(k, h)
      else null

    /** Applies `f` to the key-value pair of this node and then of every successor. */
    @tailrec
    def foreach[U](f: ((K, V)) => U): Unit = {
      f((_key, _value))
      if (_next eq null) () else _next.foreach(f)
    }

    /** Applies `f` to the key and value of this node and then of every successor. */
    @tailrec
    def foreachEntry[U](f: (K, V) => U): Unit = {
      f(_key, _value)
      if (_next eq null) () else _next.foreachEntry(f)
    }

    override def toString = s"Node($key, $value, $hash) -> $next"
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy