Skip to content

Commit 5aff9d6

Browse files
authored
Merge pull request #425 from scala/backport-lts-3.3-23197
Backport "fix: handle multiple params lists in for infer type" to 3.3 LTS
2 parents 9288f90 + 2f7b57d commit 5aff9d6

File tree

7 files changed

+371
-309
lines changed

7 files changed

+371
-309
lines changed

compiler/src/dotty/tools/dotc/config/SourceVersion.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ import core.Decorators.*
66
import util.Property
77

88
enum SourceVersion:
9-
case `3.0-migration`, `3.0`
9+
case `3.0-migration`, `3.0`
1010
case `3.1-migration`, `3.1`
1111
case `3.2-migration`, `3.2`
1212
case `3.3-migration`, `3.3`
1313
case `future-migration`, `future`
14+
case `never` // needed for MigrationVersion.errorFrom if we never want to issue an error
1415

1516
val isMigrating: Boolean = toString.endsWith("-migration")
1617

@@ -32,7 +33,7 @@ object SourceVersion extends Property.Key[SourceVersion]:
3233
val illegalInImports = List(`3.1-migration`, `never`)
3334

3435
/** language versions that may appear in a language import, are deprecated, but not removed from the standard library. */
35-
val illegalSourceVersionNames = "3.1-migration" :: illegalInImports.map(_.toString.toTermName)
36+
val illegalSourceVersionNames = illegalInImports.map(_.toString.toTermName)
3637

3738
/** language versions that the compiler recognises. */
3839
val validSourceVersionNames = values.toList.map(_.toString.toTermName)
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
package dotty.tools.pc
2+
3+
import scala.util.Try
4+
5+
import dotty.tools.dotc.ast.Trees.ValDef
6+
import dotty.tools.dotc.ast.tpd.*
7+
import dotty.tools.dotc.core.Contexts.Context
8+
import dotty.tools.dotc.core.Flags
9+
import dotty.tools.dotc.core.Flags.Method
10+
import dotty.tools.dotc.core.Names.Name
11+
import dotty.tools.dotc.core.StdNames.*
12+
import dotty.tools.dotc.core.SymDenotations.NoDenotation
13+
import dotty.tools.dotc.core.Symbols.defn
14+
import dotty.tools.dotc.core.Symbols.NoSymbol
15+
import dotty.tools.dotc.core.Symbols.Symbol
16+
import dotty.tools.dotc.core.Types.*
17+
import dotty.tools.pc.IndexedContext
18+
import dotty.tools.pc.utils.InteractiveEnrichments.*
19+
import scala.annotation.tailrec
20+
import dotty.tools.dotc.core.Denotations.SingleDenotation
21+
import dotty.tools.dotc.core.Denotations.MultiDenotation
22+
import dotty.tools.dotc.util.Spans.Span
23+
24+
object ApplyExtractor:
25+
def unapply(path: List[Tree])(using Context): Option[Apply] =
26+
path match
27+
case ValDef(_, _, _) :: Block(_, app: Apply) :: _
28+
if !app.fun.isInfix => Some(app)
29+
case rest =>
30+
def getApplyForContextFunctionParam(path: List[Tree]): Option[Apply] =
31+
path match
32+
// fun(arg@@)
33+
case (app: Apply) :: _ => Some(app)
34+
// fun(arg@@), where fun(argn: Context ?=> SomeType)
35+
// recursively matched for multiple context arguments, e.g. Context1 ?=> Context2 ?=> SomeType
36+
case (_: DefDef) :: Block(List(_), _: Closure) :: rest =>
37+
getApplyForContextFunctionParam(rest)
38+
case _ => None
39+
for
40+
app <- getApplyForContextFunctionParam(rest)
41+
if !app.fun.isInfix
42+
yield app
43+
end match
44+
45+
46+
object ApplyArgsExtractor:
47+
def getArgsAndParams(
48+
optIndexedContext: Option[IndexedContext],
49+
apply: Apply,
50+
span: Span
51+
)(using Context): List[(List[Tree], List[ParamSymbol])] =
52+
def collectArgss(a: Apply): List[List[Tree]] =
53+
def stripContextFuntionArgument(argument: Tree): List[Tree] =
54+
argument match
55+
case Block(List(d: DefDef), _: Closure) =>
56+
d.rhs match
57+
case app: Apply =>
58+
app.args
59+
case b @ Block(List(_: DefDef), _: Closure) =>
60+
stripContextFuntionArgument(b)
61+
case _ => Nil
62+
case v => List(v)
63+
64+
val args = a.args.flatMap(stripContextFuntionArgument)
65+
a.fun match
66+
case app: Apply => collectArgss(app) :+ args
67+
case _ => List(args)
68+
end collectArgss
69+
70+
val method = apply.fun
71+
72+
val argss = collectArgss(apply)
73+
74+
def fallbackFindApply(sym: Symbol) =
75+
sym.info.member(nme.apply) match
76+
case NoDenotation => Nil
77+
case den => List(den.symbol)
78+
79+
// fallback for when multiple overloaded methods match the supplied args
80+
def fallbackFindMatchingMethods() =
81+
def matchingMethodsSymbols(
82+
indexedContext: IndexedContext,
83+
method: Tree
84+
): List[Symbol] =
85+
method match
86+
case Ident(name) => indexedContext.findSymbol(name).getOrElse(Nil)
87+
case Select(This(_), name) => indexedContext.findSymbol(name).getOrElse(Nil)
88+
case sel @ Select(from, name) =>
89+
val symbol = from.symbol
90+
val ownerSymbol =
91+
if symbol.is(Method) && symbol.owner.isClass then
92+
Some(symbol.owner)
93+
else Try(symbol.info.classSymbol).toOption
94+
ownerSymbol.map(sym => sym.info.member(name)).collect{
95+
case single: SingleDenotation => List(single.symbol)
96+
case multi: MultiDenotation => multi.allSymbols
97+
}.getOrElse(Nil)
98+
case Apply(fun, _) => matchingMethodsSymbols(indexedContext, fun)
99+
case _ => Nil
100+
val matchingMethods =
101+
for
102+
indexedContext <- optIndexedContext.toList
103+
potentialMatch <- matchingMethodsSymbols(indexedContext, method)
104+
if potentialMatch.is(Flags.Method) &&
105+
potentialMatch.vparamss.length >= argss.length &&
106+
Try(potentialMatch.isAccessibleFrom(apply.symbol.info)).toOption
107+
.getOrElse(false) &&
108+
potentialMatch.vparamss
109+
.zip(argss)
110+
.reverse
111+
.zipWithIndex
112+
.forall { case (pair, index) =>
113+
FuzzyArgMatcher(potentialMatch.tparams)
114+
.doMatch(allArgsProvided = index != 0, span)
115+
.tupled(pair)
116+
}
117+
yield potentialMatch
118+
matchingMethods
119+
end fallbackFindMatchingMethods
120+
121+
val matchingMethods: List[Symbol] =
122+
if method.symbol.paramSymss.nonEmpty then
123+
val allArgsAreSupplied =
124+
val vparamss = method.symbol.vparamss
125+
vparamss.length == argss.length && vparamss
126+
.zip(argss)
127+
.lastOption
128+
.exists { case (baseParams, baseArgs) =>
129+
baseArgs.length == baseParams.length
130+
}
131+
// ```
132+
// m(arg : Int)
133+
// m(arg : Int, anotherArg : Int)
134+
// m(a@@)
135+
// ```
136+
// complier will choose the first `m`, so we need to manually look for the other one
137+
if allArgsAreSupplied then
138+
val foundPotential = fallbackFindMatchingMethods()
139+
if foundPotential.contains(method.symbol) then foundPotential
140+
else method.symbol :: foundPotential
141+
else List(method.symbol)
142+
else if method.symbol.is(Method) || method.symbol == NoSymbol then
143+
fallbackFindMatchingMethods()
144+
else fallbackFindApply(method.symbol)
145+
end if
146+
end matchingMethods
147+
148+
matchingMethods.map { methodSym =>
149+
val vparamss = methodSym.vparamss
150+
151+
// get params and args we are interested in
152+
// e.g.
153+
// in the following case, the interesting args and params are
154+
// - params: [apple, banana]
155+
// - args: [apple, b]
156+
// ```
157+
// def curry(x: Int)(apple: String, banana: String) = ???
158+
// curry(1)(apple = "test", b@@)
159+
// ```
160+
val (baseParams0, baseArgs) =
161+
vparamss.zip(argss).lastOption.getOrElse((Nil, Nil))
162+
163+
val baseParams: List[ParamSymbol] =
164+
def defaultBaseParams = baseParams0.map(JustSymbol(_))
165+
@tailrec
166+
def getRefinedParams(refinedType: Type, level: Int): List[ParamSymbol] =
167+
if level > 0 then
168+
val resultTypeOpt =
169+
refinedType match
170+
case RefinedType(AppliedType(_, args), _, _) => args.lastOption
171+
case AppliedType(_, args) => args.lastOption
172+
case _ => None
173+
resultTypeOpt match
174+
case Some(resultType) => getRefinedParams(resultType, level - 1)
175+
case _ => defaultBaseParams
176+
else
177+
refinedType match
178+
case RefinedType(AppliedType(_, args), _, MethodType(ri)) =>
179+
baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) =>
180+
RefinedSymbol(sym, name, arg)
181+
}
182+
case _ => defaultBaseParams
183+
// finds param refinements for lambda expressions
184+
// val hello: (x: Int, y: Int) => Unit = (x, _) => println(x)
185+
@tailrec
186+
def refineParams(method: Tree, level: Int): List[ParamSymbol] =
187+
method match
188+
case Select(Apply(f, _), _) => refineParams(f, level + 1)
189+
case Select(h, name) =>
190+
// for Select(foo, name = apply) we want `foo.symbol`
191+
if name == nme.apply then getRefinedParams(h.symbol.info, level)
192+
else getRefinedParams(method.symbol.info, level)
193+
case Apply(f, _) =>
194+
refineParams(f, level + 1)
195+
case _ => getRefinedParams(method.symbol.info, level)
196+
refineParams(method, 0)
197+
end baseParams
198+
(baseArgs, baseParams)
199+
}
200+
201+
extension (method: Symbol)
202+
def vparamss(using Context) = method.filteredParamss(_.isTerm)
203+
def tparams(using Context) = method.filteredParamss(_.isType).flatten
204+
def filteredParamss(f: Symbol => Boolean)(using Context) =
205+
method.paramSymss.filter(params => params.forall(f))
206+
sealed trait ParamSymbol:
207+
def name: Name
208+
def info: Type
209+
def symbol: Symbol
210+
def nameBackticked(using Context) = name.decoded.backticked
211+
212+
case class JustSymbol(symbol: Symbol)(using Context) extends ParamSymbol:
213+
def name: Name = symbol.name
214+
def info: Type = symbol.info
215+
216+
case class RefinedSymbol(symbol: Symbol, name: Name, info: Type)
217+
extends ParamSymbol
218+
219+
220+
class FuzzyArgMatcher(tparams: List[Symbol])(using Context):
221+
222+
/**
223+
* A heuristic for checking if the passed arguments match the method's arguments' types.
224+
* For non-polymorphic methods we use the subtype relation (`<:<`)
225+
* and for polymorphic methods we use a heuristic.
226+
* We check the args types not the result type.
227+
*/
228+
def doMatch(
229+
allArgsProvided: Boolean,
230+
span: Span
231+
)(expectedArgs: List[Symbol], actualArgs: List[Tree]) =
232+
(expectedArgs.length == actualArgs.length ||
233+
(!allArgsProvided && expectedArgs.length >= actualArgs.length)) &&
234+
actualArgs.zipWithIndex.forall {
235+
case (arg: Ident, _) if arg.span.contains(span) => true
236+
case (NamedArg(name, arg), _) =>
237+
expectedArgs.exists { expected =>
238+
expected.name == name && (!arg.hasType || arg.typeOpt.unfold
239+
.fuzzyArg_<:<(expected.info))
240+
}
241+
case (arg, i) =>
242+
!arg.hasType || arg.typeOpt.unfold.fuzzyArg_<:<(expectedArgs(i).info)
243+
}
244+
245+
extension (arg: Type)
246+
def fuzzyArg_<:<(expected: Type) =
247+
if tparams.isEmpty then arg <:< expected
248+
else arg <:< substituteTypeParams(expected)
249+
def unfold =
250+
arg match
251+
case arg: TermRef => arg.underlying
252+
case e => e
253+
254+
private def substituteTypeParams(t: Type): Type =
255+
t match
256+
case e if tparams.exists(_ == e.typeSymbol) =>
257+
val matchingParam = tparams.find(_ == e.typeSymbol).get
258+
matchingParam.info match
259+
case b @ TypeBounds(_, _) => WildcardType(b)
260+
case _ => WildcardType
261+
case o @ OrType(e1, e2) =>
262+
OrType(substituteTypeParams(e1), substituteTypeParams(e2), o.isSoft)
263+
case AndType(e1, e2) =>
264+
AndType(substituteTypeParams(e1), substituteTypeParams(e2))
265+
case AppliedType(et, eparams) =>
266+
AppliedType(et, eparams.map(substituteTypeParams))
267+
case _ => t
268+
269+
end FuzzyArgMatcher

presentation-compiler/src/main/dotty/tools/pc/CompilerSearchVisitor.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class CompilerSearchVisitor(
3434
logger.log(Level.WARNING, err.getMessage())
3535
false
3636
case NonFatal(e) =>
37-
reports.incognito.create(
37+
reports.incognito.nn.create(
3838
() => Report(
3939
"is_public",
4040
s"""Symbol: $sym""".stripMargin,

presentation-compiler/src/main/dotty/tools/pc/HoverProvider.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ object HoverProvider:
8686
s"$uri::$posId"
8787
)
8888
end report
89-
reportContext.unsanitized.create(() => report, /*ifVerbose =*/ true)
89+
reportContext.unsanitized.nn.create(() => report, /*ifVerbose =*/ true)
9090
ju.Optional.empty().nn
9191
else
9292
val skipCheckOnName =

0 commit comments

Comments
 (0)