Skip to content

Don't aggressively null out captured vars #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 13, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/main/scala/scala/async/internal/AsyncId.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ object AsyncTestLV extends AsyncBase {
def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit)

var log: List[(String, Any)] = List()
def assertNulledOut(a: Any): Unit = assert(log.exists(_._2 == a), AsyncTestLV.log)
def assertNotNulledOut(a: Any): Unit = assert(!log.exists(_._2 == a), AsyncTestLV.log)
def clear() = log = Nil

def apply(name: String, v: Any): Unit =
log ::= (name -> v)
Expand Down
60 changes: 50 additions & 10 deletions src/main/scala/scala/async/internal/LiveVariables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,53 @@ trait LiveVariables {
* @param as a state of an `async` expression
* @return a set of lifted fields that are used within state `as`
*/
def fieldsUsedIn(as: AsyncState): Set[Symbol] = {
class FindUseTraverser extends Traverser {
def fieldsUsedIn(as: AsyncState): ReferencedFields = {
class FindUseTraverser extends AsyncTraverser {
var usedFields = Set[Symbol]()
override def traverse(tree: Tree) = tree match {
case Ident(_) if liftedSyms(tree.symbol) =>
usedFields += tree.symbol
case _ =>
super.traverse(tree)
var capturedFields = Set[Symbol]()
private def capturing[A](body: => A): A = {
val saved = capturing
try {
capturing = true
body
} finally capturing = saved
}
private def capturingCheck(tree: Tree) = capturing(tree foreach check)
private var capturing: Boolean = false
private def check(tree: Tree) {
tree match {
case Ident(_) if liftedSyms(tree.symbol) =>
if (capturing)
capturedFields += tree.symbol
else
usedFields += tree.symbol
case _ =>
}
}
override def traverse(tree: Tree) = {
check(tree)
super.traverse(tree)
}

override def nestedClass(classDef: ClassDef): Unit = capturingCheck(classDef)

override def nestedModule(module: ModuleDef): Unit = capturingCheck(module)

override def nestedMethod(defdef: DefDef): Unit = capturingCheck(defdef)

override def byNameArgument(arg: Tree): Unit = capturingCheck(arg)

override def function(function: Function): Unit = capturingCheck(function)

override def patMatFunction(tree: Match): Unit = capturingCheck(tree)
}

val findUses = new FindUseTraverser
findUses.traverse(Block(as.stats: _*))
findUses.usedFields
ReferencedFields(findUses.usedFields, findUses.capturedFields)
}
case class ReferencedFields(used: Set[Symbol], captured: Set[Symbol]) {
override def toString = s"used: ${used.mkString(",")}\ncaptured: ${captured.mkString(",")}"
}

/* Build the control-flow graph.
Expand All @@ -104,7 +138,7 @@ trait LiveVariables {
val finalState = asyncStates.find(as => !asyncStates.exists(other => isPred(as.state, other.state))).get

for (as <- asyncStates)
AsyncUtils.vprintln(s"fields used in state #${as.state}: ${fieldsUsedIn(as).mkString(", ")}")
AsyncUtils.vprintln(s"fields used in state #${as.state}: ${fieldsUsedIn(as)}")

/* Backwards data-flow analysis. Computes live variables information at entry and exit
* of each async state.
Expand All @@ -130,13 +164,16 @@ trait LiveVariables {
var currStates = List(finalState) // start at final state
var pred = List[AsyncState]() // current predecessor states
var hasChanged = true // if something has changed we need to continue iterating
var captured: Set[Symbol] = Set()

while (hasChanged) {
hasChanged = false

for (cs <- currStates) {
val LVentryOld = LVentry(cs.state)
val LVentryNew = LVexit(cs.state) ++ fieldsUsedIn(cs)
val referenced = fieldsUsedIn(cs)
captured ++= referenced.captured
val LVentryNew = LVexit(cs.state) ++ referenced.used
if (!LVentryNew.sameElements(LVentryOld)) {
LVentry = LVentry + (cs.state -> LVentryNew)
hasChanged = true
Expand Down Expand Up @@ -164,6 +201,9 @@ trait LiveVariables {

def lastUsagesOf(field: Tree, at: AsyncState, avoid: Set[AsyncState]): Set[Int] =
if (avoid(at)) Set()
else if (captured(field.symbol)) {
Set()
}
else LVentry get at.state match {
case Some(fields) if fields.exists(_ == field.symbol) =>
Set(at.state)
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/scala/async/internal/TransformUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ private[async] trait TransformUtils {
def nestedModule(module: ModuleDef) {
}

def nestedMethod(module: DefDef) {
def nestedMethod(defdef: DefDef) {
}

def byNameArgument(arg: Tree) {
Expand Down
38 changes: 34 additions & 4 deletions src/test/scala/scala/async/TreeInterrogation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,43 @@ object TreeInterrogation extends App {
withDebug {
val cm = reflect.runtime.currentMirror
val tb = mkToolbox("-cp ${toolboxClasspath} -Xprint:typer -uniqid")
import scala.async.Async._
import scala.async.internal.AsyncTestLV._
val tree = tb.parse(
""" import _root_.scala.async.internal.AsyncId.{async, await}
"""
| import scala.async.internal.AsyncTestLV._
| import scala.async.internal.AsyncTestLV
|
| case class MCell[T](var v: T)
| val f = async { MCell(1) }
|
| def m1(x: MCell[Int], y: Int): Int =
| async { x.v + y }
| case class Cell[T](v: T)
|
| async {
| implicit def view(a: Int): String = ""
| await(0).length
| // state #1
| val a: MCell[Int] = await(f) // await$13$1
| // state #2
| var y = MCell(0)
|
| while (a.v < 10) {
| // state #4
| a.v = a.v + 1
| y = MCell(await(a).v + 1) // await$14$1
| // state #7
| }
|
| // state #3
| assert(AsyncTestLV.log.exists(entry => entry._1 == "await$14$1"))
|
| val b = await(m1(a, y.v)) // await$15$1
| // state #8
| assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10))))
| assert(AsyncTestLV.log.exists(_ == ("y$1" -> MCell(11))))
| b
| }
|
|
| """.stripMargin)
println(tree)
val tree1 = tb.typeCheck(tree.duplicate)
Expand Down
120 changes: 117 additions & 3 deletions src/test/scala/scala/async/run/live/LiveVariablesSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ case class MCell[T](var v: T)


class LiveVariablesSpec {
AsyncTestLV.clear()

@Test
def `zero out fields of reference type`() {
Expand All @@ -35,7 +36,7 @@ class LiveVariablesSpec {
// a == Cell(1)
val b: Cell[Int] = await(m1(a)) // await$2$1
// b == Cell(2)
assert(AsyncTestLV.log.exists(_ == ("await$1$1" -> Cell(1))))
assert(AsyncTestLV.log.exists(_ == ("await$1$1" -> Cell(1))), AsyncTestLV.log)
val res = await(m2(b)) // await$3$1
assert(AsyncTestLV.log.exists(_ == ("await$2$1" -> Cell(2))))
res
Expand Down Expand Up @@ -141,12 +142,125 @@ class LiveVariablesSpec {

val b = await(m1(a, y.v)) // await$15$1
// state #8
assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10))))
assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10))), AsyncTestLV.log)
assert(AsyncTestLV.log.exists(_ == ("y$1" -> MCell(11))))
b
}

assert(m3() == 21)
assert(m3() == 21, m3())
}

@Test
def `don't zero captured fields captured lambda`() {
val f = async {
val x = "x"
val y = "y"
await(0)
y.reverse
val f = () => assert(x != null)
await(0)
f
}
AsyncTestLV.assertNotNulledOut("x")
AsyncTestLV.assertNulledOut("y")
f()
}

@Test
def `don't zero captured fields captured by-name`() {
def func0[A](a: => A): () => A = () => a
val f = async {
val x = "x"
val y = "y"
await(0)
y.reverse
val f = func0(assert(x != null))
await(0)
f
}
AsyncTestLV.assertNotNulledOut("x")
AsyncTestLV.assertNulledOut("y")
f()
}

@Test
def `don't zero captured fields nested class`() {
def func0[A](a: => A): () => A = () => a
val f = async {
val x = "x"
val y = "y"
await(0)
y.reverse
val f = new Function0[Unit] {
def apply = assert(x != null)
}
await(0)
f
}
AsyncTestLV.assertNotNulledOut("x")
AsyncTestLV.assertNulledOut("y")
f()
}

@Test
def `don't zero captured fields nested object`() {
def func0[A](a: => A): () => A = () => a
val f = async {
val x = "x"
val y = "y"
await(0)
y.reverse
object f extends Function0[Unit] {
def apply = assert(x != null)
}
await(0)
f
}
AsyncTestLV.assertNotNulledOut("x")
AsyncTestLV.assertNulledOut("y")
f()
}

@Test
def `don't zero captured fields nested def`() {
val f = async {
val x = "x"
val y = "y"
await(0)
y.reverse
def xx = x
val f = xx _
await(0)
f
}
AsyncTestLV.assertNotNulledOut("x")
AsyncTestLV.assertNulledOut("y")
f()
}

@Test
def `capture bug`() {
sealed trait Base
case class B1() extends Base
case class B2() extends Base
val outer = List[(Base, Int)]((B1(), 8))

def getMore(b: Base) = 4

def baz = async {
outer.head match {
case (a @ B1(), r) => {
val ents = await(getMore(a))

{ () =>
println(a)
assert(a ne null)
}
}
case (b @ B2(), x) =>
() => ???
}
}
baz()
}
}