Skip to content

Commit

Permalink
generate variadic call by template
Browse files Browse the repository at this point in the history
  • Loading branch information
jokly committed Sep 1, 2023
1 parent 80c5909 commit 73a6a6f
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 37 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 42 additions & 16 deletions pkg/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,16 +423,24 @@ func (g *Generator) printf(s string, vals ...interface{}) {

var templates = template.New("base template")

func (g *Generator) printTemplate(data interface{}, templateString string) {
func (g *Generator) printTemplateBytes(data interface{}, templateString string) *bytes.Buffer {
tmpl, err := templates.New(templateString).Funcs(templateFuncMap).Parse(templateString)
if err != nil {
// couldn't compile template
panic(err)
}

if err := tmpl.Execute(&g.buf, data); err != nil {
var buf bytes.Buffer

err = tmpl.Execute(&buf, data)
if err != nil {
panic(err)
}

return &buf
}

func (g *Generator) printTemplate(data interface{}, templateString string) {
g.buf.Write(g.printTemplateBytes(data, templateString).Bytes())
}

type namer interface {
Expand Down Expand Up @@ -708,7 +716,7 @@ func (g *Generator) generateMethod(ctx context.Context, method *Method) {

params := g.genList(ctx, ftype.Params(), ftype.Variadic())
returns := g.genList(ctx, ftype.Results(), false)
preamble, called := g.generateCalled(params)
preamble, called := g.generateCalled(params, returns)

data := struct {
FunctionName string
Expand Down Expand Up @@ -926,26 +934,44 @@ func {{ .ConstructorName }}{{ .TypeConstraint }}(t interface {
// steps to prepare its argument list.
//
// It is separate from Generate to avoid cyclomatic complexity through early return statements.
func (g *Generator) generateCalled(list *paramList) (preamble string, called string) {
func (g *Generator) generateCalled(list *paramList, returnList *paramList) (preamble string, called string) {
namesLen := len(list.Names)
if namesLen == 0 || !list.Variadic || !g.config.UnrollVariadic {
if list.Variadic && !g.config.UnrollVariadic && g.config.WithExpecter {
variadicName := list.Names[namesLen-1]
isFuncReturns := len(returnList.Names) > 0

var tmpRet, tmpRetWithAssignment string
if isFuncReturns {
tmpRet = resolveCollision(list.Names, "tmpRet")
tmpRetWithAssignment = fmt.Sprintf("%s = ", tmpRet)
}

called = fmt.Sprintf(
`func() mock.Arguments {
if len(%s) > 0 {
return _m.Called(%s)
calledBytes := g.printTemplateBytes(
struct {
ParamList *paramList
ParamNamesWithoutVariadic []string
VariadicName string
IsFuncReturns bool
TmpRet string
TmpRetWithAssignment string
}{
ParamList: list,
ParamNamesWithoutVariadic: list.Names[:len(list.Names)-1],
VariadicName: list.Names[namesLen-1],
IsFuncReturns: isFuncReturns,
TmpRet: tmpRet,
TmpRetWithAssignment: tmpRetWithAssignment,
},
`{{ if .IsFuncReturns }}var {{ .TmpRet }} mock.Arguments {{ end }}
if len({{ .VariadicName }}) > 0 {
{{ .TmpRetWithAssignment }}_m.Called({{ join .ParamList.Names ", " }})
} else {
return _m.Called(%s)
{{ .TmpRetWithAssignment }}_m.Called({{ join .ParamNamesWithoutVariadic ", " }})
}
}()`,
variadicName,
strings.Join(list.Names, ", "),
strings.Join(list.Names[:len(list.Names)-1], ", "),
`,
)

return
return calledBytes.String(), tmpRet
}

called = "_m.Called(" + strings.Join(list.Names, ", ") + ")"
Expand Down
1 change: 1 addition & 0 deletions pkg/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"

mocks "github.com/vektra/mockery/v2/mocks/github.com/vektra/mockery/v2/pkg/fixtures"
)

Expand Down

0 comments on commit 73a6a6f

Please sign in to comment.