All Downloads are FREE. Search and download functionalities are using the official Maven repository.

mill.testkit.UnitTester.scala Maven / Gradle / Ivy

There is a newer version: 0.12.0-RC2-17-07e173
Show newest version
package mill.testkit

import mill.{Target, Task}
import mill.api.Result.OuterStack
import mill.api.{DummyInputStream, Result, SystemStreams, Val}
import mill.define.{InputImpl, TargetImpl}
import mill.eval.{Evaluator, EvaluatorImpl}
import mill.resolve.{Resolve, SelectMode}
import mill.util.PrintLogger
import mill.api.Strict.Agg

import java.io.{InputStream, PrintStream}

object UnitTester {
  case class Result[T](value: T, evalCount: Int)
  def apply(
      module: mill.testkit.TestBaseModule,
      sourceRoot: os.Path,
      failFast: Boolean = false,
      threads: Option[Int] = Some(1),
      outStream: PrintStream = System.out,
      errStream: PrintStream = System.err,
      inStream: InputStream = DummyInputStream,
      debugEnabled: Boolean = false,
      env: Map[String, String] = Evaluator.defaultEnv,
      resetSourcePath: Boolean = true
  ) = new UnitTester(
    module = module,
    sourceRoot = sourceRoot,
    failFast = failFast,
    threads = threads,
    outStream = outStream,
    errStream = errStream,
    inStream = inStream,
    debugEnabled = debugEnabled,
    env = env,
    resetSourcePath = resetSourcePath
  )
}

/**
 * @param module The module under test
 * @param failFast failFast mode enabled
 * @param threads explicitly used nr. of parallel threads
 */
class UnitTester(
    module: mill.testkit.TestBaseModule,
    sourceRoot: os.Path,
    failFast: Boolean,
    threads: Option[Int],
    outStream: PrintStream,
    errStream: PrintStream,
    inStream: InputStream,
    debugEnabled: Boolean,
    env: Map[String, String],
    resetSourcePath: Boolean
)(implicit fullName: sourcecode.FullName) {
  val outPath: os.Path = module.millSourcePath / "out"

  if (resetSourcePath) {
    os.remove.all(module.millSourcePath)
    os.makeDir.all(module.millSourcePath)

    for (sourceFileRoot <- Option(sourceRoot)) {
      os.copy.over(sourceFileRoot, module.millSourcePath, createFolders = true)
    }
  }

  object logger extends mill.util.PrintLogger(
        colored = true,
        enableTicker = true,
        mill.util.Colors.Default.info,
        mill.util.Colors.Default.error,
        new SystemStreams(out = outStream, err = errStream, in = inStream),
        debugEnabled = debugEnabled,
        context = "",
        new PrintLogger.State()
      ) {
    val prefix: String = {
      val idx = fullName.value.lastIndexOf(".")
      if (idx > 0) fullName.value.substring(0, idx)
      else fullName.value
    }
    override def error(s: String): Unit = super.error(s"${prefix}: ${s}")
    override def info(s: String): Unit = super.info(s"${prefix}: ${s}")
    override def debug(s: String): Unit = super.debug(s"${prefix}: ${s}")
    override def ticker(s: String): Unit = super.ticker(s"${prefix}: ${s}")
  }
  val evaluator: EvaluatorImpl = mill.eval.EvaluatorImpl(
    mill.api.Ctx.defaultHome,
    module.millSourcePath,
    outPath,
    outPath,
    Seq(module),
    logger,
    0,
    0,
    failFast = failFast,
    threadCount = threads,
    env = env,
    methodCodeHashSignatures = Map(),
    disableCallgraphInvalidation = false
  )

  def apply(args: String*): Either[Result.Failing[_], UnitTester.Result[Seq[_]]] = {
    mill.eval.Evaluator.currentEvaluator.withValue(evaluator) {
      Resolve.Tasks.resolve(evaluator.rootModules, args, SelectMode.Separated)
    } match {
      case Left(err) => Left(Result.Failure(err))
      case Right(resolved) => apply(resolved)
    }
  }

  def apply[T](task: Task[T]): Either[Result.Failing[T], UnitTester.Result[T]] = {
    apply(Seq(task)) match {
      case Left(f) => Left(f.asInstanceOf[Result.Failing[T]])
      case Right(UnitTester.Result(Seq(v), i)) =>
        Right(UnitTester.Result(v.asInstanceOf[T], i))
    }
  }

  def apply(
      tasks: Seq[Task[_]],
      dummy: DummyImplicit = null
  ): Either[Result.Failing[_], UnitTester.Result[Seq[_]]] = {
    val evaluated = evaluator.evaluate(tasks)

    if (evaluated.failing.keyCount == 0) {
      Right(
        UnitTester.Result(
          evaluated.rawValues.map(_.asInstanceOf[Result.Success[Val]].value.value),
          evaluated.evaluated.collect {
            case t: TargetImpl[_]
                if module.millInternal.targets.contains(t)
                  && !t.ctx.external => t
            case t: mill.define.Command[_] => t
          }.size
        )
      )
    } else {
      Left(
        evaluated
          .failing
          .lookupKey(evaluated.failing.keys().next)
          .items
          .next()
          .asFailing
          .get
          .map(_.value)
      )
    }
  }

  def fail(target: Target[_], expectedFailCount: Int, expectedRawValues: Seq[Result[_]]): Unit = {

    val res = evaluator.evaluate(Agg(target))

    val cleaned = res.rawValues.map {
      case Result.Exception(ex, _) => Result.Exception(ex, new OuterStack(Nil))
      case x => x.map(_.value)
    }

    assert(
      cleaned == expectedRawValues,
      res.failing.keyCount == expectedFailCount
    )

  }

  def check(targets: Agg[Task[_]], expected: Agg[Task[_]]): Unit = {

    val evaluated = evaluator.evaluate(targets)
      .evaluated
      .flatMap(_.asTarget)
      .filter(module.millInternal.targets.contains)
      .filter(!_.isInstanceOf[InputImpl[_]])
    assert(
      evaluated.toSet == expected.toSet,
      s"evaluated is not equal expected. evaluated=${evaluated}, expected=${expected}"
    )
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy