All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.scalatestplus.junit5.ScalaTestEngine.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2001-2023 Artima, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.scalatestplus.junit5

import org.junit.platform.commons.support.ReflectionSupport
import org.junit.platform.engine.discovery.{ClassSelector, ClasspathRootSelector, ModuleSelector, PackageSelector, UniqueIdSelector}
import org.junit.platform.engine.support.descriptor.EngineDescriptor
import org.junit.platform.engine.support.discovery.SelectorResolver.{Match, Resolution}
import org.junit.platform.engine.support.discovery.{EngineDiscoveryRequestResolver, SelectorResolver}
import org.junit.platform.engine.{EngineDiscoveryRequest, ExecutionRequest, TestDescriptor, TestExecutionResult, UniqueId}
import org.scalatest.{Args, ConfigMap, DynaTags, Filter, ParallelTestExecution, Stopper, Suite, Tracker}

import java.lang.reflect.Modifier
import java.util.concurrent.atomic.AtomicInteger
import java.util.{Optional, UUID}
import java.util.concurrent.{ExecutorService, Executors, ThreadFactory}
import java.util.logging.Logger
import java.util.stream.Collectors
import scala.collection.JavaConverters._
import scala.reflect.NameTransformer
import scala.util.Try

/**
 * ScalaTest implementation of a JUnit 5 (JUnit Platform) Test Engine.
 *
 * Discovery builds a tree of suite class descriptors (plus individual test descriptors for
 * unique ID selectors that name a test) under an engine descriptor. Execution instantiates
 * each discovered suite class and runs it through ScalaTest, using a reporter constructed
 * around the platform's execution listener.
 *
 * Both discovery and execution are skipped when the system property
 * `org.scalatestplus.junit5.ScalaTestEngine.disabled` is set to "true".
 */
class ScalaTestEngine extends org.junit.platform.engine.TestEngine {

  private val logger = Logger.getLogger(classOf[ScalaTestEngine].getName)

  /** True when the engine has been switched off through its system property. */
  private def isDisabled: Boolean =
    System.getProperty("org.scalatestplus.junit5.ScalaTestEngine.disabled") == "true"

  /**
   * Test engine ID, return "scalatest".
   */
  def getId: String = "scalatest"

