// org.apache.pekko.io.SimpleDnsCache.scala (Maven / Gradle / Ivy)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* license agreements; and to You under the Apache License, version 2.0:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* This file is part of the Apache Pekko project, which was derived from Akka.
*/
/*
* Copyright (C) 2018-2022 Lightbend Inc.
*/
package org.apache.pekko.io
import java.util.concurrent.atomic.AtomicReference
import scala.annotation.tailrec
import scala.collection.immutable
import scala.annotation.nowarn
import org.apache.pekko
import pekko.actor.NoSerializationVerificationNeeded
import pekko.annotation.InternalApi
import pekko.io.dns.{ AAAARecord, ARecord }
import pekko.io.dns.CachePolicy.CachePolicy
import pekko.io.dns.CachePolicy.Forever
import pekko.io.dns.CachePolicy.Never
import pekko.io.dns.CachePolicy.Ttl
import pekko.io.dns.DnsProtocol
import pekko.io.dns.DnsProtocol.{ Ip, RequestType, Resolved }
/**
 * Mixed into caches that need an explicit call to get rid of stale entries.
 *
 * NOTE(review): the name suggests `cleanup()` is triggered on a schedule by the
 * owner of the cache; the caller is not visible in this file - confirm.
 */
private[io] trait PeriodicCacheCleanup {
// Side-effecting: implementations are expected to drop entries that are no longer valid.
def cleanup(): Unit
}
/**
 * DNS cache keyed by `(host name, request type)`.
 *
 * The whole cache state is an immutable [[SimpleDnsCache.Cache]] held in an
 * `AtomicReference`; writers (`put`, `cleanup`) build a new `Cache` and install it
 * with compare-and-set, retrying when another writer got there first.
 */
class SimpleDnsCache extends Dns with PeriodicCacheCleanup with NoSerializationVerificationNeeded {
  import SimpleDnsCache._

  // The clock is handed over as a thunk and is only evaluated by Cache.get/put/cleanup,
  // never by the Cache constructor, so referring to clock() before nanoBase is
  // initialised (next field) is safe.
  private val cacheRef = new AtomicReference(
    new Cache[(String, RequestType), Resolved](
      immutable.SortedSet()(expiryEntryOrdering()),
      immutable.Map(),
      () => clock()))

  // Reference point for clock(); all expiry times are milliseconds relative to this instant.
  private val nanoBase = System.nanoTime()

  /**
   * Gets any IPv4 and IPv6 cached entries.
   * To get Srv or just one type use DnsProtocol
   *
   * This method is deprecated and involves a copy from the new protocol to
   * remain compatible
   *
   * @param name host name to look up
   * @return the cached addresses (IPv4-only, IPv6-only and combined lookups, in that order),
   *         or `None` when no valid A/AAAA record is cached for `name`
   */
  @nowarn("msg=deprecated")
  override def cached(name: String): Option[Dns.Resolved] = {
    // Read the reference once so that all three lookups observe the same cache
    // generation, even if put()/cleanup() runs concurrently.
    val snapshot = cacheRef.get()

    def recordsFor(requestType: RequestType) =
      snapshot.get((name, requestType)).toList.flatMap(_.records)

    // adapt response to the old protocol
    val ipv4 = recordsFor(Ip(ipv6 = false))
    val ipv6 = recordsFor(Ip(ipv4 = false))
    val both = recordsFor(Ip())
    val all = (ipv4 ++ ipv6 ++ both).collect {
      case r: ARecord    => r.ip
      case r: AAAARecord => r.ip
    }
    if (all.isEmpty) None
    else Some(Dns.Resolved(name, all))
  }

  /**
   * Looks up the cached answer for exactly the name and request type of `request`.
   *
   * @return the cached answer if present and not yet expired, otherwise `None`
   */
  override def cached(request: DnsProtocol.Resolve): Option[DnsProtocol.Resolved] =
    cacheRef.get().get((request.name, request.requestType))

  // Milliseconds since start
  protected def clock(): Long = {
    val now = System.nanoTime()
    // Clamp to zero should the difference ever come out negative.
    if (now - nanoBase < 0) 0
    else (now - nanoBase) / 1000000
  }

  /**
   * INTERNAL API
   *
   * Returns the cached answer for `key` if present and not yet expired.
   */
  @InternalApi
  private[pekko] final def get(key: (String, RequestType)): Option[Resolved] = {
    cacheRef.get().get(key)
  }

  /**
   * Stores `records` under `key` with the given cache policy.
   *
   * Retries until the compare-and-set succeeds, i.e. until the update was applied
   * to a cache state no other writer replaced in the meantime.
   */
  @tailrec
  private[io] final def put(key: (String, RequestType), records: Resolved, ttl: CachePolicy): Unit = {
    val c = cacheRef.get()
    if (!cacheRef.compareAndSet(c, c.put(key, records, ttl)))
      put(key, records, ttl)
  }

