Skip to content

Commit 3dd303b

Browse files
committed
fix: use signature help to infer type for apply
1 parent a8f2e1f commit 3dd303b

File tree

5 files changed

+93
-58
lines changed

5 files changed

+93
-58
lines changed

compiler/src/dotty/tools/dotc/util/Signatures.scala

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,15 @@ object Signatures {
5757
*/
5858
case class MethodParam(
5959
name: String,
60-
tpe: String,
60+
tpe: Type,
61+
tpeString: String,
6162
override val doc: Option[String] = None,
6263
isImplicit: Boolean = false,
6364
isReordered: Boolean = false
6465
) extends Param:
6566
def show: String = if name.nonEmpty && !isReordered then s"$name: $tpe"
6667
else if name.nonEmpty then s"[$name: $tpe]"
67-
else tpe
68+
else tpeString
6869

6970
/**
7071
* Represent a type parameter.
@@ -389,12 +390,7 @@ object Signatures {
389390
&& tree.symbol.exists
390391
&& ctx.definitions.isTupleClass(tree.symbol.owner.companionClass)
391392

392-
val isFunctionNApply =
393-
tree.symbol.name == nme.apply
394-
&& tree.symbol.exists
395-
&& ctx.definitions.isFunctionSymbol(tree.symbol.owner)
396-
397-
!isTupleApply && !isFunctionNApply
393+
!isTupleApply
398394

399395
/**
400396
* Get unapply method result type omiting unknown types and another method calls.
@@ -538,10 +534,11 @@ object Signatures {
538534
val params = currentParams.map: (symbol, info) =>
539535
// TODO after we migrate ShortenedTypePrinter into the compiler, it should rely on its api
540536
val name = if symbol.isAllOf(Flags.Given | Flags.Param) && symbol.name.startsWith("x$") then nme.EMPTY else symbol.name.asTermName
541-
537+
val tpe = info.widenTermRefExpr
542538
Signatures.MethodParam(
543539
name.show,
544-
info.widenTermRefExpr.show,
540+
tpe,
541+
tpe.show,
545542
docComment.flatMap(_.paramDoc(name)),
546543
isImplicit = tp.isImplicitMethod,
547544
)
@@ -615,10 +612,10 @@ object Signatures {
615612
*/
616613
private def toUnapplySignature(denot: SingleDenotation, paramNames: List[Name], paramTypes: List[Type])(using Context): Option[Signature] =
617614
val params = if paramNames.length == paramTypes.length then
618-
(paramNames zip paramTypes).map((name, info) => MethodParam(name.show, info.show))
615+
(paramNames zip paramTypes).map((name, info) => MethodParam(name.show, info, info.show))
619616
else
620617
// even if we only show types of arguments, they are still method params
621-
paramTypes.map(info => MethodParam("", info.show))
618+
paramTypes.map(info => MethodParam("", info, info.show))
622619

623620
if params.nonEmpty then Some(Signature("", List(params), None, None, Some(denot)))
624621
else None

presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ import dotty.tools.pc.utils.InteractiveEnrichments.*
2020
import scala.meta.pc.reports.ReportContext
2121
import scala.meta.pc.OffsetParams
2222
import scala.meta.pc.SymbolSearch
23+
import dotty.tools.dotc.util.Signatures
24+
import dotty.tools.dotc.util.Signatures.MethodParam
25+
import dotty.tools.dotc.util.Signatures.TypeParam
2326

2427
class InferExpectedType(
2528
search: SymbolSearch,
@@ -50,32 +53,32 @@ class InferExpectedType(
5053
val indexedCtx = IndexedContext(pos)(using locatedCtx)
5154
val printer =
5255
ShortenedTypePrinter(search, IncludeDefaultParam.ResolveLater)(using indexedCtx)
53-
InterCompletionType.inferType(path)(using newctx).map{
56+
InferCompletionType.inferType(path)(using newctx).map{
5457
tpe => printer.tpe(tpe)
5558
}
5659
case None => None
5760

58-
object InterCompletionType:
61+
object InferCompletionType:
5962
def inferType(path: List[Tree])(using Context): Option[Type] =
6063
path match
61-
case (lit: Literal) :: Select(Literal(_), _) :: Apply(Select(Literal(_), _), List(s: Select)) :: rest if s.symbol == defn.Predef_undefined => inferType(rest, lit.span)
62-
case ident :: rest => inferType(rest, ident.span)
64+
case (lit: Literal) :: Select(Literal(_), _) :: Apply(Select(Literal(_), _), List(s: Select)) :: rest if s.symbol == defn.Predef_undefined => inferType(rest, lit.span, path)
65+
case ident :: rest => inferType(rest, ident.span, path)
6366
case _ => None
6467

65-
def inferType(path: List[Tree], span: Span)(using Context): Option[Type] =
68+
def inferType(path: List[Tree], span: Span, fullPath: List[Tree])(using Context): Option[Type] =
6669
path match
6770
case Typed(expr, tpt) :: _ if expr.span.contains(span) && !tpt.tpe.isErroneous => Some(tpt.tpe)
6871
case Block(_, expr) :: rest if expr.span.contains(span) =>
69-
inferType(rest, span)
70-
case Bind(_, body) :: rest if body.span.contains(span) => inferType(rest, span)
71-
case Alternative(_) :: rest => inferType(rest, span)
72-
case Try(block, _, _) :: rest if block.span.contains(span) => inferType(rest, span)
73-
case CaseDef(_, _, body) :: Try(_, cases, _) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span)
74-
case If(cond, _, _) :: rest if !cond.span.contains(span) => inferType(rest, span)
72+
inferType(rest, span, fullPath)
73+
case Bind(_, body) :: rest if body.span.contains(span) => inferType(rest, span, fullPath)
74+
case Alternative(_) :: rest => inferType(rest, span, fullPath)
75+
case Try(block, _, _) :: rest if block.span.contains(span) => inferType(rest, span, fullPath)
76+
case CaseDef(_, _, body) :: Try(_, cases, _) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span, fullPath)
77+
case If(cond, _, _) :: rest if !cond.span.contains(span) => inferType(rest, span, fullPath)
7578
case If(cond, _, _) :: rest if cond.span.contains(span) => Some(defn.BooleanType)
7679
case CaseDef(_, _, body) :: Match(_, cases) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) =>
77-
inferType(rest, span)
78-
case NamedArg(_, arg) :: rest if arg.span.contains(span) => inferType(rest, span)
80+
inferType(rest, span, fullPath)
81+
case NamedArg(_, arg) :: rest if arg.span.contains(span) => inferType(rest, span, fullPath)
7982
// x match
8083
// case @@
8184
case CaseDef(pat, _, _) :: Match(sel, cases) :: rest if pat.span.contains(span) && cases.exists(_.span.contains(span)) && !sel.tpe.isErroneous =>
@@ -94,37 +97,34 @@ object InterCompletionType:
9497
else Some(UnapplyArgs(fun.tpe.finalResultType, fun, pats, NoSourcePosition).argTypes(ind))
9598
// f(@@)
9699
case ApplyExtractor(app) =>
97-
val argsAndParams = ApplyArgsExtractor.getArgsAndParams(None, app, span).headOption
98-
argsAndParams.flatMap:
99-
case (args, params) =>
100-
val idx = args.indexWhere(_.span.contains(span))
101-
val param =
102-
if idx >= 0 && params.length > idx then Some(params(idx).info)
103-
else None
104-
param match
105-
// def f[T](a: T): T = ???
106-
// f[Int](@@)
107-
// val _: Int = f(@@)
108-
case Some(t : TypeRef) if t.symbol.is(Flags.TypeParam) =>
109-
for
110-
(typeParams, args) <-
111-
app match
112-
case Apply(TypeApply(fun, args), _) =>
113-
val typeParams = fun.symbol.paramSymss.headOption.filter(_.forall(_.isTypeParam))
114-
typeParams.map((_, args.map(_.tpe)))
115-
// val f: (j: "a") => Int
116-
// f(@@)
117-
case Apply(Select(v, StdNames.nme.apply), _) =>
118-
v.symbol.info match
119-
case AppliedType(des, args) =>
120-
Some((des.typeSymbol.typeParams, args))
121-
case _ => None
122-
case _ => None
123-
ind = typeParams.indexOf(t.symbol)
124-
tpe <- args.get(ind)
125-
if !tpe.isErroneous
126-
yield tpe
127-
case Some(tpe) => Some(tpe)
128-
case _ => None
100+
val (idx, _, signatures) = Signatures.signatureHelp(fullPath, span)
101+
102+
val types: List[Type] = signatures.flatMap { s =>
103+
s.paramss.flatten.get(idx) match {
104+
case Some(mp: MethodParam) =>
105+
mp.tpe match
106+
case t : TypeParamRef =>
107+
for
108+
args <-
109+
app match
110+
case Apply(TypeApply(fun, args), _) =>
111+
Some(args.map(_.tpe))
112+
// val f: (j: "a") => Int
113+
// f(@@)
114+
case Apply(Select(v, StdNames.nme.apply), _) =>
115+
v.symbol.info match
116+
case AppliedType(des, args) =>
117+
Some(args)
118+
case _ => None
119+
case _ => None
120+
tpe <- args.get(t.paramNum)
121+
if !tpe.isErroneous
122+
yield tpe
123+
case tpe => Some(tpe)
124+
case _ => None
125+
}
126+
}
127+
if(types.isEmpty) None
128+
else Some(types.reduce(_ | _))
129129
case _ => None
130130

presentation-compiler/src/main/dotty/tools/pc/completions/Completions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ class Completions(
520520
config.isCompletionSnippetsEnabled()
521521
)
522522
(args, false)
523-
val singletonCompletions = InterCompletionType.inferType(path).map(
523+
val singletonCompletions = InferCompletionType.inferType(path).map(
524524
SingletonCompletions.contribute(path, _, completionPos)
525525
).getOrElse(Nil)
526526
(singletonCompletions ++ advanced, exclusive)

presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,15 @@ class InferExpectedTypeSuite extends BasePCSuite:
335335
"""|String
336336
|""".stripMargin
337337
)
338+
339+
@Test def using =
340+
check(
341+
"""|def go(using Ordering[Int])(x: Int, y: Int): Int =
342+
| Ordering[Int].compare(x, y)
343+
|
344+
|def test =
345+
| go(@@, ???)
346+
|""".stripMargin,
347+
"""|Int
348+
|""".stripMargin
349+
)

presentation-compiler/test/dotty/tools/pc/tests/completion/SingletonCompletionsSuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,4 +297,30 @@ class SingletonCompletionsSuite extends BaseCompletionSuite {
297297
"""|"foo": "foo"
298298
|""".stripMargin
299299
)
300+
301+
@Test def `type-apply` =
302+
check(
303+
"""|class Consumer[A]:
304+
| def eat(a: A) = ()
305+
|
306+
|def test =
307+
| Consumer[7].eat(@@)
308+
|""".stripMargin,
309+
"7: 7",
310+
topLines = Some(1)
311+
)
312+
313+
@Test def `type-apply-2` =
314+
check(
315+
"""|class Consumer[A]:
316+
| def eat(a: A) = ()
317+
|
318+
|object Consumer7 extends Consumer[7]
319+
|
320+
|def test =
321+
| Consumer7.eat(@@)
322+
|""".stripMargin,
323+
"7: 7",
324+
topLines = Some(1)
325+
)
300326
}

0 commit comments

Comments
 (0)