  /**
   * Discover ScalaTest suites, you can disable the discovery by setting system property
   * org.scalatestplus.junit5.ScalaTestEngine.disabled to "true".
   *
   * Classpath root, package, module, class and unique ID selectors are resolved. A class is
   * accepted as a suite when it is assignable to `org.scalatest.Suite`, is not abstract and
   * passes `JUnitHelper.checkForPublicNoArgConstructor`.
   *
   * @param discoveryRequest the discovery request carrying the selectors to resolve
   * @param uniqueId the unique ID to use for the returned engine descriptor
   * @return the engine descriptor; it has no children when the engine is disabled
   */
  def discover(discoveryRequest: EngineDiscoveryRequest, uniqueId: UniqueId): TestDescriptor = {
    // reference: https://blogs.oracle.com/javamagazine/post/junit-build-custom-test-engines-java
    //            https://software-matters.net/posts/custom-test-engine/

    val engineDesc = new EngineDescriptor(uniqueId, "ScalaTest EngineDescriptor")

    if (!isDisabled) {
      logger.info("Starting test discovery...")

      // Class-name filter for the classpath scans: every name is accepted.
      val alwaysTruePredicate =
        new java.util.function.Predicate[String]() {
          def test(t: String): Boolean = true
        }

      // Decides whether a class is a runnable ScalaTest suite.
      val isSuitePredicate =
        new java.util.function.Predicate[Class[_]]() {
          def test(t: Class[_]): Boolean =
            classOf[org.scalatest.Suite].isAssignableFrom(t) &&
            !Modifier.isAbstract(t.getModifiers) &&
            JUnitHelper.checkForPublicNoArgConstructor(t)
        }

      // Creates a suite descriptor for aClass under the given parent, or yields an empty
      // Optional when the parent already has a child with the same unique ID.
      // NOTE(review): the last constructor argument is `true` here and `false` in the unique ID
      // resolver below - presumably it controls whether test children are added automatically;
      // confirm against ScalaTestClassDescriptor.
      def classDescriptorFunction(aClass: Class[_]) =
        new java.util.function.Function[TestDescriptor, Optional[ScalaTestClassDescriptor]]() {
          def apply(parent: TestDescriptor): Optional[ScalaTestClassDescriptor] = {
            val suiteUniqueId = parent.getUniqueId.append(ScalaTestClassDescriptor.segmentType, aClass.getName)
            parent.getChildren.asScala.find(_.getUniqueId == suiteUniqueId) match {
              case Some(_) => Optional.empty[ScalaTestClassDescriptor]()
              case None => Optional.of(new ScalaTestClassDescriptor(engineDesc, suiteUniqueId, aClass, true))
            }
          }
        }

      // Wraps a descriptor into a single-element stream holding an exact match.
      val toMatch =
        new java.util.function.Function[TestDescriptor, java.util.stream.Stream[Match]]() {
          def apply(td: TestDescriptor): java.util.stream.Stream[Match] = {
            java.util.stream.Stream.of[Match](Match.exact(td))
          }
        }

      // Wraps a descriptor into a resolution holding an exact match; shared by the class and
      // unique ID selector resolvers.
      val toResolution =
        new java.util.function.Function[TestDescriptor, Resolution]() {
          def apply(td: TestDescriptor): Resolution = Resolution.`match`(Match.exact(td))
        }

      // Adds a suite descriptor for each scanned class to the parent and streams the resulting
      // match; classes that were already registered produce an empty stream.
      def addToParentFunction(context: SelectorResolver.Context) =
        new java.util.function.Function[Class[_], java.util.stream.Stream[Match]]() {
          def apply(aClass: Class[_]): java.util.stream.Stream[Match] = {
            context.addToParent(classDescriptorFunction(aClass))
              .map[java.util.stream.Stream[Match]](toMatch)
              .orElse(java.util.stream.Stream.empty())
          }
        }

      val classSelectorResolver = new SelectorResolver {

        override def resolve(selector: ClasspathRootSelector, context: SelectorResolver.Context): SelectorResolver.Resolution = {
          val matches =
            ReflectionSupport.findAllClassesInClasspathRoot(selector.getClasspathRoot, isSuitePredicate, alwaysTruePredicate)
              .stream()
              .flatMap(addToParentFunction(context))
              .collect(Collectors.toSet())
          Resolution.matches(matches)
        }

        override def resolve(selector: PackageSelector, context: SelectorResolver.Context): SelectorResolver.Resolution = {
          val matches =
            ReflectionSupport.findAllClassesInPackage(selector.getPackageName, isSuitePredicate, alwaysTruePredicate)
              .stream()
              .flatMap(addToParentFunction(context))
              .collect(Collectors.toSet())
          Resolution.matches(matches)
        }

        override def resolve(selector: ModuleSelector, context: SelectorResolver.Context): SelectorResolver.Resolution = {
          val matches =
            ReflectionSupport.findAllClassesInModule(selector.getModuleName, isSuitePredicate, alwaysTruePredicate)
              .stream()
              .flatMap(addToParentFunction(context))
              .collect(Collectors.toSet())
          Resolution.matches(matches)
        }

        override def resolve(selector: ClassSelector, context: SelectorResolver.Context): SelectorResolver.Resolution = {
          val testClass = selector.getJavaClass
          if (isSuitePredicate.test(testClass))
            context.addToParent(classDescriptorFunction(testClass))
              .map[Resolution](toResolution)
              .orElse(Resolution.unresolved())
          else
            Resolution.unresolved()
        }
      }

      val uniqueIdSelectorResolver = new SelectorResolver {
        override def resolve(selector: UniqueIdSelector, context: SelectorResolver.Context): SelectorResolver.Resolution = {
          selector.getUniqueId.getSegments.asScala.toList match {
            // [engine:scalatest]/[<suite segment>:<class name>]/[test:<test name>] - a single test of a suite.
            case engineSeg :: suiteSeg :: testSeg :: Nil if engineSeg.getType == "engine" && engineSeg.getValue == "scalatest" && testSeg.getType == "test" && suiteSeg.getType == ScalaTestClassDescriptor.segmentType =>
              val suiteClassName = suiteSeg.getValue
              val suiteClass = Class.forName(suiteClassName)
              if (classOf[org.scalatest.Suite].isAssignableFrom(suiteClass)) {
                context.addToParent(
                  new java.util.function.Function[TestDescriptor, Optional[ScalaTestClassDescriptor]]() {
                    def apply(parent: TestDescriptor): Optional[ScalaTestClassDescriptor] = {
                      val children = parent.getChildren.asScala
                      val suiteUniqueId = uniqueId.append(ScalaTestClassDescriptor.segmentType, suiteClass.getName)
                      val testUniqueId = suiteUniqueId.append("test", testSeg.getValue)
                      val testDesc = new ScalaTestDescriptor(testUniqueId, testSeg.getValue, None)
                      // Reuse the suite descriptor when it is already registered; only a newly
                      // created one is handed back to the resolution context.
                      val (suiteDesc, result) =
                        children.find(_.getUniqueId == suiteUniqueId) match {
                          case Some(suiteDesc) =>
                            (suiteDesc, Optional.empty[ScalaTestClassDescriptor]())

                          case None =>
                            val suiteDesc = new ScalaTestClassDescriptor(engineDesc, suiteUniqueId, suiteClass, false)
                            (suiteDesc, Optional.of(suiteDesc))
                        }

                      suiteDesc.getChildren.asScala.find(_.getUniqueId == testUniqueId) match {
                        case Some(_) => // Do nothing if the test already exists
                        case None => suiteDesc.addChild(testDesc)
                      }

                      result
                    }
                  }
                )
                .map[Resolution](toResolution)
                .orElse(Resolution.unresolved())
              }
              else
                Resolution.unresolved()

            // [engine:scalatest]/[<suite segment>:<class name>] - a whole suite.
            case engineSeg :: suiteSeg :: Nil if engineSeg.getType == "engine" && engineSeg.getValue == "scalatest" && suiteSeg.getType == ScalaTestClassDescriptor.segmentType =>
              val suiteClassName = suiteSeg.getValue
              val suiteClass = Class.forName(suiteClassName)
              if (classOf[org.scalatest.Suite].isAssignableFrom(suiteClass)) {
                context.addToParent(
                  new java.util.function.Function[TestDescriptor, Optional[ScalaTestClassDescriptor]]() {
                    def apply(parent: TestDescriptor): Optional[ScalaTestClassDescriptor] = {
                      val children = parent.getChildren.asScala
                      val suiteUniqueId = uniqueId.append(ScalaTestClassDescriptor.segmentType, suiteClass.getName)
                      children.find(_.getUniqueId == suiteUniqueId) match {
                        case Some(_) => Optional.empty[ScalaTestClassDescriptor]()
                        case None => Optional.of(new ScalaTestClassDescriptor(engineDesc, suiteUniqueId, suiteClass, false))
                      }
                    }
                  }
                )
                .map[Resolution](toResolution)
                .orElse(Resolution.unresolved())
              }
              else
                Resolution.unresolved()

            case _ => Resolution.unresolved()
          }
        }
      }

      val resolver = EngineDiscoveryRequestResolver.builder[EngineDescriptor]()
                     .addClassContainerSelectorResolver(isSuitePredicate)
                     .addSelectorResolver(classSelectorResolver)
                     .addSelectorResolver(uniqueIdSelectorResolver)
                     .build()

      resolver.resolve(discoveryRequest, engineDesc)

      logger.info("Completed test discovery, discovered suite count: " + engineDesc.getChildren.size())
    }

    engineDesc
  }