  /**
   * Removes expired entries, retrying until the cleaned state could be installed
   * without a concurrent writer interfering.
   */
  @tailrec
  override final def cleanup(): Unit = {
    val c = cacheRef.get()
    if (!cacheRef.compareAndSet(c, c.cleanup()))
      cleanup()
  }
}
object SimpleDnsCache {

  /**
   * INTERNAL API
   *
   * Immutable cache state: every modification returns a new `Cache`.
   *
   * @param queue expiry bookkeeping; its head is the entry that expires first when the
   *              set is built with [[expiryEntryOrdering]]
   * @param cache the actual key -> answer mapping, each answer carrying its expiry time
   * @param clock source of the current time, in the same unit as the stored expiry times
   */
  @InternalApi
  private[io] class Cache[K, V](
      queue: immutable.SortedSet[ExpiryEntry[K]],
      cache: immutable.Map[K, CacheEntry[V]],
      clock: () => Long) {

    /** Returns the answer stored for `name`, unless it is missing or already expired. */
    def get(name: K): Option[V] = {
      for {
        e <- cache.get(name)
        if e.isValid(clock())
      } yield e.answer
    }

    /**
     * Returns a cache that additionally maps `name` to `answer`, replacing any previous
     * mapping for `name`.
     *
     * Expiry time by policy: `Forever` never expires (`Long.MaxValue`), `Never` is stored
     * with an expiry already in the past so it is never returned by `get`, and `Ttl`
     * expires `ttl.value` after the current clock reading.
     */
    def put(name: K, answer: V, ttl: CachePolicy): Cache[K, V] = {
      val until = ttl match {
        case Forever => Long.MaxValue
        case Never   => clock() - 1
        case t: Ttl  => clock() + t.value.toMillis
      }

      new Cache[K, V](queue + new ExpiryEntry[K](name, until), cache + (name -> CacheEntry(answer, until)), clock)
    }

    /**
     * Returns a cache without the expired entries.
     *
     * Expired entries are taken off the front of the queue one by one. A key is only
     * removed from the map if the answer currently stored for it is itself expired:
     * the key may have been re-`put` with a later expiry, in which case the queue entry
     * being processed is stale and the mapping must be kept.
     */
    def cleanup(): Cache[K, V] = {
      val now = clock()
      var q = queue
      var c = cache
      while (q.nonEmpty && !q.head.isValid(now)) {
        val minEntry = q.head
        val name = minEntry.name
        q -= minEntry
        if (c.get(name).exists(!_.isValid(now)))
          c -= name
      }
      new Cache(q, c, clock)
    }
  }

  /** A cached answer together with the clock value at which it stops being valid. */
  private[io] case class CacheEntry[T](answer: T, until: Long) {
    def isValid(clock: Long): Boolean = clock < until
  }

  /**
   * INTERNAL API
   *
   * Queue element recording that `name` expires at `until`.
   *
   * Note that the natural (`Ordered`) comparison is the negated comparison of `until`,
   * i.e. the reverse of [[expiryEntryOrdering]]: entries that expire later compare
   * as smaller.
   */
  @InternalApi
  private[io] class ExpiryEntry[K](val name: K, val until: Long) extends Ordered[ExpiryEntry[K]] {
    def isValid(clock: Long): Boolean = clock < until
    override def compare(that: ExpiryEntry[K]): Int = -until.compareTo(that.until)
  }

  /**
   * INTERNAL API
   *
   * Orders entries by ascending expiry time, so the head of a sorted set is the entry
   * that expires first.
   *
   * A sorted set treats elements that compare as equal as duplicates. Comparing `until`
   * alone would therefore silently drop the queue entry of a second key that happens to
   * expire at the same instant, and `Cache.cleanup()` would never evict that key from the
   * map. Ties on `until` are broken by the key's hash code and then by its string form,
   * so distinct keys normally stay distinct while equal keys still collapse into one
   * entry. Two different keys that agree on all three are still merged.
   */
  @InternalApi
  private[io] def expiryEntryOrdering[K]() = new Ordering[ExpiryEntry[K]] {
    override def compare(x: ExpiryEntry[K], y: ExpiryEntry[K]): Int = {
      val byExpiry = java.lang.Long.compare(x.until, y.until)
      if (byExpiry != 0) byExpiry
      else {
        // `##` is null-safe, unlike hashCode.
        val byHash = java.lang.Integer.compare(x.name.##, y.name.##)
        if (byHash != 0) byHash
        else String.valueOf(x.name).compareTo(String.valueOf(y.name))
      }
    }
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy