Skip to content

Commit b2fd081

Browse files
committed
zinc based test quick command
1 parent 56da947 commit b2fd081

File tree

7 files changed

+231
-21
lines changed

7 files changed

+231
-21
lines changed

core/define/src/mill/define/Task.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,12 @@ object Task extends TaskBase {
122122
inline def Command[T](inline t: Result[T])(implicit
123123
inline w: W[T],
124124
inline ctx: mill.define.Ctx
125-
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, exclusive = '{ false }) }
125+
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, exclusive = '{ false }, '{ false }) }
126+
127+
inline def Command[T](persistent: Boolean)(inline t: Result[T])(implicit
128+
inline w: W[T],
129+
inline ctx: mill.define.Ctx
130+
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, exclusive = '{ false }, '{ persistent }) }
126131

127132
/**
128133
* @param exclusive Exclusive commands run serially at the end of an evaluation,
@@ -140,7 +145,7 @@ object Task extends TaskBase {
140145
inline def apply[T](inline t: Result[T])(implicit
141146
inline w: W[T],
142147
inline ctx: mill.define.Ctx
143-
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, '{ this.exclusive }) }
148+
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, '{ this.exclusive }, '{ false }) }
144149
}
145150

146151
/**
@@ -391,7 +396,8 @@ class Command[+T](
391396
val ctx0: mill.define.Ctx,
392397
val writer: W[?],
393398
val isPrivate: Option[Boolean],
394-
val exclusive: Boolean
399+
val exclusive: Boolean,
400+
override val persistent: Boolean
395401
) extends NamedTask[T] {
396402

397403
override def asCommand: Some[Command[T]] = Some(this)
@@ -538,12 +544,13 @@ private object TaskMacros {
538544
)(t: Expr[Result[T]])(
539545
w: Expr[W[T]],
540546
ctx: Expr[mill.define.Ctx],
541-
exclusive: Expr[Boolean]
547+
exclusive: Expr[Boolean],
548+
persistent: Expr[Boolean]
542549
): Expr[Command[T]] = {
543550
appImpl[Command, T](
544551
(in, ev) =>
545552
'{
546-
new Command[T]($in, $ev, $ctx, $w, ${ taskIsPrivate() }, exclusive = $exclusive)
553+
new Command[T]($in, $ev, $ctx, $w, ${ taskIsPrivate() }, exclusive = $exclusive, persistent = $persistent)
547554
},
548555
t
549556
)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package mill.scalalib.api
2+
3+
final case class TransitiveSourceStampResults(
4+
currentStamps: Map[String, String],
5+
previousStamps: Option[Map[String, String]] = None
6+
) {
7+
lazy val changedSources: Set[String] = {
8+
previousStamps match {
9+
case Some(prevStamps) =>
10+
currentStamps.view
11+
.flatMap { (source, stamp) =>
12+
prevStamps.get(source) match {
13+
case None => Some(source) // new source
14+
case Some(prevStamp) => Option.when(stamp != prevStamp)(source) // changed source
15+
}
16+
}
17+
.toSet
18+
case None => currentStamps.keySet
19+
}
20+
}
21+
}
22+
23+
object TransitiveSourceStampResults {
24+
implicit val jsonFormatter: upickle.default.ReadWriter[TransitiveSourceStampResults] =
25+
upickle.default.macroRW
26+
}

scalalib/package.mill

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ object `package` extends RootModule with build.MillStableScalaModule {
2525
// (also transitively included by com.eed3si9n.jarjarabrams:jarjar-abrams-core)
2626
// perhaps the class can be copied here?
2727
Agg(build.Deps.scalaReflect(scalaVersion()))
28-
}
28+
} ++
29+
Agg(build.Deps.zinc)
2930
}
3031
def testIvyDeps = super.testIvyDeps() ++ Agg(build.Deps.TestDeps.scalaCheck)
3132
def testTransitiveDeps =

scalalib/src/mill/scalalib/JavaModule.scala

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ import mill.scalalib.bsp.{BspBuildTarget, BspModule, BspUri, JvmBuildTarget}
1818
import mill.scalalib.publish.Artifact
1919
import mill.util.Jvm
2020
import os.Path
21+
import mill.testrunner.TestResult
22+
import mill.scalalib.api.TransitiveSourceStampResults
23+
import scala.collection.immutable.TreeMap
24+
import scala.util.Try
2125

2226
/**
2327
* Core configuration required to compile a single Java compilation target
@@ -92,6 +96,173 @@ trait JavaModule
9296
case _: ClassNotFoundException => // if we can't find the classes, we certainly are not in a ScalaJSModule
9397
}
9498
}
99+
100+
def testQuick(args: String*): Command[(String, Seq[TestResult])] = Task.Command(persistent = true) {
101+
val quicktestFailedClassesLog = Task.dest / "quickTestFailedClasses.json"
102+
val invalidatedClassesLog = Task.dest / "invalidatedClasses.json"
103+
val failedTestClasses =
104+
if (!os.exists(quicktestFailedClassesLog)) {
105+
Set.empty[String]
106+
} else {
107+
Try {
108+
upickle.default.read[Seq[String]](os.read.stream(quicktestFailedClassesLog))
109+
}.getOrElse(Seq.empty[String]).toSet
110+
}
111+
112+
val transitiveStampsFile = Task.dest / "transitiveStamps.json"
113+
val previousStampsOpt = if (os.exists(transitiveStampsFile)) {
114+
val previousStamps = upickle.default.read[TransitiveSourceStampResults](
115+
os.read.stream(transitiveStampsFile)
116+
).currentStamps
117+
os.remove(transitiveStampsFile)
118+
Some(previousStamps)
119+
} else {
120+
None
121+
}
122+
123+
def getAnalysisStore(compileResult: CompilationResult): Option[xsbti.compile.CompileAnalysis] = {
124+
val analysisStore = sbt.internal.inc.consistent.ConsistentFileAnalysisStore.binary(
125+
file = compileResult.analysisFile.toIO,
126+
mappers = xsbti.compile.analysis.ReadWriteMappers.getEmptyMappers(),
127+
reproducible = true,
128+
parallelism = math.min(Runtime.getRuntime.availableProcessors(), 8)
129+
)
130+
val analysisOptional = analysisStore.get()
131+
if (analysisOptional.isPresent) Some(analysisOptional.get.getAnalysis) else None
132+
}
133+
134+
val combinedAnalysis = (compile() +: upstreamCompileOutput())
135+
.flatMap(getAnalysisStore)
136+
.flatMap {
137+
case analysis: sbt.internal.inc.Analysis => Some(analysis)
138+
case _ => None
139+
}
140+
.foldLeft(sbt.internal.inc.Analysis.empty)(_ ++ _)
141+
142+
val result = TransitiveSourceStampResults(
143+
currentStamps = TreeMap.from(
144+
combinedAnalysis.stamps.sources.view.map { (source, stamp) =>
145+
source.id() -> stamp.writeStamp()
146+
}
147+
),
148+
previousStamps = previousStampsOpt
149+
)
150+
151+
def getInvalidatedClasspaths(
152+
initialInvalidatedClassNames: Set[String],
153+
relations: sbt.internal.inc.Relations
154+
): Set[os.Path] = {
155+
val seen = collection.mutable.Set.empty[String]
156+
val seenList = collection.mutable.Buffer.empty[String]
157+
val queued = collection.mutable.Queue.from(initialInvalidatedClassNames)
158+
159+
while (queued.nonEmpty) {
160+
val current = queued.dequeue()
161+
seenList.append(current)
162+
seen.add(current)
163+
164+
for (next <- relations.usesInternalClass(current)) {
165+
if (!seen.contains(next)) {
166+
seen.add(next)
167+
queued.enqueue(next)
168+
}
169+
}
170+
171+
for (next <- relations.usesExternal(current)) {
172+
if (!seen.contains(next)) {
173+
seen.add(next)
174+
queued.enqueue(next)
175+
}
176+
}
177+
}
178+
179+
seenList
180+
.iterator
181+
.flatMap { invalidatedClassName =>
182+
relations.definesClass(invalidatedClassName)
183+
}
184+
.flatMap { source =>
185+
relations.products(source)
186+
}
187+
.map { product =>
188+
os.Path(product.id)
189+
}
190+
.toSet
191+
}
192+
193+
val relations = combinedAnalysis.relations
194+
195+
val invalidatedAbsoluteClasspaths = getInvalidatedClasspaths(
196+
result.changedSources.flatMap { source =>
197+
relations.classNames(xsbti.VirtualFileRef.of(source))
198+
},
199+
combinedAnalysis.relations
200+
)
201+
202+
// We only care about testing class, so we can:
203+
// - filter out all class path that start with `testClasspath()`
204+
// - strip the prefix and safely turn them into module class path
205+
206+
val testClasspaths = testClasspath()
207+
val invalidatedClassNames = invalidatedAbsoluteClasspaths.flatMap { absoluteClasspath =>
208+
testClasspaths.collectFirst {
209+
case path if absoluteClasspath.startsWith(path.path) =>
210+
absoluteClasspath.relativeTo(path.path).segments.map(_.stripSuffix(".class")).mkString(".")
211+
}
212+
}
213+
val testingClasses = invalidatedClassNames ++ failedTestClasses
214+
val testClasses = testForkGrouping().map(_.filter(testingClasses.contains)).filter(_.nonEmpty)
215+
216+
// Clean up the directory for test runners
217+
os.walk(Task.dest).foreach { subPath => os.remove.all(subPath) }
218+
219+
val quickTestReportXml = testReportXml()
220+
221+
val testModuleUtil = new TestModuleUtil(
222+
testUseArgsFile(),
223+
forkArgs(),
224+
Seq.empty,
225+
zincWorker().scalalibClasspath(),
226+
resources(),
227+
testFramework(),
228+
runClasspath(),
229+
testClasspaths,
230+
args.toSeq,
231+
testClasses,
232+
zincWorker().testrunnerEntrypointClasspath(),
233+
forkEnv(),
234+
testSandboxWorkingDir(),
235+
forkWorkingDir(),
236+
quickTestReportXml,
237+
zincWorker().javaHome().map(_.path),
238+
testParallelism()
239+
)
240+
241+
val results = testModuleUtil.runTests()
242+
243+
val badTestClasses = (results match {
244+
case Result.Failure(_) =>
245+
// Consider all quick testing classes as failed
246+
testClasses.flatten
247+
case Result.Success((_, results)) =>
248+
// Get all test classes that failed
249+
results
250+
.filter(testResult => Set("Error", "Failure").contains(testResult.status))
251+
.map(_.fullyQualifiedName)
252+
}).distinct
253+
254+
os.write.over(transitiveStampsFile, upickle.default.write(result))
255+
os.write.over(quicktestFailedClassesLog, upickle.default.write(badTestClasses))
256+
os.write.over(invalidatedClassesLog, upickle.default.write(invalidatedClassNames))
257+
results match {
258+
case Result.Failure(errMsg) => Result.Failure(errMsg)
259+
case Result.Success((doneMsg, results)) =>
260+
try TestModule.handleResults(doneMsg, results, Task.ctx(), quickTestReportXml)
261+
catch {
262+
case e: Throwable => Result.Failure("Test reporting failed: " + e)
263+
}
264+
}
265+
}
95266
}
96267

97268
def defaultCommandName(): String = "run"

scalalib/src/mill/scalalib/TestModule.scala

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,12 @@ trait TestModule
203203
globSelectors: Task[Seq[String]]
204204
): Task[(String, Seq[TestResult])] =
205205
Task.Anon {
206+
val testGlobSelectors = globSelectors()
207+
val reportXml = testReportXml()
206208
val testModuleUtil = new TestModuleUtil(
207209
testUseArgsFile(),
208210
forkArgs(),
209-
globSelectors(),
211+
testGlobSelectors,
210212
jvmWorker().scalalibClasspath(),
211213
resources(),
212214
testFramework(),
@@ -218,12 +220,25 @@ trait TestModule
218220
forkEnv(),
219221
testSandboxWorkingDir(),
220222
forkWorkingDir(),
221-
testReportXml(),
223+
reportXml,
222224
jvmWorker().javaHome().map(_.path),
223225
testParallelism(),
224226
testLogLevel()
225227
)
226-
testModuleUtil.runTests()
228+
val result = testModuleUtil.runTests()
229+
230+
result match {
231+
case Result.Failure(errMsg) => Result.Failure(errMsg)
232+
case Result.Success((doneMsg, results)) =>
233+
if (results.isEmpty && testGlobSelectors.nonEmpty) throw new Result.Exception(
234+
s"Test selector does not match any test: ${testGlobSelectors.mkString(" ")}" +
235+
"\nRun discoveredTestClasses to see available tests"
236+
)
237+
try TestModule.handleResults(doneMsg, results, Task.ctx(), reportXml)
238+
catch {
239+
case e: Throwable => Result.Failure("Test reporting failed: " + e)
240+
}
241+
}
227242
}
228243

229244
/**

scalalib/src/mill/scalalib/TestModuleUtil.scala

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -101,21 +101,11 @@ private final class TestModuleUtil(
101101
}
102102
if (selectors.nonEmpty && filteredClassLists.isEmpty) throw doesNotMatchError
103103

104-
val result = if (testParallelism) {
104+
if (testParallelism) {
105105
runTestQueueScheduler(filteredClassLists)
106106
} else {
107107
runTestDefault(filteredClassLists)
108108
}
109-
110-
result match {
111-
case Result.Failure(errMsg) => Result.Failure(errMsg)
112-
case Result.Success((doneMsg, results)) =>
113-
if (results.isEmpty && selectors.nonEmpty) throw doesNotMatchError
114-
try TestModuleUtil.handleResults(doneMsg, results, Task.ctx(), testReportXml)
115-
catch {
116-
case e: Throwable => Result.Failure("Test reporting failed: " + e)
117-
}
118-
}
119109
}
120110

121111
private def callTestRunnerSubprocess(

testkit/src/mill/testkit/UnitTester.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class UnitTester(
5252
inStream: InputStream,
5353
debugEnabled: Boolean,
5454
env: Map[String, String],
55-
resetSourcePath: Boolean
55+
resetSourcePath: Boolean = true
5656
)(implicit fullName: sourcecode.FullName) extends AutoCloseable {
5757
val outPath: os.Path = module.moduleDir / "out"
5858

0 commit comments

Comments
 (0)