  /**
   * Execute ScalaTest suites, you can disable the ScalaTest suites execution by setting system
   * property org.scalatestplus.junit5.ScalaTestEngine.disabled to "true".
   *
   * Every suite descriptor that is started is also finished: when running a suite throws a
   * non-fatal exception, the suite is reported as failed with that exception and execution
   * continues with the next suite.
   *
   * @param request the execution request holding the root descriptor and the execution listener
   */
  def execute(request: ExecutionRequest): Unit = {
    if (!isDisabled) {
      logger.info("Start tests execution...")
      val engineDesc = request.getRootTestDescriptor
      val listener = request.getEngineExecutionListener

      listener.executionStarted(engineDesc)
      engineDesc.getChildren.asScala.foreach {
        case clzDesc: ScalaTestClassDescriptor =>
          val suiteClassName = clzDesc.suiteClass.getName
          logger.info("Start execution of suite class " + suiteClassName + "...")
          listener.executionStarted(clzDesc)
          // executionStarted must always be paired with executionFinished, so a failure while
          // instantiating or running the suite becomes a failed result instead of escaping.
          val result =
            try {
              runSuite(clzDesc, engineDesc, listener)
              TestExecutionResult.successful()
            } catch {
              case scala.util.control.NonFatal(e) =>
                logger.log(java.util.logging.Level.SEVERE, "Execution of suite class " + suiteClassName + " aborted.", e)
                TestExecutionResult.failed(e)
            }
          listener.executionFinished(clzDesc, result)

          logger.info("Completed execution of suite class " + suiteClassName + ".")

        case otherDesc =>
          // Do nothing for other descriptor, just log it.
          logger.warning("Found test descriptor " + otherDesc.toString + " that is not supported, skipping.")
      }
      listener.executionFinished(engineDesc, TestExecutionResult.successful())
      logger.info("Completed tests execution.")
    }
  }

