Skip to content

Commit 02caa63

Browse files
committed
- Refactored RewriteAstFile and related functions to accept package files for better handling of function declarations.
- Enhanced in-memory testing capabilities for function wrappers with new test cases.
1 parent 8daff0d commit 02caa63

File tree

13 files changed

+184
-45
lines changed

13 files changed

+184
-45
lines changed
8.73 MB
Binary file not shown.

cmd/gen-func-wrappers/gen-func-wrappers.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func main() {
8484

8585
jsonTypeReplacements := make(map[string]string)
8686
if replaceForJSON != "" {
87-
for _, repl := range strings.Split(replaceForJSON, ",") {
87+
for repl := range strings.SplitSeq(replaceForJSON, ",") {
8888
types := strings.Split(repl, ":")
8989
if len(types) != 2 {
9090
fmt.Fprintln(os.Stderr, "gen-func-wrappers error: invalid -replaceForJSON syntax")

cmd/gen-func-wrappers/gen/funcimpl.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,11 @@ func (impl Impl) String() string {
111111
// 7. CallWithJSON method (if ImplCallWithJSONWrapper is set): Unmarshals JSON to arguments
112112
//
113113
// The method handles:
114-
// - context.Context as first argument (automatic detection and handling)
115-
// - Variadic parameters (...type)
116-
// - Error return values (automatic error result detection)
117-
// - Type conversions for string parsing
118-
// - Proper argument descriptions from function comments
114+
// - context.Context as first argument (automatic detection and handling)
115+
// - Variadic parameters (...type)
116+
// - Error return values (automatic error result detection)
117+
// - Type conversions for string parsing
118+
// - Proper argument descriptions from function comments
119119
func (impl Impl) WriteFunctionWrapper(w io.Writer, funcFile *ast.File, funcDecl *ast.FuncDecl, implType, funcPackage string, neededImportLines map[string]struct{}, jsonTypeReplacements map[string]string, targetFileImports []*ast.ImportSpec) error {
120120
var (
121121
argNames = funcTypeArgNames(funcDecl.Type)
@@ -527,7 +527,7 @@ func reflectTypeOfTypeName(typeName string, packageRemap map[string]string) stri
527527
// remapPackageQualifiers translates package qualifiers in a type name using the provided mapping.
528528
// For example, "gmail.Label" with remap["gmail"]="gmailapi" becomes "gmailapi.Label".
529529
func remapPackageQualifiers(typeName string, packageRemap map[string]string) string {
530-
if packageRemap == nil || len(packageRemap) == 0 {
530+
if len(packageRemap) == 0 {
531531
return typeName
532532
}
533533

cmd/gen-func-wrappers/gen/imports.go

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ import (
1111

1212
// packageFuncs holds function declarations for a package along with package location info.
1313
type packageFuncs struct {
14-
Location *astvisit.PackageLocation // Package location and metadata
15-
Funcs map[string]funcDeclInFile // Map of function name to declaration
14+
Location *astvisit.PackageLocation // Package location and metadata
15+
Funcs map[string]funcDeclInFile // Map of function name to declaration
1616
}
1717

1818
// localAndImportedFunctions builds a complete map of all functions available to a file.
@@ -40,9 +40,9 @@ type packageFuncs struct {
4040
//
4141
// This is used by the code generator to find the wrapped function's declaration
4242
// when it's referenced as "pkg.FuncName" or just "FuncName".
43-
func localAndImportedFunctions(fset *token.FileSet, filePkg *ast.Package, file *ast.File, pkgDir string) (map[string]packageFuncs, error) {
43+
func localAndImportedFunctions(fset *token.FileSet, pkgName string, pkgFiles map[string]*ast.File, file *ast.File, pkgDir string) (map[string]packageFuncs, error) {
4444
localFuncs := make(map[string]funcDeclInFile)
45-
for _, f := range filePkg.Files {
45+
for _, f := range pkgFiles {
4646
for _, decl := range f.Decls {
4747
funcDecl, ok := decl.(*ast.FuncDecl)
4848
if ok && funcDecl.Recv == nil {
@@ -56,7 +56,7 @@ func localAndImportedFunctions(fset *token.FileSet, filePkg *ast.Package, file *
5656
functions := map[string]packageFuncs{
5757
"": {
5858
Location: &astvisit.PackageLocation{
59-
PkgName: filePkg.Name,
59+
PkgName: pkgName,
6060
SourcePath: pkgDir,
6161
},
6262
Funcs: localFuncs,
@@ -135,22 +135,20 @@ func gatherFieldListImports(funcFile *ast.File, fieldList *ast.FieldList, setImp
135135
// to detect conflicts and reuse existing aliases
136136
targetImportsByName := make(map[string]string) // name -> path
137137
targetImportsByPath := make(map[string]string) // path -> name
138-
if targetFileImports != nil {
139-
for _, imp := range targetFileImports {
140-
var name string
141-
if imp.Name != nil {
142-
name = imp.Name.Name
143-
} else {
144-
var err error
145-
name, err = guessPackageNameFromPath(imp.Path.Value)
146-
if err != nil {
147-
// Skip imports we can't parse
148-
continue
149-
}
138+
for _, imp := range targetFileImports {
139+
var name string
140+
if imp.Name != nil {
141+
name = imp.Name.Name
142+
} else {
143+
var err error
144+
name, err = guessPackageNameFromPath(imp.Path.Value)
145+
if err != nil {
146+
// Skip imports we can't parse
147+
continue
150148
}
151-
targetImportsByName[name] = imp.Path.Value
152-
targetImportsByPath[imp.Path.Value] = name
153149
}
150+
targetImportsByName[name] = imp.Path.Value
151+
targetImportsByPath[imp.Path.Value] = name
154152
}
155153

156154
packageNames := make(map[string]struct{})

cmd/gen-func-wrappers/gen/rewrite.go

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func RewriteDir(path string, verbose bool, printOnly io.Writer, jsonTypeReplacem
5555
}
5656
if err == nil {
5757
for fileName, file := range pkg.Files {
58-
err = RewriteAstFile(fset, pkg, file, fileName, verbose, printOnly, jsonTypeReplacements, localImportPrefixes)
58+
err = RewriteAstFile(fset, pkg.Name, pkg.Files, file, fileName, verbose, printOnly, jsonTypeReplacements, localImportPrefixes)
5959
if err != nil {
6060
return err
6161
}
@@ -114,7 +114,7 @@ func RewriteFile(filePath string, verbose bool, printOnly io.Writer, jsonTypeRep
114114
if err != nil {
115115
return err
116116
}
117-
return RewriteAstFile(fset, pkg, pkg.Files[filePath], filePath, verbose, printOnly, jsonTypeReplacements, localImportPrefixes)
117+
return RewriteAstFile(fset, pkg.Name, pkg.Files, pkg.Files[filePath], filePath, verbose, printOnly, jsonTypeReplacements, localImportPrefixes)
118118
}
119119

120120
// RewriteAstFile is the core rewriting logic that processes an AST file.
@@ -144,7 +144,36 @@ func RewriteFile(filePath string, verbose bool, printOnly io.Writer, jsonTypeRep
144144
// Wrapper declarations are found by:
145145
// - Variable assignments with TODO calls: var x = function.WrapperTODO(F)
146146
// - Implementation comments: // myWrapper wraps F as function.Wrapper (generated code)
147-
func RewriteAstFile(fset *token.FileSet, filePkg *ast.Package, astFile *ast.File, filePath string, verbose bool, printTo io.Writer, jsonTypeReplacements map[string]string, localImportPrefixes []string) (err error) {
147+
func RewriteAstFile(fset *token.FileSet, pkgName string, pkgFiles map[string]*ast.File, astFile *ast.File, filePath string, verbose bool, printTo io.Writer, jsonTypeReplacements map[string]string, localImportPrefixes []string) (err error) {
148+
filePath = filepath.Clean(filePath)
149+
150+
source, err := os.ReadFile(filePath) //#nosec G304
151+
if err != nil {
152+
return err
153+
}
154+
return RewriteAstFileSource(fset, pkgName, pkgFiles, astFile, filePath, source, verbose, printTo, jsonTypeReplacements, localImportPrefixes)
155+
}
156+
157+
// RewriteAstFileSource is like RewriteAstFile but accepts the source content directly.
158+
// This allows for in-memory testing without requiring files on disk.
159+
//
160+
// Parameters:
161+
// - fset: Token file set for position information
162+
// - pkgName: The package name
163+
// - pkgFiles: All files in the package (for finding function declarations)
164+
// - astFile: The parsed AST of the file to process
165+
// - filePath: File path for error messages (doesn't need to exist)
166+
// - source: The original source code as bytes
167+
// - verbose: If true, print detailed information
168+
// - printTo: If not nil, print to this writer instead of modifying file
169+
// - jsonTypeReplacements: Map of interface types to concrete types for JSON
170+
// - localImportPrefixes: Import path prefixes to treat as "local"
171+
//
172+
// Returns:
173+
// - error if wrapper generation or file writing fails
174+
//
175+
// This function is useful for testing as it doesn't require actual files on disk.
176+
func RewriteAstFileSource(fset *token.FileSet, pkgName string, pkgFiles map[string]*ast.File, astFile *ast.File, filePath string, source []byte, verbose bool, printTo io.Writer, jsonTypeReplacements map[string]string, localImportPrefixes []string) (err error) {
148177
filePath = filepath.Clean(filePath)
149178

150179
// ast.Print(fset, file)
@@ -164,7 +193,7 @@ func RewriteAstFile(fset *token.FileSet, filePkg *ast.Package, astFile *ast.File
164193
// Also parse all functions of the file's package
165194
// because they could als be referenced with an empty import name.
166195
// Added with empty string as package/import name.
167-
functions, err := localAndImportedFunctions(fset, filePkg, astFile, pkgDir)
196+
functions, err := localAndImportedFunctions(fset, pkgName, pkgFiles, astFile, pkgDir)
168197
if err != nil {
169198
return err
170199
}
@@ -205,10 +234,6 @@ func RewriteAstFile(fset *token.FileSet, filePkg *ast.Package, astFile *ast.File
205234
replacements.Add(implReplacements)
206235
}
207236

208-
source, err := os.ReadFile(filePath) //#nosec G304
209-
if err != nil {
210-
return err
211-
}
212237
rewritten, err := replacements.Apply(fset, source)
213238
if err != nil {
214239
return err
@@ -241,11 +266,11 @@ func RewriteAstFile(fset *token.FileSet, filePkg *ast.Package, astFile *ast.File
241266
// wrapper represents a function wrapper declaration found in source code.
242267
// It contains all information needed to generate the wrapper implementation.
243268
type wrapper struct {
244-
VarName string // Name of the wrapper variable (e.g., "myWrapper")
245-
WrappedFunc string // Full name of the wrapped function (e.g., "pkg.MyFunc" or "MyFunc")
246-
Type string // Name of the wrapper type (e.g., "myWrapperT")
247-
Nodes []ast.Node // All AST nodes to be replaced (comments, var, type, methods)
248-
Impl Impl // Which wrapper interfaces to implement
269+
VarName string // Name of the wrapper variable (e.g., "myWrapper")
270+
WrappedFunc string // Full name of the wrapped function (e.g., "pkg.MyFunc" or "MyFunc")
271+
Type string // Name of the wrapper type (e.g., "myWrapperT")
272+
Nodes []ast.Node // All AST nodes to be replaced (comments, var, type, methods)
273+
Impl Impl // Which wrapper interfaces to implement
249274
}
250275

251276
// WrappedFuncPkgAndFuncName splits the WrappedFunc into package and function name.

cmd/gen-func-wrappers/gen/rewrite_test.go

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
package gen
22

3-
import "testing"
3+
import (
4+
"bytes"
5+
"go/ast"
6+
"go/parser"
7+
"go/token"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
413

514
func Test_parseImplementsComment(t *testing.T) {
615
type args struct {
@@ -61,3 +70,96 @@ func Test_parseImplementsComment(t *testing.T) {
6170
})
6271
}
6372
}
73+
74+
// TestRewriteAstFileSource_InMemory demonstrates using RewriteAstFileSource for in-memory testing
75+
func TestRewriteAstFileSource_InMemory(t *testing.T) {
76+
// Create in-memory source code with a wrapper TODO
77+
source := []byte(`package testpkg
78+
79+
import "github.com/domonda/go-function"
80+
81+
// SimpleAdd adds two integers.
82+
// a: First number
83+
// b: Second number
84+
func SimpleAdd(a, b int) int {
85+
return a + b
86+
}
87+
88+
var simpleAddWrapper = function.WrapperTODO(SimpleAdd)
89+
`)
90+
91+
// Parse the source into an AST
92+
fset := token.NewFileSet()
93+
astFile, err := parser.ParseFile(fset, "test.go", source, parser.ParseComments)
94+
require.NoError(t, err)
95+
96+
// Create package files map (single file in this case)
97+
pkgFiles := map[string]*ast.File{
98+
"test.go": astFile,
99+
}
100+
101+
// Buffer to capture the rewritten output
102+
var output bytes.Buffer
103+
104+
// Call RewriteAstFileSource with in-memory source
105+
err = RewriteAstFileSource(
106+
fset,
107+
"testpkg",
108+
pkgFiles,
109+
astFile,
110+
"test.go", // path is only used for error messages
111+
source,
112+
false, // verbose
113+
&output,
114+
nil, // no JSON type replacements
115+
nil, // no local import prefixes
116+
)
117+
require.NoError(t, err)
118+
119+
// Verify the output contains generated wrapper code
120+
result := output.String()
121+
assert.Contains(t, result, "// simpleAddWrapper wraps SimpleAdd as function.Wrapper (generated code)")
122+
assert.Contains(t, result, "var simpleAddWrapper simpleAddWrapperT")
123+
assert.Contains(t, result, "type simpleAddWrapperT struct{}")
124+
assert.Contains(t, result, "func (simpleAddWrapperT) Name() string")
125+
assert.Contains(t, result, "func (simpleAddWrapperT) Call")
126+
assert.Contains(t, result, "results[0] = SimpleAdd(args[0].(int), args[1].(int))")
127+
}
128+
129+
// TestRewriteAstFileSource_NoWrappers tests handling of files without wrappers
130+
func TestRewriteAstFileSource_NoWrappers(t *testing.T) {
131+
source := []byte(`package testpkg
132+
133+
// SimpleAdd adds two integers.
134+
func SimpleAdd(a, b int) int {
135+
return a + b
136+
}
137+
`)
138+
139+
fset := token.NewFileSet()
140+
astFile, err := parser.ParseFile(fset, "test.go", source, parser.ParseComments)
141+
require.NoError(t, err)
142+
143+
pkgFiles := map[string]*ast.File{
144+
"test.go": astFile,
145+
}
146+
147+
var output bytes.Buffer
148+
149+
err = RewriteAstFileSource(
150+
fset,
151+
"testpkg",
152+
pkgFiles,
153+
astFile,
154+
"test.go",
155+
source,
156+
false,
157+
&output,
158+
nil,
159+
nil,
160+
)
161+
require.NoError(t, err)
162+
163+
// Should not generate anything if there are no wrappers
164+
assert.Empty(t, output.String())
165+
}

cmd/gen-func-wrappers/go.mod

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@ module github.com/domonda/go-function/cmd/gen-func-wrappers
33
go 1.24.0
44

55
require (
6+
github.com/stretchr/testify v1.11.1
67
github.com/ungerik/go-astvisit v0.0.0-20251017171216-b7bb0384dd33
78
golang.org/x/tools v0.38.0
89
)
910

1011
require (
12+
github.com/davecgh/go-spew v1.1.1 // indirect
13+
github.com/kr/text v0.2.0 // indirect
14+
github.com/pmezard/go-difflib v1.0.0 // indirect
1115
golang.org/x/mod v0.29.0 // indirect
1216
golang.org/x/sync v0.17.0 // indirect
17+
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
18+
gopkg.in/yaml.v3 v3.0.1 // indirect
1319
)
1420

1521
// replace github.com/ungerik/go-astvisit => ../../../../ungerik/go-astvisit

cmd/gen-func-wrappers/go.sum

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
2+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
23
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
34
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
5+
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
6+
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
47
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
8+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
59
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
10+
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
611
github.com/ungerik/go-astvisit v0.0.0-20251017171216-b7bb0384dd33 h1:xV30N1y6stpoqp5vt/xJSj6uS1z9wvuI15W0BAplpig=
712
github.com/ungerik/go-astvisit v0.0.0-20251017171216-b7bb0384dd33/go.mod h1:HSuqDFbjplGwkDoVmkdCNG4fes4IEsQCjS2/DQrfHl8=
813
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
@@ -11,4 +16,7 @@ golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
1116
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
1217
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
1318
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
19+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
20+
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
1421
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
22+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ require (
1010

1111
require (
1212
github.com/davecgh/go-spew v1.1.1 // indirect
13+
github.com/kr/text v0.2.0 // indirect
1314
github.com/pmezard/go-difflib v1.0.0 // indirect
1415
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
1516
gopkg.in/yaml.v3 v3.0.1 // indirect

go.sum

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy
55
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
66
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
77
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
8-
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
98
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
9+
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
1010
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
1111
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
1212
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=

0 commit comments

Comments
 (0)