Skip to content

Reimplement LabelDefs using single transform #3362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 53 additions & 82 deletions compiler/src/dotty/tools/backend/jvm/LabelDefs.scala
Original file line number Diff line number Diff line change
@@ -1,38 +1,15 @@
package dotty.tools.backend.jvm

import dotty.tools.dotc.ast.Trees.Thicket
import dotty.tools.dotc.ast.{Trees, tpd}
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Types
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransform, MiniPhase, MiniPhaseTransform}
import dotty.tools.dotc
import dotty.tools.dotc.backend.jvm.DottyPrimitives
import dotty.tools.dotc.core.Flags.FlagSet
import dotty.tools.dotc.transform.Erasure
import dotty.tools.dotc.transform.SymUtils._
import java.io.{File => JFile}
import dotty.tools.dotc.transform._
import dotty.tools.dotc.transform.TreeTransforms._

import scala.collection.generic.Clearable
import scala.collection.mutable
import scala.collection.mutable.{ListBuffer, ArrayBuffer}
import scala.reflect.ClassTag
import dotty.tools.io.{Directory, PlainDirectory, AbstractFile}
import scala.tools.asm.{ClassVisitor, FieldVisitor, MethodVisitor}
import scala.tools.nsc.backend.jvm.{BCodeHelpers, BackendInterface}
import dotty.tools.dotc.core._
import Periods._
import SymDenotations._
import Contexts._
import Types._
import Symbols._
import Denotations._
import Phases._
import java.lang.AssertionError
import dotty.tools.dotc.util.Positions.Position
import Decorators._
import tpd._
import Flags._
import StdNames.nme
import dotty.tools.dotc.core.StdNames._
import dotty.tools.dotc.core.Symbols._
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.Flags._

