Skip to content

Commit a8d4032

Browse files
committed
Rework dead state elimination to happen as part of expr builder
1 parent 5eb1cd6 commit a8d4032

File tree

3 files changed

+67
-73
lines changed

3 files changed

+67
-73
lines changed

src/main/scala/scala/async/internal/AsyncTransform.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,8 @@ trait AsyncTransform {
130130
AsyncUtils.vprintln(s"${c.macroApplication}")
131131
AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree")
132132
states foreach (s => AsyncUtils.vprintln(s))
133-
AsyncUtils.vprintln("===== BEFORE DSE =====")
134-
AsyncUtils.vprintln(block.toDot(afterDSE = false))
135-
AsyncUtils.vprintln("===== AFTER DSE =====")
136-
AsyncUtils.vprintln(block.toDot(afterDSE = true))
133+
AsyncUtils.vprintln("===== DOT =====")
134+
AsyncUtils.vprintln(block.toDot)
137135
}
138136

139137
/**

src/main/scala/scala/async/internal/ExprBuilder.scala

Lines changed: 60 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ trait ExprBuilder {
2323
val labelDefStates = collection.mutable.Map[Symbol, Int]()
2424

2525
trait AsyncState {
26+
var switchId: Int = -1
27+
2628
def state: Int
2729

2830
def nextStates: Array[Int]
@@ -76,13 +78,13 @@ trait ExprBuilder {
7678
mkHandlerCase(state, stats)
7779

7880
override val toString: String =
79-
s"AsyncStateWithoutAwait #$state, nextStates = $nextStates"
81+
s"AsyncStateWithoutAwait #$state, nextStates = ${nextStates.toList}"
8082
}
8183

8284
/** A sequence of statements that concludes with an `await` call. The `onComplete`
8385
* handler will unconditionally transition to `nextState`.
8486
*/
85-
final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, onCompleteState: Int, nextState: Int,
87+
final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, val onCompleteState: Int, nextState: Int,
8688
val awaitable: Awaitable, symLookup: SymLookup)
8789
extends AsyncState {
8890

@@ -268,11 +270,11 @@ trait ExprBuilder {
268270
}
269271

270272
// populate asyncStates
271-
def add(stat: Tree): Unit = stat match {
273+
def add(stat: Tree, afterState: Option[Int] = None): Unit = stat match {
272274
// the val name = await(..) pattern
273275
case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
274276
val onCompleteState = nextState()
275-
val afterAwaitState = nextState()
277+
val afterAwaitState = afterState.getOrElse(nextState())
276278
val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd)
277279
asyncStates += stateBuilder.resultWithAwait(awaitable, onCompleteState, afterAwaitState) // complete with await
278280
currState = afterAwaitState
@@ -283,7 +285,7 @@ trait ExprBuilder {
283285

284286
val thenStartState = nextState()
285287
val elseStartState = nextState()
286-
val afterIfState = nextState()
288+
val afterIfState = afterState.getOrElse(nextState())
287289

288290
asyncStates +=
289291
// the two Int arguments are the start state of the then branch and the else branch, respectively
@@ -305,7 +307,7 @@ trait ExprBuilder {
305307
java.util.Arrays.setAll(caseStates, new IntUnaryOperator {
306308
override def applyAsInt(operand: Int): Int = nextState()
307309
})
308-
val afterMatchState = nextState()
310+
val afterMatchState = afterState.getOrElse(nextState())
309311

310312
asyncStates +=
311313
stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup)
@@ -323,15 +325,16 @@ trait ExprBuilder {
323325
if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) =>
324326

325327
val startLabelState = stateIdForLabel(ld.symbol)
326-
val afterLabelState = nextState()
328+
val afterLabelState = afterState.getOrElse(nextState())
327329
asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup)
328330
labelDefStates(ld.symbol) = startLabelState
329331
val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState)
330332
asyncStates ++= builder.asyncStates
331333
currState = afterLabelState
332334
stateBuilder = new AsyncStateBuilder(currState, symLookup)
333335
case b @ Block(stats, expr) =>
334-
(stats :+ expr) foreach (add)
336+
for (stat <- stats) add(stat)
337+
add(expr, afterState = Some(endState))
335338
case _ =>
336339
checkForUnsupportedAwait(stat)
337340
stateBuilder += stat
@@ -346,7 +349,7 @@ trait ExprBuilder {
346349

347350
def onCompleteHandler[T: WeakTypeTag]: Tree
348351

349-
def toDot(afterDSE: Boolean): String
352+
def toDot: String
350353
}
351354

352355
case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) {
@@ -371,11 +374,10 @@ trait ExprBuilder {
371374
val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup)
372375

373376
new AsyncBlock {
374-
val liveStates = mutable.AnyRefMap[Integer, Integer]()
375-
val deadStates = mutable.AnyRefMap[Integer, Integer]()
377+
val switchIds = mutable.AnyRefMap[Integer, Integer]()
376378

377379
// render with http://graphviz.it/#/new
378-
def toDot(afterDSE: Boolean): String = {
380+
def toDot: String = {
379381
val states = asyncStates
380382
def toHtmlLabel(label: String, preText: String, builder: StringBuilder): Unit = {
381383
builder.append("<b>").append(label).append("</b>").append("<br/>")
@@ -390,39 +392,60 @@ trait ExprBuilder {
390392
val dotBuilder = new StringBuilder()
391393
dotBuilder.append("digraph {\n")
392394
def stateLabel(s: Int) = {
393-
val beforeDseLabel = if (s == 0) "INITIAL" else if (s == Int.MaxValue) "TERMINAL" else if (s > 0) "S" + s else "C" + Math.abs(s)
394-
if (afterDSE) {
395-
"\"S" + liveStates.getOrElse(s, s) + " (" + beforeDseLabel + ")\""
396-
} else {
397-
beforeDseLabel
398-
}
399-
395+
if (s == 0) "INITIAL" else if (s == Int.MaxValue) "TERMINAL" else switchIds.getOrElse(s, s).toString
400396
}
401397
val length = asyncStates.size
402398
for ((state, i) <- asyncStates.zipWithIndex) {
403-
val liveStateIdOpt: Option[Int] = if (afterDSE) {
404-
liveStates.get(state.state).map(_.intValue())
405-
} else Some(state.state)
406-
for (_ <- liveStateIdOpt) {
407-
dotBuilder.append(s"""${stateLabel(state.state)} [label=""").append("<")
408-
if (i != length - 1) {
409-
val CaseDef(_, _, body) = state.mkHandlerCaseForState
410-
toHtmlLabel(stateLabel(state.state), showCode(body), dotBuilder)
411-
} else {
412-
toHtmlLabel(stateLabel(state.state), state.allStats.map(showCode(_)).mkString("\n"), dotBuilder)
413-
}
414-
dotBuilder.append("> ]\n")
399+
dotBuilder.append(s"""${stateLabel(state.state)} [label=""").append("<")
400+
if (i != length - 1) {
401+
val CaseDef(_, _, body) = state.mkHandlerCaseForState
402+
toHtmlLabel(stateLabel(state.state), showCode(body), dotBuilder)
403+
} else {
404+
toHtmlLabel(stateLabel(state.state), state.allStats.map(showCode(_)).mkString("\n"), dotBuilder)
415405
}
406+
dotBuilder.append("> ]\n")
416407
}
417-
for (state <- states; if liveStates.contains(state.state); succ <- state.nextStates) {
408+
for (state <- states; succ <- state.nextStates) {
418409
dotBuilder.append(s"""${stateLabel(state.state)} -> ${stateLabel(succ)}""")
419410
dotBuilder.append("\n")
420411
}
421412
dotBuilder.append("}\n")
422413
dotBuilder.toString
423414
}
424415

425-
def asyncStates = blockBuilder.asyncStates.toList
416+
lazy val asyncStates: List[AsyncState] = filterStates
417+
418+
def filterStates = {
419+
val all = blockBuilder.asyncStates.toList
420+
val (initial :: rest) = all
421+
val map = all.iterator.map(x => (x.state, x)).toMap
422+
var seen = mutable.HashSet[Int]()
423+
def loop(state: AsyncState): Unit = {
424+
seen.add(state.state)
425+
for (i <- state.nextStates) {
426+
if (i != Int.MaxValue && !seen.contains(i)) {
427+
loop(map(i))
428+
}
429+
}
430+
}
431+
loop(initial)
432+
val live = rest.filter(state => seen(state.state))
433+
var nextSwitchId = 1
434+
(initial :: live).foreach { state =>
435+
val switchId = nextSwitchId
436+
switchIds(state.state) = switchId
437+
nextSwitchId += 1
438+
state match {
439+
case state: AsyncStateWithAwait =>
440+
val switchId = nextSwitchId
441+
switchIds(state.onCompleteState) = switchId
442+
nextSwitchId += 1
443+
case _ =>
444+
}
445+
}
446+
initial :: live
447+
448+
}
426449

427450
def mkCombinedHandlerCases[T: WeakTypeTag]: List[CaseDef] = {
428451
val caseForLastState: CaseDef = {
@@ -488,43 +511,14 @@ trait ExprBuilder {
488511
// Identify dead states: `case <id> => { state = nextId; (); (); ... }, eliminated, and compact state ids to
489512
// enable emission of a tableswitch.
490513
private def eliminateDeadStates(m: Match): Tree = {
491-
object DeadState {
492-
private var compactedStateId = 1
493-
for (CaseDef(Literal(Constant(stateId: Integer)), EmptyTree, body) <- m.cases) {
494-
body match {
495-
case _ if (stateId == 0) => liveStates(stateId) = stateId
496-
case Block(Assign(_, Literal(Constant(nextState: Integer))) :: rest, expr) if (expr :: rest).forall(t => isLiteralUnit(t)) =>
497-
deadStates(stateId) = nextState
498-
case _ =>
499-
liveStates(stateId) = compactedStateId
500-
compactedStateId += 1
501-
}
502-
}
503-
if (deadStates.nonEmpty)
504-
AsyncUtils.vprintln(s"${deadStates.size} dead states eliminated")
505-
def isDead(i: Integer) = deadStates.contains(i)
506-
def translatedStateId(i: Integer, tree: Tree): Integer = {
507-
def chaseDead(i: Integer): Integer = {
508-
val replacement = deadStates.getOrNull(i)
509-
if (replacement == null) i
510-
else chaseDead(replacement)
511-
}
512-
513-
val live = chaseDead(i)
514-
liveStates.get(live) match {
515-
case Some(x) => x
516-
case None => sys.error(s"$live, $liveStates \n$deadStates\n$m\n\n====\n$tree")
517-
}
518-
}
519-
}
520514
val stateMemberSymbol = symLookup.stateMachineMember(name.state)
521515
// - remove CaseDef-s for dead states
522516
// - rewrite state transitions to dead states to instead transition to the
523517
// non-dead successor.
524518
val elimDeadStateTransform = new Transformer {
525519
override def transform(tree: Tree): Tree = tree match {
526520
case as @ Assign(lhs, Literal(Constant(i: Integer))) if lhs.symbol == stateMemberSymbol =>
527-
val replacement = DeadState.translatedStateId(i, as)
521+
val replacement = switchIds(i)
528522
treeCopy.Assign(tree, lhs, Literal(Constant(replacement)))
529523
case _: Match | _: CaseDef | _: Block | _: If =>
530524
super.transform(tree)
@@ -533,12 +527,9 @@ trait ExprBuilder {
533527
}
534528
val cases1 = m.cases.flatMap {
535529
case cd @ CaseDef(Literal(Constant(i: Integer)), EmptyTree, rhs) =>
536-
if (DeadState.isDead(i)) Nil
537-
else {
538-
val replacement = DeadState.translatedStateId(i, cd)
539-
val rhs1 = elimDeadStateTransform.transform(rhs)
540-
treeCopy.CaseDef(cd, Literal(Constant(replacement)), EmptyTree, rhs1) :: Nil
541-
}
530+
val replacement = switchIds(i)
531+
val rhs1 = elimDeadStateTransform.transform(rhs)
532+
treeCopy.CaseDef(cd, Literal(Constant(replacement)), EmptyTree, rhs1) :: Nil
542533
case x => x :: Nil
543534
}
544535
treeCopy.Match(m, m.selector, cases1)

src/main/scala/scala/async/internal/TransformUtils.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ private[async] trait TransformUtils {
3939
true
4040
case _ => false
4141
}
42+
def isNullAssignment(t: Tree) = t match {
43+
case Assign(Ident(_), Literal(Constant(null))) =>
44+
true
45+
case _ => false
46+
}
4247

4348
def isPastTyper =
4449
c.universe.asInstanceOf[scala.reflect.internal.SymbolTable].isPastTyper

0 commit comments

Comments
 (0)