Skip to content

Commit

Permalink
refactor: factorize createFuncInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
ldez committed Dec 2, 2024
1 parent 725e4d9 commit d0b0385
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 23 deletions.
10 changes: 5 additions & 5 deletions report.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
// because [os.CreateTemp] takes 2 args.
const nbArgCreateTemp = 2

func (a *analyzer) reportCallExpr(pass *analysis.Pass, ce *ast.CallExpr, fnInfo FuncInfo) bool {
func (a *analyzer) reportCallExpr(pass *analysis.Pass, ce *ast.CallExpr, fnInfo *FuncInfo) bool {
if !a.osCreateTemp {
return false
}
Expand Down Expand Up @@ -64,7 +64,7 @@ func (a *analyzer) reportCallExpr(pass *analysis.Pass, ce *ast.CallExpr, fnInfo
return false
}

func (a *analyzer) reportSelector(pass *analysis.Pass, sel *ast.SelectorExpr, fnInfo FuncInfo) {
func (a *analyzer) reportSelector(pass *analysis.Pass, sel *ast.SelectorExpr, fnInfo *FuncInfo) {
expr, ok := sel.X.(*ast.Ident)
if !ok {
return
Expand All @@ -77,7 +77,7 @@ func (a *analyzer) reportSelector(pass *analysis.Pass, sel *ast.SelectorExpr, fn
a.report(pass, sel.Pos(), expr.Name, sel.Sel.Name, fnInfo)
}

func (a *analyzer) reportIdent(pass *analysis.Pass, expr *ast.Ident, fnInfo FuncInfo) {
func (a *analyzer) reportIdent(pass *analysis.Pass, expr *ast.Ident, fnInfo *FuncInfo) {
if !slices.Contains(a.fieldNames, expr.Name) {
return
}
Expand All @@ -92,7 +92,7 @@ func (a *analyzer) reportIdent(pass *analysis.Pass, expr *ast.Ident, fnInfo Func
}

//nolint:gocyclo // The complexity is expected by the cases to check.
func (a *analyzer) report(pass *analysis.Pass, pos token.Pos, origPkgName, origName string, fnInfo FuncInfo) {
func (a *analyzer) report(pass *analysis.Pass, pos token.Pos, origPkgName, origName string, fnInfo *FuncInfo) {
switch {
case a.osMkdirTemp && origPkgName == osPkgName && origName == mkdirTempName:
report(pass, pos, origPkgName, origName, tempDirName, fnInfo)
Expand All @@ -114,7 +114,7 @@ func (a *analyzer) report(pass *analysis.Pass, pos token.Pos, origPkgName, origN
}
}

func report(pass *analysis.Pass, pos token.Pos, origPkgName, origName, expectName string, fnInfo FuncInfo) {
func report(pass *analysis.Pass, pos token.Pos, origPkgName, origName, expectName string, fnInfo *FuncInfo) {
pass.Reportf(
pos,
"%s.%s() could be replaced by %s.%s() in %s",
Expand Down
39 changes: 21 additions & 18 deletions usetesting.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,21 +148,16 @@ func (a *analyzer) checkFunc(pass *analysis.Pass, ft *ast.FuncType, block *ast.B
return
}

argName, ok := isTestFunction(ft.Params.List[0], testingPkgName)
if !ok {
fnInfo := checkTestFunctionSignature(ft.Params.List[0], testingPkgName, fnName)
if fnInfo == nil {
return
}

fnInfo := FuncInfo{
Name: fnName,
ArgName: argName,
}

checkStmts(a, pass, fnInfo, block.List)
}

//nolint:funlen // The complexity is expected by the number of [ast.Stmt] variants.
func (a *analyzer) checkStmt(pass *analysis.Pass, fnInfo FuncInfo, stmt ast.Stmt) {
func (a *analyzer) checkStmt(pass *analysis.Pass, fnInfo *FuncInfo, stmt ast.Stmt) {
if stmt == nil {
return
}
Expand Down Expand Up @@ -234,7 +229,7 @@ func (a *analyzer) checkStmt(pass *analysis.Pass, fnInfo FuncInfo, stmt ast.Stmt
}
}

func (a *analyzer) checkExpr(pass *analysis.Pass, fnInfo FuncInfo, exp ast.Expr) {
func (a *analyzer) checkExpr(pass *analysis.Pass, fnInfo *FuncInfo, exp ast.Expr) {
switch expr := exp.(type) {
case *ast.BinaryExpr:
a.checkExpr(pass, fnInfo, expr.X)
Expand Down Expand Up @@ -268,34 +263,42 @@ func (a *analyzer) checkExpr(pass *analysis.Pass, fnInfo FuncInfo, exp ast.Expr)
}
}

func checkStmts[T ast.Stmt](a *analyzer, pass *analysis.Pass, fnInfo FuncInfo, stmts []T) {
func checkStmts[T ast.Stmt](a *analyzer, pass *analysis.Pass, fnInfo *FuncInfo, stmts []T) {
for _, stmt := range stmts {
a.checkStmt(pass, fnInfo, stmt)
}
}

func checkExprs(a *analyzer, pass *analysis.Pass, fnInfo FuncInfo, exprs []ast.Expr) {
func checkExprs(a *analyzer, pass *analysis.Pass, fnInfo *FuncInfo, exprs []ast.Expr) {
for _, expr := range exprs {
a.checkExpr(pass, fnInfo, expr)
}
}

func isTestFunction(arg *ast.Field, pkgName string) (string, bool) {
func checkTestFunctionSignature(arg *ast.Field, pkgName, fnName string) *FuncInfo {
switch at := arg.Type.(type) {
case *ast.StarExpr:
if se, ok := at.X.(*ast.SelectorExpr); ok {
argName := getTestArgName(arg, "<t/b>")

return argName, checkSelectorName(se, pkgName, "T", "B")
return createFuncInfo(arg, "<t/b>", se, pkgName, fnName, "T", "B")
}

case *ast.SelectorExpr:
argName := getTestArgName(arg, "tb")
return createFuncInfo(arg, "tb", at, pkgName, fnName, "TB")
}

return argName, checkSelectorName(at, pkgName, "TB")
return nil
}

func createFuncInfo(arg *ast.Field, defaultName string, se *ast.SelectorExpr, pkgName, fnName string, selectorNames ...string) *FuncInfo {
ok := checkSelectorName(se, pkgName, selectorNames...)
if !ok {
return nil
}

return "", false
return &FuncInfo{
Name: fnName,
ArgName: getTestArgName(arg, defaultName),
}
}

func checkSelectorName(se *ast.SelectorExpr, pkgName string, selectorNames ...string) bool {
Expand Down
2 changes: 2 additions & 0 deletions usetesting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ func TestAnalyzer(t *testing.T) {

for _, test := range testCases {
t.Run(test.dir, func(t *testing.T) {
t.Parallel()

newAnalyzer := NewAnalyzer()

for k, v := range test.options {
Expand Down

0 comments on commit d0b0385

Please sign in to comment.