From cd8a04247aefd08d0c89a5d2e13658ab98781d66 Mon Sep 17 00:00:00 2001 From: visualfc Date: Sun, 28 Jan 2024 20:40:28 +0800 Subject: [PATCH] support generic type oveload method --- builtin_test.go | 2 +- codebuild.go | 17 ++++++--- func_ext.go | 47 ++++++++++++++++++++++--- import.go | 10 +++--- internal/foo/foo.go | 32 +++++++++++++++++ typeparams_test.go | 84 +++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 178 insertions(+), 14 deletions(-) diff --git a/builtin_test.go b/builtin_test.go index e7484963..8f00a294 100644 --- a/builtin_test.go +++ b/builtin_test.go @@ -759,7 +759,7 @@ func TestCheckSigFuncExObjects(t *testing.T) { count int }{ {"TyOverloadFunc", sigFuncEx(nil, nil, &TyOverloadFunc{objs}), 2}, - {"TyOverloadMethod", sigFuncEx(nil, nil, &TyOverloadMethod{objs}), 2}, + {"TyOverloadMethod", sigFuncEx(nil, nil, &TyOverloadMethod{Methods: objs}), 2}, {"TyTemplateRecvMethod", sigFuncEx(nil, nil, &TyTemplateRecvMethod{types.NewParam(0, nil, "", tyInt)}), 1}, {"TyTemplateRecvMethod", sigFuncEx(nil, nil, &TyTemplateRecvMethod{fn}), 2}, {"TyOverloadNamed", sigFuncEx(nil, nil, &TyOverloadNamed{Types: []*types.Named{named}}), 1}, diff --git a/codebuild.go b/codebuild.go index dc91dc7d..42fdac51 100644 --- a/codebuild.go +++ b/codebuild.go @@ -1625,7 +1625,7 @@ retry: return kind } } - if kind := p.method(t, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid { + if kind := p.method(t, name, aliasName, flag, arg, srcExpr, t.TypeArgs() != nil); kind != MemberInvalid { return kind } if fstruc { @@ -1641,7 +1641,7 @@ retry: } case *types.Named: named, typ = o, p.getUnderlying(o) // may cause to loadNamed (delay-loaded) - if kind := p.method(o, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid { + if kind := p.method(o, name, aliasName, flag, arg, srcExpr, o.TypeArgs() != nil); kind != MemberInvalid { return kind } if _, ok := typ.(*types.Struct); ok { @@ -1657,7 +1657,7 @@ retry: } case *types.Interface: o.Complete() - if kind := p.method(o, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid { + if kind := p.method(o, name, aliasName, flag, arg, srcExpr, false); kind != MemberInvalid { return kind } case *types.Basic, *types.Slice, *types.Map, *types.Chan: @@ -1667,6 +1667,7 @@ retry: } type methodList interface { + types.Type NumMethods() int Method(i int) *types.Func } @@ -1713,7 +1714,7 @@ func (p *CodeBuilder) allowAccess(pkg *types.Package, name string) bool { } func (p *CodeBuilder) method( - o methodList, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node) (kind MemberKind) { + o methodList, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node, namedHasTypeArgs bool) (kind MemberKind) { var found *types.Func var exact bool for i, n := 0, o.NumMethods(); i < n; i++ { @@ -1738,7 +1739,13 @@ func (p *CodeBuilder) method( if autoprop && !methodHasAutoProperty(typ, 0) { return memberBad } - + if namedHasTypeArgs { + if t, ok := CheckFuncEx(typ.(*types.Signature)); ok { + if m, ok := t.(*TyOverloadMethod); ok && m.IsGeneric() { + typ = m.Instantiate(o.(*types.Named)) + } + } + } sel := selector(arg, found.Name()) ret := &internal.Elem{ Val: sel, diff --git a/func_ext.go b/func_ext.go index 2c9b13a7..ddf11f8d 100644 --- a/func_ext.go +++ b/func_ext.go @@ -175,7 +175,9 @@ func CheckOverloadFunc(sig *types.Signature) (funcs []types.Object, ok bool) { // TyOverloadMethod: overload function type type TyOverloadMethod struct { - Methods []types.Object + Methods []types.Object + indexs []int // func object indexs + instance map[*types.Named]*types.Signature // cache type signature for named } func (p *TyOverloadMethod) At(i int) types.Object { return p.Methods[i] } @@ -185,9 +187,46 @@ func (p *TyOverloadMethod) Underlying() types.Type { return p } func (p *TyOverloadMethod) String() string { return "TyOverloadMethod" } func (p *TyOverloadMethod) funcEx() {} -// NewOverloadMethod creates an overload method. -func NewOverloadMethod(typ *types.Named, pos token.Pos, pkg *types.Package, name string, methods ...types.Object) *types.Func { - return newMethodEx(typ, pos, pkg, name, &TyOverloadMethod{methods}) +func NewOverloadMethod(typ *types.Named, pos token.Pos, pkg *types.Package, name string, objectIndex map[types.Object]int, methods ...types.Object) *types.Func { + t := &TyOverloadMethod{Methods: methods} + if typ.TypeParams() != nil { + t.indexs = make([]int, len(methods)) + for i, obj := range methods { + t.indexs[i] = objectIndex[obj] + } + t.instance = make(map[*types.Named]*types.Signature) + } + return newMethodEx(typ, pos, pkg, name, t) +} + +func (m *TyOverloadMethod) IsGeneric() bool { + return len(m.indexs) != 0 +} + +func (m *TyOverloadMethod) Instantiate(named *types.Named) *types.Signature { + sig, ok := m.instance[named] + if !ok { + sig = newOverloadMethodType(named, m) + m.instance[named] = sig + } + return sig +} + +func newOverloadMethodType(named *types.Named, m *TyOverloadMethod) *types.Signature { + var list methodList + switch t := named.Underlying().(type) { + case *types.Interface: + list = t + default: + list = named + } + pkg := named.Obj().Pkg() + recv := types.NewVar(token.NoPos, pkg, "", named) + methods := make([]types.Object, len(m.indexs)) + for i, index := range m.indexs { + methods[i] = list.Method(index) + } + return sigFuncEx(pkg, recv, &TyOverloadMethod{Methods: methods}) } // CheckOverloadMethod checks a func is overload method or not. diff --git a/import.go b/import.go index 698a95fe..83ab39c5 100644 --- a/import.go +++ b/import.go @@ -94,6 +94,7 @@ func InitThisGopPkg(pkg *types.Package) { scope := pkg.Scope() gopos := make([]string, 0, 4) overloads := make(map[omthd][]types.Object) + mobjectIndexs := make(map[types.Object]int) onameds := make(map[string][]*types.Named) names := scope.Names() for _, name := range names { @@ -120,6 +121,7 @@ func InitThisGopPkg(pkg *types.Package) { mthd := mName[:len(mName)-3] key := omthd{named, mthd} overloads[key] = append(overloads[key], m) + mobjectIndexs[m] = i } } if isOverload(name) { // overload named @@ -150,7 +152,7 @@ func InitThisGopPkg(pkg *types.Package) { } } if len(fns) > 0 { - newOverload(pkg, scope, m, fns) + newOverload(pkg, scope, m, fns, nil) } delete(overloads, m) } @@ -158,7 +160,7 @@ func InitThisGopPkg(pkg *types.Package) { for key, items := range overloads { off := len(key.name) + 2 fns := overloadFuncs(off, items) - newOverload(pkg, scope, key, fns) + newOverload(pkg, scope, key, fns, mobjectIndexs) } for name, items := range onameds { off := len(name) + 2 @@ -282,7 +284,7 @@ func checkOverloads(scope *types.Scope, gopoName string) (ret []string, exists b return } -func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Object) { +func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Object, mobjectIndexs map[types.Object]int) { if m.typ == nil { if debugImport { log.Println("==> NewOverloadFunc", m.name) @@ -294,7 +296,7 @@ func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Ob if debugImport { log.Println("==> NewOverloadMethod", m.typ.Obj().Name(), m.name) } - NewOverloadMethod(m.typ, token.NoPos, pkg, m.name, fns...) + NewOverloadMethod(m.typ, token.NoPos, pkg, m.name, mobjectIndexs, fns...) } } diff --git a/internal/foo/foo.go b/internal/foo/foo.go index 5d3dc165..c3e8c32e 100644 --- a/internal/foo/foo.go +++ b/internal/foo/foo.go @@ -116,4 +116,36 @@ type NodeSeter interface { Attr__1(k, v string) (ret NodeSeter) } +type Data[T any] struct { + data []T +} + +func (p *Data[T]) Size() int { + return len(p.data) +} + +func (p *Data[T]) Add__0(v ...T) { + p.data = append(p.data, v...) +} + +func (p *Data[T]) Add__1(v Data[T]) { + p.data = append(p.data, v.data...) +} + +func (p *Data[T]) IndexOf__0(v T) int { + return -1 +} + +func (p *Data[T]) IndexOf__1(pos int, v T) int { + return -1 +} + +type DataInterface[T any] interface { + Size() int + Add__0(v ...T) + Add__1(v DataInterface[T]) + IndexOf__0(v T) int + IndexOf__1(pos int, v T) int +} + // ----------------------------------------------------------------------------- diff --git a/typeparams_test.go b/typeparams_test.go index c5af42e4..7f9de436 100644 --- a/typeparams_test.go +++ b/typeparams_test.go @@ -1157,3 +1157,87 @@ func main() { } `) } + +func TestGenericTypeOverloadMethod(t *testing.T) { + pkg := newMainPackage() + foo := pkg.Import("github.com/goplus/gogen/internal/foo") + tyDataT := foo.Ref("Data").Type() + tyInt := types.Typ[types.Int] + tyData, _ := types.Instantiate(nil, tyDataT, []types.Type{tyInt}, true) + v := pkg.NewParam(token.NoPos, "v", tyData) + pkg.NewFunc(nil, "bar", types.NewTuple(v), nil, false).BodyStart(pkg). + DefineVarStart(token.NoPos, "n").Val(v). + Debug(func(cb *gogen.CodeBuilder) { + cb.Member("size", gogen.MemberFlagMethodAlias) + }). + Call(0).EndInit(1).EndStmt(). + Val(v). + Debug(func(cb *gogen.CodeBuilder) { + cb.Member("add", gogen.MemberFlagMethodAlias) + }). + Val(0).Val(1).Call(2).EndStmt(). + Val(v). + Debug(func(cb *gogen.CodeBuilder) { + cb.Member("add", gogen.MemberFlagMethodAlias) + }). + Val(v).Call(1).EndStmt(). + DefineVarStart(token.NoPos, "i").Val(v). + Debug(func(cb *gogen.CodeBuilder) { + cb.Member("indexOf", gogen.MemberFlagMethodAlias) + }). + Val(0).Val(1).Call(2).EndInit(1).EndStmt(). + End() + domTest(t, pkg, `package main + +import "github.com/goplus/gogen/internal/foo" + +func bar(v foo.Data[int]) { + n := v.Size() + v.Add__0(0, 1) + v.Add__1(v) + i := v.IndexOf__1(0, 1) +} +`) +} + +func TestGenericInterfaceOverloadMethod(t *testing.T) { + pkg := newMainPackage() + foo := pkg.Import("github.com/goplus/gogen/internal/foo") + tyDataT := foo.Ref("DataInterface").Type() + tyInt := types.Typ[types.Int] + tyData, _ := types.Instantiate(nil, tyDataT, []types.Type{tyInt}, true) + v := pkg.NewParam(token.NoPos, "v", tyData) + pkg.NewFunc(nil, "bar", types.NewTuple(v), nil, false).BodyStart(pkg). + DefineVarStart(token.NoPos, "n").Val(v). + Debug(func(cb *gogen.CodeBuilder) { + cb.Member("size", gogen.MemberFlagMethodAlias) + }). + Call(0).EndInit(1).EndStmt(). + Val(v). + Debug(func(cb *gogen.CodeBuilder) { + cb.Member("add", gogen.MemberFlagMethodAlias) + }). + Val(0).Val(1).Call(2).EndStmt(). + Val(v). + Debug(func(cb *gogen.CodeBuilder) { + cb.Member("add", gogen.MemberFlagMethodAlias) + }). + Val(v).Call(1).EndStmt(). + DefineVarStart(token.NoPos, "i").Val(v). + Debug(func(cb *gogen.CodeBuilder) { + cb.Member("indexOf", gogen.MemberFlagMethodAlias) + }). + Val(0).Val(1).Call(2).EndInit(1).EndStmt(). + End() + domTest(t, pkg, `package main + +import "github.com/goplus/gogen/internal/foo" + +func bar(v foo.DataInterface[int]) { + n := v.Size() + v.Add__0(0, 1) + v.Add__1(v) + i := v.IndexOf__1(0, 1) +} +`) +}