Skip to content

Commit e363724

Browse files
committed
Merge pull request #40 from retronym/topic/live-var-capture
Don't aggressively null out captured vars
2 parents 490238d + 302a07c commit e363724

File tree

5 files changed

+205
-18
lines changed

5 files changed

+205
-18
lines changed

src/main/scala/scala/async/internal/AsyncId.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ object AsyncTestLV extends AsyncBase {
2727
def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit)
2828

2929
var log: List[(String, Any)] = List()
30+
def assertNulledOut(a: Any): Unit = assert(log.exists(_._2 == a), AsyncTestLV.log)
31+
def assertNotNulledOut(a: Any): Unit = assert(!log.exists(_._2 == a), AsyncTestLV.log)
32+
def clear() = log = Nil
3033

3134
def apply(name: String, v: Any): Unit =
3235
log ::= (name -> v)

src/main/scala/scala/async/internal/LiveVariables.scala

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,53 @@ trait LiveVariables {
6868
* @param as a state of an `async` expression
6969
* @return a set of lifted fields that are used within state `as`
7070
*/
71-
def fieldsUsedIn(as: AsyncState): Set[Symbol] = {
72-
class FindUseTraverser extends Traverser {
71+
def fieldsUsedIn(as: AsyncState): ReferencedFields = {
72+
class FindUseTraverser extends AsyncTraverser {
7373
var usedFields = Set[Symbol]()
74-
override def traverse(tree: Tree) = tree match {
75-
case Ident(_) if liftedSyms(tree.symbol) =>
76-
usedFields += tree.symbol
77-
case _ =>
78-
super.traverse(tree)
74+
var capturedFields = Set[Symbol]()
75+
private def capturing[A](body: => A): A = {
76+
val saved = capturing
77+
try {
78+
capturing = true
79+
body
80+
} finally capturing = saved
7981
}
82+
private def capturingCheck(tree: Tree) = capturing(tree foreach check)
83+
private var capturing: Boolean = false
84+
private def check(tree: Tree) {
85+
tree match {
86+
case Ident(_) if liftedSyms(tree.symbol) =>
87+
if (capturing)
88+
capturedFields += tree.symbol
89+
else
90+
usedFields += tree.symbol
91+
case _ =>
92+
}
93+
}
94+
override def traverse(tree: Tree) = {
95+
check(tree)
96+
super.traverse(tree)
97+
}
98+
99+
override def nestedClass(classDef: ClassDef): Unit = capturingCheck(classDef)
100+
101+
override def nestedModule(module: ModuleDef): Unit = capturingCheck(module)
102+
103+
override def nestedMethod(defdef: DefDef): Unit = capturingCheck(defdef)
104+
105+
override def byNameArgument(arg: Tree): Unit = capturingCheck(arg)
106+
107+
override def function(function: Function): Unit = capturingCheck(function)
108+
109+
override def patMatFunction(tree: Match): Unit = capturingCheck(tree)
80110
}
111+
81112
val findUses = new FindUseTraverser
82113
findUses.traverse(Block(as.stats: _*))
83-
findUses.usedFields
114+
ReferencedFields(findUses.usedFields, findUses.capturedFields)
115+
}
116+
case class ReferencedFields(used: Set[Symbol], captured: Set[Symbol]) {
117+
override def toString = s"used: ${used.mkString(",")}\ncaptured: ${captured.mkString(",")}"
84118
}
85119

86120
/* Build the control-flow graph.
@@ -104,7 +138,7 @@ trait LiveVariables {
104138
val finalState = asyncStates.find(as => !asyncStates.exists(other => isPred(as.state, other.state))).get
105139

106140
for (as <- asyncStates)
107-
AsyncUtils.vprintln(s"fields used in state #${as.state}: ${fieldsUsedIn(as).mkString(", ")}")
141+
AsyncUtils.vprintln(s"fields used in state #${as.state}: ${fieldsUsedIn(as)}")
108142

109143
/* Backwards data-flow analysis. Computes live variables information at entry and exit
110144
* of each async state.
@@ -130,13 +164,16 @@ trait LiveVariables {
130164
var currStates = List(finalState) // start at final state
131165
var pred = List[AsyncState]() // current predecessor states
132166
var hasChanged = true // if something has changed we need to continue iterating
167+
var captured: Set[Symbol] = Set()
133168

134169
while (hasChanged) {
135170
hasChanged = false
136171

137172
for (cs <- currStates) {
138173
val LVentryOld = LVentry(cs.state)
139-
val LVentryNew = LVexit(cs.state) ++ fieldsUsedIn(cs)
174+
val referenced = fieldsUsedIn(cs)
175+
captured ++= referenced.captured
176+
val LVentryNew = LVexit(cs.state) ++ referenced.used
140177
if (!LVentryNew.sameElements(LVentryOld)) {
141178
LVentry = LVentry + (cs.state -> LVentryNew)
142179
hasChanged = true
@@ -164,6 +201,9 @@ trait LiveVariables {
164201

165202
def lastUsagesOf(field: Tree, at: AsyncState, avoid: Set[AsyncState]): Set[Int] =
166203
if (avoid(at)) Set()
204+
else if (captured(field.symbol)) {
205+
Set()
206+
}
167207
else LVentry get at.state match {
168208
case Some(fields) if fields.exists(_ == field.symbol) =>
169209
Set(at.state)

src/main/scala/scala/async/internal/TransformUtils.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ private[async] trait TransformUtils {
166166
def nestedModule(module: ModuleDef) {
167167
}
168168

169-
def nestedMethod(module: DefDef) {
169+
def nestedMethod(defdef: DefDef) {
170170
}
171171

172172
def byNameArgument(arg: Tree) {

src/test/scala/scala/async/TreeInterrogation.scala

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,43 @@ object TreeInterrogation extends App {
6666
withDebug {
6767
val cm = reflect.runtime.currentMirror
6868
val tb = mkToolbox("-cp ${toolboxClasspath} -Xprint:typer -uniqid")
69-
import scala.async.Async._
69+
import scala.async.internal.AsyncTestLV._
7070
val tree = tb.parse(
71-
""" import _root_.scala.async.internal.AsyncId.{async, await}
71+
"""
72+
| import scala.async.internal.AsyncTestLV._
73+
| import scala.async.internal.AsyncTestLV
74+
|
75+
| case class MCell[T](var v: T)
76+
| val f = async { MCell(1) }
77+
|
78+
| def m1(x: MCell[Int], y: Int): Int =
79+
| async { x.v + y }
80+
| case class Cell[T](v: T)
81+
|
7282
| async {
73-
| implicit def view(a: Int): String = ""
74-
| await(0).length
83+
| // state #1
84+
| val a: MCell[Int] = await(f) // await$13$1
85+
| // state #2
86+
| var y = MCell(0)
87+
|
88+
| while (a.v < 10) {
89+
| // state #4
90+
| a.v = a.v + 1
91+
| y = MCell(await(a).v + 1) // await$14$1
92+
| // state #7
93+
| }
94+
|
95+
| // state #3
96+
| assert(AsyncTestLV.log.exists(entry => entry._1 == "await$14$1"))
97+
|
98+
| val b = await(m1(a, y.v)) // await$15$1
99+
| // state #8
100+
| assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10))))
101+
| assert(AsyncTestLV.log.exists(_ == ("y$1" -> MCell(11))))
102+
| b
75103
| }
104+
|
105+
|
76106
| """.stripMargin)
77107
println(tree)
78108
val tree1 = tb.typeCheck(tree.duplicate)

src/test/scala/scala/async/run/live/LiveVariablesSpec.scala

Lines changed: 117 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ case class MCell[T](var v: T)
1919

2020

2121
class LiveVariablesSpec {
22+
AsyncTestLV.clear()
2223

2324
@Test
2425
def `zero out fields of reference type`() {
@@ -35,7 +36,7 @@ class LiveVariablesSpec {
3536
// a == Cell(1)
3637
val b: Cell[Int] = await(m1(a)) // await$2$1
3738
// b == Cell(2)
38-
assert(AsyncTestLV.log.exists(_ == ("await$1$1" -> Cell(1))))
39+
assert(AsyncTestLV.log.exists(_ == ("await$1$1" -> Cell(1))), AsyncTestLV.log)
3940
val res = await(m2(b)) // await$3$1
4041
assert(AsyncTestLV.log.exists(_ == ("await$2$1" -> Cell(2))))
4142
res
@@ -141,12 +142,125 @@ class LiveVariablesSpec {
141142

142143
val b = await(m1(a, y.v)) // await$15$1
143144
// state #8
144-
assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10))))
145+
assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10))), AsyncTestLV.log)
145146
assert(AsyncTestLV.log.exists(_ == ("y$1" -> MCell(11))))
146147
b
147148
}
148149

149-
assert(m3() == 21)
150+
assert(m3() == 21, m3())
150151
}
151152

153+
@Test
154+
def `don't zero captured fields captured lambda`() {
155+
val f = async {
156+
val x = "x"
157+
val y = "y"
158+
await(0)
159+
y.reverse
160+
val f = () => assert(x != null)
161+
await(0)
162+
f
163+
}
164+
AsyncTestLV.assertNotNulledOut("x")
165+
AsyncTestLV.assertNulledOut("y")
166+
f()
167+
}
168+
169+
@Test
170+
def `don't zero captured fields captured by-name`() {
171+
def func0[A](a: => A): () => A = () => a
172+
val f = async {
173+
val x = "x"
174+
val y = "y"
175+
await(0)
176+
y.reverse
177+
val f = func0(assert(x != null))
178+
await(0)
179+
f
180+
}
181+
AsyncTestLV.assertNotNulledOut("x")
182+
AsyncTestLV.assertNulledOut("y")
183+
f()
184+
}
185+
186+
@Test
187+
def `don't zero captured fields nested class`() {
188+
def func0[A](a: => A): () => A = () => a
189+
val f = async {
190+
val x = "x"
191+
val y = "y"
192+
await(0)
193+
y.reverse
194+
val f = new Function0[Unit] {
195+
def apply = assert(x != null)
196+
}
197+
await(0)
198+
f
199+
}
200+
AsyncTestLV.assertNotNulledOut("x")
201+
AsyncTestLV.assertNulledOut("y")
202+
f()
203+
}
204+
205+
@Test
206+
def `don't zero captured fields nested object`() {
207+
def func0[A](a: => A): () => A = () => a
208+
val f = async {
209+
val x = "x"
210+
val y = "y"
211+
await(0)
212+
y.reverse
213+
object f extends Function0[Unit] {
214+
def apply = assert(x != null)
215+
}
216+
await(0)
217+
f
218+
}
219+
AsyncTestLV.assertNotNulledOut("x")
220+
AsyncTestLV.assertNulledOut("y")
221+
f()
222+
}
223+
224+
@Test
225+
def `don't zero captured fields nested def`() {
226+
val f = async {
227+
val x = "x"
228+
val y = "y"
229+
await(0)
230+
y.reverse
231+
def xx = x
232+
val f = xx _
233+
await(0)
234+
f
235+
}
236+
AsyncTestLV.assertNotNulledOut("x")
237+
AsyncTestLV.assertNulledOut("y")
238+
f()
239+
}
240+
241+
@Test
242+
def `capture bug`() {
243+
sealed trait Base
244+
case class B1() extends Base
245+
case class B2() extends Base
246+
val outer = List[(Base, Int)]((B1(), 8))
247+
248+
def getMore(b: Base) = 4
249+
250+
def baz = async {
251+
outer.head match {
252+
case (a @ B1(), r) => {
253+
val ents = await(getMore(a))
254+
255+
{ () =>
256+
println(a)
257+
assert(a ne null)
258+
}
259+
}
260+
case (b @ B2(), x) =>
261+
() => ???
262+
}
263+
}
264+
baz()
265+
}
152266
}

0 commit comments

Comments
 (0)