Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring mirror so we can use analysis.SuggestedFixes #20

Merged
merged 3 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
feature: analysis.SuggestedFix
  • Loading branch information
butuzov committed May 12, 2023
commit 1c407e09b59d71e5b27997f668ab7ad31d46cce8
40 changes: 12 additions & 28 deletions analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,32 +40,16 @@ func (a *analyzer) run(pass *analysis.Pass) (interface{}, error) {
// --- Setup -----------------------------------------------------------------

check := checker.New(
BytesFunctions,
BytesBufferMethods,
BytesFunctions,
BytesBufferMethods,
BufioMethods,
HTTPTestMethods,
OsFileMethods,
RegexpFunctions,
RegexpRegexpMethods,
StringFunctions,
StringsBuilderMethods,
MaphashMethods,
BytesFunctions, BytesBufferMethods,
RegexpFunctions, RegexpRegexpMethods,
StringFunctions, StringsBuilderMethods,
BufioMethods, HTTPTestMethods,
OsFileMethods, MaphashMethods,
UTF8Functions,
)

check.Type = func(node ast.Expr) string {
if t := pass.TypesInfo.TypeOf(node); t != nil {
return t.String()
}

if tv, ok := pass.TypesInfo.Types[node]; ok {
return tv.Type.Underlying().String()
}

return ""
}
check.Type = checker.WrapType(pass.TypesInfo)
check.Print = checker.WrapPrint(pass.Fset)

violations := []*checker.Violation{}

Expand All @@ -76,7 +60,7 @@ func (a *analyzer) run(pass *analysis.Pass) (interface{}, error) {

// --- Preorder Checker ------------------------------------------------------
ins.Preorder([]ast.Node{(*ast.CallExpr)(nil)}, func(n ast.Node) {
callExpr := n.(*ast.CallExpr) //nolint: forcetypeassert
callExpr := n.(*ast.CallExpr)
fileName := pass.Fset.Position(callExpr.Pos()).Filename

if !a.withTests && strings.HasSuffix(fileName, "_test.go") {
Expand Down Expand Up @@ -106,7 +90,7 @@ func (a *analyzer) run(pass *analysis.Pass) (interface{}, error) {
if pkg, ok := imports.Lookup(fileName, pkgName); ok {
if v := check.Match(pkg, name); v != nil {
if args, found := check.Handle(v, callExpr); found {
violations = append(violations, v.With(callExpr, args))
violations = append(violations, v.With(check.Print(expr.X), callExpr, args))
}
return
}
Expand All @@ -121,7 +105,7 @@ func (a *analyzer) run(pass *analysis.Pass) (interface{}, error) {
pkgStruct, name := cleanAsterisk(tv.Type.String()), expr.Sel.Name
if v := check.Match(pkgStruct, name); v != nil {
if args, found := check.Handle(v, callExpr); found {
violations = append(violations, v.With(callExpr, args))
violations = append(violations, v.With(check.Print(expr.X), callExpr, args))
}
return
}
Expand All @@ -132,7 +116,7 @@ func (a *analyzer) run(pass *analysis.Pass) (interface{}, error) {
if pkg, ok := imports.Lookup(fileName, "."); ok {
if v := check.Match(pkg, expr.Name); v != nil {
if args, found := check.Handle(v, callExpr); found {
violations = append(violations, v.With(callExpr, args))
violations = append(violations, v.With(nil, callExpr, args))
}
return
}
Expand All @@ -142,7 +126,7 @@ func (a *analyzer) run(pass *analysis.Pass) (interface{}, error) {

// --- Reporting violations via issues ---------------------------------------
for _, violation := range violations {
pass.Report(violation.Issue())
pass.Report(violation.Issue(pass.Fset))
}

return nil, nil
Expand Down
29 changes: 28 additions & 1 deletion internal/checker/checker.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
package checker

import (
"bytes"
"go/ast"
"go/printer"
"go/token"
"go/types"
"strings"
)

// Checker will perform standart check on package and its methods.
type Checker struct {
Violations []Violation // List of available violations
Packages map[string][]int // Storing indexes of Violations per pkg/kg.Struct
Type func(ast.Expr) string // Closure for the
Type func(ast.Expr) string // Type Checker closure.
Print func(ast.Node) []byte // String representation of the expresion.
}

func New(violations ...[]Violation) Checker {
Expand Down Expand Up @@ -109,3 +114,25 @@ func (c *Checker) register(violations []Violation) {
func (c *Checker) registerIdxPer(pkg string) {
c.Packages[pkg] = append(c.Packages[pkg], len(c.Violations)-1)
}

func WrapType(info *types.Info) func(node ast.Expr) string {
return func(node ast.Expr) string {
if t := info.TypeOf(node); t != nil {
return t.String()
}

if tv, ok := info.Types[node]; ok {
return tv.Type.Underlying().String()
}

return ""
}
}

func WrapPrint(fSet *token.FileSet) func(ast.Node) []byte {
return func(node ast.Node) []byte {
var buf bytes.Buffer
printer.Fprint(&buf, fSet, node)
return buf.Bytes()
}
}
128 changes: 59 additions & 69 deletions internal/checker/violation.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package checker

import (
"bytes"
"fmt"
"go/ast"
"go/printer"
"go/token"
"path"

"golang.org/x/tools/go/analysis"
Expand Down Expand Up @@ -37,8 +40,9 @@ type Violation struct {
Generate *Generate

// --- suggestions related info about violation of rules.
callExpr *ast.CallExpr
arguments map[int]ast.Expr
base []byte // receiver of the method or pkg name
callExpr *ast.CallExpr // actual call expression, to extract arguments
arguments map[int]ast.Expr // fixed arguments
}

// Tests (generation) related struct.
Expand All @@ -48,68 +52,12 @@ type Generate struct {
Returns int // Expected to return n elements
}

// func (v *Violation) Diagnostic(fset *token.FileSet, n *ast.CallExpr) *analysis.Diagnostic {
// diagnostic := &analysis.Diagnostic{
// Pos: n.Pos(),
// End: n.Pos(),
// Message: v.Message,
// }

// if b := v.suggestion(fset, n); len(b) > 0 {
// diagnostic.SuggestedFixes = []analysis.SuggestedFix{{
// Message: "",
// TextEdits: []analysis.TextEdit{{Pos: n.Pos(), End: n.End(), NewText: b}},
// }}

// fmt.Println(">>>", string(b))
// }

// return diagnostic
// }

// // nolint: revive
// func (v *Violation) suggestion(fset *token.FileSet, n *ast.CallExpr) []byte {
// var buf bytes.Buffer

// changeCall := func(buf *bytes.Buffer, base, alternative string) []byte {
// buf.WriteString(base)
// buf.WriteString(".")
// buf.WriteString(alternative)
// buf.WriteString("(")

// for i := range n.Args {
// if arg, ok := v.arguments.args[i]; ok {
// printer.Fprint(buf, fset, arg)
// } else {
// printer.Fprint(buf, fset, n.Args[i])
// }

// if i != len(n.Args)-1 {
// buf.WriteString(", ")
// }
// }

// buf.WriteString(")")
// return buf.Bytes()
// }

// // is it method call?
// if len(v.arguments.obj) > 0 {
// return changeCall(&buf, v.arguments.obj, v.Alternative.Method)
// }

// // Old And New imports (and names) needs to be resolved.

// // rest of cases.
// return changeCall(&buf, v.arguments.pkg, v.Alternative.Function)
// }

func (v *Violation) With(e *ast.CallExpr, args map[int]ast.Expr) *Violation {
v2 := (*v)
v2.callExpr = e
v2.arguments = args

return &v2
func (v *Violation) With(base []byte, e *ast.CallExpr, args map[int]ast.Expr) *Violation {
v.base = base
v.callExpr = e
v.arguments = args

return v
}

func (v *Violation) Message() string {
Expand All @@ -126,20 +74,62 @@ func (v *Violation) Message() string {
return fmt.Sprintf("avoid allocations with %s.%s", path.Base(pkg), v.AltCaller)
}

func (v *Violation) Issue() analysis.Diagnostic {
func (v *Violation) suggest(fSet *token.FileSet) []byte {
var buf bytes.Buffer

if len(v.base) > 0 {
buf.Write(v.base)
buf.WriteString(".")
}

buf.WriteString(v.AltCaller)
buf.WriteByte('(')
for idx := range v.callExpr.Args {
if arg, ok := v.arguments[idx]; ok {
printer.Fprint(&buf, fSet, arg)
} else {
printer.Fprint(&buf, fSet, v.callExpr.Args[idx])
}

if idx != len(v.callExpr.Args)-1 {
buf.WriteString(", ")
}
}
buf.WriteByte(')')

return buf.Bytes()
}

func (v *Violation) Issue(fSet *token.FileSet) analysis.Diagnostic {
diagnostic := analysis.Diagnostic{
Pos: v.callExpr.Pos(),
End: v.callExpr.Pos(),
Message: v.Message(),
}

// fmt.Println(string(v.suggest(fSet)))

// Struct based fix.
if v.Type == Method {
return diagnostic
diagnostic.SuggestedFixes = []analysis.SuggestedFix{{
Message: "Fix Issue With",
TextEdits: []analysis.TextEdit{{
Pos: v.callExpr.Pos(), End: v.callExpr.End(), NewText: v.suggest(fSet),
}},
}}
}

// Hooray! we dont need to change package and redo imports.
if v.Type == Function && len(v.AltPackage) == 0 {
diagnostic.SuggestedFixes = []analysis.SuggestedFix{{
Message: "Fix Issue With",
TextEdits: []analysis.TextEdit{{
Pos: v.callExpr.Pos(), End: v.callExpr.End(), NewText: v.suggest(fSet),
}},
}}
}

// fmt.Println("package", c.Package)
// fmt.Println("target methods ?", v.Type == Method)
// fmt.Println("alternative", v.Alt)
// do not change

return diagnostic
}