Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ErrorIs assertion; go1.13 feature #884

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,10 @@ Please feel free to submit issues, fork the repository and send pull requests!
When submitting an issue, we ask that you please include a complete test function that demonstrates the issue. Extra credit for those using Testify to write the test code that demonstrates it.

Code generation is used. Look for `CODE GENERATED AUTOMATICALLY` at the top of some files. Run `go generate ./...` to update generated files.
Assertions are added to `assert/assertions.go` file and then code generated to all other forms including `require.*` as
```bash
go generate ./assert/ ./require/
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this not redundant to line 335? Would running go generate ./... not have the same effect?

Copy link
Author

@nikandfor nikandfor Feb 10, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

335 says how to generate code, but doesn't say, where to add new one by hands.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I wrote about go generate second time.

```

We also chat on the [Gophers Slack](https://gophers.slack.com) group in the `#testify` and `#testify-dev` channels.

Expand Down
160 changes: 127 additions & 33 deletions _codegen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,45 +35,91 @@ var (
out = flag.String("out", "", "What file to write the source code to")
)

type Context struct {
files map[string]*ast.File
fset *token.FileSet
scope *types.Scope
docs *doc.Package

tags map[string]string
}

func main() {
flag.Parse()

scope, docs, err := parsePackageSource(*pkg)
var ctx Context

err := ctx.parsePackageSource(*pkg)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What problem are we solving by moving all the functions to be receivers on Context?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Getting access to more data. To assign build tags to a function in analyzeCode I needed *token.FileSet (that wasn't in returning arguments) and a map to save file to tags relations.

Adding these new return arguments meant to have 4 of them. I decided it's time to store them all in a receiver struct.

if err != nil {
log.Fatal(err)
}

importer, funcs, err := analyzeCode(scope, docs)
funcs, err := ctx.analyzeCode()
if err != nil {
log.Fatal(err)
}

if err := generateCode(importer, funcs); err != nil {
err = ctx.generateCode(funcs)
if err != nil {
log.Fatal(err)
}
}

func generateCode(importer imports.Importer, funcs []testFunc) error {
func (c *Context) generateCode(funcs []testFunc) error {
tags := map[string]struct{}{}
for _, f := range funcs {
tags[f.Tags] = struct{}{}
}

for tag := range tags {
err := c.generateCodeTag(funcs, tag)
if err != nil {
return err
}
}

return nil
}

func (c *Context) generateCodeTag(funcs []testFunc, tags string) error {
buff := bytes.NewBuffer(nil)

tmplHead, tmplFunc, err := parseTemplates()
if err != nil {
return err
}

// Make imports for given functions set (filtered by build tags)
importer := imports.New(*outputPkg)
for _, fn := range funcs {
if fn.Tags != tags {
continue
}

sig := fn.TypeInfo.Type().(*types.Signature)

importer.AddImportsFrom(sig.Params())
}

// Generate header
if err := tmplHead.Execute(buff, struct {
Name string
Imports map[string]string
Name string
BuildTags string
Imports map[string]string
}{
*outputPkg,
tags,
importer.Imports(),
}); err != nil {
return err
}

// Generate funcs
for _, fn := range funcs {
if fn.Tags != tags {
continue
}

buff.Write([]byte("\n\n"))
if err := tmplFunc.Execute(buff, &fn); err != nil {
return err
Expand All @@ -86,7 +132,7 @@ func generateCode(importer imports.Importer, funcs []testFunc) error {
}

// Write file
output, err := outputFile()
output, err := outputFile(tags)
if err != nil {
return err
}
Expand Down Expand Up @@ -114,28 +160,30 @@ func parseTemplates() (*template.Template, *template.Template, error) {
return tmplHead, tmpl, nil
}

func outputFile() (*os.File, error) {
func outputFile(tags string) (*os.File, error) {
filename := *out
if filename == "-" || (filename == "" && *tmplFile == "") {
return os.Stdout, nil
}
if filename == "" {
filename = strings.TrimSuffix(strings.TrimSuffix(*tmplFile, ".tmpl"), ".go") + ".go"
if tags != "" {
tags = "_" + tags
}
filename = strings.TrimSuffix(strings.TrimSuffix(*tmplFile, ".tmpl"), ".go") + tags + ".go"
}
return os.Create(filename)
}

// analyzeCode takes the types scope and the docs and returns the import
// information and information about all the assertion functions.
func analyzeCode(scope *types.Scope, docs *doc.Package) (imports.Importer, []testFunc, error) {
testingT := scope.Lookup("TestingT").Type().Underlying().(*types.Interface)
func (c *Context) analyzeCode() ([]testFunc, error) {
testingT := c.scope.Lookup("TestingT").Type().Underlying().(*types.Interface)

importer := imports.New(*outputPkg)
var funcs []testFunc
// Go through all the top level functions
for _, fdocs := range docs.Funcs {
for _, fdocs := range c.docs.Funcs {
// Find the function
obj := scope.Lookup(fdocs.Name)
obj := c.scope.Lookup(fdocs.Name)

fn, ok := obj.(*types.Func)
if !ok {
Expand Down Expand Up @@ -164,33 +212,43 @@ func analyzeCode(scope *types.Scope, docs *doc.Package) (imports.Importer, []tes
continue
}

funcs = append(funcs, testFunc{*outputPkg, fdocs, fn})
importer.AddImportsFrom(sig.Params())
tags := c.buildTags(obj)

funcs = append(funcs, testFunc{
CurrentPkg: *outputPkg,
Tags: tags,
DocInfo: fdocs,
TypeInfo: fn,
})
}
return importer, funcs, nil
return funcs, nil
}

// parsePackageSource returns the types scope and the package documentation from the package
func parsePackageSource(pkg string) (*types.Scope, *doc.Package, error) {
func (c *Context) parsePackageSource(pkg string) error {
pd, err := build.Import(pkg, ".", 0)
if err != nil {
return nil, nil, err
return err
}

fset := token.NewFileSet()
files := make(map[string]*ast.File)
c.fset = token.NewFileSet()
c.files = make(map[string]*ast.File)
c.tags = make(map[string]string)
fileList := make([]*ast.File, len(pd.GoFiles))
for i, fname := range pd.GoFiles {
src, err := ioutil.ReadFile(path.Join(pd.Dir, fname))
if err != nil {
return nil, nil, err
return err
}
f, err := parser.ParseFile(fset, fname, src, parser.ParseComments|parser.AllErrors)
f, err := parser.ParseFile(c.fset, fname, src, parser.ParseComments|parser.AllErrors)
if err != nil {
return nil, nil, err
return err
}
files[fname] = f

c.files[fname] = f
fileList[i] = f

c.parseBuildTags(fname, f)
}

cfg := types.Config{
Expand All @@ -199,21 +257,52 @@ func parsePackageSource(pkg string) (*types.Scope, *doc.Package, error) {
info := types.Info{
Defs: make(map[*ast.Ident]types.Object),
}
tp, err := cfg.Check(pkg, fset, fileList, &info)
tp, err := cfg.Check(pkg, c.fset, fileList, &info)
if err != nil {
return nil, nil, err
return err
}

scope := tp.Scope()
c.scope = tp.Scope()

ap, _ := ast.NewPackage(c.fset, c.files, nil, nil)
c.docs = doc.New(ap, pkg, 0)

return nil
}

func (c *Context) buildTags(o types.Object) string {
tf := c.fset.File(o.Pos())

return c.tags[tf.Name()]
}

func (c *Context) parseBuildTags(fname string, f *ast.File) {
const pref = "// +build "

for _, g := range f.Comments {
for _, comm := range g.List {
t := comm.Text
if !strings.HasPrefix(t, pref) {
continue
}

ap, _ := ast.NewPackage(fset, files, nil, nil)
docs := doc.New(ap, pkg, 0)
t = strings.TrimPrefix(t, pref)
t = strings.TrimSpace(t)
t = strings.ReplaceAll(t, ",", "-")
t = strings.ReplaceAll(t, " ", "_")
t = strings.ReplaceAll(t, "!", "N")

c.tags[fname] = t
return
}
}

return scope, docs, nil
c.tags[fname] = ""
}

type testFunc struct {
CurrentPkg string
Tags string
DocInfo *doc.Func
TypeInfo *types.Func
}
Expand Down Expand Up @@ -297,17 +386,22 @@ func (f *testFunc) CommentWithoutT(receiver string) string {
return strings.Replace(f.Comment(), search, replace, -1)
}

var headerTemplate = `/*
var headerTemplate = `{{ with .BuildTags }}// +build {{ . }}
{{ end }}
/*
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/

package {{.Name}}

{{ with .Imports }}
import (
{{range $path, $name := .Imports}}
{{range $path, $name := .}}
{{$name}} "{{$path}}"{{end}}
)
{{ end }}
{{ if ne .Name "assert" }}var _ assert.TestingT // in case no function required assert package{{ end }}
`

var funcTemplate = `{{.Comment}}
Expand Down
20 changes: 20 additions & 0 deletions assert/assertion_format_go1.13.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// +build go1.13

/*
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/

package assert

// ErrorIsf asserts that a specified error is an another error wrapper as defined by go1.13 errors package.
//
// actualObj, err := SomeFunction()
// assert.ErrorIsf(t, err, ErrNotFound, "error message %s", "formatted")
// assert.Nil(t, actualObj)
func ErrorIsf(t TestingT, theError error, theTarget error, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return ErrorIs(t, theError, theTarget, append([]interface{}{msg}, args...)...)
}
32 changes: 32 additions & 0 deletions assert/assertion_forward_go1.13.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// +build go1.13

/*
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/

package assert

// ErrorIs asserts that a specified error is an another error wrapper as defined by go1.13 errors package.
//
// actualObj, err := SomeFunction()
// a.ErrorIs(err, ErrNotFound)
// assert.Nil(t, actualObj)
func (a *Assertions) ErrorIs(theError error, theTarget error, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorIs(a.t, theError, theTarget, msgAndArgs...)
}

// ErrorIsf asserts that a specified error is an another error wrapper as defined by go1.13 errors package.
//
// actualObj, err := SomeFunction()
// a.ErrorIsf(err, ErrNotFound, "error message %s", "formatted")
// assert.Nil(t, actualObj)
func (a *Assertions) ErrorIsf(theError error, theTarget error, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorIsf(a.t, theError, theTarget, msg, args...)
}
25 changes: 25 additions & 0 deletions assert/assertions_go1.13.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// +build go1.13

package assert

import (
"errors"
"fmt"
)

// ErrorIs asserts that a specified error is an another error wrapper as defined by go1.13 errors package.
//
// actualObj, err := SomeFunction()
// assert.ErrorIs(t, err, ErrNotFound)
// assert.Nil(t, actualObj)
func ErrorIs(t TestingT, theError, theTarget error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}

if errors.Is(theError, theTarget) {
return true
}

return Fail(t, fmt.Sprintf("Error is not %v, but %v", theTarget, theError), msgAndArgs...)
}
Loading