Skip to content

Commit 1dae57f

Browse files
committed
support generic type oveload method
1 parent 49986c4 commit 1dae57f

File tree

5 files changed

+177
-11
lines changed

5 files changed

+177
-11
lines changed

codebuild.go

+12-4
Original file line numberDiff line numberDiff line change
@@ -1598,7 +1598,7 @@ retry:
15981598
return kind
15991599
}
16001600
}
1601-
if kind := p.method(t, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
1601+
if kind := p.method(t, name, aliasName, flag, arg, srcExpr, t.TypeArgs() != nil); kind != MemberInvalid {
16021602
return kind
16031603
}
16041604
if fstruc {
@@ -1614,7 +1614,7 @@ retry:
16141614
}
16151615
case *types.Named:
16161616
named, typ = o, p.getUnderlying(o) // may cause to loadNamed (delay-loaded)
1617-
if kind := p.method(o, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
1617+
if kind := p.method(o, name, aliasName, flag, arg, srcExpr, o.TypeArgs() != nil); kind != MemberInvalid {
16181618
return kind
16191619
}
16201620
if _, ok := typ.(*types.Struct); ok {
@@ -1630,7 +1630,7 @@ retry:
16301630
}
16311631
case *types.Interface:
16321632
o.Complete()
1633-
if kind := p.method(o, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
1633+
if kind := p.method(o, name, aliasName, flag, arg, srcExpr, false); kind != MemberInvalid {
16341634
return kind
16351635
}
16361636
case *types.Basic, *types.Slice, *types.Map, *types.Chan:
@@ -1640,6 +1640,7 @@ retry:
16401640
}
16411641

16421642
type methodList interface {
1643+
types.Type
16431644
NumMethods() int
16441645
Method(i int) *types.Func
16451646
}
@@ -1666,7 +1667,7 @@ func (p *CodeBuilder) allowAccess(pkg *types.Package, name string) bool {
16661667
}
16671668

16681669
func (p *CodeBuilder) method(
1669-
o methodList, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node) (kind MemberKind) {
1670+
o methodList, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node, namedHasTypeArgs bool) (kind MemberKind) {
16701671
var found *types.Func
16711672
var exact bool
16721673
for i, n := 0, o.NumMethods(); i < n; i++ {
@@ -1691,6 +1692,13 @@ func (p *CodeBuilder) method(
16911692
if autoprop && !methodHasAutoProperty(typ, 0) {
16921693
return memberBad
16931694
}
1695+
if namedHasTypeArgs {
1696+
if t, ok := CheckFuncEx(typ.(*types.Signature)); ok {
1697+
if m, ok := t.(*TyOverloadMethod); ok && m.IsGeneric() {
1698+
typ = m.Instantiate(o.(*types.Named))
1699+
}
1700+
}
1701+
}
16941702
sel := selector(arg, found.Name())
16951703
p.stk.Ret(1, &internal.Elem{
16961704
Val: sel,

func_ext.go

+43-3
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ func CheckOverloadFunc(sig *types.Signature) (funcs []types.Object, ok bool) {
117117

118118
// TyOverloadMethod: overload function type
119119
type TyOverloadMethod struct {
120-
Methods []types.Object
120+
Methods []types.Object
121+
indexs []int // func object indexs
122+
instance map[*types.Named]*types.Signature // cache type signature for named
121123
}
122124

123125
func (p *TyOverloadMethod) At(i int) types.Object { return p.Methods[i] }
@@ -127,8 +129,46 @@ func (p *TyOverloadMethod) Underlying() types.Type { return p }
127129
func (p *TyOverloadMethod) String() string { return "TyOverloadMethod" }
128130
func (p *TyOverloadMethod) funcEx() {}
129131

130-
func NewOverloadMethod(typ *types.Named, pos token.Pos, pkg *types.Package, name string, methods ...types.Object) *types.Func {
131-
return newMethodEx(typ, pos, pkg, name, &TyOverloadMethod{methods})
132+
func NewOverloadMethod(typ *types.Named, pos token.Pos, pkg *types.Package, name string, objectIndex map[types.Object]int, methods ...types.Object) *types.Func {
133+
t := &TyOverloadMethod{Methods: methods}
134+
if typ.TypeParams() != nil {
135+
t.indexs = make([]int, len(methods))
136+
for i, obj := range methods {
137+
t.indexs[i] = objectIndex[obj]
138+
}
139+
t.instance = make(map[*types.Named]*types.Signature)
140+
}
141+
return newMethodEx(typ, pos, pkg, name, t)
142+
}
143+
144+
func (m *TyOverloadMethod) IsGeneric() bool {
145+
return len(m.indexs) != 0
146+
}
147+
148+
func (m *TyOverloadMethod) Instantiate(named *types.Named) *types.Signature {
149+
sig, ok := m.instance[named]
150+
if !ok {
151+
sig = newOverloadMethodType(named, m)
152+
m.instance[named] = sig
153+
}
154+
return sig
155+
}
156+
157+
func newOverloadMethodType(named *types.Named, m *TyOverloadMethod) *types.Signature {
158+
var list methodList
159+
switch t := named.Underlying().(type) {
160+
case *types.Interface:
161+
list = t
162+
default:
163+
list = named
164+
}
165+
pkg := named.Obj().Pkg()
166+
recv := types.NewVar(token.NoPos, pkg, "", named)
167+
methods := make([]types.Object, len(m.indexs))
168+
for i, index := range m.indexs {
169+
methods[i] = list.Method(index)
170+
}
171+
return sigFuncEx(pkg, recv, &TyOverloadMethod{Methods: methods})
132172
}
133173

134174
func CheckOverloadMethod(sig *types.Signature) (methods []types.Object, ok bool) {

import.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ func initThisGopPkg(pkg *types.Package) {
107107
}
108108
gopos := make([]string, 0, 4)
109109
overloads := make(map[omthd][]types.Object)
110+
mobjectIndexs := make(map[types.Object]int)
110111
onameds := make(map[string][]*types.Named)
111112
names := scope.Names()
112113
for _, name := range names {
@@ -133,6 +134,7 @@ func initThisGopPkg(pkg *types.Package) {
133134
mthd := mName[:len(mName)-3]
134135
key := omthd{named, mthd}
135136
overloads[key] = append(overloads[key], m)
137+
mobjectIndexs[m] = i
136138
}
137139
}
138140
if isOverload(name) { // overload named
@@ -160,14 +162,14 @@ func initThisGopPkg(pkg *types.Package) {
160162
}
161163
fns[i] = lookupFunc(scope, name, tname)
162164
}
163-
newOverload(pkg, scope, m, fns)
165+
newOverload(pkg, scope, m, fns, nil)
164166
delete(overloads, m)
165167
}
166168
}
167169
for key, items := range overloads {
168170
off := len(key.name) + 2
169171
fns := overloadFuncs(off, items)
170-
newOverload(pkg, scope, key, fns)
172+
newOverload(pkg, scope, key, fns, mobjectIndexs)
171173
}
172174
for name, items := range onameds {
173175
off := len(name) + 2
@@ -290,7 +292,7 @@ func checkOverloads(scope *types.Scope, gopoName string) (ret []string, exists b
290292
return
291293
}
292294

293-
func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Object) {
295+
func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Object, mobjectIndexs map[types.Object]int) {
294296
if m.typ == nil {
295297
if debugImport {
296298
log.Println("==> NewOverloadFunc", m.name)
@@ -302,7 +304,7 @@ func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Ob
302304
if debugImport {
303305
log.Println("==> NewOverloadMethod", m.typ.Obj().Name(), m.name)
304306
}
305-
NewOverloadMethod(m.typ, token.NoPos, pkg, m.name, fns...)
307+
NewOverloadMethod(m.typ, token.NoPos, pkg, m.name, mobjectIndexs, fns...)
306308
}
307309
}
308310

internal/foo/foo.go

+32
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,36 @@ type NodeSeter interface {
116116
Attr__1(k, v string) (ret NodeSeter)
117117
}
118118

119+
type Data[T any] struct {
120+
data []T
121+
}
122+
123+
func (p *Data[T]) Size() int {
124+
return len(p.data)
125+
}
126+
127+
func (p *Data[T]) Add__0(v ...T) {
128+
p.data = append(p.data, v...)
129+
}
130+
131+
func (p *Data[T]) Add__1(v Data[T]) {
132+
p.data = append(p.data, v.data...)
133+
}
134+
135+
func (p *Data[T]) IndexOf__0(v T) int {
136+
return -1
137+
}
138+
139+
func (p *Data[T]) IndexOf__1(pos int, v T) int {
140+
return -1
141+
}
142+
143+
type DataInterface[T any] interface {
144+
Size() int
145+
Add__0(v ...T)
146+
Add__1(v DataInterface[T])
147+
IndexOf__0(v T) int
148+
IndexOf__1(pos int, v T) int
149+
}
150+
119151
// -----------------------------------------------------------------------------

typeparams_test.go

+84
Original file line numberDiff line numberDiff line change
@@ -962,3 +962,87 @@ func main() {
962962
}
963963
`)
964964
}
965+
966+
func TestGenericTypeOverloadMethod(t *testing.T) {
967+
pkg := newMainPackage()
968+
foo := pkg.Import("github.com/goplus/gox/internal/foo")
969+
tyDataT := foo.Ref("Data").Type()
970+
tyInt := types.Typ[types.Int]
971+
tyData, _ := types.Instantiate(nil, tyDataT, []types.Type{tyInt}, true)
972+
v := pkg.NewParam(token.NoPos, "v", tyData)
973+
pkg.NewFunc(nil, "bar", types.NewTuple(v), nil, false).BodyStart(pkg).
974+
DefineVarStart(token.NoPos, "n").Val(v).
975+
Debug(func(cb *gox.CodeBuilder) {
976+
cb.Member("size", gox.MemberFlagMethodAlias)
977+
}).
978+
Call(0).EndInit(1).EndStmt().
979+
Val(v).
980+
Debug(func(cb *gox.CodeBuilder) {
981+
cb.Member("add", gox.MemberFlagMethodAlias)
982+
}).
983+
Val(0).Val(1).Call(2).EndStmt().
984+
Val(v).
985+
Debug(func(cb *gox.CodeBuilder) {
986+
cb.Member("add", gox.MemberFlagMethodAlias)
987+
}).
988+
Val(v).Call(1).EndStmt().
989+
DefineVarStart(token.NoPos, "i").Val(v).
990+
Debug(func(cb *gox.CodeBuilder) {
991+
cb.Member("indexOf", gox.MemberFlagMethodAlias)
992+
}).
993+
Val(0).Val(1).Call(2).EndInit(1).EndStmt().
994+
End()
995+
domTest(t, pkg, `package main
996+
997+
import "github.com/goplus/gox/internal/foo"
998+
999+
func bar(v foo.Data[int]) {
1000+
n := v.Size()
1001+
v.Add__0(0, 1)
1002+
v.Add__1(v)
1003+
i := v.IndexOf__1(0, 1)
1004+
}
1005+
`)
1006+
}
1007+
1008+
func TestGenericInterfaceOverloadMethod(t *testing.T) {
1009+
pkg := newMainPackage()
1010+
foo := pkg.Import("github.com/goplus/gox/internal/foo")
1011+
tyDataT := foo.Ref("DataInterface").Type()
1012+
tyInt := types.Typ[types.Int]
1013+
tyData, _ := types.Instantiate(nil, tyDataT, []types.Type{tyInt}, true)
1014+
v := pkg.NewParam(token.NoPos, "v", tyData)
1015+
pkg.NewFunc(nil, "bar", types.NewTuple(v), nil, false).BodyStart(pkg).
1016+
DefineVarStart(token.NoPos, "n").Val(v).
1017+
Debug(func(cb *gox.CodeBuilder) {
1018+
cb.Member("size", gox.MemberFlagMethodAlias)
1019+
}).
1020+
Call(0).EndInit(1).EndStmt().
1021+
Val(v).
1022+
Debug(func(cb *gox.CodeBuilder) {
1023+
cb.Member("add", gox.MemberFlagMethodAlias)
1024+
}).
1025+
Val(0).Val(1).Call(2).EndStmt().
1026+
Val(v).
1027+
Debug(func(cb *gox.CodeBuilder) {
1028+
cb.Member("add", gox.MemberFlagMethodAlias)
1029+
}).
1030+
Val(v).Call(1).EndStmt().
1031+
DefineVarStart(token.NoPos, "i").Val(v).
1032+
Debug(func(cb *gox.CodeBuilder) {
1033+
cb.Member("indexOf", gox.MemberFlagMethodAlias)
1034+
}).
1035+
Val(0).Val(1).Call(2).EndInit(1).EndStmt().
1036+
End()
1037+
domTest(t, pkg, `package main
1038+
1039+
import "github.com/goplus/gox/internal/foo"
1040+
1041+
func bar(v foo.DataInterface[int]) {
1042+
n := v.Size()
1043+
v.Add__0(0, 1)
1044+
v.Add__1(v)
1045+
i := v.IndexOf__1(0, 1)
1046+
}
1047+
`)
1048+
}

0 commit comments

Comments
 (0)