
// io.otoroshi.wasm4s.impl.runtimev2.scala
package io.otoroshi.wasm4s.impl
import akka.stream.OverflowStrategy
import akka.stream.scaladsl._
import com.codahale.metrics.UniformReservoir
import io.otoroshi.wasm4s.scaladsl._
import io.otoroshi.wasm4s.scaladsl.opa._
import io.otoroshi.wasm4s.scaladsl.implicits._
import org.extism.sdk.{HostFunction, HostUserData, Plugin}
import org.extism.sdk.manifest.{Manifest, MemoryOptions}
import org.extism.sdk.wasm.WasmSourceResolver
import org.extism.sdk.wasmotoroshi._
import org.joda.time.DateTime
import play.api.libs.json._
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic._
import scala.collection.JavaConverters._
import scala.collection.concurrent.TrieMap
import scala.concurrent.duration.{DurationInt, FiniteDuration}
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
sealed trait WasmVmAction
object WasmVmAction {
case object WasmVmKillAction extends WasmVmAction
case class WasmVmCallAction(
parameters: WasmFunctionParameters,
context: Option[WasmVmData],
promise: Promise[Either[JsValue, (String, ResultsWrapper)]]
) extends WasmVmAction
}
case class OPAWasmVm(opaDataAddr: Int, opaBaseHeapPtr: Int)
case class WasmVmImpl(
index: Int,
maxCalls: Int,
maxMemory: Long,
resetMemory: Boolean,
instance: Plugin,
vmDataRef: AtomicReference[WasmVmData],
memories: Array[LinearMemory],
functions: Array[HostFunction[_ <: HostUserData]],
pool: WasmVmPoolImpl,
var opaPointers: Option[OPAWasmVm] = None
) extends WasmVm {
private val callDurationReservoirNs = new UniformReservoir()
private val lastUsage: AtomicLong = new AtomicLong(System.currentTimeMillis())
private val initializedRef: AtomicBoolean = new AtomicBoolean(false)
private val killAtRelease: AtomicBoolean = new AtomicBoolean(false)
private val inFlight = new AtomicInteger(0)
private val callCounter = new AtomicInteger(0)
private val queue = {
val env = pool.ic
Source
.queue[WasmVmAction](env.wasmQueueBufferSize, OverflowStrategy.dropTail)
.mapAsync(1)(handle)
.toMat(Sink.ignore)(Keep.both)
.run()(env.materializer)
._1
}
def calls: Int = callCounter.get()
def current: Int = inFlight.get()
private def handle(act: WasmVmAction): Future[Unit] = {
Future.apply {
lastUsage.set(System.currentTimeMillis())
act match {
case WasmVmAction.WasmVmKillAction => destroy()
case action: WasmVmAction.WasmVmCallAction => {
try {
inFlight.decrementAndGet()
// action.context.foreach(ctx => WasmContextSlot.setCurrentContext(ctx))
action.context.foreach(ctx => vmDataRef.set(ctx))
if (pool.ic.logger.isDebugEnabled)
pool.ic.logger.debug(s"call vm ${index} with method ${action.parameters.functionName} on thread ${Thread
.currentThread()
.getName} on path ${action.context.map(_.properties.get("request.path").map(v => new String(v))).getOrElse("--")}")
val start = System.nanoTime()
val res = action.parameters.call(instance)
callDurationReservoirNs.update(System.nanoTime() - start)
res match {
  case Right((_, wrapper)) if wrapper.results.getValues() != null =>
    val ret = wrapper.results.getValues()(0).v.i32
    if (ret > 7 || ret < 0) { // guard against weird multi-threading issues
      ignore()
      killAtRelease.set(true)
    }
  case _ => ()
}
action.promise.trySuccess(res)
} catch {
case t: Throwable => action.promise.tryFailure(t)
} finally {
if (resetMemory) {
action.parameters match {
case _m: WasmFunctionParameters.ExtismFuntionCall => instance.reset()
case _m: WasmFunctionParameters.OPACall => // the memory data will already be overwritten during the next call
case _ => instance.resetCustomMemory()
}
}
pool.ic.logger.debug(s"functions: ${functions.size}")
pool.ic.logger.debug(s"memories: ${memories.size}")
// WasmContextSlot.clearCurrentContext()
// vmDataRef.set(null)
val count = callCounter.incrementAndGet()
if (count >= maxCalls) {
callCounter.set(0)
if (pool.ic.logger.isDebugEnabled)
pool.ic.logger.debug(s"killing vm ${index} with remaining ${inFlight.get()} calls (${count})")
destroyAtRelease()
}
}
}
}
()
}(pool.ic.wasmExecutor)
}
def getOpaPointers(): Option[OPAWasmVm] = opaPointers
def reset(): Unit = instance.reset()
def destroy(): Unit = {
if (pool.ic.logger.isDebugEnabled) pool.ic.logger.debug(s"destroy vm: ${index}")
pool.clear(this)
instance.close()
}
def isAquired(): Boolean = {
pool.inUseVms.contains(this)
}
def isBusy(): Boolean = {
inFlight.get() > 0
}
def destroyAtRelease(): Unit = {
ignore()
killAtRelease.set(true)
}
def release(): Unit = {
if (killAtRelease.get()) {
queue.offer(WasmVmAction.WasmVmKillAction)
} else {
pool.release(this)
}
}
def lastUsedAt(): Long = lastUsage.get()
def hasNotBeenUsedInTheLast(duration: FiniteDuration): Boolean =
if (duration.toNanos == 0L) false else !hasBeenUsedInTheLast(duration)
def consumesMoreThanMemoryPercent(percent: Double): Boolean = if (percent == 0.0) {
false
} else {
val consumed: Double = instance.getMemorySize.toDouble / maxMemory.toDouble
val res = consumed > percent
if (pool.ic.logger.isDebugEnabled)
pool.ic.logger.debug(
s"consumesMoreThanMemoryPercent($percent) = (${instance.getMemorySize} / $maxMemory) > $percent : $res : (${consumed * 100.0}%)"
)
res
}
def tooSlow(max: Long): Boolean = {
if (max == 0L) {
false
} else {
callDurationReservoirNs.getSnapshot.getMean.toLong > max
}
}
def hasBeenUsedInTheLast(duration: FiniteDuration): Boolean = {
val now = System.currentTimeMillis()
val limit = lastUsage.get() + duration.toMillis
now < limit
}
def ignore(): Unit = pool.ignore(this)
def initialized(): Boolean = initializedRef.get()
def initialize(f: => Any): Unit = {
if (initializedRef.compareAndSet(false, true)) {
f
}
}
def finitialize[A](f: => Future[A]): Future[Unit] = {
if (initializedRef.compareAndSet(false, true)) {
f.map(_ => ())(pool.ic.executionContext)
} else {
().vfuture
}
}
def call(
parameters: WasmFunctionParameters,
context: Option[WasmVmData]
): Future[Either[JsValue, (String, ResultsWrapper)]] = {
val promise = Promise[Either[JsValue, (String, ResultsWrapper)]]()
inFlight.incrementAndGet()
lastUsage.set(System.currentTimeMillis())
queue.offer(WasmVmAction.WasmVmCallAction(parameters, context, promise))
promise.future
}
def callOpa(functionName: String, in: String, context: Option[WasmVmData] = None)(implicit ec: ExecutionContext): Future[Either[JsValue, (String, ResultsWrapper)]] = {
ensureOpaInitialized().call(WasmFunctionParameters.OPACall(functionName, opaPointers, in), context)
}
def ensureOpaInitializedAsync(in: Option[String] = None)(implicit ec: ExecutionContext): Future[WasmVmImpl] = {
if (!initialized()) {
call(
WasmFunctionParameters.OPACall(
"initialize",
in = in.getOrElse(Json.obj().stringify),
),
None
) flatMap {
case Left(error) => Future.failed(new RuntimeException(s"opa initialize error: ${error.stringify}"))
case Right(value) =>
initialize {
val pointers = Json.parse(value._1)
opaPointers = OPAWasmVm(
opaDataAddr = (pointers \ "dataAddr").as[Int],
opaBaseHeapPtr = (pointers \ "baseHeapPtr").as[Int]
).some
}
this.vfuture
}
} else {
this.vfuture
}
}
def ensureOpaInitialized(in: Option[String] = None)(implicit ec: ExecutionContext): WasmVmImpl = {
Await.result(
ensureOpaInitializedAsync(in),
10.seconds
)
}
}
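// Usage sketch (illustrative only, not part of the library API): how a caller
// typically drives a vm obtained from a pool. The `pool` value, the execution
// context and the exported "execute" function name are assumptions made for the
// example; ExtismFuntionCall is assumed to take a function name and an input
// string, as used in handle() above.
//
//   implicit val ec: ExecutionContext = pool.ic.executionContext
//   pool.getPooledVm().flatMap { vm =>
//     vm.call(WasmFunctionParameters.ExtismFuntionCall("execute", """{"foo":"bar"}"""), None)
//       .andThen { case _ => vm.release() } // always hand the vm back to the pool
//       .map {
//         case Left(error)        => println(s"wasm call failed: ${error}")
//         case Right((output, _)) => println(s"wasm call returned: ${output}")
//       }
//   }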
case class WasmVmPoolAction(promise: Promise[WasmVmImpl], options: WasmVmInitOptions) {
private[wasm4s] def provideVm(vm: WasmVmImpl): Unit = promise.trySuccess(vm)
private[wasm4s] def fail(e: Throwable): Unit = promise.tryFailure(e)
}
object WasmVmPoolImpl {
private val instances = new TrieMap[String, WasmVmPoolImpl]()
def allInstances(): Map[String, WasmVmPoolImpl] = instances.synchronized {
instances.toMap
}
def forConfig(config: => WasmConfiguration, maxCallsBetweenUpdates: Int = 100000)(implicit ic: WasmIntegrationContext): WasmVmPoolImpl = instances.synchronized {
val key = s"${config.source.cacheKey}?mcbu=${maxCallsBetweenUpdates}&cfg=${config.json.stringify.sha512}"
instances.getOrUpdate(key) {
new WasmVmPoolImpl(key, config.some, maxCallsBetweenUpdates, ic)
}
}
private[wasm4s] def removePlugin(id: String): Unit = instances.synchronized {
instances.remove(id)
}
}
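// Illustrative note (not part of the library): pools returned by forConfig are
// memoized per (source cache key, maxCallsBetweenUpdates, config hash), so two
// calls with the same configuration are expected to yield the same pool
// instance. `myConfig` and the implicit context are assumptions for the example.
//
//   implicit val ic: WasmIntegrationContext = ???
//   val poolA = WasmVmPoolImpl.forConfig(myConfig)
//   val poolB = WasmVmPoolImpl.forConfig(myConfig)
//   assert(poolA eq poolB) // same key => same cached pool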
class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration], maxCallsBetweenUpdates: Int = 100000, val ic: WasmIntegrationContext) extends WasmVmPool {
ic.logger.trace("new WasmVmPool")
private val counter = new AtomicInteger(-1)
private[wasm4s] val availableVms = new ConcurrentLinkedQueue[WasmVmImpl]()
private[wasm4s] val inUseVms = new ConcurrentLinkedQueue[WasmVmImpl]()
private val lastCacheUpdateTime = new AtomicLong(System.currentTimeMillis())
private val lastCacheUpdateCalls = new AtomicLong(0L)
private val creatingRef = new AtomicBoolean(false)
private val lastPluginVersion = new AtomicReference[String](null)
private val requestsSource = Source.queue[WasmVmPoolAction](ic.wasmQueueBufferSize, OverflowStrategy.dropTail)
private val prioritySource = Source.queue[WasmVmPoolAction](ic.wasmQueueBufferSize, OverflowStrategy.dropTail)
private val (priorityQueue, requestsQueue) = {
prioritySource
.mergePrioritizedMat(requestsSource, 99, 1, false)(Keep.both)
.map(handleAction)
.toMat(Sink.ignore)(Keep.both)
.run()(ic.materializer)
._1
}
// dequeue actions from the action queue
private def handleAction(action: WasmVmPoolAction): Unit = try {
val time = System.currentTimeMillis()
wasmConfig() match {
case None =>
// if we cannot find the current wasm config, something is wrong, we destroy the pool
destroyCurrentVms()
WasmVmPoolImpl.removePlugin(stableId)
action.fail(new RuntimeException(s"No more plugin ${stableId}"))
case Some(wcfg) => {
// first we ensure the wasm source has been fetched
if (!wcfg.source.isCached()(ic)) {
wcfg.source
.getWasm()(ic, ic.executionContext)
.andThen { case _ =>
priorityQueue.offer(action)
}(ic.executionContext)
} else if (wcfg.source.isFailed()(ic)) {
val until = wcfg.source.getFailedFromCache()(ic).get.until
if (until < time) {
wcfg.source.removeFromCache()(ic)
}
action.fail(new RuntimeException(s"accessing wasm binary was impossible. will retry after ${new DateTime(until).toString()}"))
} else {
// try to self-refresh the cache if too many calls have been made or too much time has elapsed since the last update
if (ic.selfRefreshingPools && (((time - lastCacheUpdateTime.get()) > ic.wasmCacheTtl) || (lastCacheUpdateCalls.get() > maxCallsBetweenUpdates))) {
lastCacheUpdateTime.set(time)
lastCacheUpdateCalls.set(0L)
wcfg.source.getWasm()(ic, ic.executionContext)
}
// TODO: try to refresh cache if more than n calls or env.wasmCacheTtl elapsed since last time
val changed = hasChanged(wcfg)
val available = hasAvailableVm(wcfg)
val creating = isVmCreating()
val atMax = atMaxPoolCapacity(wcfg)
// then we check whether the underlying wasm code + config has changed since last time
if (changed) {
// if so, we destroy all current vms and recreate a new one
ic.logger.warn("plugin has changed, destroying old instances")
destroyCurrentVms()
createVm(wcfg, action.options)
}
// check if a vm is available
if (!available) {
// if not, but a new one is being created, just wait a little bit longer
if (creating) {
priorityQueue.offer(action)
} else {
// check if we hit the max possible instances
if (atMax) {
// if so, just wait
priorityQueue.offer(action)
} else {
// if not, create a new instance because we need one
createVm(wcfg, action.options)
priorityQueue.offer(action)
}
}
} else {
// a vm is available, so acquire one
val vm = acquireVm()
action.provideVm(vm)
}
}
}
}
} catch {
case t: Throwable =>
t.printStackTrace()
action.fail(t)
}
// create a new vm for the pool
// vms are created one at a time so we never create more than needed
private def createVm(config: WasmConfiguration, options: WasmVmInitOptions): Unit = synchronized {
if (creatingRef.compareAndSet(false, true)) {
val index = counter.incrementAndGet()
ic.logger.debug(s"creating vm: ${index}")
if (config.source.isFailed()(ic)) {
creatingRef.compareAndSet(true, false)
return
}
if (!config.source.isCached()(ic)) {
// this part should never happen anymore, but just in case
ic.logger.warn("fetching missing source")
Await.result(config.source.getWasm()(ic, ic.executionContext), 30.seconds)
}
lastPluginVersion.set(computeHash(config, config.source.cacheKey, ic.wasmScriptCache))
val cache = ic.wasmScriptCache
val key = config.source.cacheKey
val wasm = cache(key) match {
case CacheableWasmScript.CachedWasmScript(script, _) => script
case CacheableWasmScript.FetchingCachedWasmScript(_, script) => script
case CacheableWasmScript.FailedFetch(_, until) => throw new RuntimeException(s"accessing wasm binary was impossible. will retry after ${new DateTime(until).toString()}")
case _ => throw new RuntimeException("unable to get wasm source from cache. this should not happen !")
}
// val hash = wasm.sha256
val resolver = new WasmSourceResolver()
val source = resolver.resolve("wasm", wasm.toByteBuffer.array())
val vmDataRef = new AtomicReference[WasmVmData](null)
val addedFunctions = options.addHostFunctions(vmDataRef)
val functions: Array[HostFunction[_ <: HostUserData]] = {
val fs: Array[HostFunction[_ <: HostUserData]] = if (options.importDefaultHostFunctions) {
ic.hostFunctions(config, stableId) ++ addedFunctions
} else {
addedFunctions.toArray[HostFunction[_ <: HostUserData]]
}
if (config.opa) {
val opaFunctions: Seq[HostFunction[_ <: HostUserData]] = OPA.getFunctions(config).collect {
case func if func.authorized(config) => func.function
}
fs ++ opaFunctions
} else {
fs
}
}
val memories = LinearMemories.getMemories(config)
val instance = new Plugin(
new Manifest(
Seq[org.extism.sdk.wasm.WasmSource](source).asJava,
new MemoryOptions(config.memoryPages),
config.config.asJava,
config.allowedHosts.asJava,
config.allowedPaths.asJava
),
config.wasi,
functions,
memories
)
// val instance = template.instantiate(engine, functions, memories, config.wasi)
val vm = WasmVmImpl(
index,
config.killOptions.maxCalls,
config.memoryPages * (64L * 1024L),
options.resetMemory,
instance,
vmDataRef,
memories,
functions,
this
)
availableVms.offer(vm)
creatingRef.compareAndSet(true, false)
}
}
// acquire an available vm for work
private def acquireVm(): WasmVmImpl = synchronized {
if (availableVms.size() > 0) {
availableVms.synchronized {
val vm = availableVms.poll() // poll() already removes the head of the queue
inUseVms.offer(vm)
vm
}
} else {
throw new RuntimeException("no instances available")
}
}
// release the vm to be available for other tasks
private[wasm4s] def release(vm: WasmVmImpl): Unit = synchronized {
availableVms.synchronized {
availableVms.offer(vm)
inUseVms.remove(vm)
}
}
// stop considering the vm for new work (the vm is being dropped for some reason)
private[wasm4s] def ignore(vm: WasmVmImpl): Unit = synchronized {
availableVms.synchronized {
inUseVms.remove(vm)
}
}
// remove the vm from the available list (the vm has been destroyed)
private[wasm4s] def clear(vm: WasmVmImpl): Unit = synchronized {
availableVms.synchronized {
availableVms.remove(vm)
}
}
private[wasm4s] def wasmConfig(): Option[WasmConfiguration] = {
optConfig.orElse(ic.wasmConfigSync(stableId)).orElse(ic.wasmConfig(stableId).await(30.seconds)) // ugly but ...
}
private def hasAvailableVm(plugin: WasmConfiguration): Boolean =
availableVms.size() > 0 && (inUseVms.size < plugin.instances)
private def isVmCreating(): Boolean = creatingRef.get()
private def atMaxPoolCapacity(plugin: WasmConfiguration): Boolean = (availableVms.size + inUseVms.size) >= plugin.instances
// close the current pool
private[wasm4s] def close(): Unit = availableVms.synchronized {
// engine.close()
}
// destroy all vms and clear everything in order to destroy the current pool
private[wasm4s] def destroyCurrentVms(): Unit = availableVms.synchronized {
ic.logger.info("destroying all vms")
availableVms.asScala.foreach(_.destroy())
availableVms.clear()
inUseVms.clear()
//counter.set(0)
// templateRef.set(null)
creatingRef.set(false)
lastPluginVersion.set(null)
}
// compute the current hash for the (wasm code + config) tuple
private def computeHash(
config: WasmConfiguration,
key: String,
cache: TrieMap[String, CacheableWasmScript]
): String = {
config.json.stringify.sha512 + "#" + cache
.get(key)
.map {
case CacheableWasmScript.CachedWasmScript(wasm, _) => wasm.sha512
case CacheableWasmScript.FetchingCachedWasmScript(_, wasm) => wasm.sha512
case CacheableWasmScript.FailedFetch(_, _) => "failed"
case _ => "fetching"
}
.getOrElse("null")
}
// compute whether the source (wasm code + config) is the same as the current one
private def hasChanged(config: WasmConfiguration): Boolean = availableVms.synchronized {
val key = config.source.cacheKey
val cache = ic.wasmScriptCache
var oldHash = lastPluginVersion.get()
if (oldHash == null) {
oldHash = computeHash(config, key, cache)
lastPluginVersion.set(oldHash)
}
cache.get(key) match {
case Some(CacheableWasmScript.CachedWasmScript(_, _)) => {
val currentHash = computeHash(config, key, cache)
oldHash != currentHash
}
case Some(CacheableWasmScript.FetchingCachedWasmScript(_, _)) => {
val currentHash = computeHash(config, key, cache)
oldHash != currentHash
}
case Some(CacheableWasmScript.FailedFetch(_, _)) => false
case _ => false
}
}
// get a pooled vm when one is available.
// Do not forget to release it after usage
def getPooledVm(options: WasmVmInitOptions = WasmVmInitOptions.empty()): Future[WasmVm] = {
val p = Promise[WasmVmImpl]()
requestsQueue.offer(WasmVmPoolAction(p, options))
p.future
}
// borrow a vm for sync operations
def withPooledVm[A](options: WasmVmInitOptions = WasmVmInitOptions.empty())(f: WasmVm => A): Future[A] = {
implicit val ev = ic
implicit val ec = ic.executionContext
getPooledVm(options).flatMap { vm =>
val p = Promise[A]()
try {
val ret = f(vm)
p.trySuccess(ret)
} catch {
case e: Throwable =>
p.tryFailure(e)
} finally {
vm.release()
}
p.future
}
}
// borrow a vm for async operations
def withPooledVmF[A](options: WasmVmInitOptions = WasmVmInitOptions.empty())(f: WasmVm => Future[A]): Future[A] = {
implicit val ev = ic
implicit val ec = ic.executionContext
getPooledVm(options).flatMap { vm =>
f(vm).andThen { case _ =>
vm.release()
}
}
}
}
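// Usage sketch (illustrative only): borrowing a vm for an async operation with
// withPooledVmF, which releases the vm once the returned future completes.
// `config`, the integration context and the exported "execute" function are
// assumptions made for the example.
//
//   implicit val ic: WasmIntegrationContext = ???
//   implicit val ec: ExecutionContext = ic.executionContext
//   val pool = WasmVmPoolImpl.forConfig(config)
//   val result: Future[Either[JsValue, (String, ResultsWrapper)]] =
//     pool.withPooledVmF() { vm =>
//       vm.call(WasmFunctionParameters.ExtismFuntionCall("execute", "{}"), None)
//     }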