Skip to content

Commit

Permalink
Support adding new stubs to existing stub files (#25130)
Browse files Browse the repository at this point in the history
* stubmaker can generate stubs for only the missing functions

* check error
  • Loading branch information
miagilepner authored Feb 1, 2024
1 parent f0e7f11 commit eb2b905
Showing 1 changed file with 106 additions and 67 deletions.
173 changes: 106 additions & 67 deletions tools/stubmaker/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,22 @@ import (
"bytes"
"errors"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"go/types"
"io"
"os"
"path/filepath"
"regexp"
"sort"
"strings"

"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing"
"github.com/go-git/go-git/v5/plumbing/object"
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/go-hclog"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/imports"
)

var logger hclog.Logger
Expand All @@ -31,6 +33,11 @@ func fatal(err error) {
os.Exit(1)
}

type generator struct {
file *ast.File
fset *token.FileSet
}

func main() {
logger = hclog.New(&hclog.LoggerOptions{
Name: "stubmaker",
Expand Down Expand Up @@ -67,14 +74,15 @@ func main() {
fatal(err)
}

inputLines, err := readLines(bytes.NewBuffer(b))
inputParsed, err := parseFile(b)
if err != nil {
fatal(err)
}
funcs := getFuncs(inputLines)
if needed, err := isStubNeeded(funcs); err != nil {
needed, existing, err := inputParsed.areStubsNeeded()
if err != nil {
fatal(err)
} else if !needed {
}
if !needed {
return
}

Expand Down Expand Up @@ -107,7 +115,7 @@ func main() {
if err != nil {
fatal(err)
}
_, err = io.WriteString(output, strings.Join(getOutput(inputLines), "\n")+"\n")
err = inputParsed.writeStubs(output, existing)
if err != nil {
// If we don't end up writing to the file, delete it.
os.Remove(outputFile + ".tmp")
Expand All @@ -119,6 +127,57 @@ func main() {
}
}

func (g *generator) writeStubs(output *os.File, existingFuncs map[string]struct{}) error {
// delete all functions/methods that are already defined
g.modifyAST(existingFuncs)

// write the updated code to buf
buf := new(bytes.Buffer)
err := format.Node(buf, g.fset, g.file)
if err != nil {
return err
}

// remove any unneeded imports
res, err := imports.Process("", buf.Bytes(), &imports.Options{
Fragment: true,
AllErrors: false,
Comments: true,
FormatOnly: false,
})
if err != nil {
return err
}

// add the code generation line and update the build tags
outputLines, err := fixGeneratedComments(res)
if err != nil {
return err
}
_, err = output.WriteString(strings.Join(outputLines, "\n") + "\n")
return err
}

func fixGeneratedComments(b []byte) ([]string, error) {
warning := "// Code generated by tools/stubmaker; DO NOT EDIT."
goGenerate := "//go:generate go run github.com/hashicorp/vault/tools/stubmaker"

scanner := bufio.NewScanner(bytes.NewBuffer(b))
var outputLines []string
for scanner.Scan() {
line := scanner.Text()
switch {
case strings.Contains(line, "//go:build ") && strings.Contains(line, "!enterprise"):
outputLines = append(outputLines, warning, "")
line = strings.ReplaceAll(line, "!enterprise", "enterprise")
case line == goGenerate:
continue
}
outputLines = append(outputLines, line)
}
return outputLines, scanner.Err()
}

func inGit(wt *git.Worktree, st git.Status, obj object.Object, path string) (bool, error) {
absPath, err := filepath.Abs(path)
if err != nil {
Expand Down Expand Up @@ -189,27 +248,24 @@ func resolve(obj object.Object, path string) (*object.Blob, error) {
}
}

func readLines(r io.Reader) ([]string, error) {
scanner := bufio.NewScanner(r)
scanner.Split(bufio.ScanLines)

var lines []string
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
if err := scanner.Err(); err != nil {
return nil, err
}
return lines, nil
}

func isStubNeeded(funcs []string) (bool, error) {
// areStubsNeeded checks if all functions and methods defined in the stub file
// are present in the package
func (g *generator) areStubsNeeded() (needed bool, existingStubs map[string]struct{}, err error) {
pkg, err := parsePackage(".", []string{"enterprise"})
if err != nil {
return false, err
return false, nil, err
}

var found []string
stubFunctions := make(map[string]struct{})
for _, d := range g.file.Decls {
dFunc, ok := d.(*ast.FuncDecl)
if !ok {
continue
}
stubFunctions[dFunc.Name.Name] = struct{}{}

}
found := make(map[string]struct{})
for name, val := range pkg.TypesInfo.Defs {
if val == nil {
continue
Expand All @@ -218,54 +274,25 @@ func isStubNeeded(funcs []string) (bool, error) {
if !ok {
continue
}
for _, f := range funcs {
if name.Name == f {
found = append(found, f)
}
if _, ok := stubFunctions[name.Name]; ok {
found[name.Name] = struct{}{}
}
}
switch {
case len(found) == len(funcs):
return false, nil
case len(found) != 0:
sort.Strings(found)
sort.Strings(funcs)
delta := cmp.Diff(found, funcs)
return false, fmt.Errorf("funcs partially defined, delta=%s", delta)
}

return true, nil
return len(found) != len(stubFunctions), found, nil
}

var funcRE = regexp.MustCompile("^func *(?:[(][^)]+[)])? *([^(]+)")

func getFuncs(inputLines []string) []string {
var funcs []string
for _, line := range inputLines {
matches := funcRE.FindStringSubmatch(line)
if len(matches) > 1 {
funcs = append(funcs, matches[1])
}
}
return funcs
}

func getOutput(inputLines []string) []string {
warning := "// Code generated by tools/stubmaker; DO NOT EDIT."

var outputLines []string
for _, line := range inputLines {
switch line {
case "//go:build !enterprise":
outputLines = append(outputLines, warning, "")
line = "//go:build enterprise"
case "//go:generate go run github.com/hashicorp/vault/tools/stubmaker":
continue
func (g *generator) modifyAST(exists map[string]struct{}) {
astutil.Apply(g.file, nil, func(c *astutil.Cursor) bool {
switch x := c.Node().(type) {
case *ast.FuncDecl:
if _, ok := exists[x.Name.Name]; ok {
c.Delete()
}
}
outputLines = append(outputLines, line)
}

return outputLines
return true
})
}

func parsePackage(name string, tags []string) (*packages.Package, error) {
Expand All @@ -283,3 +310,15 @@ func parsePackage(name string, tags []string) (*packages.Package, error) {
}
return pkgs[0], nil
}

func parseFile(buffer []byte) (*generator, error) {
fs := token.NewFileSet()
f, err := parser.ParseFile(fs, "", buffer, parser.AllErrors|parser.ParseComments)
if err != nil {
return nil, err
}
return &generator{
file: f,
fset: fs,
}, nil
}

0 comments on commit eb2b905

Please sign in to comment.