@@ -55,7 +55,7 @@ func RewriteDir(path string, verbose bool, printOnly io.Writer, jsonTypeReplacem
5555 }
5656 if err == nil {
5757 for fileName , file := range pkg .Files {
58- err = RewriteAstFile (fset , pkg , file , fileName , verbose , printOnly , jsonTypeReplacements , localImportPrefixes )
58+ err = RewriteAstFile (fset , pkg . Name , pkg . Files , file , fileName , verbose , printOnly , jsonTypeReplacements , localImportPrefixes )
5959 if err != nil {
6060 return err
6161 }
@@ -114,7 +114,7 @@ func RewriteFile(filePath string, verbose bool, printOnly io.Writer, jsonTypeRep
114114 if err != nil {
115115 return err
116116 }
117- return RewriteAstFile (fset , pkg , pkg .Files [filePath ], filePath , verbose , printOnly , jsonTypeReplacements , localImportPrefixes )
117+ return RewriteAstFile (fset , pkg . Name , pkg . Files , pkg .Files [filePath ], filePath , verbose , printOnly , jsonTypeReplacements , localImportPrefixes )
118118}
119119
120120// RewriteAstFile is the core rewriting logic that processes an AST file.
@@ -144,7 +144,36 @@ func RewriteFile(filePath string, verbose bool, printOnly io.Writer, jsonTypeRep
144144// Wrapper declarations are found by:
145145// - Variable assignments with TODO calls: var x = function.WrapperTODO(F)
146146// - Implementation comments: // myWrapper wraps F as function.Wrapper (generated code)
147- func RewriteAstFile (fset * token.FileSet , filePkg * ast.Package , astFile * ast.File , filePath string , verbose bool , printTo io.Writer , jsonTypeReplacements map [string ]string , localImportPrefixes []string ) (err error ) {
147+ func RewriteAstFile (fset * token.FileSet , pkgName string , pkgFiles map [string ]* ast.File , astFile * ast.File , filePath string , verbose bool , printTo io.Writer , jsonTypeReplacements map [string ]string , localImportPrefixes []string ) (err error ) {
148+ filePath = filepath .Clean (filePath )
149+
150+ source , err := os .ReadFile (filePath ) //#nosec G304
151+ if err != nil {
152+ return err
153+ }
154+ return RewriteAstFileSource (fset , pkgName , pkgFiles , astFile , filePath , source , verbose , printTo , jsonTypeReplacements , localImportPrefixes )
155+ }
156+
157+ // RewriteAstFileSource is like RewriteAstFile but accepts the source content directly.
158+ // This allows for in-memory testing without requiring files on disk.
159+ //
160+ // Parameters:
161+ // - fset: Token file set for position information
162+ // - pkgName: The package name
163+ // - pkgFiles: All files in the package (for finding function declarations)
164+ // - astFile: The parsed AST of the file to process
165+ // - filePath: File path for error messages (doesn't need to exist)
166+ // - source: The original source code as bytes
167+ // - verbose: If true, print detailed information
168+ // - printTo: If not nil, print to this writer instead of modifying file
169+ // - jsonTypeReplacements: Map of interface types to concrete types for JSON
170+ // - localImportPrefixes: Import path prefixes to treat as "local"
171+ //
172+ // Returns:
173+ // - error if wrapper generation or file writing fails
174+ //
175+ // This function is useful for testing as it doesn't require actual files on disk.
176+ func RewriteAstFileSource (fset * token.FileSet , pkgName string , pkgFiles map [string ]* ast.File , astFile * ast.File , filePath string , source []byte , verbose bool , printTo io.Writer , jsonTypeReplacements map [string ]string , localImportPrefixes []string ) (err error ) {
148177 filePath = filepath .Clean (filePath )
149178
150179 // ast.Print(fset, file)
@@ -164,7 +193,7 @@ func RewriteAstFile(fset *token.FileSet, filePkg *ast.Package, astFile *ast.File
164193 // Also parse all functions of the file's package
165194 // because they could als be referenced with an empty import name.
166195 // Added with empty string as package/import name.
167- functions , err := localAndImportedFunctions (fset , filePkg , astFile , pkgDir )
196+ functions , err := localAndImportedFunctions (fset , pkgName , pkgFiles , astFile , pkgDir )
168197 if err != nil {
169198 return err
170199 }
@@ -205,10 +234,6 @@ func RewriteAstFile(fset *token.FileSet, filePkg *ast.Package, astFile *ast.File
205234 replacements .Add (implReplacements )
206235 }
207236
208- source , err := os .ReadFile (filePath ) //#nosec G304
209- if err != nil {
210- return err
211- }
212237 rewritten , err := replacements .Apply (fset , source )
213238 if err != nil {
214239 return err
@@ -241,11 +266,11 @@ func RewriteAstFile(fset *token.FileSet, filePkg *ast.Package, astFile *ast.File
241266// wrapper represents a function wrapper declaration found in source code.
242267// It contains all information needed to generate the wrapper implementation.
243268type wrapper struct {
244- VarName string // Name of the wrapper variable (e.g., "myWrapper")
245- WrappedFunc string // Full name of the wrapped function (e.g., "pkg.MyFunc" or "MyFunc")
246- Type string // Name of the wrapper type (e.g., "myWrapperT")
247- Nodes []ast.Node // All AST nodes to be replaced (comments, var, type, methods)
248- Impl Impl // Which wrapper interfaces to implement
269+ VarName string // Name of the wrapper variable (e.g., "myWrapper")
270+ WrappedFunc string // Full name of the wrapped function (e.g., "pkg.MyFunc" or "MyFunc")
271+ Type string // Name of the wrapper type (e.g., "myWrapperT")
272+ Nodes []ast.Node // All AST nodes to be replaced (comments, var, type, methods)
273+ Impl Impl // Which wrapper interfaces to implement
249274}
250275
251276// WrappedFuncPkgAndFuncName splits the WrappedFunc into package and function name.
0 commit comments