diff --git a/errorlint/allowed.go b/errorlint/allowed.go index 7490cef..791e50f 100644 --- a/errorlint/allowed.go +++ b/errorlint/allowed.go @@ -127,13 +127,13 @@ func isAllowedErrAndFunc(err, fun string) bool { return false } -func isAllowedErrorComparison(pass *TypesInfoExt, binExpr *ast.BinaryExpr) bool { +func isAllowedErrorComparison(pass *TypesInfoExt, a, b ast.Expr) bool { var errName string // `.`, e.g. `io.EOF` var callExprs []*ast.CallExpr // Figure out which half of the expression is the returned error and which // half is the presumed error declaration. - for _, expr := range []ast.Expr{binExpr.X, binExpr.Y} { + for _, expr := range []ast.Expr{a, b} { switch t := expr.(type) { case *ast.SelectorExpr: // A selector which we assume refers to a staticaly declared error diff --git a/errorlint/lint.go b/errorlint/lint.go index 6648d31..9ac465c 100644 --- a/errorlint/lint.go +++ b/errorlint/lint.go @@ -172,15 +172,15 @@ func LintErrorComparisons(info *TypesInfoExt) []analysis.Diagnostic { continue } // Comparing errors with nil is okay. - if isNilComparison(binExpr) { + if isNil(binExpr.X) || isNil(binExpr.Y) { continue } // Find comparisons of which one side is a of type error. - if !isErrorComparison(info.TypesInfo, binExpr) { + if !isErrorType(info.TypesInfo, binExpr.X) && !isErrorType(info.TypesInfo, binExpr.Y) { continue } // Some errors that are returned from some functions are exempt. - if isAllowedErrorComparison(info, binExpr) { + if isAllowedErrorComparison(info, binExpr.X, binExpr.Y) { continue } // Comparisons that happen in `func (type) Is(error) bool` are okay. @@ -201,13 +201,29 @@ func LintErrorComparisons(info *TypesInfoExt) []analysis.Diagnostic { continue } // Check whether the switch operates on an error type. - if switchStmt.Tag == nil { + if !isErrorType(info.TypesInfo, switchStmt.Tag) { continue } - tagType := info.TypesInfo.Types[switchStmt.Tag] - if tagType.Type.String() != "error" { + + var problematicCaseClause *ast.CaseClause + outer: + for _, stmt := range switchStmt.Body.List { + caseClause := stmt.(*ast.CaseClause) + for _, caseExpr := range caseClause.List { + if isNil(caseExpr) { + continue + } + // Some errors that are returned from some functions are exempt. + if !isAllowedErrorComparison(info, switchStmt.Tag, caseExpr) { + problematicCaseClause = caseClause + break outer + } + } + } + if problematicCaseClause == nil { continue } + // Comparisons that happen in `func (type) Is(error) bool` are okay. if isNodeInErrorIsFunc(info, switchStmt) { continue } @@ -215,29 +231,22 @@ func LintErrorComparisons(info *TypesInfoExt) []analysis.Diagnostic { if switchComparesNonNil(switchStmt) { lints = append(lints, analysis.Diagnostic{ Message: "switch on an error will fail on wrapped errors. Use errors.Is to check for specific errors", - Pos: switchStmt.Pos(), + Pos: problematicCaseClause.Pos(), }) } - } return lints } -func isNilComparison(binExpr *ast.BinaryExpr) bool { - if ident, ok := binExpr.X.(*ast.Ident); ok && ident.Name == "nil" { - return true - } - if ident, ok := binExpr.Y.(*ast.Ident); ok && ident.Name == "nil" { - return true - } - return false +func isNil(ex ast.Expr) bool { + ident, ok := ex.(*ast.Ident) + return ok && ident.Name == "nil" } -func isErrorComparison(info *types.Info, binExpr *ast.BinaryExpr) bool { - tx := info.Types[binExpr.X] - ty := info.Types[binExpr.Y] - return tx.Type.String() == "error" || ty.Type.String() == "error" +func isErrorType(info *types.Info, ex ast.Expr) bool { + t := info.Types[ex].Type + return t != nil && t.String() == "error" } func isNodeInErrorIsFunc(info *TypesInfoExt, node ast.Node) bool { diff --git a/errorlint/testdata/src/errorsis/errorsis.go b/errorlint/testdata/src/errorsis/errorsis.go index 905678e..cbf0991 100644 --- a/errorlint/testdata/src/errorsis/errorsis.go +++ b/errorlint/testdata/src/errorsis/errorsis.go @@ -76,10 +76,10 @@ func NotEqualOperatorYoda() { func CompareSwitch() { err := doThing() - switch err { // want `switch on an error will fail on wrapped errors. Use errors.Is to check for specific errors` + switch err { case nil: fmt.Println("nil") - case ErrFoo: + case ErrFoo: // want `switch on an error will fail on wrapped errors. Use errors.Is to check for specific errors` fmt.Println("ErrFoo") } } @@ -95,8 +95,8 @@ func CompareSwitchSafe() { } func CompareSwitchInline() { - switch doThing() { // want `switch on an error will fail on wrapped errors. Use errors.Is to check for specific errors` - case ErrFoo: + switch doThing() { + case ErrFoo: // want `switch on an error will fail on wrapped errors. Use errors.Is to check for specific errors` fmt.Println("ErrFoo") } } diff --git a/errorlint/testdata/src/issues/github-54.go b/errorlint/testdata/src/issues/github-54.go new file mode 100644 index 0000000..0debd3e --- /dev/null +++ b/errorlint/testdata/src/issues/github-54.go @@ -0,0 +1,18 @@ +package issues + +import ( + "fmt" + + "golang.org/x/sys/unix" +) + +func SwitchOnUnixErrors() { + err := unix.Rmdir("somepath") + switch err { + case unix.ENOENT: + return + case unix.EPERM: + return + } + fmt.Println(err) +}