Skip to content

Commit cd8a042

Browse files
committed
support generic type oveload method
1 parent af40cbf commit cd8a042

File tree

6 files changed

+178
-14
lines changed

6 files changed

+178
-14
lines changed

builtin_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ func TestCheckSigFuncExObjects(t *testing.T) {
759759
count int
760760
}{
761761
{"TyOverloadFunc", sigFuncEx(nil, nil, &TyOverloadFunc{objs}), 2},
762-
{"TyOverloadMethod", sigFuncEx(nil, nil, &TyOverloadMethod{objs}), 2},
762+
{"TyOverloadMethod", sigFuncEx(nil, nil, &TyOverloadMethod{Methods: objs}), 2},
763763
{"TyTemplateRecvMethod", sigFuncEx(nil, nil, &TyTemplateRecvMethod{types.NewParam(0, nil, "", tyInt)}), 1},
764764
{"TyTemplateRecvMethod", sigFuncEx(nil, nil, &TyTemplateRecvMethod{fn}), 2},
765765
{"TyOverloadNamed", sigFuncEx(nil, nil, &TyOverloadNamed{Types: []*types.Named{named}}), 1},

codebuild.go

+12-5
Original file line numberDiff line numberDiff line change
@@ -1625,7 +1625,7 @@ retry:
16251625
return kind
16261626
}
16271627
}
1628-
if kind := p.method(t, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
1628+
if kind := p.method(t, name, aliasName, flag, arg, srcExpr, t.TypeArgs() != nil); kind != MemberInvalid {
16291629
return kind
16301630
}
16311631
if fstruc {
@@ -1641,7 +1641,7 @@ retry:
16411641
}
16421642
case *types.Named:
16431643
named, typ = o, p.getUnderlying(o) // may cause to loadNamed (delay-loaded)
1644-
if kind := p.method(o, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
1644+
if kind := p.method(o, name, aliasName, flag, arg, srcExpr, o.TypeArgs() != nil); kind != MemberInvalid {
16451645
return kind
16461646
}
16471647
if _, ok := typ.(*types.Struct); ok {
@@ -1657,7 +1657,7 @@ retry:
16571657
}
16581658
case *types.Interface:
16591659
o.Complete()
1660-
if kind := p.method(o, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
1660+
if kind := p.method(o, name, aliasName, flag, arg, srcExpr, false); kind != MemberInvalid {
16611661
return kind
16621662
}
16631663
case *types.Basic, *types.Slice, *types.Map, *types.Chan:
@@ -1667,6 +1667,7 @@ retry:
16671667
}
16681668

16691669
type methodList interface {
1670+
types.Type
16701671
NumMethods() int
16711672
Method(i int) *types.Func
16721673
}
@@ -1713,7 +1714,7 @@ func (p *CodeBuilder) allowAccess(pkg *types.Package, name string) bool {
17131714
}
17141715

17151716
func (p *CodeBuilder) method(
1716-
o methodList, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node) (kind MemberKind) {
1717+
o methodList, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node, namedHasTypeArgs bool) (kind MemberKind) {
17171718
var found *types.Func
17181719
var exact bool
17191720
for i, n := 0, o.NumMethods(); i < n; i++ {
@@ -1738,7 +1739,13 @@ func (p *CodeBuilder) method(
17381739
if autoprop && !methodHasAutoProperty(typ, 0) {
17391740
return memberBad
17401741
}
1741-
1742+
if namedHasTypeArgs {
1743+
if t, ok := CheckFuncEx(typ.(*types.Signature)); ok {
1744+
if m, ok := t.(*TyOverloadMethod); ok && m.IsGeneric() {
1745+
typ = m.Instantiate(o.(*types.Named))
1746+
}
1747+
}
1748+
}
17421749
sel := selector(arg, found.Name())
17431750
ret := &internal.Elem{
17441751
Val: sel,

func_ext.go

+43-4
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,9 @@ func CheckOverloadFunc(sig *types.Signature) (funcs []types.Object, ok bool) {
175175

176176
// TyOverloadMethod: overload function type
177177
type TyOverloadMethod struct {
178-
Methods []types.Object
178+
Methods []types.Object
179+
indexs []int // func object indexs
180+
instance map[*types.Named]*types.Signature // cache type signature for named
179181
}
180182

181183
func (p *TyOverloadMethod) At(i int) types.Object { return p.Methods[i] }
@@ -185,9 +187,46 @@ func (p *TyOverloadMethod) Underlying() types.Type { return p }
185187
func (p *TyOverloadMethod) String() string { return "TyOverloadMethod" }
186188
func (p *TyOverloadMethod) funcEx() {}
187189

188-
// NewOverloadMethod creates an overload method.
189-
func NewOverloadMethod(typ *types.Named, pos token.Pos, pkg *types.Package, name string, methods ...types.Object) *types.Func {
190-
return newMethodEx(typ, pos, pkg, name, &TyOverloadMethod{methods})
190+
func NewOverloadMethod(typ *types.Named, pos token.Pos, pkg *types.Package, name string, objectIndex map[types.Object]int, methods ...types.Object) *types.Func {
191+
t := &TyOverloadMethod{Methods: methods}
192+
if typ.TypeParams() != nil {
193+
t.indexs = make([]int, len(methods))
194+
for i, obj := range methods {
195+
t.indexs[i] = objectIndex[obj]
196+
}
197+
t.instance = make(map[*types.Named]*types.Signature)
198+
}
199+
return newMethodEx(typ, pos, pkg, name, t)
200+
}
201+
202+
func (m *TyOverloadMethod) IsGeneric() bool {
203+
return len(m.indexs) != 0
204+
}
205+
206+
func (m *TyOverloadMethod) Instantiate(named *types.Named) *types.Signature {
207+
sig, ok := m.instance[named]
208+
if !ok {
209+
sig = newOverloadMethodType(named, m)
210+
m.instance[named] = sig
211+
}
212+
return sig
213+
}
214+
215+
func newOverloadMethodType(named *types.Named, m *TyOverloadMethod) *types.Signature {
216+
var list methodList
217+
switch t := named.Underlying().(type) {
218+
case *types.Interface:
219+
list = t
220+
default:
221+
list = named
222+
}
223+
pkg := named.Obj().Pkg()
224+
recv := types.NewVar(token.NoPos, pkg, "", named)
225+
methods := make([]types.Object, len(m.indexs))
226+
for i, index := range m.indexs {
227+
methods[i] = list.Method(index)
228+
}
229+
return sigFuncEx(pkg, recv, &TyOverloadMethod{Methods: methods})
191230
}
192231

193232
// CheckOverloadMethod checks a func is overload method or not.

import.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ func InitThisGopPkg(pkg *types.Package) {
9494
scope := pkg.Scope()
9595
gopos := make([]string, 0, 4)
9696
overloads := make(map[omthd][]types.Object)
97+
mobjectIndexs := make(map[types.Object]int)
9798
onameds := make(map[string][]*types.Named)
9899
names := scope.Names()
99100
for _, name := range names {
@@ -120,6 +121,7 @@ func InitThisGopPkg(pkg *types.Package) {
120121
mthd := mName[:len(mName)-3]
121122
key := omthd{named, mthd}
122123
overloads[key] = append(overloads[key], m)
124+
mobjectIndexs[m] = i
123125
}
124126
}
125127
if isOverload(name) { // overload named
@@ -150,15 +152,15 @@ func InitThisGopPkg(pkg *types.Package) {
150152
}
151153
}
152154
if len(fns) > 0 {
153-
newOverload(pkg, scope, m, fns)
155+
newOverload(pkg, scope, m, fns, nil)
154156
}
155157
delete(overloads, m)
156158
}
157159
}
158160
for key, items := range overloads {
159161
off := len(key.name) + 2
160162
fns := overloadFuncs(off, items)
161-
newOverload(pkg, scope, key, fns)
163+
newOverload(pkg, scope, key, fns, mobjectIndexs)
162164
}
163165
for name, items := range onameds {
164166
off := len(name) + 2
@@ -282,7 +284,7 @@ func checkOverloads(scope *types.Scope, gopoName string) (ret []string, exists b
282284
return
283285
}
284286

285-
func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Object) {
287+
func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Object, mobjectIndexs map[types.Object]int) {
286288
if m.typ == nil {
287289
if debugImport {
288290
log.Println("==> NewOverloadFunc", m.name)
@@ -294,7 +296,7 @@ func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Ob
294296
if debugImport {
295297
log.Println("==> NewOverloadMethod", m.typ.Obj().Name(), m.name)
296298
}
297-
NewOverloadMethod(m.typ, token.NoPos, pkg, m.name, fns...)
299+
NewOverloadMethod(m.typ, token.NoPos, pkg, m.name, mobjectIndexs, fns...)
298300
}
299301
}
300302

internal/foo/foo.go

+32
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,36 @@ type NodeSeter interface {
116116
Attr__1(k, v string) (ret NodeSeter)
117117
}
118118

119+
type Data[T any] struct {
120+
data []T
121+
}
122+
123+
func (p *Data[T]) Size() int {
124+
return len(p.data)
125+
}
126+
127+
func (p *Data[T]) Add__0(v ...T) {
128+
p.data = append(p.data, v...)
129+
}
130+
131+
func (p *Data[T]) Add__1(v Data[T]) {
132+
p.data = append(p.data, v.data...)
133+
}
134+
135+
func (p *Data[T]) IndexOf__0(v T) int {
136+
return -1
137+
}
138+
139+
func (p *Data[T]) IndexOf__1(pos int, v T) int {
140+
return -1
141+
}
142+
143+
type DataInterface[T any] interface {
144+
Size() int
145+
Add__0(v ...T)
146+
Add__1(v DataInterface[T])
147+
IndexOf__0(v T) int
148+
IndexOf__1(pos int, v T) int
149+
}
150+
119151
// -----------------------------------------------------------------------------

typeparams_test.go

+84
Original file line numberDiff line numberDiff line change
@@ -1157,3 +1157,87 @@ func main() {
11571157
}
11581158
`)
11591159
}
1160+
1161+
func TestGenericTypeOverloadMethod(t *testing.T) {
1162+
pkg := newMainPackage()
1163+
foo := pkg.Import("github.com/goplus/gogen/internal/foo")
1164+
tyDataT := foo.Ref("Data").Type()
1165+
tyInt := types.Typ[types.Int]
1166+
tyData, _ := types.Instantiate(nil, tyDataT, []types.Type{tyInt}, true)
1167+
v := pkg.NewParam(token.NoPos, "v", tyData)
1168+
pkg.NewFunc(nil, "bar", types.NewTuple(v), nil, false).BodyStart(pkg).
1169+
DefineVarStart(token.NoPos, "n").Val(v).
1170+
Debug(func(cb *gogen.CodeBuilder) {
1171+
cb.Member("size", gogen.MemberFlagMethodAlias)
1172+
}).
1173+
Call(0).EndInit(1).EndStmt().
1174+
Val(v).
1175+
Debug(func(cb *gogen.CodeBuilder) {
1176+
cb.Member("add", gogen.MemberFlagMethodAlias)
1177+
}).
1178+
Val(0).Val(1).Call(2).EndStmt().
1179+
Val(v).
1180+
Debug(func(cb *gogen.CodeBuilder) {
1181+
cb.Member("add", gogen.MemberFlagMethodAlias)
1182+
}).
1183+
Val(v).Call(1).EndStmt().
1184+
DefineVarStart(token.NoPos, "i").Val(v).
1185+
Debug(func(cb *gogen.CodeBuilder) {
1186+
cb.Member("indexOf", gogen.MemberFlagMethodAlias)
1187+
}).
1188+
Val(0).Val(1).Call(2).EndInit(1).EndStmt().
1189+
End()
1190+
domTest(t, pkg, `package main
1191+
1192+
import "github.com/goplus/gogen/internal/foo"
1193+
1194+
func bar(v foo.Data[int]) {
1195+
n := v.Size()
1196+
v.Add__0(0, 1)
1197+
v.Add__1(v)
1198+
i := v.IndexOf__1(0, 1)
1199+
}
1200+
`)
1201+
}
1202+
1203+
func TestGenericInterfaceOverloadMethod(t *testing.T) {
1204+
pkg := newMainPackage()
1205+
foo := pkg.Import("github.com/goplus/gogen/internal/foo")
1206+
tyDataT := foo.Ref("DataInterface").Type()
1207+
tyInt := types.Typ[types.Int]
1208+
tyData, _ := types.Instantiate(nil, tyDataT, []types.Type{tyInt}, true)
1209+
v := pkg.NewParam(token.NoPos, "v", tyData)
1210+
pkg.NewFunc(nil, "bar", types.NewTuple(v), nil, false).BodyStart(pkg).
1211+
DefineVarStart(token.NoPos, "n").Val(v).
1212+
Debug(func(cb *gogen.CodeBuilder) {
1213+
cb.Member("size", gogen.MemberFlagMethodAlias)
1214+
}).
1215+
Call(0).EndInit(1).EndStmt().
1216+
Val(v).
1217+
Debug(func(cb *gogen.CodeBuilder) {
1218+
cb.Member("add", gogen.MemberFlagMethodAlias)
1219+
}).
1220+
Val(0).Val(1).Call(2).EndStmt().
1221+
Val(v).
1222+
Debug(func(cb *gogen.CodeBuilder) {
1223+
cb.Member("add", gogen.MemberFlagMethodAlias)
1224+
}).
1225+
Val(v).Call(1).EndStmt().
1226+
DefineVarStart(token.NoPos, "i").Val(v).
1227+
Debug(func(cb *gogen.CodeBuilder) {
1228+
cb.Member("indexOf", gogen.MemberFlagMethodAlias)
1229+
}).
1230+
Val(0).Val(1).Call(2).EndInit(1).EndStmt().
1231+
End()
1232+
domTest(t, pkg, `package main
1233+
1234+
import "github.com/goplus/gogen/internal/foo"
1235+
1236+
func bar(v foo.DataInterface[int]) {
1237+
n := v.Size()
1238+
v.Add__0(0, 1)
1239+
v.Add__1(v)
1240+
i := v.IndexOf__1(0, 1)
1241+
}
1242+
`)
1243+
}

0 commit comments

Comments
 (0)