From d366f44d350201cf9d6907a28c3eb2be7075db86 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 13 Nov 2024 23:16:35 +0000 Subject: [PATCH] gopls/internal/golang: more idiomatic result naming in extract Use field names or type names for result variables names, if available. Otherwise, use the same conventions for var naming as completion. If all else fails, use 'result' rather than 'returnValue'. For golang/go#66289 Change-Id: Ife1d5435f00d2c4930fad48e7373e987668139ef Reviewed-on: https://go-review.googlesource.com/c/tools/+/627775 Auto-Submit: Robert Findley LUCI-TryBot-Result: Go LUCI Reviewed-by: Hongxiang Jiang --- gopls/internal/golang/completion/literal.go | 29 +--- gopls/internal/golang/extract.go | 130 ++++++++++++++---- gopls/internal/golang/util.go | 34 +++++ .../codeaction/functionextraction.txt | 16 +-- .../functionextraction_issue66289.txt | 61 +++++++- 5 files changed, 204 insertions(+), 66 deletions(-) diff --git a/gopls/internal/golang/completion/literal.go b/gopls/internal/golang/completion/literal.go index 7427d559e94..82f123048e8 100644 --- a/gopls/internal/golang/completion/literal.go +++ b/gopls/internal/golang/completion/literal.go @@ -370,6 +370,8 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m // conventionalAcronyms contains conventional acronyms for type names // in lower case. For example, "ctx" for "context" and "err" for "error". +// +// Keep this up to date with golang.conventionalVarNames. var conventionalAcronyms = map[string]string{ "context": "ctx", "error": "err", @@ -382,11 +384,6 @@ var conventionalAcronyms = map[string]string{ // non-identifier runes. For example, "[]int" becomes "i", and // "struct { i int }" becomes "s". func abbreviateTypeName(s string) string { - var ( - b strings.Builder - useNextUpper bool - ) - // Trim off leading non-letters. We trim everything between "[" and // "]" to handle array types like "[someConst]int". var inBracket bool @@ -407,27 +404,7 @@ func abbreviateTypeName(s string) string { return acr } - for i, r := range s { - // Stop if we encounter a non-identifier rune. - if !unicode.IsLetter(r) && !unicode.IsNumber(r) { - break - } - - if i == 0 { - b.WriteRune(unicode.ToLower(r)) - } - - if unicode.IsUpper(r) { - if useNextUpper { - b.WriteRune(unicode.ToLower(r)) - useNextUpper = false - } - } else { - useNextUpper = true - } - } - - return b.String() + return golang.AbbreviateVarName(s) } // compositeLiteral adds a composite literal completion item for the given typeName. diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index 510c6f6eba3..82ef6fd69ad 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -37,14 +37,14 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file // TODO: stricter rules for selectorExpr. case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: - lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0) + lhsName, _ := generateAvailableName(expr.Pos(), path, pkg, info, "x", 0) lhsNames = append(lhsNames, lhsName) case *ast.CallExpr: tup, ok := info.TypeOf(expr).(*types.Tuple) if !ok { // If the call expression only has one return value, we can treat it the // same as our standard extract variable case. - lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0) + lhsName, _ := generateAvailableName(expr.Pos(), path, pkg, info, "x", 0) lhsNames = append(lhsNames, lhsName) break } @@ -52,7 +52,7 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file for i := 0; i < tup.Len(); i++ { // Generate a unique variable for each return value. var lhsName string - lhsName, idx = generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", idx) + lhsName, idx = generateAvailableName(expr.Pos(), path, pkg, info, "x", idx) lhsNames = append(lhsNames, lhsName) } default: @@ -150,12 +150,12 @@ func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast. return string(content[lineOffset:stmtOffset]), nil } -// generateAvailableIdentifier adjusts the new function name until there are no collisions in scope. +// generateAvailableName adjusts the new function name until there are no collisions in scope. // Possible collisions include other function and variable names. Returns the next index to check for prefix. -func generateAvailableIdentifier(pos token.Pos, path []ast.Node, pkg *types.Package, info *types.Info, prefix string, idx int) (string, int) { +func generateAvailableName(pos token.Pos, path []ast.Node, pkg *types.Package, info *types.Info, prefix string, idx int) (string, int) { scopes := CollectScopes(info, path, pos) scopes = append(scopes, pkg.Scope()) - return generateIdentifier(idx, prefix, func(name string) bool { + return generateName(idx, prefix, func(name string) bool { for _, scope := range scopes { if scope != nil && scope.Lookup(name) != nil { return true @@ -165,7 +165,31 @@ func generateAvailableIdentifier(pos token.Pos, path []ast.Node, pkg *types.Pack }) } -func generateIdentifier(idx int, prefix string, hasCollision func(string) bool) (string, int) { +// generateNameOutsideOfRange is like generateAvailableName, but ignores names +// declared between start and end for the purposes of detecting conflicts. +// +// This is used for function extraction, where [start, end) will be extracted +// to a new scope. +func generateNameOutsideOfRange(start, end token.Pos, path []ast.Node, pkg *types.Package, info *types.Info, prefix string, idx int) (string, int) { + scopes := CollectScopes(info, path, start) + scopes = append(scopes, pkg.Scope()) + return generateName(idx, prefix, func(name string) bool { + for _, scope := range scopes { + if scope != nil { + if obj := scope.Lookup(name); obj != nil { + // Only report a collision if the object declaration was outside the + // extracted range. + if obj.Pos() < start || end <= obj.Pos() { + return true + } + } + } + } + return false + }) +} + +func generateName(idx int, prefix string, hasCollision func(string) bool) (string, int) { name := prefix if idx != 0 { name += fmt.Sprintf("%d", idx) @@ -182,7 +206,7 @@ func generateIdentifier(idx int, prefix string, hasCollision func(string) bool) type returnVariable struct { // name is the identifier that is used on the left-hand side of the call to // the extracted function. - name ast.Expr + name *ast.Ident // decl is the declaration of the variable. It is used in the type signature of the // extracted function and for variable declarations. decl *ast.Field @@ -517,7 +541,7 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte // statements in the selection. Update the type signature of the extracted // function and construct the if statement that will be inserted in the enclosing // function. - retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, start, hasNonNestedReturn) + retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, start, end, hasNonNestedReturn) if err != nil { return nil, nil, err } @@ -552,7 +576,7 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte funName = name } else { name = "newFunction" - funName, _ = generateAvailableIdentifier(start, path, pkg, info, name, 0) + funName, _ = generateAvailableName(start, path, pkg, info, name, 0) } extractedFunCall := generateFuncCall(hasNonNestedReturn, hasReturnValues, params, append(returns, getNames(retVars)...), funName, sym, receiverName) @@ -1187,12 +1211,12 @@ func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) { // signature of the extracted function. We prepare names, signatures, and "zero values" that // represent the new variables. We also use this information to construct the if statement that // is inserted below the call to the extracted function. -func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, pos token.Pos, hasNonNestedReturns bool) ([]*returnVariable, *ast.IfStmt, error) { +func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, start, end token.Pos, hasNonNestedReturns bool) ([]*returnVariable, *ast.IfStmt, error) { var retVars []*returnVariable var cond *ast.Ident if !hasNonNestedReturns { // Generate information for the added bool value. - name, _ := generateAvailableIdentifier(pos, path, pkg, info, "shouldReturn", 0) + name, _ := generateNameOutsideOfRange(start, end, path, pkg, info, "shouldReturn", 0) cond = &ast.Ident{Name: name} retVars = append(retVars, &returnVariable{ name: cond, @@ -1202,7 +1226,7 @@ func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast. } // Generate information for the values in the return signature of the enclosing function. if enclosing.Results != nil { - idx := 0 + nameIdx := make(map[string]int) // last integral suffixes of generated names for _, field := range enclosing.Results.List { typ := info.TypeOf(field.Type) if typ == nil { @@ -1213,17 +1237,32 @@ func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast. if expr == nil { return nil, nil, fmt.Errorf("nil AST expression") } - var name string - name, idx = generateAvailableIdentifier(pos, path, pkg, info, "returnValue", idx) - z := analysisinternal.ZeroValue(file, pkg, typ) - if z == nil { - return nil, nil, fmt.Errorf("can't generate zero value for %T", typ) + names := []string{""} + if len(field.Names) > 0 { + names = nil + for _, n := range field.Names { + names = append(names, n.Name) + } + } + for _, name := range names { + bestName := "result" + if name != "" && name != "_" { + bestName = name + } else if n, ok := varNameForType(typ); ok { + bestName = n + } + retName, idx := generateNameOutsideOfRange(start, end, path, pkg, info, bestName, nameIdx[bestName]) + nameIdx[bestName] = idx + z := analysisinternal.ZeroValue(file, pkg, typ) + if z == nil { + return nil, nil, fmt.Errorf("can't generate zero value for %T", typ) + } + retVars = append(retVars, &returnVariable{ + name: ast.NewIdent(retName), + decl: &ast.Field{Type: expr}, + zeroVal: z, + }) } - retVars = append(retVars, &returnVariable{ - name: ast.NewIdent(name), - decl: &ast.Field{Type: expr}, - zeroVal: z, - }) } } var ifReturn *ast.IfStmt @@ -1240,6 +1279,48 @@ func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast. return retVars, ifReturn, nil } +type objKey struct{ pkg, name string } + +// conventionalVarNames specifies conventional names for variables with various +// standard library types. +// +// Keep this up to date with completion.conventionalAcronyms. +// +// TODO(rfindley): consider factoring out a "conventions" library. +var conventionalVarNames = map[objKey]string{ + {"", "error"}: "err", + {"context", "Context"}: "ctx", + {"sql", "Tx"}: "tx", + {"http", "ResponseWriter"}: "rw", // Note: same as [AbbreviateVarName]. +} + +// varNameForTypeName chooses a "good" name for a variable with the given type, +// if possible. Otherwise, it returns "", false. +// +// For special types, it uses known conventional names. +func varNameForType(t types.Type) (string, bool) { + var typeName string + if tn, ok := t.(interface{ Obj() *types.TypeName }); ok { + obj := tn.Obj() + k := objKey{name: obj.Name()} + if obj.Pkg() != nil { + k.pkg = obj.Pkg().Name() + } + if name, ok := conventionalVarNames[k]; ok { + return name, true + } + typeName = obj.Name() + } else if b, ok := t.(*types.Basic); ok { + typeName = b.Name() + } + + if typeName == "" { + return "", false + } + + return AbbreviateVarName(typeName), true +} + // adjustReturnStatements adds "zero values" of the given types to each return statement // in the given AST node. func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]ast.Expr, file *ast.File, pkg *types.Package, extractedBlock *ast.BlockStmt) error { @@ -1346,9 +1427,8 @@ func initializeVars(uninitialized []types.Object, retVars []*returnVariable, see // Each variable added from a return statement in the selection // must be initialized. for i, retVar := range retVars { - n := retVar.name.(*ast.Ident) valSpec := &ast.ValueSpec{ - Names: []*ast.Ident{n}, + Names: []*ast.Ident{retVar.name}, Type: retVars[i].decl.Type, } genDecl := &ast.GenDecl{ diff --git a/gopls/internal/golang/util.go b/gopls/internal/golang/util.go index 18f72421a64..06239af17d6 100644 --- a/gopls/internal/golang/util.go +++ b/gopls/internal/golang/util.go @@ -12,6 +12,7 @@ import ( "go/types" "regexp" "strings" + "unicode" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/cache/metadata" @@ -363,3 +364,36 @@ func btoi(b bool) int { return 0 } } + +// AbbreviateVarName returns an abbreviated var name based on the given full +// name (which may be a type name, for example). +// +// See the simple heuristics documented in line. +func AbbreviateVarName(s string) string { + var ( + b strings.Builder + useNextUpper bool + ) + for i, r := range s { + // Stop if we encounter a non-identifier rune. + if !unicode.IsLetter(r) && !unicode.IsNumber(r) { + break + } + + // Otherwise, take the first letter from word boundaries, assuming + // camelCase. + if i == 0 { + b.WriteRune(unicode.ToLower(r)) + } + + if unicode.IsUpper(r) { + if useNextUpper { + b.WriteRune(unicode.ToLower(r)) + useNextUpper = false + } + } else { + useNextUpper = true + } + } + return b.String() +} diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt index 1b9f487c49d..6ae0bc7177e 100644 --- a/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt @@ -56,9 +56,9 @@ package extract func _() bool { x := 1 //@codeaction("if", ifend, "refactor.extract.function", return) - shouldReturn, returnValue := newFunction(x) + shouldReturn, b := newFunction(x) if shouldReturn { - return returnValue + return b } //@loc(ifend, "}") return false } @@ -124,9 +124,9 @@ func _() (int, string, error) { x := 1 y := "hello" //@codeaction("z", rcEnd, "refactor.extract.function", rc) - z, shouldReturn, returnValue, returnValue1, returnValue2 := newFunction(y, x) + z, shouldReturn, i, s, err := newFunction(y, x) if shouldReturn { - return returnValue, returnValue1, returnValue2 + return i, s, err } //@loc(rcEnd, "}") return x, z, nil } @@ -205,9 +205,9 @@ import "go/ast" func _() { ast.Inspect(ast.NewIdent("a"), func(n ast.Node) bool { //@codeaction("if", rflEnd, "refactor.extract.function", rfl) - shouldReturn, returnValue := newFunction(n) + shouldReturn, b := newFunction(n) if shouldReturn { - return returnValue + return b } //@loc(rflEnd, "}") return false }) @@ -272,9 +272,9 @@ package extract func _() string { x := 1 //@codeaction("if", riEnd, "refactor.extract.function", ri) - shouldReturn, returnValue := newFunction(x) + shouldReturn, s := newFunction(x) if shouldReturn { - return returnValue + return s } //@loc(riEnd, "}") x = 2 return "b" diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt index 65412ee91fa..c032c7797a6 100644 --- a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt @@ -8,19 +8,19 @@ import ( ) func F() error { - a, err := json.Marshal(0) //@codeaction("a", end, "refactor.extract.function", out) + a, err := json.Marshal(0) //@codeaction("a", endF, "refactor.extract.function", F) if err != nil { return fmt.Errorf("1: %w", err) } b, err := json.Marshal(0) if err != nil { return fmt.Errorf("2: %w", err) - } //@loc(end, "}") + } //@loc(endF, "}") fmt.Println(a, b) return nil } --- @out/a.go -- +-- @F/a.go -- package a import ( @@ -29,11 +29,11 @@ import ( ) func F() error { - //@codeaction("a", end, "refactor.extract.function", out) - a, b, shouldReturn, returnValue := newFunction() + //@codeaction("a", endF, "refactor.extract.function", F) + a, b, shouldReturn, err := newFunction() if shouldReturn { - return returnValue - } //@loc(end, "}") + return err + } //@loc(endF, "}") fmt.Println(a, b) return nil } @@ -50,3 +50,50 @@ func newFunction() ([]byte, []byte, bool, error) { return a, b, false, nil } +-- b.go -- +package a + +import ( + "fmt" + "math/rand" +) + +func G() (x, y int) { + v := rand.Int() //@codeaction("v", endG, "refactor.extract.function", G) + if v < 0 { + return 1, 2 + } + if v > 0 { + return 3, 4 + } //@loc(endG, "}") + fmt.Println(v) + return 5, 6 +} +-- @G/b.go -- +package a + +import ( + "fmt" + "math/rand" +) + +func G() (x, y int) { + //@codeaction("v", endG, "refactor.extract.function", G) + v, shouldReturn, x1, y1 := newFunction() + if shouldReturn { + return x1, y1 + } //@loc(endG, "}") + fmt.Println(v) + return 5, 6 +} + +func newFunction() (int, bool, int, int) { + v := rand.Int() + if v < 0 { + return 0, true, 1, 2 + } + if v > 0 { + return 0, true, 3, 4 + } + return v, false, 0, 0 +}