
sbt.TestFramework.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of testing_2.11 Show documentation
Show all versions of testing_2.11 Show documentation
sbt is an interactive build tool
The newest version!
/* sbt -- Simple Build Tool
* Copyright 2008, 2009 Steven Blundy, Mark Harrah, Josh Cough
*/
package sbt
import java.io.File
import java.net.URLClassLoader
import testing.{ Logger => TLogger, Task => TestTask, _ }
import org.scalatools.testing.{ Framework => OldFramework }
import sbt.internal.inc.classpath.{ ClasspathUtilities, DualLoader, FilteredLoader }
import sbt.internal.inc.ScalaInstance
import scala.annotation.tailrec
import sbt.util.Logger
import sbt.io.IO
/**
 * Overall outcome of running tests.
 *
 * Values are declared in the order `Passed`, `Failed`, `Error`, so their
 * enumeration ids are 0, 1 and 2 respectively.
 */
object TestResult extends Enumeration {
  /** Outcome value with id 0. */
  val Passed = Value

  /** Outcome value with id 1. */
  val Failed = Value

  /** Outcome value with id 2. */
  val Error = Value
}
/**
 * Predefined [[TestFramework]] descriptors.
 *
 * When a descriptor lists several class names, they are attempted in the order
 * given and the first one found on the class loader is used.
 */
object TestFrameworks {
  /** ScalaCheck, implemented by `org.scalacheck.ScalaCheckFramework`. */
  val ScalaCheck: TestFramework =
    TestFramework("org.scalacheck.ScalaCheckFramework")

  /** ScalaTest: `org.scalatest.tools.Framework`, then `org.scalatest.tools.ScalaTestFramework`. */
  val ScalaTest: TestFramework =
    TestFramework("org.scalatest.tools.Framework", "org.scalatest.tools.ScalaTestFramework")

  /** Specs, implemented by `org.specs.runner.SpecsFramework`. */
  val Specs: TestFramework =
    TestFramework("org.specs.runner.SpecsFramework")

  /** Specs2: `org.specs2.runner.Specs2Framework`, then `org.specs2.runner.SpecsFramework`. */
  val Specs2: TestFramework =
    TestFramework("org.specs2.runner.Specs2Framework", "org.specs2.runner.SpecsFramework")

  /** JUnit, implemented by `com.novocode.junit.JUnitFramework`. */
  val JUnit: TestFramework =
    TestFramework("com.novocode.junit.JUnitFramework")
}
/**
 * Identifies a test framework by the fully qualified names of the classes that
 * may implement it.
 *
 * @param implClassNames candidate implementation class names, tried in order
 */
case class TestFramework(implClassNames: String*) {
  /**
   * Walks `frameworkClassNames` and instantiates the first class that can be
   * loaded from `loader`.
   *
   * A candidate that is not on the class path is reported at debug level and
   * skipped. Only `ClassNotFoundException` is handled; any other failure while
   * loading or instantiating a candidate propagates to the caller, including
   * the `MatchError` raised when the instance is neither a `Framework` nor an
   * `OldFramework`.
   */
  @tailrec
  private def createFramework(loader: ClassLoader, log: Logger, frameworkClassNames: List[String]): Option[Framework] =
    frameworkClassNames match {
      case Nil =>
        None
      case className :: remaining =>
        try {
          val instance = Class.forName(className, true, loader).newInstance
          val framework: Framework = instance match {
            case current: Framework  => current
            // Implementations of the older interface are adapted to the current one.
            case legacy: OldFramework => new FrameworkWrapper(legacy)
          }
          Some(framework)
        } catch {
          case _: ClassNotFoundException =>
            log.debug("Framework implementation '" + className + "' not present.")
            createFramework(loader, log, remaining)
        }
    }

  /**
   * Instantiates this framework using `loader`.
   *
   * @param loader class loader used to look up the implementation classes
   * @param log    receives a debug message for every candidate class that is absent
   * @return the first implementation that could be loaded, or `None` when none
   *         of `implClassNames` is present
   */
  def create(loader: ClassLoader, log: Logger): Option[Framework] = {
    val candidates = implClassNames.toList
    createFramework(loader, log, candidates)
  }
}
/**
 * A test identified by its name and the fingerprint it was detected with.
 *
 * Equality and hashing look only at `name` and `fingerprint`;
 * `explicitlySpecified` and `selectors` do not take part. Fingerprints are
 * compared structurally through `TestFramework.matches`, which yields `false`
 * for fingerprints that are neither subclass nor annotated fingerprints, so
 * definitions carrying such a fingerprint never compare equal.
 */