  /**
   * Instantiates the suite class of the given descriptor and runs it, blocking until the run
   * has completed. Suites mixing in `ParallelTestExecution` run with a `ConcurrentDistributor`
   * backed by a thread pool that is shut down afterwards.
   */
  private def runSuite(
    clzDesc: ScalaTestClassDescriptor,
    engineDesc: TestDescriptor,
    listener: org.junit.platform.engine.EngineExecutionListener
  ): Unit = {
    val suiteClass = clzDesc.suiteClass
    val canInstantiate = JUnitHelper.checkForPublicNoArgConstructor(suiteClass) && classOf[org.scalatest.Suite].isAssignableFrom(suiteClass)
    require(canInstantiate, "Must pass an org.scalatest.Suite with a public no-arg constructor")
    val suiteToRun = suiteClass.newInstance.asInstanceOf[org.scalatest.Suite]
    val reporter = new EngineExecutionListenerReporter(listener, clzDesc, engineDesc)
    val filter = selectionFilter(suiteToRun, clzDesc.getChildren.asScala)

    suiteToRun match {
      case _: ParallelTestExecution =>
        val execSvc = newExecutorService()
        // The pool is created right before the try so it is shut down even when building the
        // distributor fails.
        try {
          val distributor = new ConcurrentDistributor(Args(reporter, Stopper.default, filter, ConfigMap.empty, None, new Tracker), execSvc)
          suiteToRun.run(None, Args(reporter, Stopper.default, filter, ConfigMap.empty, Some(distributor), new Tracker))
          distributor.waitUntilDone()
        } finally {
          execSvc.shutdown()
        }

      case _ =>
        val status = suiteToRun.run(None, Args(reporter, Stopper.default, filter, ConfigMap.empty, None, new Tracker))
        status.waitUntilCompleted()
    }
  }

  /**
   * Builds the ScalaTest filter matching the tests selected during discovery.
   *
   *  - No test children: nothing is filtered out and nested suites are kept.
   *  - As many children as the suite has test names: all tests are selected, so the default
   *    filter is used; this solves the issue of dynamic test names when running a suite.
   *  - Otherwise: only the tests whose (optionally decoded) display name is one of the suite's
   *    test names are tagged as selected, and nested suites are excluded.
   */
  private def selectionFilter(suiteToRun: Suite, children: Iterable[TestDescriptor]): Filter =
    if (children.isEmpty)
      Filter(
        tagsToInclude = None,
        excludeNestedSuites = false,
        dynaTags = DynaTags(Map.empty, Map(suiteToRun.suiteId -> Map.empty))
      )
    else if (suiteToRun.testNames.size == children.size)
      Filter.default
    else {
      val selectedTag = "Selected"
      val selectedSet = Set(selectedTag)
      val testNames = suiteToRun.testNames
      val desiredTests: Set[String] =
        children.map(_.getDisplayName).filter { tn =>
          testNames.contains(tn) || testNames.contains(NameTransformer.decode(tn))
        }.toSet
      val taggedTests: Map[String, Set[String]] = desiredTests.map(_ -> selectedSet).toMap
      Filter(
        tagsToInclude = Some(selectedSet),
        excludeNestedSuites = true,
        dynaTags = DynaTags(Map.empty, Map(suiteToRun.suiteId -> taggedTests))
      )
    }

  /**
   * Creates the thread pool used for parallel suites. The size comes from system property
   * org.scalatestplus.junit5.numThreads: "0" (the default) means twice the number of available
   * processors, a positive number a fixed pool of that size, and a negative number a cached
   * pool. Threads are named "ScalaTest-<n>".
   *
   * Throws a RuntimeException when the property value is not an integer.
   */
  private def newExecutorService(): ExecutorService = {
    val numThreads = System.getProperty("org.scalatestplus.junit5.numThreads", "0")
    val poolSize =
      if (numThreads == "0")
        Runtime.getRuntime.availableProcessors * 2
      else
        Try(numThreads.toInt).getOrElse(throw new RuntimeException(Resources.invalidNumThreads(numThreads)))
    val threadFactory =
      new ThreadFactory {
        val defaultThreadFactory = Executors.defaultThreadFactory
        val atomicThreadCounter = new AtomicInteger
        def newThread(runnable: Runnable): Thread = {
          val thread = defaultThreadFactory.newThread(runnable)
          thread.setName("ScalaTest-" + atomicThreadCounter.incrementAndGet())
          thread
        }
      }

    if (poolSize > 0)
      Executors.newFixedThreadPool(poolSize, threadFactory)
    else
      Executors.newCachedThreadPool(threadFactory)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy