Skip to content
This repository was archived by the owner on Jun 27, 2023. It is now read-only.

Commit cf7e215

Browse files
authored
fix: Support array parsing with length using binary expression and parenthesis (#603)
Fixes #575
1 parent d0edad8 commit cf7e215

File tree

4 files changed

+90
-22
lines changed

4 files changed

+90
-22
lines changed

mockgen/internal/tests/const_array_length/input.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,7 @@ type I interface {
1010
Foo() [C]int
1111
Bar() [2]int
1212
Baz() [math.MaxInt8]int
13+
Qux() [1 + 2]int
14+
Quux() [(1 + 2)]int
15+
Corge() [math.MaxInt8 - 120]int
1316
}

mockgen/internal/tests/const_array_length/mock.go

Lines changed: 42 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mockgen/parse.go

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -418,31 +418,14 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
418418
case *ast.ArrayType:
419419
ln := -1
420420
if v.Len != nil {
421-
var value string
422-
switch val := v.Len.(type) {
423-
case (*ast.BasicLit):
424-
value = val.Value
425-
case (*ast.Ident):
426-
// when the length is a const defined locally
427-
value = val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value
428-
case (*ast.SelectorExpr):
429-
// when the length is a const defined in an external package
430-
usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X))
431-
if err != nil {
432-
return nil, p.errorf(v.Len.Pos(), "unknown package in array length: %v", err)
433-
}
434-
ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name)
435-
if err != nil {
436-
return nil, p.errorf(v.Len.Pos(), "unknown constant in array length: %v", err)
437-
}
438-
value = ev.Value.String()
421+
value, err := p.parseArrayLength(v.Len)
422+
if err != nil {
423+
return nil, err
439424
}
440-
441-
x, err := strconv.Atoi(value)
425+
ln, err = strconv.Atoi(value)
442426
if err != nil {
443427
return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err)
444428
}
445-
ln = x
446429
}
447430
t, err := p.parseType(pkg, v.Elt)
448431
if err != nil {
@@ -525,6 +508,46 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
525508
return nil, fmt.Errorf("don't know how to parse type %T", typ)
526509
}
527510

511+
func (p *fileParser) parseArrayLength(expr ast.Expr) (string, error) {
512+
switch val := expr.(type) {
513+
case (*ast.BasicLit):
514+
return val.Value, nil
515+
case (*ast.Ident):
516+
// when the length is a const defined locally
517+
return val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value, nil
518+
case (*ast.SelectorExpr):
519+
// when the length is a const defined in an external package
520+
usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X))
521+
if err != nil {
522+
return "", p.errorf(expr.Pos(), "unknown package in array length: %v", err)
523+
}
524+
ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name)
525+
if err != nil {
526+
return "", p.errorf(expr.Pos(), "unknown constant in array length: %v", err)
527+
}
528+
return ev.Value.String(), nil
529+
case (*ast.ParenExpr):
530+
return p.parseArrayLength(val.X)
531+
case (*ast.BinaryExpr):
532+
x, err := p.parseArrayLength(val.X)
533+
if err != nil {
534+
return "", err
535+
}
536+
y, err := p.parseArrayLength(val.Y)
537+
if err != nil {
538+
return "", err
539+
}
540+
biExpr := fmt.Sprintf("%s%v%s", x, val.Op, y)
541+
tv, err := types.Eval(token.NewFileSet(), nil, token.NoPos, biExpr)
542+
if err != nil {
543+
return "", p.errorf(expr.Pos(), "invalid expression in array length: %v", err)
544+
}
545+
return tv.Value.String(), nil
546+
default:
547+
return "", p.errorf(expr.Pos(), "invalid expression in array length: %v", val)
548+
}
549+
}
550+
528551
// importsOfFile returns a map of package name to import path
529552
// of the imports in file.
530553
func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, dotImports []string) {

mockgen/parse_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ func TestParseArrayWithConstLength(t *testing.T) {
136136
t.Fatalf("Unexpected error: %v", err)
137137
}
138138

139-
expects := []string{"[2]int", "[2]int", "[127]int"}
139+
expects := []string{"[2]int", "[2]int", "[127]int", "[3]int", "[3]int", "[7]int"}
140140
for i, e := range expects {
141141
got := pkg.Interfaces[0].Methods[i].Out[0].Type.String(nil, "")
142142
if got != e {

0 commit comments

Comments
 (0)