Skip to content

Commit 032868a

Browse files
committed
Run all found test frameworks, rather than just one (on the JVM)
1 parent 65a0c49 commit 032868a

File tree

2 files changed

+169
-52
lines changed

2 files changed

+169
-52
lines changed

modules/integration/src/test/scala/scala/cli/integration/TestTestDefinitions.scala

+39
Original file line numberDiff line numberDiff line change
@@ -810,4 +810,43 @@ abstract class TestTestDefinitions extends ScalaCliSuite with TestScalaVersionAr
810810
expect(res.out.text().contains(expectedMessage))
811811
}
812812
}
813+
814+
test("multiple test frameworks") {
815+
val scalatestMessage = "Hello from ScalaTest"
816+
val munitMessage = "Hello from Munit"
817+
TestInputs(
818+
os.rel / "project.scala" ->
819+
s"""//> using test.dep org.scalatest::scalatest::3.2.19
820+
|//> using test.dep org.scalameta::munit::$munitVersion
821+
|""".stripMargin,
822+
os.rel / "scalatest.test.scala" ->
823+
s"""import org.scalatest.flatspec.AnyFlatSpec
824+
|
825+
|class ScalaTestSpec extends AnyFlatSpec {
826+
| "example" should "work" in {
827+
| assertResult(1)(1)
828+
| println("$scalatestMessage")
829+
| }
830+
|}
831+
|""".stripMargin,
832+
os.rel / "munit.test.scala" ->
833+
s"""import munit.FunSuite
834+
|
835+
|class Munit extends FunSuite {
836+
| test("foo") {
837+
| assert(2 + 2 == 4)
838+
| println("$munitMessage")
839+
| }
840+
|}
841+
|""".stripMargin
842+
).fromRoot { root =>
843+
val r = os.proc(TestUtil.cli, "test", extraOptions, ".").call(cwd = root)
844+
val output = r.out.trim()
845+
expect(output.nonEmpty)
846+
expect(output.contains(scalatestMessage))
847+
expect(countSubStrings(output, scalatestMessage) == 1)
848+
expect(output.contains(munitMessage))
849+
expect(countSubStrings(output, munitMessage) == 1)
850+
}
851+
}
813852
}

modules/test-runner/src/main/scala/scala/build/testrunner/DynamicTestRunner.scala

+130-52
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,11 @@ object DynamicTestRunner {
8989
def listClasses(classPath: Seq[Path], keepJars: Boolean): Iterator[String] =
9090
classPath.iterator.flatMap(listClasses(_, keepJars))
9191

92-
def findFrameworkService(loader: ClassLoader): Option[Framework] =
92+
def findFrameworkServices(loader: ClassLoader): Seq[Framework] =
9393
ServiceLoader.load(classOf[Framework], loader)
9494
.iterator()
9595
.asScala
96-
.take(1)
97-
.toList
98-
.headOption
96+
.toSeq
9997

10098
def loadFramework(
10199
loader: ClassLoader,
@@ -106,11 +104,11 @@ object DynamicTestRunner {
106104
constructor.newInstance().asInstanceOf[Framework]
107105
}
108106

109-
def findFramework(
107+
def findFrameworks(
110108
classPath: Seq[Path],
111109
loader: ClassLoader,
112110
preferredClasses: Seq[String]
113-
): Option[Framework] = {
111+
): Seq[Framework] = {
114112
val frameworkCls = classOf[Framework]
115113
(preferredClasses.iterator ++ listClasses(classPath, true))
116114
.flatMap { name =>
@@ -144,9 +142,7 @@ object DynamicTestRunner {
144142
case _: NoSuchMethodException => Iterator.empty
145143
}
146144
}
147-
.take(1)
148-
.toList
149-
.headOption
145+
.toSeq
150146
}
151147

152148
/** Based on junit-interface [GlobFilter.
@@ -220,15 +216,83 @@ object DynamicTestRunner {
220216

221217
val classLoader = Thread.currentThread().getContextClassLoader
222218
val classPath0 = TestRunner.classPath(classLoader)
223-
val framework = testFrameworkOpt.map(loadFramework(classLoader, _))
224-
.orElse(findFrameworkService(classLoader))
225-
.orElse(findFramework(classPath0, classLoader, TestRunner.commonTestFrameworks))
219+
val frameworks = testFrameworkOpt
220+
.map(loadFramework(classLoader, _))
221+
.map(Seq(_))
226222
.getOrElse {
227-
if (verbosity >= 2)
228-
sys.error("No test framework found")
229-
else {
230-
System.err.println("No test framework found")
231-
sys.exit(1)
223+
// needed for Scala 2.12
224+
def distinctBy[A, B](seq: Seq[A])(f: A => B): Seq[A] = {
225+
@annotation.tailrec
226+
def loop(remaining: Seq[A], seen: Set[B], acc: Vector[A]): Vector[A] =
227+
if (remaining.isEmpty) acc
228+
else {
229+
val head = remaining.head
230+
val tail = remaining.tail
231+
val key = f(head)
232+
if (seen(key)) loop(tail, seen, acc)
233+
else loop(tail, seen + key, acc :+ head)
234+
}
235+
loop(seq, Set.empty, Vector.empty)
236+
}
237+
238+
def getFrameworkDescription(f: Framework): String =
239+
s"${f.name()} (${Option(f.getClass.getCanonicalName).getOrElse(f.toString)})"
240+
241+
val foundFrameworkServices = findFrameworkServices(classLoader)
242+
if (verbosity >= 2 && foundFrameworkServices.nonEmpty)
243+
System.err.println(
244+
s"""Found test framework services:
245+
| - ${foundFrameworkServices.map(getFrameworkDescription).mkString("\n - ")}
246+
|""".stripMargin
247+
)
248+
249+
val foundFrameworks =
250+
findFrameworks(classPath0, classLoader, TestRunner.commonTestFrameworks)
251+
if (verbosity >= 2 && foundFrameworks.nonEmpty)
252+
System.err.println(
253+
s"""Found test frameworks:
254+
| - ${foundFrameworks.map(getFrameworkDescription).mkString("\n - ")}
255+
|""".stripMargin
256+
)
257+
258+
val distinctFrameworks = distinctBy(foundFrameworkServices ++ foundFrameworks)(_.name())
259+
if (verbosity >= 2 && distinctFrameworks.nonEmpty)
260+
System.err.println(
261+
s"""Distinct test frameworks found (by framework name):
262+
| - ${distinctFrameworks.map(getFrameworkDescription).mkString("\n - ")}
263+
|""".stripMargin
264+
)
265+
266+
val finalFrameworks =
267+
distinctFrameworks
268+
.filter(f1 =>
269+
!distinctFrameworks
270+
.filter(_ != f1)
271+
.exists(f2 =>
272+
f1.getClass.isAssignableFrom(f2.getClass)
273+
)
274+
)
275+
if (verbosity >= 1 && finalFrameworks.nonEmpty)
276+
System.err.println(
277+
s"""Final list of test frameworks found:
278+
| - ${finalFrameworks.map(getFrameworkDescription).mkString("\n - ")}
279+
|""".stripMargin
280+
)
281+
282+
val skippedInheritedFrameworks = distinctFrameworks.diff(finalFrameworks)
283+
if (verbosity >= 1 && skippedInheritedFrameworks.nonEmpty)
284+
System.err.println(
285+
s"""The following test frameworks have been filtered out, as they're being inherited from by others:
286+
| - ${skippedInheritedFrameworks.map(getFrameworkDescription).mkString("\n - ")}
287+
|""".stripMargin
288+
)
289+
290+
finalFrameworks match {
291+
case f if f.nonEmpty => f
292+
case _ if verbosity >= 2 => sys.error("No test framework found")
293+
case _ =>
294+
System.err.println("No test framework found")
295+
sys.exit(1)
232296
}
233297
}
234298
def classes = {
@@ -237,41 +301,55 @@ object DynamicTestRunner {
237301
}
238302
val out = System.out
239303

240-
val fingerprints = framework.fingerprints()
241-
val runner = framework.runner(args0.toArray, Array(), classLoader)
242-
def clsFingerprints = classes.flatMap { cls =>
243-
matchFingerprints(classLoader, cls, fingerprints)
244-
.map((cls, _))
245-
.iterator
246-
}
247-
val taskDefs = clsFingerprints
248-
.filter {
249-
case (cls, _) =>
250-
testOnly.forall(pattern =>
251-
globPattern(pattern).matcher(cls.getName.stripSuffix("$")).matches()
252-
)
253-
}
254-
.map {
255-
case (cls, fp) =>
256-
new TaskDef(cls.getName.stripSuffix("$"), fp, false, Array(new SuiteSelector))
257-
}
258-
.toVector
259-
val initialTasks = runner.tasks(taskDefs.toArray)
260-
val events = TestRunner.runTasks(initialTasks, out)
261-
val failed = events.exists { ev =>
262-
ev.status == Status.Error ||
263-
ev.status == Status.Failure ||
264-
ev.status == Status.Canceled
265-
}
266-
val doneMsg = runner.done()
267-
if (doneMsg.nonEmpty)
268-
out.println(doneMsg)
269-
if (requireTests && events.isEmpty) {
270-
System.err.println("Error: no tests were run.")
271-
sys.exit(1)
272-
}
273-
if (failed)
274-
sys.exit(1)
304+
val exitCodes =
305+
frameworks
306+
.map { framework =>
307+
if (verbosity >= 1) System.err.println(s"Running test framework: ${framework.name}")
308+
val fingerprints = framework.fingerprints()
309+
val runner = framework.runner(args0.toArray, Array(), classLoader)
310+
311+
def clsFingerprints = classes.flatMap { cls =>
312+
matchFingerprints(classLoader, cls, fingerprints)
313+
.map((cls, _))
314+
.iterator
315+
}
316+
317+
val taskDefs = clsFingerprints
318+
.filter {
319+
case (cls, _) =>
320+
testOnly.forall(pattern =>
321+
globPattern(pattern).matcher(cls.getName.stripSuffix("$")).matches()
322+
)
323+
}
324+
.map {
325+
case (cls, fp) =>
326+
new TaskDef(cls.getName.stripSuffix("$"), fp, false, Array(new SuiteSelector))
327+
}
328+
.toVector
329+
val initialTasks = runner.tasks(taskDefs.toArray)
330+
val events = TestRunner.runTasks(initialTasks, out)
331+
val failed = events.exists { ev =>
332+
ev.status == Status.Error ||
333+
ev.status == Status.Failure ||
334+
ev.status == Status.Canceled
335+
}
336+
val doneMsg = runner.done()
337+
if (doneMsg.nonEmpty) out.println(doneMsg)
338+
if (requireTests && events.isEmpty) {
339+
System.err.println(s"Error: no tests were run for ${framework.name()}.")
340+
1
341+
}
342+
else if (failed) {
343+
System.err.println(s"Error: ${framework.name()} tests failed.")
344+
1
345+
}
346+
else {
347+
if (verbosity >= 1) System.err.println(s"${framework.name()} tests ran successfully.")
348+
0
349+
}
350+
}
351+
if (exitCodes.contains(1)) sys.exit(1)
352+
else sys.exit(0)
275353
}
276354
}
277355

0 commit comments

Comments
 (0)