final class TestDefinition(
  val name: String,
  val fingerprint: Fingerprint,
  val explicitlySpecified: Boolean,
  val selectors: Array[Selector]
) {
  /** Renders as `Test <name> : <fingerprint description>`. */
  override def toString: String =
    s"Test $name : ${TestFramework.toString(fingerprint)}"

  override def equals(t: Any): Boolean = t match {
    case other: TestDefinition =>
      name == other.name && TestFramework.matches(fingerprint, other.fingerprint)
    case _ =>
      false
  }

  /** Combines the name hash with the structural fingerprint hash, consistent with `equals`. */
  override def hashCode: Int = {
    val nameHash = name.hashCode
    val fingerprintHash = TestFramework.hashCode(fingerprint)
    (nameHash, fingerprintHash).hashCode
  }
}
/**
 * Drives a framework `Runner` and reports progress to the given listeners.
 *
 * @param delegate  the framework runner that produces the tasks
 * @param listeners listeners notified about group start/end and test events
 * @param log       logger used for debug output and for reporting listener failures
 */
final class TestRunner(delegate: Runner, listeners: Seq[TestReportListener], log: Logger) {
  /** Converts each definition to a `TaskDef` and asks the delegate runner for the matching tasks. */
  final def tasks(testDefs: Set[TestDefinition]): Array[TestTask] = {
    val taskDefs = testDefs.map { definition =>
      new TaskDef(definition.name, definition.fingerprint, definition.explicitlySpecified, definition.selectors)
    }
    delegate.tasks(taskDefs.toArray)
  }

  /**
   * Executes `testTask`, bracketing it with `startGroup` / `endGroup` listener calls.
   *
   * @param taskDef  definition of the task being run; its fully qualified name is the group name
   * @param testTask the task to execute
   * @return the suite result built from the collected events, together with any
   *         nested tasks returned by the execution. If anything is thrown, the
   *         listeners receive `endGroup(name, throwable)` and the result is
   *         `SuiteResult.Error` with no nested tasks.
   */
  final def run(taskDef: TaskDef, testTask: TestTask): (SuiteResult, Seq[TestTask]) = {
    val testDefinition =
      new TestDefinition(taskDef.fullyQualifiedName, taskDef.fingerprint, taskDef.explicitlySpecified, taskDef.selectors)
    log.debug("Running " + taskDef)
    val name = testDefinition.name

    def runTest(): (SuiteResult, Seq[TestTask]) = {
      // Every event emitted by the task is accumulated here through the handler below.
      val collected = new scala.collection.mutable.ListBuffer[Event]
      val collector = new EventHandler {
        def handle(event: Event): Unit = { collected += event }
      }
      val contentLoggers = listeners.flatMap(_.contentLogger(testDefinition))
      val followUps =
        try testTask.execute(collector, contentLoggers.map(_.log).toArray)
        // Flush the content loggers whether or not execution succeeded.
        finally contentLoggers.foreach(_.flush())
      val summary = TestEvent(collected)
      safeListenersCall(_.testEvent(summary))
      (SuiteResult(collected), followUps.toSeq)
    }

    safeListenersCall(_.startGroup(name))
    try {
      runTest() match {
        case outcome @ (suite, _) =>
          safeListenersCall(_.endGroup(name, suite.result))
          outcome
      }
    } catch {
      case failure: Throwable =>
        safeListenersCall(_.endGroup(name, failure))
        (SuiteResult.Error, Seq.empty[TestTask])
    }
  }

  /** Invokes `call` on every listener via `TestFramework.safeForeach`. */
  protected def safeListenersCall(call: (TestReportListener) => Unit): Unit =
    TestFramework.safeForeach(listeners, log)(call)
}
/** Helpers for matching tests to frameworks, building test tasks and creating the test class loader. */
object TestFramework {
  /**
   * Obtains the fingerprints of `framework` by reflectively invoking its
   * `fingerprints` method.
   *
   * Fails via `sys.error` when the reflective call does not return an array of
   * `Fingerprint`.
   */
  def getFingerprints(framework: Framework): Seq[Fingerprint] = {
    val accessor = framework.getClass.getMethod("fingerprints")
    accessor.invoke(framework) match {
      case found: Array[Fingerprint] => found.toList
      case _                         => sys.error("Could not call 'fingerprints' on framework " + framework)
    }
  }

