Skip to content

Commit 1b0da54

Browse files
committed
Prepare to improve positions of async synthetics
Add a test for positions and show the status quo
1 parent cb1486d commit 1b0da54

File tree

1 file changed

+206
-61
lines changed

1 file changed

+206
-61
lines changed

test/junit/scala/tools/nsc/async/AnnotationDrivenAsyncTest.scala

Lines changed: 206 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import org.junit.Assert.assertEquals
99
import org.junit.{Assert, Ignore, Test}
1010

1111
import scala.annotation.{StaticAnnotation, nowarn, unused}
12+
import scala.collection.mutable
1213
import scala.concurrent.duration.Duration
1314
import scala.reflect.internal.util.Position
1415
import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader
@@ -367,6 +368,116 @@ class AnnotationDrivenAsyncTest {
367368
assertEquals(classOf[Array[String]], result.getClass)
368369
}
369370

371+
@Test
372+
def testPositions(): Unit = {
373+
val code =
374+
"""
375+
|import scala.tools.nsc.async.{autoawait, customAsync}
376+
|object Test {
377+
| @autoawait def id(a: Int) = a
378+
| @customAsync def test = {
379+
| val x = id(1)
380+
| val y = id(2)
381+
| x + y
382+
| }
383+
|}""".stripMargin
384+
385+
val result = compile(code)
386+
387+
import result.global._
388+
// settings.Xprintpos.value = true // enable to help debugging
389+
val methpdParseTree = result.parseTree.find { case dt: DefTree => dt.name.string_==("test") case _ => false }
390+
methpdParseTree.get match {
391+
case DefDef(_, _, _, _, _, Block(stats, expr)) =>
392+
val parseTreeStats: List[Tree] = (expr :: stats)
393+
val fsmTree = result.fsmTree
394+
val posMap = mutable.HashMap[Tree, Tree]()
395+
for {
396+
parseTreeStat <- parseTreeStats
397+
pos = parseTreeStat.pos
398+
tree <- fsmTree.get
399+
} {
400+
if (pos.includes(tree.pos)) {
401+
posMap.get(parseTreeStat) match {
402+
case Some(existing) =>
403+
if (existing.pos.includes(tree.pos)) {
404+
405+
} else {
406+
posMap(parseTreeStat) = tree
407+
}
408+
case None =>
409+
posMap(parseTreeStat) = tree
410+
}
411+
}
412+
}
413+
414+
val actual = posMap.toList.mkString("\n")
415+
val expected =
416+
"""(val x = id(1),self.x)
417+
|(x.$plus(y),while$(){
418+
| try {
419+
| self.state() match {
420+
| case 0 => {
421+
| val awaitable$async: scala.tools.nsc.async.CustomFuture = scala.tools.nsc.async.CustomFuture._successful(scala.Int.box(Test.this.id(1)));
422+
| tr = self.getCompleted(awaitable$async);
423+
| self.state_=(1);
424+
| if (null.!=(tr))
425+
| while$()
426+
| else
427+
| {
428+
| self.onComplete(awaitable$async);
429+
| return ()
430+
| }
431+
| }
432+
| case 1 => {
433+
| <synthetic> val await$1: Object = {
434+
| val tryGetResult$async: Object = self.tryGet(tr);
435+
| if (self.eq(tryGetResult$async))
436+
| return ()
437+
| else
438+
| tryGetResult$async.$asInstanceOf[Object]()
439+
| };
440+
| self.x = scala.Int.unbox(await$1);
441+
| val awaitable$async: scala.tools.nsc.async.CustomFuture = scala.tools.nsc.async.CustomFuture._successful(scala.Int.box(Test.this.id(2)));
442+
| tr = self.getCompleted(awaitable$async);
443+
| self.state_=(2);
444+
| if (null.!=(tr))
445+
| while$()
446+
| else
447+
| {
448+
| self.onComplete(awaitable$async);
449+
| return ()
450+
| }
451+
| }
452+
| case 2 => {
453+
| <synthetic> val await$2: Object = {
454+
| val tryGetResult$async: Object = self.tryGet(tr);
455+
| if (self.eq(tryGetResult$async))
456+
| return ()
457+
| else
458+
| tryGetResult$async.$asInstanceOf[Object]()
459+
| };
460+
| val y: Int = scala.Int.unbox(await$2);
461+
| self.completeSuccess(scala.Int.box(self.x.+(y)));
462+
| return ()
463+
| }
464+
| case _ => throw new IllegalStateException(java.lang.String.valueOf(self.state()))
465+
| }
466+
| } catch {
467+
| case (throwable$async @ (_: Throwable)) => {
468+
| self.completeFailure(throwable$async);
469+
| return ()
470+
| }
471+
| };
472+
| while$()
473+
|})
474+
|(val y = id(2),val y: Int = scala.Int.unbox(await$2))""".stripMargin
475+
assertEquals(
476+
expected, actual)
477+
}
478+
}
479+
480+
370481
// Handy to debug the compiler or to collect code coverage statistics in IntelliJ.
371482
@Test
372483
@Ignore
@@ -389,76 +500,110 @@ class AnnotationDrivenAsyncTest {
389500
f
390501
}
391502

