Skip to content

Commit

Permalink
mockgen: Added support for dot imports
Browse files Browse the repository at this point in the history
  • Loading branch information
aksdb authored and poy committed Dec 5, 2018
1 parent bcb66aa commit 5f3a40b
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 13 deletions.
14 changes: 14 additions & 0 deletions mockgen/internal/tests/dot_imports/input.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//go:generate mockgen -package dot_imports -destination mock.go -source input.go
package dot_imports

import (
"bytes"
. "context"
. "net/http"
)

type WithDotImports interface {
Method1() Request
Method2() *bytes.Buffer
Method3() Context
}
72 changes: 72 additions & 0 deletions mockgen/internal/tests/dot_imports/mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion mockgen/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,12 @@ func (nt *NamedType) String(pm map[string]string, pkgOverride string) string {
if pkgOverride == nt.Package {
return nt.Type
}
return pm[nt.Package] + "." + nt.Type
prefix := pm[nt.Package]
if prefix != "" {
return prefix + "." + nt.Type
} else {
return nt.Type
}
}
func (nt *NamedType) addImports(im map[string]bool) {
if nt.Package != "" {
Expand Down
30 changes: 19 additions & 11 deletions mockgen/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ func ParseFile(source string) (*model.Package, error) {
if err != nil {
return nil, err
}
pkg.DotImports = make([]string, 0, len(dotImports))
for path := range dotImports {
pkg.DotImports = append(pkg.DotImports, path)
}
Expand Down Expand Up @@ -149,7 +148,7 @@ func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) {
// parseFile loads all file imports and auxiliary files import into the
// fileParser, parses all file interfaces and returns package model.
func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) {
allImports := importsOfFile(file)
allImports, dotImports := importsOfFile(file)
// Don't stomp imports provided by -imports. Those should take precedence.
for pkg, path := range allImports {
if _, ok := p.imports[pkg]; !ok {
Expand All @@ -159,7 +158,8 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag
// Add imports from auxiliary files, which might be needed for embedded interfaces.
// Don't stomp any other imports.
for _, f := range p.auxFiles {
for pkg, path := range importsOfFile(f) {
auxImports, _ := importsOfFile(f)
for pkg, path := range auxImports {
if _, ok := p.imports[pkg]; !ok {
p.imports[pkg] = path
}
Expand All @@ -177,6 +177,7 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag
return &model.Package{
Name: file.Name.String(),
Interfaces: is,
DotImports: dotImports,
}, nil
}

Expand All @@ -197,7 +198,8 @@ func (p *fileParser) parsePackage(path string) error {
for ni := range iterInterfaces(file) {
p.importedInterfaces[path][ni.name.Name] = ni.it
}
for pkgName, pkgPath := range importsOfFile(file) {
imports, _ := importsOfFile(file)
for pkgName, pkgPath := range imports {
if _, ok := p.imports[pkgName]; !ok {
p.imports[pkgName] = pkgPath
}
Expand Down Expand Up @@ -428,8 +430,9 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {

// importsOfFile returns a map of package name to import path
// of the imports in file.
func importsOfFile(file *ast.File) map[string]string {
m := make(map[string]string)
func importsOfFile(file *ast.File) (normalImports map[string]string, dotImports []string) {
normalImports = make(map[string]string)
dotImports = make([]string, 0)
for _, is := range file.Imports {
var pkgName string
importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes
Expand All @@ -439,7 +442,7 @@ func importsOfFile(file *ast.File) map[string]string {
if is.Name.Name == "_" {
continue
}
pkgName = removeDot(is.Name.Name)
pkgName = is.Name.Name
} else {
pkg, err := build.Import(importPath, "", 0)
if err != nil {
Expand All @@ -453,12 +456,17 @@ func importsOfFile(file *ast.File) map[string]string {
}
}

if _, ok := m[pkgName]; ok {
log.Fatalf("imported package collision: %q imported twice", pkgName)
if pkgName == "." {
dotImports = append(dotImports, importPath)
} else {

if _, ok := normalImports[pkgName]; ok {
log.Fatalf("imported package collision: %q imported twice", pkgName)
}
normalImports[pkgName] = importPath
}
m[pkgName] = importPath
}
return m
return
}

type namedInterface struct {
Expand Down
2 changes: 1 addition & 1 deletion mockgen/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func TestImportsOfFile(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}

imports := importsOfFile(file)
imports, _ := importsOfFile(file)
checkGreeterImports(t, imports)
}

Expand Down

0 comments on commit 5f3a40b

Please sign in to comment.