Skip to content

Add support for await inside try-catch #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 36 additions & 29 deletions src/main/scala/scala/async/AnfTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,44 +121,42 @@ private[async] final case class AnfTransform[C <: Context](c: C) {

private object inline {
def transformToList(tree: Tree): List[Tree] = trace("inline", tree) {
def branchWithAssign(orig: Tree, varDef: ValDef) = orig match {
case Block(stats, expr) => Block(stats, Assign(Ident(varDef.name), expr))
case _ => Assign(Ident(varDef.name), orig)
}

def casesWithAssign(cases: List[CaseDef], varDef: ValDef) = cases map {
case cd @ CaseDef(pat, guard, orig) =>
attachCopy(cd)(CaseDef(pat, guard, branchWithAssign(orig, varDef)))
}

val stats :+ expr = anf.transformToList(tree)
expr match {
// if type of if-else/try/match is Unit don't introduce assignment,
// but add Unit value to bring it into form expected by async transform
case If(_, _, _) | Try(_, _, _) | Match(_, _) if expr.tpe =:= definitions.UnitTpe =>
stats :+ expr :+ Literal(Constant(()))

case Apply(fun, args) if isAwait(fun) =>
val valDef = defineVal(name.await, expr, tree.pos)
stats :+ valDef :+ Ident(valDef.name)

case If(cond, thenp, elsep) =>
// if type of if-else is Unit don't introduce assignment,
// but add Unit value to bring it into form expected by async transform
if (expr.tpe =:= definitions.UnitTpe) {
stats :+ expr :+ Literal(Constant(()))
} else {
val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
def branchWithAssign(orig: Tree) = orig match {
case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr))
case _ => Assign(Ident(varDef.name), orig)
}
val ifWithAssign = If(cond, branchWithAssign(thenp), branchWithAssign(elsep))
stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name)
}
val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
val ifWithAssign = If(cond, branchWithAssign(thenp, varDef), branchWithAssign(elsep, varDef))
stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name)

case Try(body, catches, finalizer) =>
val varDef = defineVar(name.tryRes, expr.tpe, tree.pos)
val tryWithAssign = Try(branchWithAssign(body, varDef), casesWithAssign(catches, varDef), finalizer)
stats :+ varDef :+ tryWithAssign :+ Ident(varDef.name)

case Match(scrut, cases) =>
// if type of match is Unit don't introduce assignment,
// but add Unit value to bring it into form expected by async transform
if (expr.tpe =:= definitions.UnitTpe) {
stats :+ expr :+ Literal(Constant(()))
}
else {
val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
val casesWithAssign = cases map {
case cd@CaseDef(pat, guard, Block(caseStats, caseExpr)) =>
attachCopy(cd)(CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr))))
case cd@CaseDef(pat, guard, body) =>
attachCopy(cd)(CaseDef(pat, guard, Assign(Ident(varDef.name), body)))
}
val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign))
stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name)
}
val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign(cases, varDef)))
stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name)

case _ =>
stats :+ expr
}
Expand Down Expand Up @@ -225,6 +223,15 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
val stats :+ expr = inline.transformToList(rhs)
stats :+ attachCopy(tree)(Assign(lhs, expr))

case Try(body, catches, finalizer) if containsAwait =>
val stats :+ expr = inline.transformToList(body)
val finBlock =
if (!finalizer.isEmpty) {
val fstats :+ fexpr = inline.transformToList(finalizer)
Block(fstats, fexpr)
} else finalizer
List(c.typeCheck(attachCopy(tree)(Try(Block(stats, expr), catches, finBlock))))

case If(cond, thenp, elsep) if containsAwait =>
val condStats :+ condExpr = inline.transformToList(cond)
val thenBlock = inline.transformToBlock(thenp)
Expand Down
9 changes: 8 additions & 1 deletion src/main/scala/scala/async/Async.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ abstract class AsyncBase {
val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0)))
val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree)
val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree)

// the stack of currently active exception handlers
val handlers = ValDef(Modifiers(Flag.MUTABLE), name.handlers, TypeTree(typeOf[List[PartialFunction[Throwable, Unit]]]), (reify { List() }).tree)

// the exception that is currently in-flight or `null` otherwise
val exception = ValDef(Modifiers(Flag.MUTABLE), name.exception, TypeTree(typeOf[Throwable]), Literal(Constant(null)))

val applyDefDef: DefDef = {
val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)))
val applyBody = asyncBlock.onCompleteHandler
Expand All @@ -132,7 +139,7 @@ abstract class AsyncBase {
val applyBody = asyncBlock.onCompleteHandler
DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil))
}
List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef)
List(utils.emptyConstructor, stateVar, result, execContext, handlers, exception) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef)
}
val template = {
Template(List(stateMachineType), emptyValDef, body)
Expand Down
14 changes: 8 additions & 6 deletions src/main/scala/scala/async/AsyncAnalysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,16 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: Asy
}

override def traverse(tree: Tree) {
def containsAwait = tree exists isAwait
def containsAwait(t: Tree) = t exists isAwait
tree match {
case Try(_, _, _) if containsAwait =>
reportUnsupportedAwait(tree, "try/catch")
case Try(_, catches, _) if catches exists containsAwait =>
reportUnsupportedAwait(tree, "catch")
super.traverse(tree)
case Return(_) =>
case Return(_) =>
c.abort(tree.pos, "return is illegal within a async block")
case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) =>
case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) =>
c.abort(tree.pos, "lazy vals are illegal within an async block")
case _ =>
case _ =>
super.traverse(tree)
}
}
Expand Down Expand Up @@ -152,6 +152,8 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: Asy
traverseChunks(List(cond, thenp, elsep))
case Match(selector, cases) if tree exists isAwait =>
traverseChunks(selector :: cases)
case Try(body, catches, fin) if tree exists isAwait =>
traverseChunks((body :: catches) ::: (fin :: Nil))
case LabelDef(name, params, rhs) if rhs exists isAwait =>
traverseChunks(rhs :: Nil)
case Apply(fun, args) if isAwait(fun) =>
Expand Down
Loading