503+
abstract class CompileResult {
504+
val global: Global
505+
val tree: global.Tree
506+
val parseTree: global.Tree
507+
def run(): Any
508+
def close(): Unit
509+
def fsmTree: Option[global.Tree] = tree.find { case dd: global.DefDef => dd.symbol.name.containsName("fsm"); case _ => false }
510+
}
511+
392512
def run(code: String, compileOnly: Boolean = false): Any = {
513+
val compileResult = compile(code, compileOnly)
514+
try
515+
if (!compileOnly) compileResult.run()
516+
finally {
517+
compileResult.close()
518+
}
519+
}
520+
521+
def compile(code: String, compileOnly: Boolean = false): CompileResult = {
393522
val out = createTempDir()
394-
try {
395-
val reporter = new StoreReporter(new Settings) {
396-
override def doReport(pos: Position, msg: String, severity: Severity): Unit =
397-
if (severity == INFO) println(msg)
398-
else super.doReport(pos, msg, severity)
399-
}
400-
val settings = new Settings(println(_))
401-
settings.async.value = true
402-
settings.outdir.value = out.getAbsolutePath
403-
settings.embeddedDefaults(getClass.getClassLoader)
404523

405-
// settings.debug.value = true
406-
// settings.uniqid.value = true
407-
// settings.processArgumentString("-Xprint:typer,posterasure,async -nowarn")
408-
// settings.log.value = List("async")
524+
val reporter = new StoreReporter(new Settings) {
525+
override def doReport(pos: Position, msg: String, severity: Severity): Unit =
526+
if (severity == INFO) println(msg)
527+
else super.doReport(pos, msg, severity)
528+
}
529+
val settings = new Settings(println(_))
530+
settings.async.value = true
531+
settings.outdir.value = out.getAbsolutePath
532+
settings.embeddedDefaults(getClass.getClassLoader)
409533

410-
// NOTE: edit ANFTransform.traceAsync to `= true` to get additional diagnostic tracing.
534+
// settings.debug.value = true
535+
// settings.uniqid.value = true
536+
// settings.processArgumentString("-Xprint:typer,posterasure,async -nowarn")
537+
// settings.log.value = List("async")
411538

412-
val isInSBT = !settings.classpath.isSetByUser
413-
if (isInSBT) settings.usejavacp.value = true
414-
val global = new Global(settings, reporter) {
415-
self =>
539+
// NOTE: edit ANFTransform.traceAsync to `= true` to get additional diagnostic tracing.
416540

417-
@nowarn("cat=deprecation&msg=early initializers")
418-
object late extends {
419-
val global: self.type = self
420-
} with AnnotationDrivenAsyncPlugin
541+
val isInSBT = !settings.classpath.isSetByUser
542+
if (isInSBT) settings.usejavacp.value = true
543+
val g = new Global(settings, reporter) {
544+
self =>
421545

422-
override protected def loadPlugins(): List[Plugin] = late :: Nil
423-
}
424-
import global._
425-
426-
val run = new Run
427-
val source = newSourceFile(code)
428-
run.compileSources(source :: Nil)
429-
if (compileOnly) return null
430-
def showInfo(info: StoreReporter#Info): String = {
431-
Position.formatMessage(info.pos, info.severity.toString.toLowerCase + " : " + info.msg, false)
432-
}
433-
Assert.assertTrue(reporter.infos.map(showInfo).mkString("\n"), !reporter.hasErrors)
434-
Assert.assertTrue(reporter.infos.map(showInfo).mkString("\n"), !reporter.hasWarnings)
435-
val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader)
436-
val cls = loader.loadClass("Test")
437-
val result = try {
438-
cls.getMethod("test").invoke(null)
439-
} catch {
440-
case ite: InvocationTargetException => throw ite.getCause
441-
case _: NoSuchMethodException =>
442-
cls.getMethod("main", classOf[Array[String]]).invoke(null, null)
546+
@nowarn("cat=deprecation&msg=early initializers")
547+
object late extends {
548+
val global: self.type = self
549+
} with AnnotationDrivenAsyncPlugin
550+
551+
override protected def loadPlugins(): List[Plugin] = late :: Nil
552+
}
553+
554+
import g._
555+
556+
val run = new Run
557+
val source = newSourceFile(code)
558+
run.compileSources(source :: Nil)
559+
560+
def showInfo(info: StoreReporter#Info): String = {
561+
Position.formatMessage(info.pos, info.severity.toString.toLowerCase + " : " + info.msg, false)
562+
}
563+
564+
Assert.assertTrue(reporter.infos.map(showInfo).mkString("\n"), !reporter.hasErrors)
565+
Assert.assertTrue(reporter.infos.map(showInfo).mkString("\n"), !reporter.hasWarnings)
566+
567+
val unit: CompilationUnit = run.units.next()
568+
val parseTree0 = newUnitParser(unit).parse()
569+
new CompileResult {
570+
val global: g.type = g
571+
572+
val tree = unit.body
573+
override val parseTree: global.Tree = parseTree0
574+
575+
def run(): Any = {
576+
try {
577+
val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader)
578+
val cls = loader.loadClass("Test")
579+
val result = try {
580+
cls.getMethod("test").invoke(null)
581+
} catch {
582+
case ite: InvocationTargetException => throw ite.getCause
583+
case _: NoSuchMethodException =>
584+
cls.getMethod("main", classOf[Array[String]]).invoke(null, null)
585+
}
586+
result match {
587+
case t: scala.concurrent.Future[_] =>
588+
scala.concurrent.Await.result(t, Duration.Inf)
589+
case cf: CustomFuture[_] =>
590+
cf._block
591+
case cf: CompletableFuture[_] =>
592+
cf.get()
593+
case value => value
594+
}
595+
} catch {
596+
case ve: VerifyError =>
597+
val asm = out.listFiles().flatMap { file =>
598+
val asmp = AsmUtils.textify(AsmUtils.readClass(file.getAbsolutePath))
599+
asmp :: Nil
600+
}.mkString("\n\n")
601+
throw new AssertionError(asm, ve)
602+
}
443603
}
444-
result match {
445-
case t: scala.concurrent.Future[_] =>
446-
scala.concurrent.Await.result(t, Duration.Inf)
447-
case cf: CustomFuture[_] =>
448-
cf._block
449-
case cf: CompletableFuture[_] =>
450-
cf.get()
451-
case value => value
604+
override def close(): Unit = {
605+
scala.reflect.io.Path.apply(out).deleteRecursively()
452606
}
453-
} catch {
454-
case ve: VerifyError =>
455-
val asm = out.listFiles().flatMap { file =>
456-
val asmp = AsmUtils.textify(AsmUtils.readClass(file.getAbsolutePath))
457-
asmp :: Nil
458-
}.mkString("\n\n")
459-
throw new AssertionError(asm, ve)
460-
} finally {
461-
scala.reflect.io.Path.apply(out).deleteRecursively()
462607
}
463608
}
464609
}

0 commit comments

Comments
 (0)