Skip to content

Allow await in applications with multiple argument lists #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 17, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 14 additions & 31 deletions src/main/scala/scala/async/AnfTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,7 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
indent += 1
def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127)
try {
AsyncUtils.trace(s"${
indentString
}$prefix(${oneLine(args)})")
AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
val result = t
AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
result
Expand Down Expand Up @@ -187,31 +185,31 @@ private[async] final case class AnfTransform[C <: Context](c: C) {

private[AnfTransform] def transformToList(tree: Tree): List[Tree] = trace("anf", tree) {
def containsAwait = tree exists isAwait

tree match {
case Select(qual, sel) if containsAwait =>
val stats :+ expr = inline.transformToList(qual)
stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol))

case Apply(fun, args) if containsAwait =>
checkForAwaitInNonPrimaryParamSection(fun, args)

case Applied(fun, targs, argss) if argss.nonEmpty && containsAwait =>
// we an assume that no await call appears in a by-name argument position,
// this has already been checked.
val isByName: (Int) => Boolean = utils.isByName(fun)
val funStats :+ simpleFun = inline.transformToList(fun)
def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$")
val (argStats, argExprs): (List[List[Tree]], List[Tree]) =
mapArguments[List[Tree]](args) {
case (arg, i) if isByName(i) || isSafeToInline(arg) => (Nil, arg)
case (arg@Ident(name), _) if isAwaitRef(name) => (Nil, arg) // not typed, so it eludes the check in `isSafeToInline`
case (arg, i) =>
inline.transformToList(arg) match {
case stats :+ expr =>
val valDef = defineVal(name.arg(i), expr, arg.pos)
val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) =
mapArgumentss[List[Tree]](fun, argss) {
case Arg(expr, byName, _) if byName || isSafeToInline(expr) => (Nil, expr)
case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // not typed, so it eludes the check in `isSafeToInline`
case Arg(expr, _, argName) =>
inline.transformToList(expr) match {
case stats :+ expr1 =>
val valDef = defineVal(argName, expr1, expr.pos)
(stats :+ valDef, Ident(valDef.name))
}
}
funStats ++ argStats.flatten :+ attachCopy(tree)(Apply(simpleFun, argExprs).setSymbol(tree.symbol))
val core = if (targs.isEmpty) simpleFun else TypeApply(simpleFun, targs)
val newApply = argExprss.foldLeft(core)(Apply(_, _).setSymbol(tree.symbol))
funStats ++ argStatss.flatten.flatten :+ attachCopy(tree)(newApply)
case Block(stats, expr) if containsAwait =>
inline.transformToList(stats :+ expr)

Expand Down Expand Up @@ -273,19 +271,4 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
}
}
}

def checkForAwaitInNonPrimaryParamSection(fun: Tree, args: List[Tree]) {
// TODO treat the Apply(Apply(.., argsN), ...), args0) holistically, and rewrite
// *all* argument lists in the correct order to preserve semantics.
fun match {
case Apply(fun1, _) =>
fun1.tpe match {
case MethodType(_, resultType: MethodType) if resultType =:= fun.tpe =>
c.error(fun.pos, "implementation restriction: await may only be used in the first parameter list.")
case _ =>
}
case _ =>
}

}
}
70 changes: 55 additions & 15 deletions src/main/scala/scala/async/TransformUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
val await = "await"
val bindSuffix = "$bind"

def arg(i: Int) = "arg" + i

def fresh(name: TermName): TermName = newTermName(fresh(name.toString))

