diff --git a/mocks/github.com/vektra/mockery/v2/pkg/fixtures/ExpecterAndRolledVariadic.go b/mocks/github.com/vektra/mockery/v2/pkg/fixtures/ExpecterAndRolledVariadic.go index 201243a2..c8a278c2 100644 --- a/mocks/github.com/vektra/mockery/v2/pkg/fixtures/ExpecterAndRolledVariadic.go +++ b/mocks/github.com/vektra/mockery/v2/pkg/fixtures/ExpecterAndRolledVariadic.go @@ -148,13 +148,13 @@ func (_c *ExpecterAndRolledVariadic_NoReturn_Call) RunAndReturn(run func(string) // Variadic provides a mock function with given fields: ints func (_m *ExpecterAndRolledVariadic) Variadic(ints ...int) error { - ret := func() mock.Arguments { - if len(ints) > 0 { - return _m.Called(ints) - } else { - return _m.Called() - } - }() + var tmpRet mock.Arguments + if len(ints) > 0 { + tmpRet = _m.Called(ints) + } else { + tmpRet = _m.Called() + } + ret := tmpRet var r0 error if rf, ok := ret.Get(0).(func(...int) error); ok { @@ -203,13 +203,13 @@ func (_c *ExpecterAndRolledVariadic_Variadic_Call) RunAndReturn(run func(...int) // VariadicMany provides a mock function with given fields: i, a, intfs func (_m *ExpecterAndRolledVariadic) VariadicMany(i int, a string, intfs ...interface{}) error { - ret := func() mock.Arguments { - if len(intfs) > 0 { - return _m.Called(i, a, intfs) - } else { - return _m.Called(i, a) - } - }() + var tmpRet mock.Arguments + if len(intfs) > 0 { + tmpRet = _m.Called(i, a, intfs) + } else { + tmpRet = _m.Called(i, a) + } + ret := tmpRet var r0 error if rf, ok := ret.Get(0).(func(int, string, ...interface{}) error); ok { @@ -260,13 +260,12 @@ func (_c *ExpecterAndRolledVariadic_VariadicMany_Call) RunAndReturn(run func(int // VariadicNoReturn provides a mock function with given fields: j, is func (_m *ExpecterAndRolledVariadic) VariadicNoReturn(j int, is ...interface{}) { - func() mock.Arguments { - if len(is) > 0 { - return _m.Called(j, is) - } else { - return _m.Called(j) - } - }() + if len(is) > 0 { + _m.Called(j, is) + } else { + _m.Called(j) + } + } // ExpecterAndRolledVariadic_VariadicNoReturn_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VariadicNoReturn' diff --git a/pkg/generator.go b/pkg/generator.go index 8bb71289..cd3df88a 100644 --- a/pkg/generator.go +++ b/pkg/generator.go @@ -423,16 +423,24 @@ func (g *Generator) printf(s string, vals ...interface{}) { var templates = template.New("base template") -func (g *Generator) printTemplate(data interface{}, templateString string) { +func (g *Generator) printTemplateBytes(data interface{}, templateString string) *bytes.Buffer { tmpl, err := templates.New(templateString).Funcs(templateFuncMap).Parse(templateString) if err != nil { - // couldn't compile template panic(err) } - if err := tmpl.Execute(&g.buf, data); err != nil { + var buf bytes.Buffer + + err = tmpl.Execute(&buf, data) + if err != nil { panic(err) } + + return &buf +} + +func (g *Generator) printTemplate(data interface{}, templateString string) { + g.buf.Write(g.printTemplateBytes(data, templateString).Bytes()) } type namer interface { @@ -708,7 +716,7 @@ func (g *Generator) generateMethod(ctx context.Context, method *Method) { params := g.genList(ctx, ftype.Params(), ftype.Variadic()) returns := g.genList(ctx, ftype.Results(), false) - preamble, called := g.generateCalled(params) + preamble, called := g.generateCalled(params, returns) data := struct { FunctionName string @@ -926,26 +934,44 @@ func {{ .ConstructorName }}{{ .TypeConstraint }}(t interface { // steps to prepare its argument list. // // It is separate from Generate to avoid cyclomatic complexity through early return statements. -func (g *Generator) generateCalled(list *paramList) (preamble string, called string) { +func (g *Generator) generateCalled(list *paramList, returnList *paramList) (preamble string, called string) { namesLen := len(list.Names) if namesLen == 0 || !list.Variadic || !g.config.UnrollVariadic { if list.Variadic && !g.config.UnrollVariadic && g.config.WithExpecter { - variadicName := list.Names[namesLen-1] + isFuncReturns := len(returnList.Names) > 0 + + var tmpRet, tmpRetWithAssignment string + if isFuncReturns { + tmpRet = resolveCollision(list.Names, "tmpRet") + tmpRetWithAssignment = fmt.Sprintf("%s = ", tmpRet) + } - called = fmt.Sprintf( - `func() mock.Arguments { - if len(%s) > 0 { - return _m.Called(%s) + calledBytes := g.printTemplateBytes( + struct { + ParamList *paramList + ParamNamesWithoutVariadic []string + VariadicName string + IsFuncReturns bool + TmpRet string + TmpRetWithAssignment string + }{ + ParamList: list, + ParamNamesWithoutVariadic: list.Names[:len(list.Names)-1], + VariadicName: list.Names[namesLen-1], + IsFuncReturns: isFuncReturns, + TmpRet: tmpRet, + TmpRetWithAssignment: tmpRetWithAssignment, + }, + `{{ if .IsFuncReturns }}var {{ .TmpRet }} mock.Arguments {{ end }} + if len({{ .VariadicName }}) > 0 { + {{ .TmpRetWithAssignment }}_m.Called({{ join .ParamList.Names ", " }}) } else { - return _m.Called(%s) + {{ .TmpRetWithAssignment }}_m.Called({{ join .ParamNamesWithoutVariadic ", " }}) } -}()`, - variadicName, - strings.Join(list.Names, ", "), - strings.Join(list.Names[:len(list.Names)-1], ", "), +`, ) - return + return calledBytes.String(), tmpRet } called = "_m.Called(" + strings.Join(list.Names, ", ") + ")" diff --git a/pkg/generator_test.go b/pkg/generator_test.go index 92f3a6bb..56a5e568 100644 --- a/pkg/generator_test.go +++ b/pkg/generator_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + mocks "github.com/vektra/mockery/v2/mocks/github.com/vektra/mockery/v2/pkg/fixtures" )