Skip to content

Redefine quoted.Expr.betaReduce #9469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion community-build/community-projects/utest
Original file line number Diff line number Diff line change
Expand Up @@ -2042,21 +2042,18 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
case _ => None
}

def betaReduce(fn: Term, args: List[Term])(using Context): Term = {
val (argVals0, argRefs0) = args.foldLeft((List.empty[ValDef], List.empty[Tree])) { case ((acc1, acc2), arg) => arg.tpe match {
case tpe: SingletonType if isIdempotentExpr(arg) => (acc1, arg :: acc2)
def betaReduce(tree: Term)(using Context): Option[Term] =
tree match
case app @ Apply(Select(fn, nme.apply), args) if defn.isFunctionType(fn.tpe) =>
val app1 = transform.BetaReduce(app, fn, args)
if app1 eq app then None
else Some(app1.withSpan(tree.span))
case Block(Nil, expr) =>
for e <- betaReduce(expr) yield cpy.Block(tree)(Nil, e)
case Inlined(_, Nil, expr) =>
betaReduce(expr)
case _ =>
val argVal = SyntheticValDef(NameKinds.UniqueName.fresh("x".toTermName), arg).withSpan(arg.span)
(argVal :: acc1, ref(argVal.symbol) :: acc2)
}}
val argVals = argVals0.reverse
val argRefs = argRefs0.reverse
val reducedBody = lambdaExtractor(fn, argRefs.map(_.tpe)) match {
case Some(body) => body(argRefs)
case None => fn.select(nme.apply).appliedToArgs(argRefs)
}
seq(argVals, reducedBody).withSpan(fn.span)
}
None

def lambdaExtractor(fn: Term, paramTypes: List[Type])(using Context): Option[List[Term] => Term] = {
def rec(fn: Term, transformBody: Term => Term): Option[List[Term] => Term] = {
Expand Down
31 changes: 22 additions & 9 deletions compiler/src/dotty/tools/dotc/transform/BetaReduce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,26 @@ class BetaReduce extends MiniPhase:

override def transformApply(app: Apply)(using Context): Tree = app.fun match
case Select(fn, nme.apply) if defn.isFunctionType(fn.tpe) =>
val app1 = betaReduce(app, fn, app.args)
val app1 = BetaReduce(app, fn, app.args)
if app1 ne app then report.log(i"beta reduce $app -> $app1")
app1
case _ =>
app

private def betaReduce(tree: Apply, fn: Tree, args: List[Tree])(using Context): Tree =
fn match
case Typed(expr, _) => betaReduce(tree, expr, args)
case Block(Nil, expr) => betaReduce(tree, expr, args)
case Block((anonFun: DefDef) :: Nil, closure: Closure) => BetaReduce(anonFun, args)
case _ => tree

object BetaReduce:
import ast.tpd._

/** Beta-reduces a call to `fn` with arguments `argSyms` or returns `tree` */
def apply(tree: Apply, fn: Tree, args: List[Tree])(using Context): Tree =
fn match
case Typed(expr, _) => BetaReduce(tree, expr, args)
case Block(Nil, expr) => BetaReduce(tree, expr, args)
case Inlined(_, Nil, expr) => BetaReduce(tree, expr, args)
case Block((anonFun: DefDef) :: Nil, closure: Closure) => BetaReduce(anonFun, args)
case _ => tree
end apply

/** Beta-reduces a call to `ddef` with arguments `argSyms` */
def apply(ddef: DefDef, args: List[Tree])(using Context) =
val bindings = List.newBuilder[ValDef]
Expand All @@ -65,7 +69,8 @@ object BetaReduce:
ref.symbol
case _ =>
val flags = Synthetic | (param.symbol.flags & Erased)
val binding = ValDef(newSymbol(ctx.owner, param.name, flags, arg.tpe.widen, coord = arg.span), arg)
val tpe = if arg.tpe.dealias.isInstanceOf[ConstantType] then arg.tpe.dealias else arg.tpe.widen
val binding = ValDef(newSymbol(ctx.owner, param.name, flags, tpe, coord = arg.span), arg).withSpan(arg.span)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unclear to me why we cannot just use arg.tpe.widen as the type of the binding?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is to allow TermRef to definitions with constant types to be constant folded after this transformation.

bindings += binding
binding.symbol

Expand All @@ -76,5 +81,13 @@ object BetaReduce:
substTo = argSyms
).transform(ddef.rhs)

seq(bindings.result(), expansion)
val expansion1 = new TreeMap {
override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match
case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const)
case _ => super.transform(tree)
}.transform(expansion)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is constant folding -- can we reuse the logic in ConstFold?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not exactly constant folding. It is just the propagation of constants. After this ConstFold is used at some point and performs the actual folding.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about TreeInfo.constToLiteral?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It fails. It also looks like overkill.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Maybe as an optimization for later, beta-reduce & inlining may create new const-folding opportunities.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed

val bindings1 =
bindings.result().filterNot(vdef => vdef.tpt.tpe.isInstanceOf[ConstantType] && isPureExpr(vdef.rhs))

seq(bindings1, expansion1)
end apply
41 changes: 30 additions & 11 deletions compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -399,22 +399,15 @@ class InlineBytecodeTests extends DottyBytecodeTest {
val instructions = instructionsFromMethod(fun)
val expected =
List(
// Head tested separatly
VarOp(ALOAD, 0),
Invoke(INVOKEVIRTUAL, "Test", "given_Int", "()I", false),
Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "boxToInteger", "(I)Ljava/lang/Integer;", false),
Invoke(INVOKEINTERFACE, "dotty/runtime/function/JFunction1$mcZI$sp", "apply", "(Ljava/lang/Object;)Ljava/lang/Object;", true),
Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "unboxToBoolean", "(Ljava/lang/Object;)Z", false),
VarOp(ISTORE, 1),
Op(ICONST_1),
Op(IRETURN)
)

instructions.head match {
case InvokeDynamic(INVOKEDYNAMIC, "apply$mcZI$sp", "()Ldotty/runtime/function/JFunction1$mcZI$sp;", _, _) =>
case _ => assert(false, "`g` was not properly inlined in `test`\n")
}

assert(instructions.tail == expected,
"`fg was not properly inlined in `test`\n" + diffInstructions(instructions.tail, expected))
assert(instructions == expected,
"`fg was not properly inlined in `test`\n" + diffInstructions(instructions, expected))

}
}
Expand Down Expand Up @@ -505,4 +498,30 @@ class InlineBytecodeTests extends DottyBytecodeTest {
}
}