def fresh(name: String): String = if (name.toString.contains("$")) name else c.fresh("" + name + "$")
Expand Down Expand Up @@ -102,11 +100,13 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
case dd: DefDef => nestedMethod(dd)
case fun: Function => function(fun)
case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions`
case Apply(fun, args) =>
case Applied(fun, targs, argss) if argss.nonEmpty =>
val isInByName = isByName(fun)
for ((arg, index) <- args.zipWithIndex) {
if (!isInByName(index)) traverse(arg)
else byNameArgument(arg)
for ((args, i) <- argss.zipWithIndex) {
for ((arg, j) <- args.zipWithIndex) {
if (!isInByName(i, j)) traverse(arg)
else byNameArgument(arg)
}
}
traverse(fun)
case _ => super.traverse(tree)
Expand All @@ -122,13 +122,31 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
Set(Boolean_&&, Boolean_||)
}

def isByName(fun: Tree): (Int => Boolean) = {
if (Boolean_ShortCircuits contains fun.symbol) i => true
else fun.tpe match {
case MethodType(params, _) =>
val isByNameParams = params.map(_.asTerm.isByNameParam)
(i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false)
case _ => Map()
def isByName(fun: Tree): ((Int, Int) => Boolean) = {
if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true
else {
val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss
val byNamess = paramss.map(_.map(_.isByNameParam))
(i, j) => util.Try(byNamess(i)(j)).getOrElse(false)
}
}
def argName(fun: Tree): ((Int, Int) => String) = {
val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss
val namess = paramss.map(_.map(_.name.toString))
(i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}")
}

object Applied {
val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
object treeInfo extends {
val global: symtab.type = symtab
} with reflect.internal.TreeInfo

def unapply(tree: Tree): Some[(Tree, List[Tree], List[List[Tree]])] = {
val treeInfo.Applied(core, targs, argss) = tree.asInstanceOf[symtab.Tree]
Some((core.asInstanceOf[Tree], targs.asInstanceOf[List[Tree]], argss.asInstanceOf[List[List[Tree]]]))
}
}

Expand Down Expand Up @@ -302,7 +320,6 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
}
}


def isSafeToInline(tree: Tree) = {
val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
object treeInfo extends {
Expand All @@ -322,7 +339,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
* @param f A function from argument (with '_*' unwrapped) and argument index to argument.
* @tparam A The type of the auxillary result
*/
def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = {
private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = {
args match {
case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) =>
val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip
Expand All @@ -332,4 +349,27 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
args.zipWithIndex.map(f.tupled).unzip
}
}

case class Arg(expr: Tree, isByName: Boolean, argName: String)

/**
* Transform a list of argument lists, producing the transformed lists, and lists of auxillary
* results.
*
* The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will
* receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`.
*
* @param fun The function being applied
* @param argss The argument lists
* @return (auxillary results, mapped argument trees)
*/
def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = {
val isByNamess: (Int, Int) => Boolean = isByName(fun)
val argNamess: (Int, Int) => String = argName(fun)
argss.zipWithIndex.map { case (args, i) =>
mapArguments[A](args) {
(tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j)))
}
}.unzip
}
}
17 changes: 7 additions & 10 deletions src/test/scala/scala/async/TreeInterrogation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,14 @@ object TreeInterrogation extends App {
val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:flatten")
import scala.async.Async._
val tree = tb.parse(
""" import scala.async.AsyncId._
| async {
| val x = 1
| val opt = Some("")
| await(0)
| val o @ Some(y) = opt
|
| {
| val o @ Some(y) = Some(".")
| }
""" import _root_.scala.async.AsyncId.{async, await}
| def foo[T](a0: Int)(b0: Int*) = s"a0 = $a0, b0 = ${b0.head}"
| val res = async {
| var i = 0
| def get = async {i += 1; i}
| foo[Int](await(get))(await(get) :: Nil : _*)
| }
| res
| """.stripMargin)
println(tree)
val tree1 = tb.typeCheck(tree.duplicate)
Expand Down
76 changes: 58 additions & 18 deletions src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -232,28 +232,68 @@ class AnfTransformSpec {
}

@Test
def awaitNotAllowedInNonPrimaryParamSection1() {
expectError("implementation restriction: await may only be used in the first parameter list.") {
"""
| import _root_.scala.async.AsyncId.{async, await}
| def foo(primary: Any)(i: Int) = i
| async {
| foo(???)(await(0))
| }
""".stripMargin
def awaitInNonPrimaryParamSection1() {
import _root_.scala.async.AsyncId.{async, await}
def foo(a0: Int)(b0: Int) = s"a0 = $a0, b0 = $b0"
val res = async {
var i = 0
def get = {i += 1; i}
foo(get)(get)
}
res mustBe "a0 = 1, b0 = 2"
}

@Test
def awaitInNonPrimaryParamSection2() {
import _root_.scala.async.AsyncId.{async, await}
def foo[T](a0: Int)(b0: Int*) = s"a0 = $a0, b0 = ${b0.head}"
val res = async {
var i = 0
def get = async {i += 1; i}
foo[Int](await(get))(await(get) :: await(async(Nil)) : _*)
}
res mustBe "a0 = 1, b0 = 2"
}

@Test
def awaitInNonPrimaryParamSectionWithLazy1() {
import _root_.scala.async.AsyncId.{async, await}
def foo[T](a: => Int)(b: Int) = b
val res = async {
def get = async {0}
foo[Int](???)(await(get))
}
res mustBe 0
}

@Test
def awaitNotAllowedInNonPrimaryParamSection2() {
expectError("implementation restriction: await may only be used in the first parameter list.") {
"""
| import _root_.scala.async.AsyncId.{async, await}
| def foo[T](primary: Any)(i: Int) = i
| async {
| foo[Int](???)(await(0))
| }
""".stripMargin
def awaitInNonPrimaryParamSectionWithLazy2() {
import _root_.scala.async.AsyncId.{async, await}
def foo[T](a: Int)(b: => Int) = a
val res = async {
def get = async {0}
foo[Int](await(get))(???)
}
res mustBe 0
}

@Test
def awaitWithLazy() {
import _root_.scala.async.AsyncId.{async, await}
def foo[T](a: Int, b: => Int) = a
val res = async {
def get = async {0}
foo[Int](await(get), ???)
}
res mustBe 0
}

@Test
def awaitOkInReciever() {
import scala.async.AsyncId.{async, await}
class Foo { def bar(a: Int)(b: Int) = a + b }
async {
await(async(new Foo)).bar(1)(2)
}
}

Expand Down