Skip to content

Commit 14974ed

Browse files
committed
Drop trailing implicit args to dropped map
1 parent 240df1c commit 14974ed

File tree

2 files changed

+70
-24
lines changed

2 files changed

+70
-24
lines changed

compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,42 +12,62 @@ import dotty.tools.dotc.transform.MegaPhase.MiniPhase
1212
import dotty.tools.dotc.ast.desugar
1313

1414
/** Drop unused trailing map calls in for comprehensions.
15-
* We can drop the map call if:
16-
* - it won't change the type of the expression, and
17-
* - the function is an identity function or a const function to unit.
18-
*
19-
* The latter condition is checked in [[Desugar.scala#makeFor]]
20-
*/
15+
*
16+
* We can drop the map call if:
17+
* - it won't change the type of the expression, and
18+
* - the function is an identity function or a const function to unit.
19+
*
20+
* The latter condition is checked in [[Desugar.scala#makeFor]]
21+
*/
2122
class DropForMap extends MiniPhase:
2223
import DropForMap.*
2324

2425
override def phaseName: String = DropForMap.name
2526

2627
override def description: String = DropForMap.description
2728

28-
override def transformApply(tree: Apply)(using Context): Tree =
29-
tree.removeAttachment(desugar.TrailingForMap) match
30-
case Some(_) =>
29+
/** r.map(x => x)(using y) --> r
30+
* ^ TrailingForMap
31+
*/
32+
override def transformApply(tree: Apply)(using Context): Tree = tree match
33+
case Unmapped(f) =>
34+
if f.tpe =:= tree.tpe then // make sure that the type of the expression won't change
35+
f // drop the map call
36+
else
37+
f match
38+
case Converted(r) if r.tpe =:= tree.tpe => r // drop the map call and the conversion
39+
case _ => tree
40+
case tree => tree
41+
42+
// Extracts a fun from a possibly nested Apply with lambda and arbitrary implicit args.
43+
private object Unmapped:
44+
private def loop(tree: Tree)(using Context): Option[Tree] =
3145
tree match
32-
case aply @ Apply(MapCall(f), List(Lambda(List(param), body))) =>
33-
if f.tpe =:= aply.tpe then // make sure that the type of the expression won't change
34-
return f // drop the map call
35-
else
36-
f match
37-
case Converted(r) if r.tpe =:= aply.tpe =>
38-
return r // drop the map call and the conversion
46+
case Apply(fun, Lambda(_ :: Nil, _) :: Nil) =>
47+
tree.removeAttachment(desugar.TrailingForMap) match
48+
case Some(_) =>
49+
fun match
50+
case MapCall(f) => return Some(f)
3951
case _ =>
52+
case _ =>
53+
case Apply(fun, _) =>
54+
fun.tpe match
55+
case mt: MethodType if mt.isImplicitMethod => return loop(fun)
56+
case _ =>
4057
case _ =>
41-
case _ =>
42-
tree
58+
None
59+
end loop
60+
def unapply(tree: Apply)(using Context): Option[Tree] =
61+
tree.tpe match
62+
case _: MethodOrPoly => None
63+
case _ => loop(tree)
4364

4465
private object Lambda:
45-
def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] =
46-
tree match
47-
case Block(List(defdef: DefDef), Closure(Nil, ref, _))
48-
if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) =>
49-
Some((defdef.termParamss.flatten, defdef.rhs))
50-
case _ => None
66+
def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] = tree match
67+
case Block(List(defdef: DefDef), Closure(Nil, ref, _))
68+
if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) =>
69+
Some((defdef.termParamss.flatten, defdef.rhs))
70+
case _ => None
5171

5272
private object MapCall:
5373
def unapply(tree: Tree)(using Context): Option[Tree] = tree match

tests/run/i23409b.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//> using options -preview
2+
3+
final class Implicit()
4+
5+
final class Id[+A, -U](val value: A):
6+
def map[B](f: A => B)(using Implicit): Id[B, U] = ??? //Id(f(value))
7+
def flatMap[B, V <: U](f: A => Id[B, V]): Id[B, V] = f(value)
8+
def run: A = value
9+
10+
type Foo = Foo.type
11+
case object Foo:
12+
def get: Id[Int, Foo] = Id(42)
13+
14+
type Bar = Bar.type
15+
case object Bar:
16+
def inc(i: Int): Id[Int, Bar] = Id(i * 10)
17+
18+
def program(using Implicit) =
19+
for
20+
a <- Foo.get
21+
x <- Bar.inc(a)
22+
yield x
23+
24+
@main def Test = println:
25+
given Implicit = Implicit()
26+
program.run

0 commit comments

Comments
 (0)