All Downloads are FREE. Search and download functionalities are using the official Maven repository.

wvlet.airspec.Compat.scala Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package wvlet.airspec

import java.lang.reflect.InvocationTargetException
import sbt.testing.Fingerprint
import wvlet.log.LogFormatter.SourceCodeLogFormatter
import wvlet.log.{LogSupport, Logger}

import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}
import wvlet.airframe.surface.reflect.ReflectTypeUtil
import wvlet.airspec.Framework.{AirSpecClassFingerPrint, AirSpecObjectFingerPrint}
import wvlet.airspec.spi.{AirSpecException, Asserts}

import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{Executors, ThreadFactory}
import scala.concurrent.ExecutionContext

/**
  */
private[airspec] object Compat extends CompatApi with LogSupport {
  override def isScalaJVM    = true
  override def isScalaJs     = false
  override def isScalaNative = false

  override private[airspec] val executionContext: ExecutionContext =
    ExecutionContext.fromExecutorService(Executors.newCachedThreadPool(newDaemonThreadFactory("airspec-executor")))

  /**
    * Create a thread factory for daemon threads, which do not block JVM shutdown
    *
    * @param name
    *   the name of the new thread group. New threads will be named (name)-1, (name)-2, etc.
    */
  private def newDaemonThreadFactory(name: String): ThreadFactory = new ThreadFactory {
    private val group: ThreadGroup = new ThreadGroup(Thread.currentThread().getThreadGroup(), name)
    private val threadNumber       = new AtomicInteger(1)

    override def newThread(r: Runnable): Thread = {
      val threadName = s"${name}-${threadNumber.getAndIncrement()}"
      val thread     = new Thread(group, r, threadName)
      thread.setName(threadName)
      thread.setDaemon(true)
      thread
    }
  }

  private[airspec] def findCompanionObjectOf(fullyQualifiedName: String, classLoader: ClassLoader): Option[Any] = {
    val cls = classLoader.loadClass(fullyQualifiedName)
    ReflectTypeUtil.companionObject(cls)
  }

  private[airspec] def getFingerprint(fullyQualifiedName: String, classLoader: ClassLoader): Option[Fingerprint] = {
    Try(findCompanionObjectOf(fullyQualifiedName, classLoader)).toOption
      .flatMap {
        case Some(spec: AirSpecSpi) =>
          Some(AirSpecObjectFingerPrint)
        case other =>
          None
      }
      .orElse {
        Try(classLoader.loadClass(fullyQualifiedName)).toOption
          .flatMap { x =>
            if (classOf[AirSpec].isAssignableFrom(x))
              Some(AirSpecClassFingerPrint)
            else {
              None
            }
          }
      }
  }

  private[airspec] def newInstanceOf(fullyQualifiedName: String, classLoader: ClassLoader): Option[Any] = {
    Try(classLoader.loadClass(fullyQualifiedName).getDeclaredConstructor().newInstance()) match {
      case Success(x) => Some(x)
      case Failure(e: InvocationTargetException) if e.getCause != null =>
        if (classOf[spi.AirSpecException].isAssignableFrom(e.getCause.getClass)) {
          // For assertion failrues, throw it as is
          throw e
        } else {
          // For other failures when instantiating the object, throw the cause
          throw e.getCause
        }
      case _ =>
        // Ignore other types of failures, which should not happen in general
        None
    }
  }

  private[airspec] def withLogScanner[U](block: => U): U = {
    try {
      startLogScanner
      block
    } finally {
      stopLogScanner
    }
  }

  private[airspec] def startLogScanner: Unit = {
    Logger.setDefaultFormatter(SourceCodeLogFormatter)

    // Periodically scan log level file
    Logger.scheduleLogLevelScan
  }
  private[airspec] def stopLogScanner: Unit = {
    Logger.stopScheduledLogLevelScan
  }

  @tailrec private[airspec] def findCause(e: Throwable): Throwable = {
    e match {
      case i: InvocationTargetException => findCause(i.getTargetException)
      case _                            => e
    }
  }

  override private[airspec] def getSpecName(cl: Class[_]): String = {
    var name = cl.getName

    if (name.endsWith("$")) {
      // Remove trailing $ of Scala Object name
      name = name.substring(0, name.length - 1)
    }

    // When class is an anonymous trait
    if (name.contains("$anon$")) {
      val interfaces = cl.getInterfaces
      if (interfaces != null && interfaces.length > 0) {
        // Use the first interface name instead of the anonymous name and Airframe SessionHolder (injected at compile-time)
        interfaces
          .map(_.getName)
          .find(x => x != "wvlet.airframe.SessionHolder" && !x.contains("$anon$"))
          .foreach(name = _)
      }
    }
    name
  }

  private[airspec] def getContextClassLoader: ClassLoader = {
    Thread.currentThread().getContextClassLoader
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy