Skip to content

Commit 931b753

Browse files
committed
NIT Refactor test-runner
1 parent 83644e1 commit 931b753

File tree

4 files changed

+231
-240
lines changed

4 files changed

+231
-240
lines changed

modules/test-runner/src/main/scala/scala/build/testrunner/AsmTestRunner.scala

+6-25
Original file line numberDiff line numberDiff line change
@@ -200,18 +200,6 @@ object AsmTestRunner {
200200
.map(b => new String(b, StandardCharsets.UTF_8))
201201
.toSeq
202202

203-
def findFramework(
204-
classPath: Seq[Path],
205-
preferredClasses: Seq[String]
206-
): Option[String] = {
207-
val parentInspector = new ParentInspector(classPath)
208-
findFrameworks(
209-
classPath,
210-
preferredClasses,
211-
parentInspector
212-
).headOption // TODO: handle multiple frameworks
213-
}
214-
215203
def findFrameworks(
216204
classPath: Seq[Path],
217205
preferredClasses: Seq[String],
@@ -224,7 +212,7 @@ object AsmTestRunner {
224212
findInClassPath(classPath, name + ".class")
225213
.map { b =>
226214
def openStream() = new ByteArrayInputStream(b)
227-
(name, openStream _)
215+
(name, () => openStream())
228216
}
229217
}
230218
(preferredClassesByteCode ++ listClassesByteCode(classPath, true))
@@ -254,18 +242,11 @@ object AsmTestRunner {
254242
private var isInterfaceOpt = Option.empty[Boolean]
255243
private var isAbstractOpt = Option.empty[Boolean]
256244
private var implements0 = List.empty[String]
257-
def canBeTestSuite: Boolean = {
258-
val isModule = nameOpt.exists(_.endsWith("$"))
259-
!isAbstractOpt.contains(true) &&
260-
!isInterfaceOpt.contains(true) &&
261-
publicConstructorCount0 <= 1 &&
262-
isModule != (publicConstructorCount0 == 1)
263-
}
264-
def name = nameOpt.getOrElse(sys.error("Class not visited"))
265-
def publicConstructorCount = publicConstructorCount0
266-
def implements = implements0
267-
def isAbstract = isAbstractOpt.getOrElse(sys.error("Class not visited"))
268-
def isInterface = isInterfaceOpt.getOrElse(sys.error("Class not visited"))
245+
def name: String = nameOpt.getOrElse(sys.error("Class not visited"))
246+
def publicConstructorCount: Int = publicConstructorCount0
247+
def implements: Seq[String] = implements0
248+
def isAbstract: Boolean = isAbstractOpt.getOrElse(sys.error("Class not visited"))
249+
def isInterface: Boolean = isInterfaceOpt.getOrElse(sys.error("Class not visited"))
269250
override def visit(
270251
version: Int,
271252
access: Int,

modules/test-runner/src/main/scala/scala/build/testrunner/DynamicTestRunner.scala

+4-202
Original file line numberDiff line numberDiff line change
@@ -2,150 +2,13 @@ package scala.build.testrunner
22

33
import sbt.testing.{Logger => _, _}
44

5-
import java.lang.annotation.Annotation
6-
import java.lang.reflect.Modifier
7-
import java.nio.file.{Files, Path}
8-
import java.util.ServiceLoader
95
import java.util.regex.Pattern
106

117
import scala.annotation.tailrec
128
import scala.build.testrunner.FrameworkUtils._
13-
import scala.jdk.CollectionConverters._
149

1510
object DynamicTestRunner {
1611

17-
// adapted from https://github.com/com-lihaoyi/mill/blob/ab4d61a50da24fb7fac97c4453dd8a770d8ac62b/scalalib/src/Lib.scala#L156-L172
18-
private def matchFingerprints(
19-
loader: ClassLoader,
20-
cls: Class[_],
21-
fingerprints: Array[Fingerprint]
22-
): Option[Fingerprint] = {
23-
val isModule = cls.getName.endsWith("$")
24-
val publicConstructorCount = cls.getConstructors.count(c => Modifier.isPublic(c.getModifiers))
25-
val noPublicConstructors = publicConstructorCount == 0
26-
val definitelyNoTests = Modifier.isAbstract(cls.getModifiers) ||
27-
cls.isInterface ||
28-
publicConstructorCount > 1 ||
29-
isModule != noPublicConstructors
30-
if (definitelyNoTests)
31-
None
32-
else
33-
fingerprints.find {
34-
case f: SubclassFingerprint =>
35-
f.isModule == isModule &&
36-
loader.loadClass(f.superclassName())
37-
.isAssignableFrom(cls)
38-
39-
case f: AnnotatedFingerprint =>
40-
val annotationCls = loader.loadClass(f.annotationName())
41-
.asInstanceOf[Class[Annotation]]
42-
f.isModule == isModule && (
43-
cls.isAnnotationPresent(annotationCls) ||
44-
cls.getDeclaredMethods.exists(_.isAnnotationPresent(annotationCls)) ||
45-
cls.getMethods.exists { m =>
46-
m.isAnnotationPresent(annotationCls) &&
47-
Modifier.isPublic(m.getModifiers())
48-
}
49-
)
50-
}
51-
}
52-
53-
def listClasses(classPathEntry: Path, keepJars: Boolean): Iterator[String] =
54-
if (Files.isDirectory(classPathEntry)) {
55-
var stream: java.util.stream.Stream[Path] = null
56-
try {
57-
stream = Files.walk(classPathEntry, Int.MaxValue)
58-
stream
59-
.iterator
60-
.asScala
61-
.filter(_.getFileName.toString.endsWith(".class"))
62-
.map(classPathEntry.relativize(_))
63-
.map { p =>
64-
val count = p.getNameCount
65-
(0 until count).map(p.getName).mkString(".")
66-
}
67-
.map(_.stripSuffix(".class"))
68-
.toVector // fully consume stream before closing it
69-
.iterator
70-
}
71-
finally if (stream != null) stream.close()
72-
}
73-
else if (keepJars && Files.isRegularFile(classPathEntry)) {
74-
import java.util.zip._
75-
var zf: ZipFile = null
76-
try {
77-
zf = new ZipFile(classPathEntry.toFile)
78-
zf.entries
79-
.asScala
80-
// FIXME Check if these are files too
81-
.filter(_.getName.endsWith(".class"))
82-
.map(ent => ent.getName.stripSuffix(".class").replace("/", "."))
83-
.toVector // full consume ZipFile before closing it
84-
.iterator
85-
}
86-
finally if (zf != null) zf.close()
87-
}
88-
else Iterator.empty
89-
90-
def listClasses(classPath: Seq[Path], keepJars: Boolean): Iterator[String] =
91-
classPath.iterator.flatMap(listClasses(_, keepJars))
92-
93-
def findFrameworkServices(loader: ClassLoader): Seq[Framework] =
94-
ServiceLoader.load(classOf[Framework], loader)
95-
.iterator()
96-
.asScala
97-
.toSeq
98-
99-
def loadFramework(
100-
loader: ClassLoader,
101-
className: String
102-
): Framework = {
103-
val cls = loader.loadClass(className)
104-
val constructor = cls.getConstructor()
105-
constructor.newInstance().asInstanceOf[Framework]
106-
}
107-
108-
def findFrameworks(
109-
classPath: Seq[Path],
110-
loader: ClassLoader,
111-
preferredClasses: Seq[String]
112-
): Seq[Framework] = {
113-
val frameworkCls = classOf[Framework]
114-
(preferredClasses.iterator ++ listClasses(classPath, true))
115-
.flatMap { name =>
116-
val it: Iterator[Class[_]] =
117-
try Iterator(loader.loadClass(name))
118-
catch {
119-
case _: ClassNotFoundException | _: UnsupportedClassVersionError | _: NoClassDefFoundError | _: IncompatibleClassChangeError =>
120-
Iterator.empty
121-
}
122-
it
123-
}
124-
.flatMap { cls =>
125-
def isAbstract = Modifier.isAbstract(cls.getModifiers)
126-
def publicConstructorCount =
127-
cls.getConstructors.count { c =>
128-
Modifier.isPublic(c.getModifiers) && c.getParameterCount() == 0
129-
}
130-
val it: Iterator[Class[_]] =
131-
if (frameworkCls.isAssignableFrom(cls) && !isAbstract && publicConstructorCount == 1)
132-
Iterator(cls)
133-
else
134-
Iterator.empty
135-
it
136-
}
137-
.flatMap { cls =>
138-
try {
139-
val constructor = cls.getConstructor()
140-
Iterator(constructor.newInstance().asInstanceOf[Framework])
141-
}
142-
catch {
143-
case _: NoSuchMethodException => Iterator.empty
144-
}
145-
}
146-
.toSeq
147-
}
148-
14912
/** Based on junit-interface [GlobFilter.
15013
* compileGlobPattern](https://github.com/sbt/junit-interface/blob/f8c6372ed01ce86f15393b890323d96afbe6d594/src/main/java/com/novocode/junit/GlobFilter.java#L37)
15114
*
@@ -223,71 +86,10 @@ object DynamicTestRunner {
22386
.map(loadFramework(classLoader, _))
22487
.map(Seq(_))
22588
.getOrElse {
226-
// needed for Scala 2.12
227-
def distinctBy[A, B](seq: Seq[A])(f: A => B): Seq[A] = {
228-
@annotation.tailrec
229-
def loop(remaining: Seq[A], seen: Set[B], acc: Vector[A]): Vector[A] =
230-
if (remaining.isEmpty) acc
231-
else {
232-
val head = remaining.head
233-
val tail = remaining.tail
234-
val key = f(head)
235-
if (seen(key)) loop(tail, seen, acc)
236-
else loop(tail, seen + key, acc :+ head)
237-
}
238-
loop(seq, Set.empty, Vector.empty)
239-
}
240-
241-
val foundFrameworkServices = findFrameworkServices(classLoader)
242-
if (foundFrameworkServices.nonEmpty)
243-
logger.debug(
244-
s"""Found test framework services:
245-
| - ${foundFrameworkServices.map(_.description).mkString("\n - ")}
246-
|""".stripMargin
247-
)
248-
249-
val foundFrameworks =
250-
findFrameworks(classPath0, classLoader, TestRunner.commonTestFrameworks)
251-
if (foundFrameworks.nonEmpty)
252-
logger.debug(
253-
s"""Found test frameworks:
254-
| - ${foundFrameworks.map(_.description).mkString("\n - ")}
255-
|""".stripMargin
256-
)
257-
258-
val distinctFrameworks = distinctBy(foundFrameworkServices ++ foundFrameworks)(_.name())
259-
if (distinctFrameworks.nonEmpty)
260-
logger.debug(
261-
s"""Distinct test frameworks found (by framework name):
262-
| - ${distinctFrameworks.map(_.description).mkString("\n - ")}
263-
|""".stripMargin
264-
)
265-
266-
val finalFrameworks =
267-
distinctFrameworks
268-
.filter(f1 =>
269-
!distinctFrameworks
270-
.filter(_ != f1)
271-
.exists(f2 =>
272-
f1.getClass.isAssignableFrom(f2.getClass)
273-
)
274-
)
275-
if (finalFrameworks.nonEmpty)
276-
logger.log(
277-
s"""Final list of test frameworks found:
278-
| - ${finalFrameworks.map(_.description).mkString("\n - ")}
279-
|""".stripMargin
280-
)
281-
282-
val skippedInheritedFrameworks = distinctFrameworks.diff(finalFrameworks)
283-
if (skippedInheritedFrameworks.nonEmpty)
284-
logger.log(
285-
s"""The following test frameworks have been filtered out, as they're being inherited from by others:
286-
| - ${skippedInheritedFrameworks.map(_.description).mkString("\n - ")}
287-
|""".stripMargin
288-
)
289-
290-
finalFrameworks match {
89+
getFrameworksToRun(
90+
frameworkServices = findFrameworkServices(classLoader),
91+
frameworks = findFrameworks(classPath0, classLoader, TestRunner.commonTestFrameworks)
92+
)(logger) match {
29193
case f if f.nonEmpty => f
29294
case _ if verbosity >= 2 => sys.error("No test framework found")
29395
case _ =>

0 commit comments

Comments
 (0)