/**
* Verifies that each Label DefDef has only a single address to jump back and
Expand Down Expand Up @@ -80,70 +57,64 @@ import StdNames.nme
* This is modified by setting `labelsReordered` flag in Phases.
*
* @author Dmitry Petrashko
* @author Nicolas Stucki
*/
class LabelDefs extends MiniPhaseTransform {
def phaseName: String = "labelDef"

val queue = new ArrayBuffer[Tree]()
val beingAppended = new mutable.HashSet[Symbol]()
var labelLevel = 0

override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
if (tree.symbol is Flags.Label) tree
else {
collectLabelDefs.clear
val newRhs = collectLabelDefs.transform(tree.rhs)
var labelDefs = collectLabelDefs.labelDefs
import tpd._

def putLabelDefsNearCallees = new TreeMap() {

override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = {
tree match {
case t: Apply if labelDefs.contains(t.symbol) =>
val labelDef = labelDefs(t.symbol)
labelDefs -= t.symbol

val labelDef2 = transform(labelDef)
Block(labelDef2:: Nil, t)
def phaseName: String = "labelDef"

case _ => if (labelDefs.nonEmpty) super.transform(tree) else tree
}
}
}
override def runsAfterGroupsOf = Set(classOf[Flatten])

val res = cpy.DefDef(tree)(rhs = putLabelDefsNearCallees.transform(newRhs))
private val labelDefs: mutable.HashMap[Symbol, DefDef] = new mutable.HashMap[Symbol, DefDef]()

res
override def prepareForDefDef(tree: DefDef)(implicit ctx: Context): TreeTransform = {
if (isWhileDef(tree)) this
else if (tree.symbol.is(Label)) NoTransform // transformation is done in transformApply
else {
collectLabelDefs(tree.rhs)
this
}
}

object collectLabelDefs extends TreeMap() {

// labelSymbol -> Defining tree
val labelDefs = new mutable.HashMap[Symbol, Tree]()
override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (!tree.symbol.is(Label)) tree
else labelDefs.get(tree.symbol) match {
case Some(labelDef) =>
labelDefs -= tree.symbol
val newRhs = transform(labelDef.rhs) // transform the body of the label def
val transformedLabelDef = cpy.DefDef(labelDef)(rhs = newRhs)
Block(transformedLabelDef :: Nil, tree)
case None => tree
}
}

def clear = {
labelDefs.clear()
override def transformBlock(tree: Block)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (tree.stats.isEmpty || !tree.stats.exists(_.symbol.is(Label))) tree
else {
// Only keep label defs if they are followed by their applies
val newStats = tree.stats.zipWithConserve(tree.stats.tail :+ tree.expr) { (stat, next) =>
if (!stat.symbol.is(Label) || stat.symbol == next.symbol) stat
else EmptyTree
}.filterNot(_.isEmpty)
seq(newStats, tree.expr)
}
}

override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = tree match {
case t: Template => t
case t: Block =>
val r = super.transform(t)
r match {
case t: Block if t.stats.isEmpty => t.expr
case _ => r
}
case t: DefDef =>
assert(t.symbol is Flags.Label)
val r = super.transform(tree)
labelDefs(r.symbol) = r
EmptyTree
case t: Apply if t.symbol is Flags.Label =>
val sym = t.symbol
super.transform(tree)
case _ =>
super.transform(tree)
private def collectLabelDefs(tree: Tree)(implicit ctx: Context): Unit = new TreeTraverser() {
assert(labelDefs.isEmpty)
override def traverse(tree: Tree)(implicit ctx: Context): Unit = tree match {
case _: Template =>
case tree: DefDef if !isWhileDef(tree) =>
assert(tree.symbol.is(Label))
labelDefs(tree.symbol) = tree
traverseChildren(tree)
case _ => traverseChildren(tree)
}
}.traverse(tree)

private def isWhileDef(ddef: DefDef)(implicit ctx: Context): Boolean = {
ddef.symbol.is(Label) &&
(ddef.name == nme.WHILE_PREFIX || ddef.name == nme.DO_WHILE_PREFIX)
}
}
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ class Compiler {
new ElimStaticThis, // Replace `this` references to static objects by global identifiers
new Flatten, // Lift all inner classes to package scope
new RestoreScopes), // Repair scopes rendered invalid by moving definitions in prior phases of the group
List(new RenameLifted, // Renames lifted classes to local numbering scheme
List(new LabelDefs, // Converts calls to labels to jumps
new RenameLifted, // Renames lifted classes to local numbering scheme
new TransformWildcards, // Replace wildcards with default values
new MoveStatics, // Move static methods to companion classes
new ExpandPrivate, // Widen private definitions accessed from nested classes
new SelectStatic, // get rid of selects that would be compiled into GetStatic
new CollectEntryPoints, // Find classes with main methods
new CollectSuperCalls, // Find classes that are called with super
new DropInlined, // Drop Inlined nodes, since backend has no use for them
new LabelDefs), // Converts calls to labels to jumps
new DropInlined), // Drop Inlined nodes, since backend has no use for them
List(new GenBCode) // Generate JVM bytecode
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package dotty.tools.dotc
package transform.localopt

import core.TypeErasure
import core.Contexts.Context
import core.Symbols._
import core.Types._
import core.Flags._
import core.StdNames._
import ast.Trees._
import Simplify._

Expand Down Expand Up @@ -36,7 +36,7 @@ class DropNoEffects(val simplifyPhase: Simplify) extends Optimisation {
case t => t :: Nil
}
val (newStats2, newExpr) = a.expr match {
case Block(stats2, expr) => (newStats1 ++ stats2, expr)
case Block(stats2, expr) if !isWhileLabel(expr.symbol) => (newStats1 ++ stats2, expr)
case _ => (newStats1, a.expr)
}

Expand Down Expand Up @@ -201,4 +201,7 @@ class DropNoEffects(val simplifyPhase: Simplify) extends Optimisation {
case _ =>
false
}

private def isWhileLabel(sym: Symbol)(implicit ctx: Context): Boolean =
sym.is(Label) && (sym.name == nme.WHILE_PREFIX || sym.name == nme.DO_WHILE_PREFIX)
}
67 changes: 67 additions & 0 deletions tests/pos/i2903-fixme.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
case class Foo01(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo02(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo03(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo04(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo05(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo06(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo07(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo08(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo09(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo10(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo11(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo12(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo13(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo14(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo15(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo16(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo17(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo18(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo19(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo20(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo21(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo22(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo23(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo24(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo25(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo26(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo27(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo28(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)
case class Foo29(x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, x8: Int, x9: Int, x10: Int)

object Test {
def stuff() = {}

def test(x: Any): Unit = x match {
case Foo01(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo02(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo03(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo04(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo05(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo06(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo07(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo08(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo09(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo10(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo11(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo12(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
/* FIXME: #2903
case Foo13(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo14(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo15(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo16(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo17(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo18(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo19(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo20(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo21(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo22(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo23(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo24(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo25(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo26(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo27(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo28(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
case Foo29(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) => stuff()
*/
}
}
9 changes: 9 additions & 0 deletions tests/pos/nested-do-while.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
object Test {
def foo(): Unit = {
val elems: Iterator[Int] = ???
do {
elems.next()
do elems.next() while (elems.hasNext)
} while (elems.hasNext)
}
}
9 changes: 9 additions & 0 deletions tests/pos/nested-while.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
object Test {
def foo(): Unit = {
val elems: Iterator[Int] = ???
while (elems.hasNext) {
elems.next()
while (elems.hasNext) elems.next()
}
}
}