  /**
   * Applies `f` to every element of `it`. An `Exception` thrown for one element
   * is logged (trace plus error message) and does not stop the iteration.
   */
  private[sbt] def safeForeach[T](it: Iterable[T], log: Logger)(f: T => Unit): Unit =
    for (item <- it) {
      try f(item)
      catch {
        case failure: Exception =>
          log.trace(failure)
          log.error(failure.toString)
      }
    }

  /** Structural hash of a fingerprint; fingerprints of an unrecognised kind hash to 0. */
  private[sbt] def hashCode(f: Fingerprint): Int = f match {
    case sub: SubclassFingerprint   => (sub.isModule, sub.superclassName).hashCode
    case ann: AnnotatedFingerprint  => (ann.isModule, ann.annotationName).hashCode
    case _                          => 0
  }

  /**
   * Structural comparison of two fingerprints: both must be of the same kind
   * (subclass or annotated) and agree on `isModule` and on the superclass or
   * annotation name. Every other combination is `false`.
   */
  def matches(a: Fingerprint, b: Fingerprint): Boolean =
    (a, b) match {
      case (left: SubclassFingerprint, right: SubclassFingerprint) =>
        left.isModule == right.isModule && left.superclassName == right.superclassName
      case (left: AnnotatedFingerprint, right: AnnotatedFingerprint) =>
        left.isModule == right.isModule && left.annotationName == right.annotationName
      case _ =>
        false
    }

  /** Human-readable description of a fingerprint; unrecognised kinds use their own `toString`. */
  def toString(f: Fingerprint): String =
    f match {
      case sub: SubclassFingerprint  => s"subclass(${sub.isModule}, ${sub.superclassName})"
      case ann: AnnotatedFingerprint => s"annotation(${ann.isModule}, ${ann.annotationName})"
      case _                         => f.toString
    }

  /**
   * Builds the pieces needed to run `tests`.
   *
   * @return a triple of: an action to run before the tests, the named test
   *         functions in the order of `tests`, and a function producing the
   *         action to run afterwards for a given overall result. When no test
   *         matches any framework, both actions do nothing and the list is empty.
   */
  def testTasks(frameworks: Map[TestFramework, Framework],
    runners: Map[TestFramework, Runner],
    testLoader: ClassLoader,
    tests: Seq[TestDefinition],
    log: Logger,
    listeners: Seq[TestReportListener]): (() => Unit, Seq[(String, TestFunction)], TestResult.Value => () => Unit) = {
    val testsByFramework = testMap(frameworks.values.toSeq, tests)
    if (testsByFramework.nonEmpty) {
      // Re-key the runners by framework instance, wrapping each in a TestRunner.
      val testRunners = runners.map {
        case (descriptor, runner) => (frameworks(descriptor), new TestRunner(runner, listeners, log))
      }
      createTestTasks(testLoader, testRunners, testsByFramework, tests, log, listeners)
    } else
      (() => (), Nil, _ => () => ())
  }

  /** Lists the mapped test functions in the order of `inputs`, skipping inputs without an entry. */
  private[this] def order(mapped: Map[String, TestFunction], inputs: Seq[TestDefinition]): Seq[(String, TestFunction)] =
    inputs.flatMap { definition =>
      mapped.get(definition.name).map(function => (definition.name, function))
    }

  /**
   * Groups `tests` by the first framework in `frameworks` that declares a
   * matching fingerprint. Tests matching no framework are left out.
   */
  private[this] def testMap(frameworks: Seq[Framework], tests: Seq[TestDefinition]): Map[Framework, Set[TestDefinition]] = {
    import scala.collection.mutable
    val assigned = new mutable.HashMap[Framework, mutable.Set[TestDefinition]]

    def assignTest(test: TestDefinition): Unit = {
      val owner = frameworks.find { candidate =>
        getFingerprints(candidate).exists(fingerprint => matches(fingerprint, test.fingerprint))
      }
      owner.foreach { framework =>
        assigned.getOrElseUpdate(framework, new mutable.HashSet[TestDefinition]) += test
      }
    }

    if (frameworks.nonEmpty)
      tests.foreach(assignTest)
    assigned.toMap.mapValues(_.toSet)
  }