@Test def i9466 = {
val source = """class Test:
| inline def i(inline f: Int => Boolean): String =
| if f(34) then "a"
| else "b"
| def test = i(f = _ == 34)
""".stripMargin

checkBCode(source) { dir =>
val clsIn = dir.lookupName("Test.class", directory = false).input
val clsNode = loadClassNode(clsIn)

val fun = getMethod(clsNode, "test")
val instructions = instructionsFromMethod(fun)
val expected =
List(
Ldc(LDC, "a"),
Op(ARETURN)
)

assert(instructions == expected,
"`i was not properly inlined in `test`\n" + diffInstructions(instructions, expected))

}
}
}
31 changes: 8 additions & 23 deletions library/src-bootstrapped/scala/quoted/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,30 +61,15 @@ abstract class Expr[+T] private[scala] {

object Expr {

/** Converts a tuple `(T1, ..., Tn)` to `(Expr[T1], ..., Expr[Tn])` */
type TupleOfExpr[Tup <: Tuple] = Tuple.Map[Tup, [X] =>> QuoteContext ?=> Expr[X]]

/** `Expr.betaReduce(f)(x1, ..., xn)` is functionally the same as `'{($f)($x1, ..., $xn)}`, however it optimizes this call
* by returning the result of beta-reducing `f(x1, ..., xn)` if `f` is a known lambda expression.
*
* `Expr.betaReduce` distributes applications of `Expr` over function arrows
* ```scala
* Expr.betaReduce(_): Expr[(T1, ..., Tn) => R] => ((Expr[T1], ..., Expr[Tn]) => Expr[R])
* ```
*/
def betaReduce[F, Args <: Tuple, R, G](f: Expr[F])(using tf: TupledFunction[F, Args => R], tg: TupledFunction[G, TupleOfExpr[Args] => Expr[R]], qctx: QuoteContext): G =
tg.untupled(args => qctx.tasty.internal.betaReduce(f.unseal, args.toArray.toList.map(_.asInstanceOf[QuoteContext => Expr[Any]](qctx).unseal)).seal.asInstanceOf[Expr[R]])

/** `Expr.betaReduceGiven(f)(x1, ..., xn)` is functionally the same as `'{($f)(using $x1, ..., $xn)}`, however it optimizes this call
* by returning the result of beta-reducing `f(using x1, ..., xn)` if `f` is a known lambda expression.
*
* `Expr.betaReduceGiven` distributes applications of `Expr` over function arrows
* ```scala
* Expr.betaReduceGiven(_): Expr[(T1, ..., Tn) ?=> R] => ((Expr[T1], ..., Expr[Tn]) => Expr[R])
* ```
/** `e.betaReduce` returns an expression that is functionally equivalent to `e`,
* however if `e` is of the form `((y1, ..., yn) => e2)(x1, ..., xn)`
* then it optimizes this the top most call by returning the result of beta-reducing the application.
* Otherwise returns `expr`.
*/
def betaReduceGiven[F, Args <: Tuple, R, G](f: Expr[F])(using tf: TupledFunction[F, Args ?=> R], tg: TupledFunction[G, TupleOfExpr[Args] => Expr[R]], qctx: QuoteContext): G =
tg.untupled(args => qctx.tasty.internal.betaReduce(f.unseal, args.toArray.toList.map(_.asInstanceOf[QuoteContext => Expr[Any]](qctx).unseal)).seal.asInstanceOf[Expr[R]])
def betaReduce[T](expr: Expr[T])(using qctx: QuoteContext): Expr[T] =
qctx.tasty.internal.betaReduce(expr.unseal) match
case Some(expr1) => expr1.seal.asInstanceOf[Expr[T]]
case _ => expr

/** Returns a null expresssion equivalent to `'{null}` */
def nullExpr: QuoteContext ?=> Expr[Null] = qctx ?=> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ object UnsafeExpr {
def underlyingArgument[T](expr: Expr[T])(using qctx: QuoteContext): Expr[T] =
expr.unseal.underlyingArgument.seal.asInstanceOf[Expr[T]]

// TODO generalize for any function arity (see Expr.betaReduce)
// TODO generalize for any function arity
/** Allows inspection or transformation of the body of the expression of function.
* This body may have references to the arguments of the function which should be closed
* over if the expression will be spliced.
Expand Down
6 changes: 2 additions & 4 deletions library/src/scala/tasty/reflect/CompilerInterface.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1600,10 +1600,8 @@ trait CompilerInterface {
*/
def searchImplicit(tpe: Type)(using ctx: Context): ImplicitSearchResult

/** Inline fn if it is an explicit closure possibly nested inside the expression of a block.
* Otherwise apply the arguments to the closure.
*/
def betaReduce(f: Term, args: List[Term])(using ctx: Context): Term
/** Returns Some with a beta-reduced application or None */
def betaReduce(tree: Term)(using Context): Option[Term]

def lambdaExtractor(term: Term, paramTypes: List[Type])(using ctx: Context): Option[List[Term] => Term]

Expand Down
4 changes: 2 additions & 2 deletions tests/neg-macros/beta-reduce-inline-result/Macro_1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ object Macros {
inline def betaReduce[Arg,Result](inline fn: Arg=>Result)(inline arg: Arg): Result =
${ betaReduceImpl('{ fn })('{ arg }) }

def betaReduceImpl[Arg,Result](fn: Expr[Arg=>Result])(arg: Expr[Arg])(using qctx: QuoteContext): Expr[Result] =
Expr.betaReduce(fn)(arg)
def betaReduceImpl[Arg: Type, Result: Type](fn: Expr[Arg=>Result])(arg: Expr[Arg])(using qctx: QuoteContext): Expr[Result] =
Expr.betaReduce('{$fn($arg)})
}

2 changes: 1 addition & 1 deletion tests/pos-macros/i6783.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import scala.quoted._

def testImpl(f: Expr[(Int, Int) => Int])(using QuoteContext): Expr[Int] = Expr.betaReduce(f)('{1}, '{2})
def testImpl(f: Expr[(Int, Int) => Int])(using QuoteContext): Expr[Int] = Expr.betaReduce('{$f(1, 2)})

inline def test(f: (Int, Int) => Int) = ${
testImpl('f)
Expand Down
4 changes: 2 additions & 2 deletions tests/run-macros/beta-reduce-inline-result.check
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
compile-time: ((3.+(1): scala.Int): scala.Int)
compile-time: (4: scala.Int)
run-time: 4
compile-time: ((1: 1): scala.Int)
compile-time: (1: scala.Int)
run-time: 1
run-time: 5
run-time: 7
Expand Down
9 changes: 5 additions & 4 deletions tests/run-macros/beta-reduce-inline-result/Macro_1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ object Macros {
inline def betaReduce[Arg,Result](inline fn : Arg=>Result)(inline arg: Arg): Result =
${ betaReduceImpl('{ fn })('{ arg }) }

def betaReduceImpl[Arg,Result](fn: Expr[Arg=>Result])(arg: Expr[Arg])(using qctx : QuoteContext): Expr[Result] =
Expr.betaReduce(fn)(arg)
def betaReduceImpl[Arg: Type, Result: Type](fn: Expr[Arg=>Result])(arg: Expr[Arg])(using qctx : QuoteContext): Expr[Result] =
Expr.betaReduce('{$fn($arg)})

inline def betaReduceAdd1[Arg](inline fn: Arg=>Int)(inline arg: Arg): Int =
${ betaReduceAdd1Impl('{ fn })('{ arg }) }

def betaReduceAdd1Impl[Arg](fn: Expr[Arg=>Int])(arg: Expr[Arg])(using qctx: QuoteContext): Expr[Int] =
'{ ${ Expr.betaReduce(fn)(arg) } + 1 }
def betaReduceAdd1Impl[Arg: Type](fn: Expr[Arg=>Int])(arg: Expr[Arg])(using qctx: QuoteContext): Expr[Int] =
val app = '{$fn.asInstanceOf[Arg=>Int]($arg)} // FIXME: remove asInstanceOf (workaround for #8612)
'{ ${ Expr.betaReduce(app) } + 1 }
}

1 change: 0 additions & 1 deletion tests/run-macros/beta-reduce-inline-result/Test_2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,3 @@ object Test {
println(s"run-time: ${Macros.betaReduce(dummy7)(8)}")
}
}

4 changes: 2 additions & 2 deletions tests/run-macros/gestalt-optional-staging/Macro_1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ final class Optional[+A >: Null](val value: A) extends AnyVal {

inline def getOrElse[B >: A](alt: => B): B = ${ Optional.getOrElseImpl('this, 'alt) }

inline def map[B >: Null](f: A => B): Optional[B] = ${ Optional.mapImpl('this, 'f) }
inline def map[B >: Null](inline f: A => B): Optional[B] = ${ Optional.mapImpl('this, 'f) }

override def toString = if (isEmpty) "<empty>" else s"$value"
}
Expand All @@ -24,7 +24,7 @@ object Optional {
// FIXME fix issue #5097 and enable private
/*private*/ def mapImpl[A >: Null : Type, B >: Null : Type](opt: Expr[Optional[A]], f: Expr[A => B])(using QuoteContext): Expr[Optional[B]] = '{
if ($opt.isEmpty) new Optional(null)
else new Optional(${Expr.betaReduce(f)('{$opt.value})})
else new Optional(${Expr.betaReduce('{$f($opt.value)})})
}

}
4 changes: 2 additions & 2 deletions tests/run-macros/i4734/Macro_1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import scala.annotation.tailrec
import scala.quoted._

object Macros {
inline def unrolledForeach(seq: IndexedSeq[Int], f: => Int => Unit, inline unrollSize: Int): Unit = // or f: Int => Unit
inline def unrolledForeach(seq: IndexedSeq[Int], inline f: Int => Unit, inline unrollSize: Int): Unit = // or f: Int => Unit
${ unrolledForeachImpl('seq, 'f, 'unrollSize) }

def unrolledForeachImpl(seq: Expr[IndexedSeq[Int]], f: Expr[Int => Unit], unrollSizeExpr: Expr[Int]) (using QuoteContext): Expr[Unit] =
Expand All @@ -17,7 +17,7 @@ object Macros {
for (j <- new UnrolledRange(0, unrollSize)) '{
val index = i + ${Expr(j)}
val element = ($seq)(index)
${ Expr.betaReduce(f)('element) } // or `($f)(element)` if `f` should not be inlined
${ Expr.betaReduce('{$f(element)}) } // or `($f)(element)` if `f` should not be inlined
}
}
i += ${Expr(unrollSize)}
Expand Down
2 changes: 1 addition & 1 deletion tests/run-macros/i4735/App_2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ object Test {
}

class Unrolled(arr: Array[Int]) extends AnyVal {
inline def foreach(f: => Int => Unit): Unit = Macro.unrolledForeach(3, arr, f)
inline def foreach(inline f: Int => Unit): Unit = Macro.unrolledForeach(3, arr, f)
}
4 changes: 2 additions & 2 deletions tests/run-macros/i4735/Macro_1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import scala.quoted._

object Macro {

inline def unrolledForeach(inline unrollSize: Int, seq: Array[Int], f: => Int => Unit): Unit = // or f: Int => Unit
inline def unrolledForeach(inline unrollSize: Int, seq: Array[Int], inline f: Int => Unit): Unit = // or f: Int => Unit
${ unrolledForeachImpl('unrollSize, 'seq, 'f) }

private def unrolledForeachImpl(unrollSize: Expr[Int], seq: Expr[Array[Int]], f: Expr[Int => Unit]) (using QuoteContext): Expr[Unit] = '{
Expand All @@ -16,7 +16,7 @@ object Macro {
${
for (j <- new UnrolledRange(0, unrollSize.unliftOrError)) '{
val element = ($seq)(i + ${Expr(j)})
${Expr.betaReduce(f)('element)} // or `($f)(element)` if `f` should not be inlined
${Expr.betaReduce('{$f(element)})} // or `($f)(element)` if `f` should not be inlined
}
}
i += ${unrollSize}
Expand Down
2 changes: 1 addition & 1 deletion tests/run-macros/i7008/macro_1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ def mcrProxy(expr: Expr[Boolean])(using QuoteContext): Expr[Unit] = {
def mcrImpl[T](func: Expr[Seq[Box[T]] => Unit], expr: Expr[T])(using ctx: QuoteContext, tt: Type[T]): Expr[Unit] = {
import ctx.tasty._
val arg = Varargs(Seq('{(Box($expr))}))
Expr.betaReduce(func)(arg)
Expr.betaReduce('{$func($arg)})
}
20 changes: 8 additions & 12 deletions tests/run-macros/quote-inline-function.check
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@ Normal function
var i: scala.Int = 0
val j: scala.Int = 5
while (i.<(j)) {
val x$1: scala.Int = i
f.apply(x$1)
f.apply(i)
i = i.+(1)
}
while ({
val x$2: scala.Int = i
f.apply(x$2)
f.apply(i)
i = i.+(1)
i.<(j)
}) ()
Expand All @@ -20,13 +18,11 @@ By name function
var i: scala.Int = 0
val j: scala.Int = 5
while (i.<(j)) {
val x$3: scala.Int = i
f.apply(x$3)
f.apply(i)
i = i.+(1)
}
while ({
val x$4: scala.Int = i
f.apply(x$4)
f.apply(i)
i = i.+(1)
i.<(j)
}) ()
Expand All @@ -37,13 +33,13 @@ Inline function
var i: scala.Int = 0
val j: scala.Int = 5
while (i.<(j)) {
val x$5: scala.Int = i
scala.Predef.println(x$5)
val x: scala.Int = i
scala.Predef.println(x)
i = i.+(1)
}
while ({
val x$6: scala.Int = i
scala.Predef.println(x$6)
val `x₂`: scala.Int = i
scala.Predef.println(`x₂`)
i = i.+(1)
i.<(j)
}) ()
Expand Down
Loading