@@ -20,6 +20,9 @@ import dotty.tools.pc.utils.InteractiveEnrichments.*
20
20
import scala .meta .pc .reports .ReportContext
21
21
import scala .meta .pc .OffsetParams
22
22
import scala .meta .pc .SymbolSearch
23
+ import dotty .tools .dotc .util .Signatures
24
+ import dotty .tools .dotc .util .Signatures .MethodParam
25
+ import dotty .tools .dotc .util .Signatures .TypeParam
23
26
24
27
class InferExpectedType (
25
28
search : SymbolSearch ,
@@ -50,32 +53,32 @@ class InferExpectedType(
50
53
val indexedCtx = IndexedContext (pos)(using locatedCtx)
51
54
val printer =
52
55
ShortenedTypePrinter (search, IncludeDefaultParam .ResolveLater )(using indexedCtx)
53
- InterCompletionType .inferType(path)(using newctx).map{
56
+ InferCompletionType .inferType(path)(using newctx).map{
54
57
tpe => printer.tpe(tpe)
55
58
}
56
59
case None => None
57
60
58
- object InterCompletionType :
61
+ object InferCompletionType :
59
62
def inferType (path : List [Tree ])(using Context ): Option [Type ] =
60
63
path match
61
- case (lit : Literal ) :: Select (Literal (_), _) :: Apply (Select (Literal (_), _), List (s : Select )) :: rest if s.symbol == defn.Predef_undefined => inferType(rest, lit.span)
62
- case ident :: rest => inferType(rest, ident.span)
64
+ case (lit : Literal ) :: Select (Literal (_), _) :: Apply (Select (Literal (_), _), List (s : Select )) :: rest if s.symbol == defn.Predef_undefined => inferType(rest, lit.span, path )
65
+ case ident :: rest => inferType(rest, ident.span, path )
63
66
case _ => None
64
67
65
- def inferType (path : List [Tree ], span : Span )(using Context ): Option [Type ] =
68
+ def inferType (path : List [Tree ], span : Span , fullPath : List [ Tree ] )(using Context ): Option [Type ] =
66
69
path match
67
70
case Typed (expr, tpt) :: _ if expr.span.contains(span) && ! tpt.tpe.isErroneous => Some (tpt.tpe)
68
71
case Block (_, expr) :: rest if expr.span.contains(span) =>
69
- inferType(rest, span)
70
- case Bind (_, body) :: rest if body.span.contains(span) => inferType(rest, span)
71
- case Alternative (_) :: rest => inferType(rest, span)
72
- case Try (block, _, _) :: rest if block.span.contains(span) => inferType(rest, span)
73
- case CaseDef (_, _, body) :: Try (_, cases, _) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span)
74
- case If (cond, _, _) :: rest if ! cond.span.contains(span) => inferType(rest, span)
72
+ inferType(rest, span, fullPath )
73
+ case Bind (_, body) :: rest if body.span.contains(span) => inferType(rest, span, fullPath )
74
+ case Alternative (_) :: rest => inferType(rest, span, fullPath )
75
+ case Try (block, _, _) :: rest if block.span.contains(span) => inferType(rest, span, fullPath )
76
+ case CaseDef (_, _, body) :: Try (_, cases, _) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span, fullPath )
77
+ case If (cond, _, _) :: rest if ! cond.span.contains(span) => inferType(rest, span, fullPath )
75
78
case If (cond, _, _) :: rest if cond.span.contains(span) => Some (defn.BooleanType )
76
79
case CaseDef (_, _, body) :: Match (_, cases) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) =>
77
- inferType(rest, span)
78
- case NamedArg (_, arg) :: rest if arg.span.contains(span) => inferType(rest, span)
80
+ inferType(rest, span, fullPath )
81
+ case NamedArg (_, arg) :: rest if arg.span.contains(span) => inferType(rest, span, fullPath )
79
82
// x match
80
83
// case @@
81
84
case CaseDef (pat, _, _) :: Match (sel, cases) :: rest if pat.span.contains(span) && cases.exists(_.span.contains(span)) && ! sel.tpe.isErroneous =>
@@ -94,37 +97,34 @@ object InterCompletionType:
94
97
else Some (UnapplyArgs (fun.tpe.finalResultType, fun, pats, NoSourcePosition ).argTypes(ind))
95
98
// f(@@)
96
99
case ApplyExtractor (app) =>
97
- val argsAndParams = ApplyArgsExtractor .getArgsAndParams(None , app, span).headOption
98
- argsAndParams.flatMap:
99
- case (args, params) =>
100
- val idx = args.indexWhere(_.span.contains(span))
101
- val param =
102
- if idx >= 0 && params.length > idx then Some (params(idx).info)
103
- else None
104
- param match
105
- // def f[T](a: T): T = ???
106
- // f[Int](@@)
107
- // val _: Int = f(@@)
108
- case Some (t : TypeRef ) if t.symbol.is(Flags .TypeParam ) =>
109
- for
110
- (typeParams, args) <-
111
- app match
112
- case Apply (TypeApply (fun, args), _) =>
113
- val typeParams = fun.symbol.paramSymss.headOption.filter(_.forall(_.isTypeParam))
114
- typeParams.map((_, args.map(_.tpe)))
115
- // val f: (j: "a") => Int
116
- // f(@@)
117
- case Apply (Select (v, StdNames .nme.apply), _) =>
118
- v.symbol.info match
119
- case AppliedType (des, args) =>
120
- Some ((des.typeSymbol.typeParams, args))
121
- case _ => None
122
- case _ => None
123
- ind = typeParams.indexOf(t.symbol)
124
- tpe <- args.get(ind)
125
- if ! tpe.isErroneous
126
- yield tpe
127
- case Some (tpe) => Some (tpe)
128
- case _ => None
100
+ val (idx, _, signatures) = Signatures .signatureHelp(fullPath, span)
101
+
102
+ val types : List [Type ] = signatures.flatMap { s =>
103
+ s.paramss.flatten.get(idx) match {
104
+ case Some (mp : MethodParam ) =>
105
+ mp.tpe match
106
+ case t : TypeParamRef =>
107
+ for
108
+ args <-
109
+ app match
110
+ case Apply (TypeApply (fun, args), _) =>
111
+ Some (args.map(_.tpe))
112
+ // val f: (j: "a") => Int
113
+ // f(@@)
114
+ case Apply (Select (v, StdNames .nme.apply), _) =>
115
+ v.symbol.info match
116
+ case AppliedType (des, args) =>
117
+ Some (args)
118
+ case _ => None
119
+ case _ => None
120
+ tpe <- args.get(t.paramNum)
121
+ if ! tpe.isErroneous
122
+ yield tpe
123
+ case tpe => Some (tpe)
124
+ case _ => None
125
+ }
126
+ }
127
+ if (types.isEmpty) None
128
+ else Some (types.reduce(_ | _))
129
129
case _ => None
130
130
0 commit comments