  /**
   * Creates one test function per task reported by each framework's runner and
   * pairs them with the start and end actions for the `TestsListener`s.
   */
  private def createTestTasks(loader: ClassLoader, runners: Map[Framework, TestRunner], tests: Map[Framework, Set[TestDefinition]], ordered: Seq[TestDefinition], log: Logger, listeners: Seq[TestReportListener]): (() => Unit, Seq[(String, TestFunction)], TestResult.Value => () => Unit) = {
    val lifecycleListeners = listeners.collect { case tl: TestsListener => tl }

    // Wraps a listener call in a deferred action that tolerates failing listeners.
    def deferredBroadcast(action: TestsListener => Unit): () => Unit =
      () => safeForeach(lifecycleListeners, log)(action)

    val startTask = deferredBroadcast(_.doInit)

    // Keyed by fully qualified task name; a later task with the same name replaces an earlier one.
    val tasksByName: Map[String, TestFunction] = tests.flatMap {
      case (framework, definitions) =>
        val runner = runners(framework)
        val frameworkTasks = withContextLoader(loader) { runner.tasks(definitions) }
        frameworkTasks.map { task =>
          val taskDef = task.taskDef
          (taskDef.fullyQualifiedName, createTestFunction(loader, taskDef, runner, task))
        }
    }

    val endTask = (result: TestResult.Value) => deferredBroadcast(_.doComplete(result))
    (startTask, order(tasksByName, ordered), endTask)
  }

  /** Evaluates `eval` with `loader` as the thread's context class loader, restoring the previous one afterwards. */
  private[this] def withContextLoader[T](loader: ClassLoader)(eval: => T): T = {
    val thread = Thread.currentThread
    val previous = thread.getContextClassLoader
    thread.setContextClassLoader(loader)
    try eval
    finally thread.setContextClassLoader(previous)
  }

  /**
   * Creates the class loader used to load and run tests.
   *
   * NOTE(review): the `DualLoader` arguments suggest that test-interface classes
   * (`org.scalatools.testing.*`, `sbt.testing.*`) are resolved through this
   * class's own loader and all other classes through the Scala instance's
   * loader -- confirm against `DualLoader`.
   *
   * @param classpath     the test class path
   * @param scalaInstance the Scala instance whose loader and jars are used
   * @param tempDir       passed through to `ClasspathUtilities.makeLoader`
   */
  def createTestLoader(classpath: Seq[File], scalaInstance: ScalaInstance, tempDir: File): ClassLoader = {
    val interfaceJar = IO.classLocationFile(classOf[testing.Framework])
    val isInterface = (name: String) => name.startsWith("org.scalatools.testing.") || name.startsWith("sbt.testing.")
    val isNotInterface = (name: String) => !isInterface(name)
    val dual = new DualLoader(scalaInstance.loader, isNotInterface, _ => true, getClass.getClassLoader, isInterface, _ => false)
    val main = ClasspathUtilities.makeLoader(classpath, dual, scalaInstance, tempDir)
    // TODO - There's actually an issue with the classpath facility such that unmanagedScalaInstances are not added
    // to the classpath correctly. We have a temporary workaround here.
    val withInterface = interfaceJar +: classpath
    val cp: Seq[File] =
      if (scalaInstance.isManagedVersion) withInterface
      else scalaInstance.allJars ++ withInterface
    ClasspathUtilities.filterByClasspath(cp, main)
  }

  /**
   * Wraps `testTask` in a [[TestFunction]] that runs it through `runner` with
   * `loader` installed as the context class loader; its tags are those of the task.
   */
  def createTestFunction(loader: ClassLoader, taskDef: TaskDef, runner: TestRunner, testTask: TestTask): TestFunction = {
    val invoke = (r: TestRunner) => withContextLoader(loader) { r.run(taskDef, testTask) }
    new TestFunction(taskDef, runner, invoke) {
      def tags: Seq[String] = testTask.tags
    }
  }
}
/**
 * A runnable test: a task definition, the runner that executes it, and the
 * function performing the run.
 *
 * @param taskDef the definition of the task this function runs
 * @param runner  the runner handed to `fun` when the function is applied
 * @param fun     performs the run and yields the suite result plus nested tasks
 */
abstract class TestFunction(
  val taskDef: TaskDef,
  val runner: TestRunner,
  fun: (TestRunner) => (SuiteResult, Seq[TestTask])
) {
  /** Tags of this test, supplied by the concrete subclass. */
  def tags: Seq[String]

  /** Applies the wrapped function to `runner`. */
  def apply(): (SuiteResult, Seq[TestTask]) = fun.apply(runner)
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy