diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go index 639e6a6857b..670865fa3c6 100644 --- a/internal/lsp/source/completion.go +++ b/internal/lsp/source/completion.go @@ -1445,7 +1445,7 @@ func enclosingCompositeLiteral(path []ast.Node, pos token.Pos, info *types.Info) return &clInfo default: - if breaksExpectedTypeInference(n) { + if breaksExpectedTypeInference(n, pos) { return nil } } @@ -1535,11 +1535,11 @@ type typeModifier struct { type typeMod int const ( - star typeMod = iota // pointer indirection for expressions, pointer indicator for types - address // address operator ("&") - chanRead // channel read operator ("<-") - slice // make a slice type ("[]" in "[]int") - array // make an array type ("[2]" in "[2]int") + dereference typeMod = iota // pointer indirection: "*" + reference // adds level of pointer: "&" for values, "*" for type names + chanRead // channel read operator ("<-") + slice // make a slice type ("[]" in "[]int") + array // make an array type ("[2]" in "[2]int") ) type objKind int @@ -1651,6 +1651,10 @@ type typeNameInference struct { // seenTypeSwitchCases tracks types that have already been used by // the containing type switch. seenTypeSwitchCases []types.Type + + // compLitType is true if we are completing a composite literal type + // name, e.g "foo<>{}". + compLitType bool } // expectedCandidate returns information about the expected candidate @@ -1862,11 +1866,11 @@ Nodes: } return inf case *ast.StarExpr: - inf.modifiers = append(inf.modifiers, typeModifier{mod: star}) + inf.modifiers = append(inf.modifiers, typeModifier{mod: dereference}) case *ast.UnaryExpr: switch node.Op { case token.AND: - inf.modifiers = append(inf.modifiers, typeModifier{mod: address}) + inf.modifiers = append(inf.modifiers, typeModifier{mod: reference}) case token.ARROW: inf.modifiers = append(inf.modifiers, typeModifier{mod: chanRead}) } @@ -1874,7 +1878,7 @@ Nodes: inf.objKind |= kindFunc return inf default: - if breaksExpectedTypeInference(node) { + if breaksExpectedTypeInference(node, c.pos) { return inf } } @@ -1928,7 +1932,7 @@ func objChain(info *types.Info, e ast.Expr) []types.Object { func (ci candidateInference) applyTypeModifiers(typ types.Type, addressable bool) types.Type { for _, mod := range ci.modifiers { switch mod.mod { - case star: + case dereference: // For every "*" indirection operator, remove a pointer layer // from candidate type. if ptr, ok := typ.Underlying().(*types.Pointer); ok { @@ -1936,7 +1940,7 @@ func (ci candidateInference) applyTypeModifiers(typ types.Type, addressable bool } else { return nil } - case address: + case reference: // For every "&" address operator, add another pointer layer to // candidate type, if the candidate is addressable. if addressable { @@ -1961,8 +1965,7 @@ func (ci candidateInference) applyTypeModifiers(typ types.Type, addressable bool func (ci candidateInference) applyTypeNameModifiers(typ types.Type) types.Type { for _, mod := range ci.typeName.modifiers { switch mod.mod { - case star: - // For every "*" indicator, add a pointer layer to type name. + case reference: typ = types.NewPointer(typ) case array: typ = types.NewArray(typ, mod.arrayLen) @@ -2006,9 +2009,17 @@ func findSwitchStmt(path []ast.Node, pos token.Pos, c *ast.CaseClause) ast.Stmt // breaksExpectedTypeInference reports if an expression node's type is unrelated // to its child expression node types. For example, "Foo{Bar: x.Baz(<>)}" should // expect a function argument, not a composite literal value. -func breaksExpectedTypeInference(n ast.Node) bool { - switch n.(type) { - case *ast.FuncLit, *ast.CallExpr, *ast.IndexExpr, *ast.SliceExpr, *ast.CompositeLit: +func breaksExpectedTypeInference(n ast.Node, pos token.Pos) bool { + switch n := n.(type) { + case *ast.CompositeLit: + // Doesn't break inference if pos is in type name. + // For example: "Foo<>{Bar: 123}" + return !nodeContains(n.Type, pos) + case *ast.CallExpr: + // Doesn't break inference if pos is in func name. + // For example: "Foo<>(123)" + return !nodeContains(n.Fun, pos) + case *ast.FuncLit, *ast.IndexExpr, *ast.SliceExpr: return true default: return false @@ -2017,13 +2028,7 @@ func breaksExpectedTypeInference(n ast.Node) bool { // expectTypeName returns information about the expected type name at position. func expectTypeName(c *completer) typeNameInference { - var ( - wantTypeName bool - wantComparable bool - modifiers []typeModifier - assertableFrom types.Type - seenTypeSwitchCases []types.Type - ) + var inf typeNameInference Nodes: for i, p := range c.path { @@ -2034,7 +2039,7 @@ Nodes: // InterfaceType. We don't need to worry about the field name // because completion bails out early if pos is in an *ast.Ident // that defines an object. - wantTypeName = true + inf.wantTypeName = true break Nodes case *ast.CaseClause: // Expect type names in type switch case clauses. @@ -2042,12 +2047,12 @@ Nodes: // The case clause types must be assertable from the type switch parameter. ast.Inspect(swtch.Assign, func(n ast.Node) bool { if ta, ok := n.(*ast.TypeAssertExpr); ok { - assertableFrom = c.pkg.GetTypesInfo().TypeOf(ta.X) + inf.assertableFrom = c.pkg.GetTypesInfo().TypeOf(ta.X) return false } return true }) - wantTypeName = true + inf.wantTypeName = true // Track the types that have already been used in this // switch's case statements so we don't recommend them. @@ -2060,7 +2065,7 @@ Nodes: } if t := c.pkg.GetTypesInfo().TypeOf(typeExpr); t != nil { - seenTypeSwitchCases = append(seenTypeSwitchCases, t) + inf.seenTypeSwitchCases = append(inf.seenTypeSwitchCases, t) } } } @@ -2072,33 +2077,43 @@ Nodes: // Expect type names in type assert expressions. if n.Lparen < c.pos && c.pos <= n.Rparen { // The type in parens must be assertable from the expression type. - assertableFrom = c.pkg.GetTypesInfo().TypeOf(n.X) - wantTypeName = true + inf.assertableFrom = c.pkg.GetTypesInfo().TypeOf(n.X) + inf.wantTypeName = true break Nodes } return typeNameInference{} case *ast.StarExpr: - modifiers = append(modifiers, typeModifier{mod: star}) + inf.modifiers = append(inf.modifiers, typeModifier{mod: reference}) case *ast.CompositeLit: // We want a type name if position is in the "Type" part of a // composite literal (e.g. "Foo<>{}"). if n.Type != nil && n.Type.Pos() <= c.pos && c.pos <= n.Type.End() { - wantTypeName = true + inf.wantTypeName = true + inf.compLitType = true + + if i < len(c.path)-1 { + // Track preceding "&" operator. Technically it applies to + // the composite literal and not the type name, but if + // affects our type completion nonetheless. + if u, ok := c.path[i+1].(*ast.UnaryExpr); ok && u.Op == token.AND { + inf.modifiers = append(inf.modifiers, typeModifier{mod: reference}) + } + } } break Nodes case *ast.ArrayType: // If we are inside the "Elt" part of an array type, we want a type name. if n.Elt.Pos() <= c.pos && c.pos <= n.Elt.End() { - wantTypeName = true + inf.wantTypeName = true if n.Len == nil { // No "Len" expression means a slice type. - modifiers = append(modifiers, typeModifier{mod: slice}) + inf.modifiers = append(inf.modifiers, typeModifier{mod: slice}) } else { // Try to get the array type using the constant value of "Len". tv, ok := c.pkg.GetTypesInfo().Types[n.Len] if ok && tv.Value != nil && tv.Value.Kind() == constant.Int { if arrayLen, ok := constant.Int64Val(tv.Value); ok { - modifiers = append(modifiers, typeModifier{mod: array, arrayLen: arrayLen}) + inf.modifiers = append(inf.modifiers, typeModifier{mod: array, arrayLen: arrayLen}) } } } @@ -2114,34 +2129,28 @@ Nodes: break Nodes } case *ast.MapType: - wantTypeName = true + inf.wantTypeName = true if n.Key != nil { - wantComparable = nodeContains(n.Key, c.pos) + inf.wantComparable = nodeContains(n.Key, c.pos) } else { // If the key is empty, assume we are completing the key if // pos is directly after the "map[". - wantComparable = c.pos == n.Pos()+token.Pos(len("map[")) + inf.wantComparable = c.pos == n.Pos()+token.Pos(len("map[")) } break Nodes case *ast.ValueSpec: - wantTypeName = nodeContains(n.Type, c.pos) + inf.wantTypeName = nodeContains(n.Type, c.pos) break Nodes case *ast.TypeSpec: - wantTypeName = nodeContains(n.Type, c.pos) + inf.wantTypeName = nodeContains(n.Type, c.pos) default: - if breaksExpectedTypeInference(p) { + if breaksExpectedTypeInference(p, c.pos) { return typeNameInference{} } } } - return typeNameInference{ - wantTypeName: wantTypeName, - wantComparable: wantComparable, - modifiers: modifiers, - assertableFrom: assertableFrom, - seenTypeSwitchCases: seenTypeSwitchCases, - } + return inf } func (c *completer) fakeObj(T types.Type) *types.Var { @@ -2519,7 +2528,15 @@ func (c *completer) matchingTypeName(cand *candidate) bool { } if !isInterface(t) && typeMatches(types.NewPointer(t)) { - cand.makePointer = true + if c.inference.typeName.compLitType { + // If we are completing a composite literal type as in + // "foo<>{}", to make a pointer we must prepend "&". + cand.takeAddress = true + } else { + // If we are completing a normal type name such as "foo<>", to + // make a pointer we must prepend "*". + cand.makePointer = true + } return true } diff --git a/internal/lsp/testdata/lsp/primarymod/complit/complit.go.in b/internal/lsp/testdata/lsp/primarymod/complit/complit.go.in index ec6544eb6b4..465a72cc288 100644 --- a/internal/lsp/testdata/lsp/primarymod/complit/complit.go.in +++ b/internal/lsp/testdata/lsp/primarymod/complit/complit.go.in @@ -94,6 +94,20 @@ func _() { _ = position{X} //@complete("}", fieldX, varX) } +func _() { + type foo struct{} //@item(complitFoo, "foo", "struct{...}", "struct") + + "&foo" //@item(complitAndFoo, "&foo", "struct{...}", "struct") + + var _ *foo = &fo{} //@rank("{", complitFoo) + var _ *foo = fo{} //@rank("{", complitAndFoo) + + struct { a, b *foo }{ + a: &fo{}, //@rank("{", complitFoo) + b: fo{}, //@rank("{", complitAndFoo) + } +} + func _() { _ := position{ X: 1, //@complete("X", fieldX),complete(" 1", exportedFunc, multilineWithPrefix, structPosition, cVar, exportedConst, exportedType) diff --git a/internal/lsp/testdata/lsp/primarymod/snippets/literal_snippets.go.in b/internal/lsp/testdata/lsp/primarymod/snippets/literal_snippets.go.in index d970bf1f9d6..4a505e38f8e 100644 --- a/internal/lsp/testdata/lsp/primarymod/snippets/literal_snippets.go.in +++ b/internal/lsp/testdata/lsp/primarymod/snippets/literal_snippets.go.in @@ -199,6 +199,12 @@ func _() { ptrStruct{ p: &ptrSt, //@rank(",", litPtrStruct) } + + &ptrStruct{} //@item(litPtrStructPtr, "&ptrStruct{}", "", "var") + + &ptrStruct{ + p: ptrSt, //@rank(",", litPtrStructPtr) + } } func _() { diff --git a/internal/lsp/testdata/lsp/summary.txt.golden b/internal/lsp/testdata/lsp/summary.txt.golden index 8ff02aff16a..5d0f432e52c 100644 --- a/internal/lsp/testdata/lsp/summary.txt.golden +++ b/internal/lsp/testdata/lsp/summary.txt.golden @@ -6,7 +6,7 @@ CompletionSnippetCount = 85 UnimportedCompletionsCount = 6 DeepCompletionsCount = 5 FuzzyCompletionsCount = 8 -RankedCompletionsCount = 152 +RankedCompletionsCount = 157 CaseSensitiveCompletionsCount = 4 DiagnosticsCount = 44 FoldingRangesCount = 2