@@ -23,6 +23,8 @@ trait ExprBuilder {
23
23
val labelDefStates = collection.mutable.Map [Symbol , Int ]()
24
24
25
25
trait AsyncState {
26
+ var switchId : Int = - 1
27
+
26
28
def state : Int
27
29
28
30
def nextStates : Array [Int ]
@@ -76,13 +78,13 @@ trait ExprBuilder {
76
78
mkHandlerCase(state, stats)
77
79
78
80
override val toString : String =
79
- s " AsyncStateWithoutAwait # $state, nextStates = $nextStates"
81
+ s " AsyncStateWithoutAwait # $state, nextStates = ${ nextStates.toList} "
80
82
}
81
83
82
84
/** A sequence of statements that concludes with an `await` call. The `onComplete`
83
85
* handler will unconditionally transition to `nextState`.
84
86
*/
85
- final class AsyncStateWithAwait (var stats : List [Tree ], val state : Int , onCompleteState : Int , nextState : Int ,
87
+ final class AsyncStateWithAwait (var stats : List [Tree ], val state : Int , val onCompleteState : Int , nextState : Int ,
86
88
val awaitable : Awaitable , symLookup : SymLookup )
87
89
extends AsyncState {
88
90
@@ -268,11 +270,11 @@ trait ExprBuilder {
268
270
}
269
271
270
272
// populate asyncStates
271
- def add (stat : Tree ): Unit = stat match {
273
+ def add (stat : Tree , afterState : Option [ Int ] = None ): Unit = stat match {
272
274
// the val name = await(..) pattern
273
275
case vd @ ValDef (mods, name, tpt, Apply (fun, arg :: Nil )) if isAwait(fun) =>
274
276
val onCompleteState = nextState()
275
- val afterAwaitState = nextState()
277
+ val afterAwaitState = afterState.getOrElse( nextState() )
276
278
val awaitable = Awaitable (arg, stat.symbol, tpt.tpe, vd)
277
279
asyncStates += stateBuilder.resultWithAwait(awaitable, onCompleteState, afterAwaitState) // complete with await
278
280
currState = afterAwaitState
@@ -283,7 +285,7 @@ trait ExprBuilder {
283
285
284
286
val thenStartState = nextState()
285
287
val elseStartState = nextState()
286
- val afterIfState = nextState()
288
+ val afterIfState = afterState.getOrElse( nextState() )
287
289
288
290
asyncStates +=
289
291
// the two Int arguments are the start state of the then branch and the else branch, respectively
@@ -305,7 +307,7 @@ trait ExprBuilder {
305
307
java.util.Arrays .setAll(caseStates, new IntUnaryOperator {
306
308
override def applyAsInt (operand : Int ): Int = nextState()
307
309
})
308
- val afterMatchState = nextState()
310
+ val afterMatchState = afterState.getOrElse( nextState() )
309
311
310
312
asyncStates +=
311
313
stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup)
@@ -323,15 +325,16 @@ trait ExprBuilder {
323
325
if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) =>
324
326
325
327
val startLabelState = stateIdForLabel(ld.symbol)
326
- val afterLabelState = nextState()
328
+ val afterLabelState = afterState.getOrElse( nextState() )
327
329
asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup)
328
330
labelDefStates(ld.symbol) = startLabelState
329
331
val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState)
330
332
asyncStates ++= builder.asyncStates
331
333
currState = afterLabelState
332
334
stateBuilder = new AsyncStateBuilder (currState, symLookup)
333
335
case b @ Block (stats, expr) =>
334
- (stats :+ expr) foreach (add)
336
+ for (stat <- stats) add(stat)
337
+ add(expr, afterState = Some (endState))
335
338
case _ =>
336
339
checkForUnsupportedAwait(stat)
337
340
stateBuilder += stat
@@ -346,7 +349,7 @@ trait ExprBuilder {
346
349
347
350
def onCompleteHandler [T : WeakTypeTag ]: Tree
348
351
349
- def toDot ( afterDSE : Boolean ) : String
352
+ def toDot : String
350
353
}
351
354
352
355
case class SymLookup (stateMachineClass : Symbol , applyTrParam : Symbol ) {
@@ -371,11 +374,10 @@ trait ExprBuilder {
371
374
val blockBuilder = new AsyncBlockBuilder (stats, expr, startState, endState, symLookup)
372
375
373
376
new AsyncBlock {
374
- val liveStates = mutable.AnyRefMap [Integer , Integer ]()
375
- val deadStates = mutable.AnyRefMap [Integer , Integer ]()
377
+ val switchIds = mutable.AnyRefMap [Integer , Integer ]()
376
378
377
379
// render with http://graphviz.it/#/new
378
- def toDot ( afterDSE : Boolean ) : String = {
380
+ def toDot : String = {
379
381
val states = asyncStates
380
382
def toHtmlLabel (label : String , preText : String , builder : StringBuilder ): Unit = {
381
383
builder.append(" <b>" ).append(label).append(" </b>" ).append(" <br/>" )
@@ -390,39 +392,60 @@ trait ExprBuilder {
390
392
val dotBuilder = new StringBuilder ()
391
393
dotBuilder.append(" digraph {\n " )
392
394
def stateLabel (s : Int ) = {
393
- val beforeDseLabel = if (s == 0 ) " INITIAL" else if (s == Int .MaxValue ) " TERMINAL" else if (s > 0 ) " S" + s else " C" + Math .abs(s)
394
- if (afterDSE) {
395
- " \" S" + liveStates.getOrElse(s, s) + " (" + beforeDseLabel + " )\" "
396
- } else {
397
- beforeDseLabel
398
- }
399
-
395
+ if (s == 0 ) " INITIAL" else if (s == Int .MaxValue ) " TERMINAL" else switchIds.getOrElse(s, s).toString
400
396
}
401
397
val length = asyncStates.size
402
398
for ((state, i) <- asyncStates.zipWithIndex) {
403
- val liveStateIdOpt : Option [Int ] = if (afterDSE) {
404
- liveStates.get(state.state).map(_.intValue())
405
- } else Some (state.state)
406
- for (_ <- liveStateIdOpt) {
407
- dotBuilder.append(s """ ${stateLabel(state.state)} [label= """ ).append(" <" )
408
- if (i != length - 1 ) {
409
- val CaseDef (_, _, body) = state.mkHandlerCaseForState
410
- toHtmlLabel(stateLabel(state.state), showCode(body), dotBuilder)
411
- } else {
412
- toHtmlLabel(stateLabel(state.state), state.allStats.map(showCode(_)).mkString(" \n " ), dotBuilder)
413
- }
414
- dotBuilder.append(" > ]\n " )
399
+ dotBuilder.append(s """ ${stateLabel(state.state)} [label= """ ).append(" <" )
400
+ if (i != length - 1 ) {
401
+ val CaseDef (_, _, body) = state.mkHandlerCaseForState
402
+ toHtmlLabel(stateLabel(state.state), showCode(body), dotBuilder)
403
+ } else {
404
+ toHtmlLabel(stateLabel(state.state), state.allStats.map(showCode(_)).mkString(" \n " ), dotBuilder)
415
405
}
406
+ dotBuilder.append(" > ]\n " )
416
407
}
417
- for (state <- states; if liveStates.contains(state.state); succ <- state.nextStates) {
408
+ for (state <- states; succ <- state.nextStates) {
418
409
dotBuilder.append(s """ ${stateLabel(state.state)} -> ${stateLabel(succ)}""" )
419
410
dotBuilder.append(" \n " )
420
411
}
421
412
dotBuilder.append(" }\n " )
422
413
dotBuilder.toString
423
414
}
424
415
425
- def asyncStates = blockBuilder.asyncStates.toList
416
+ lazy val asyncStates : List [AsyncState ] = filterStates
417
+
418
+ def filterStates = {
419
+ val all = blockBuilder.asyncStates.toList
420
+ val (initial :: rest) = all
421
+ val map = all.iterator.map(x => (x.state, x)).toMap
422
+ var seen = mutable.HashSet [Int ]()
423
+ def loop (state : AsyncState ): Unit = {
424
+ seen.add(state.state)
425
+ for (i <- state.nextStates) {
426
+ if (i != Int .MaxValue && ! seen.contains(i)) {
427
+ loop(map(i))
428
+ }
429
+ }
430
+ }
431
+ loop(initial)
432
+ val live = rest.filter(state => seen(state.state))
433
+ var nextSwitchId = 1
434
+ (initial :: live).foreach { state =>
435
+ val switchId = nextSwitchId
436
+ switchIds(state.state) = switchId
437
+ nextSwitchId += 1
438
+ state match {
439
+ case state : AsyncStateWithAwait =>
440
+ val switchId = nextSwitchId
441
+ switchIds(state.onCompleteState) = switchId
442
+ nextSwitchId += 1
443
+ case _ =>
444
+ }
445
+ }
446
+ initial :: live
447
+
448
+ }
426
449
427
450
def mkCombinedHandlerCases [T : WeakTypeTag ]: List [CaseDef ] = {
428
451
val caseForLastState : CaseDef = {
@@ -488,43 +511,14 @@ trait ExprBuilder {
488
511
// Identify dead states: `case <id> => { state = nextId; (); (); ... }, eliminated, and compact state ids to
489
512
// enable emission of a tableswitch.
490
513
private def eliminateDeadStates (m : Match ): Tree = {
491
- object DeadState {
492
- private var compactedStateId = 1
493
- for (CaseDef (Literal (Constant (stateId : Integer )), EmptyTree , body) <- m.cases) {
494
- body match {
495
- case _ if (stateId == 0 ) => liveStates(stateId) = stateId
496
- case Block (Assign (_, Literal (Constant (nextState : Integer ))) :: rest, expr) if (expr :: rest).forall(t => isLiteralUnit(t)) =>
497
- deadStates(stateId) = nextState
498
- case _ =>
499
- liveStates(stateId) = compactedStateId
500
- compactedStateId += 1
501
- }
502
- }
503
- if (deadStates.nonEmpty)
504
- AsyncUtils .vprintln(s " ${deadStates.size} dead states eliminated " )
505
- def isDead (i : Integer ) = deadStates.contains(i)
506
- def translatedStateId (i : Integer , tree : Tree ): Integer = {
507
- def chaseDead (i : Integer ): Integer = {
508
- val replacement = deadStates.getOrNull(i)
509
- if (replacement == null ) i
510
- else chaseDead(replacement)
511
- }
512
-
513
- val live = chaseDead(i)
514
- liveStates.get(live) match {
515
- case Some (x) => x
516
- case None => sys.error(s " $live, $liveStates \n $deadStates\n $m\n\n ==== \n $tree" )
517
- }
518
- }
519
- }
520
514
val stateMemberSymbol = symLookup.stateMachineMember(name.state)
521
515
// - remove CaseDef-s for dead states
522
516
// - rewrite state transitions to dead states to instead transition to the
523
517
// non-dead successor.
524
518
val elimDeadStateTransform = new Transformer {
525
519
override def transform (tree : Tree ): Tree = tree match {
526
520
case as @ Assign (lhs, Literal (Constant (i : Integer ))) if lhs.symbol == stateMemberSymbol =>
527
- val replacement = DeadState .translatedStateId(i, as )
521
+ val replacement = switchIds(i )
528
522
treeCopy.Assign (tree, lhs, Literal (Constant (replacement)))
529
523
case _ : Match | _ : CaseDef | _ : Block | _ : If =>
530
524
super .transform(tree)
@@ -533,12 +527,9 @@ trait ExprBuilder {
533
527
}
534
528
val cases1 = m.cases.flatMap {
535
529
case cd @ CaseDef (Literal (Constant (i : Integer )), EmptyTree , rhs) =>
536
- if (DeadState .isDead(i)) Nil
537
- else {
538
- val replacement = DeadState .translatedStateId(i, cd)
539
- val rhs1 = elimDeadStateTransform.transform(rhs)
540
- treeCopy.CaseDef (cd, Literal (Constant (replacement)), EmptyTree , rhs1) :: Nil
541
- }
530
+ val replacement = switchIds(i)
531
+ val rhs1 = elimDeadStateTransform.transform(rhs)
532
+ treeCopy.CaseDef (cd, Literal (Constant (replacement)), EmptyTree , rhs1) :: Nil
542
533
case x => x :: Nil
543
534
}
544
535
treeCopy.Match (m, m.selector, cases1)
0 commit comments