Skip to content

Commit b55f8be

Browse files
committed
Add initial fmt.Sprintf rewriting
1 parent ae7166c commit b55f8be

File tree

6 files changed

+85
-22
lines changed

6 files changed

+85
-22
lines changed

internal/rewriter/rewriter.go

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,34 +31,18 @@ func Rewrite(filename string, oldSource []byte) ([]byte, error) {
3131
}
3232

3333
func visitor(n ast.Node) (ast.Node, bool) {
34-
c, ok := n.(*ast.CallExpr)
35-
if !ok {
34+
c, name := getCallExprLiteral(n)
35+
if c == nil {
3636
return n, true
3737
}
38-
39-
s, ok := c.Fun.(*ast.SelectorExpr)
40-
if !ok {
41-
return n, true
42-
}
43-
44-
i, ok := s.X.(*ast.Ident)
45-
if !ok {
46-
return n, true
47-
}
48-
49-
if i.Name != "errors" {
50-
return n, true
51-
}
52-
53-
switch s.Sel.Name {
54-
case "Wrap":
38+
switch name {
39+
case "errors.Wrap":
5540
return rewriteWrap(c), true
56-
case "Wrapf":
41+
case "errors.Wrapf":
5742
return rewriteWrap(c), true
5843
default:
44+
return n, true
5945
}
60-
61-
return n, true
6246
}
6347

6448
func rewriteWrap(ce *ast.CallExpr) *ast.CallExpr {
@@ -67,6 +51,12 @@ func rewriteWrap(ce *ast.CallExpr) *ast.CallExpr {
6751
copy(newArgs, ce.Args[1:])
6852
newArgs = append(newArgs, ce.Args[0])
6953

54+
// If the format string is a fmt.Sprintf call, we can unwrap it.
55+
c, name := getCallExprLiteral(newArgs[0])
56+
if c != nil && name == "fmt.Sprintf" {
57+
newArgs = append(c.Args, newArgs[1:]...)
58+
}
59+
7060
// If the format string is a literal, we can rewrite it:
7161
// "......" -> "......: %w"
7262
// Otherwise, we replace it with a binary op to add the wrap code:
@@ -87,6 +77,25 @@ func rewriteWrap(ce *ast.CallExpr) *ast.CallExpr {
8777
return newErrorfExpr(newArgs)
8878
}
8979

80+
func getCallExprLiteral(n ast.Node) (*ast.CallExpr, string) {
81+
c, ok := n.(*ast.CallExpr)
82+
if !ok {
83+
return nil, ""
84+
}
85+
86+
s, ok := c.Fun.(*ast.SelectorExpr)
87+
if !ok {
88+
return nil, ""
89+
}
90+
91+
i, ok := s.X.(*ast.Ident)
92+
if !ok {
93+
return nil, ""
94+
}
95+
96+
return c, i.Name + "." + s.Sel.Name
97+
}
98+
9099
func newErrorfExpr(args []ast.Expr) *ast.CallExpr {
91100
return &ast.CallExpr{
92101
Fun: &ast.SelectorExpr{

internal/rewriter/rewriter_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ func TestRewrite(t *testing.T) {
1515
"wrap_string",
1616
"wrap_fcn",
1717
"wrap_var",
18+
"wrap_sprintf",
1819
"wrapf_string",
1920
"wrapf_fcn",
2021
"wrapf_var",
22+
"wrapf_sprintf",
2123
}
2224

2325
for _, c := range cases {
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package testdata
2+
3+
import (
4+
"fmt"
5+
"log"
6+
)
7+
8+
func main() {
9+
err := fmt.Errorf("this is an error")
10+
foo := "foo"
11+
log.Print(fmt.Errorf("error occurred '%s': %w", foo, err))
12+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package testdata
2+
3+
import (
4+
"fmt"
5+
"log"
6+
)
7+
8+
func main() {
9+
err := fmt.Errorf("this is an error")
10+
foo := "foo"
11+
log.Print(fmt.Errorf("error occurred '%s': %w", foo, err))
12+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package testdata
2+
3+
import (
4+
"fmt"
5+
"log"
6+
7+
"github.com/pkg/errors"
8+
)
9+
10+
func main() {
11+
err := fmt.Errorf("this is an error")
12+
foo := "foo"
13+
log.Print(errors.Wrap(err, fmt.Sprintf("error occurred '%s'", foo)))
14+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package testdata
2+
3+
import (
4+
"fmt"
5+
"log"
6+
7+
"github.com/pkg/errors"
8+
)
9+
10+
func main() {
11+
err := fmt.Errorf("this is an error")
12+
foo := "foo"
13+
log.Print(errors.Wrapf(err, fmt.Sprintf("error occurred '%s'", foo)))
14+
}

0 commit comments

Comments
 (0)