Skip to content

Commit

Permalink
improved generics support
Browse files Browse the repository at this point in the history
  • Loading branch information
ccbrown committed Aug 27, 2019
1 parent e70023c commit a58af6b
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 75 deletions.
Binary file modified experimental/generics/cmd/preprocess.wasm
Binary file not shown.
217 changes: 142 additions & 75 deletions experimental/generics/preprocessor/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ func isParameterized(obj types.Object) bool {
}

var generatedDecls = map[string]ast.Decl{}
var methodTemplates = map[string][]*ast.FuncDecl{}
var functionTemplates = map[string]*ast.FuncDecl{}
var typeTemplates = map[string]*ast.TypeSpec{}

Expand All @@ -114,6 +115,7 @@ func preprocess(node ast.Node, info *types.Info, parentTypeParams map[string]typ
targs []types.Type
}

instantiatedMethods := map[string]instantiation{}
instantiatedFunctions := map[string]instantiation{}
instantiatedTypes := map[string]instantiation{}

Expand Down Expand Up @@ -148,9 +150,18 @@ func preprocess(node ast.Node, info *types.Info, parentTypeParams map[string]typ
targs: targs,
}
} else {
instantiatedTypes[name] = instantiation{
name: node.Name,
targs: targs,
hasKnownTypes := true
for _, targ := range targs {
if _, ok := targ.(*types.TypeParam); ok {
hasKnownTypes = false
break
}
}
if hasKnownTypes {
instantiatedTypes[name] = instantiation{
name: node.Name,
targs: targs,
}
}
}
c.Replace(ast.NewIdent(name))
Expand All @@ -160,62 +171,22 @@ func preprocess(node ast.Node, info *types.Info, parentTypeParams map[string]typ
return true
}, nil)

// Generate functions
for name, f := range instantiatedFunctions {
if _, ok := generatedDecls[name]; ok {
continue
}

for i, targ := range f.targs {
if param, ok := targ.(*types.TypeParam); ok {
f.targs[i] = parentTypeParams[param.String()]
for len(instantiatedMethods) > 0 || len(instantiatedFunctions) > 0 || len(instantiatedTypes) > 0 {
// Generate functions
for name, f := range instantiatedFunctions {
if _, ok := generatedDecls[name]; ok {
continue
}
}

template := astcopy.FuncDecl(functionTemplates[f.name])
template.Name = ast.NewIdent(name)
decl := astutil.Apply(template, func(c *astutil.Cursor) bool {
ident, ok := c.Node().(*ast.Ident)
if !ok {
return true
}
t := info.Types[ident]
if !t.IsType() {
return true
}
for i, tparam := range template.TParams.List {
for j, tpident := range tparam.Names {
if info.Defs[tpident] == info.Uses[ident] {
newIdent := ast.NewIdent(f.targs[i+j].String())
c.Replace(newIdent)
}
for i, targ := range f.targs {
if param, ok := targ.(*types.TypeParam); ok {
f.targs[i] = parentTypeParams[param.String()]
}
}
return true
}, nil).(*ast.FuncDecl)
parentTypeParams := map[string]types.Type{}
for i, tparam := range template.TParams.List {
for j, tpident := range tparam.Names {
parentTypeParams[tpident.Name] = f.targs[i+j]
}
}
generatedDecls[name] = nil
decl = preprocess(decl, info, parentTypeParams).(*ast.FuncDecl)
decl.TParams = nil
generatedDecls[name] = decl
}

// Generate types
for len(instantiatedTypes) > 0 {
temp := instantiatedTypes
instantiatedTypes = map[string]instantiation{}
for name, f := range temp {
if _, ok := generatedDecls[name]; ok {
continue
}

template := astcopy.TypeSpec(typeTemplates[f.name])
spec := astutil.Apply(template, func(c *astutil.Cursor) bool {
template := astcopy.FuncDecl(functionTemplates[f.name])
template.Name = ast.NewIdent(name)
decl := astutil.Apply(template, func(c *astutil.Cursor) bool {
ident, ok := c.Node().(*ast.Ident)
if !ok {
return true
Expand All @@ -224,38 +195,125 @@ func preprocess(node ast.Node, info *types.Info, parentTypeParams map[string]typ
if !t.IsType() {
return true
}
if info.Uses[ident] == info.Defs[template.Name] {
c.Replace(ast.NewIdent(name))
return true
}
for i, tparam := range template.TParams.List {
for j, tpident := range tparam.Names {
if info.Defs[tpident] == info.Uses[ident] {
if named, ok := f.targs[i+j].(*types.Named); ok {
if targs, ok := info.NamedTypeArguments[named]; ok {
templateName := strings.Split(strings.Split(named.String(), ".")[1], "<")[0]
name := generatedName(templateName, targs)
instantiatedTypes[name] = instantiation{
name: templateName,
targs: targs,
newIdent := ast.NewIdent(f.targs[i+j].String())
c.Replace(newIdent)
}
}
}
return true
}, nil).(*ast.FuncDecl)
parentTypeParams := map[string]types.Type{}
for i, tparam := range template.TParams.List {
for j, tpident := range tparam.Names {
parentTypeParams[tpident.Name] = f.targs[i+j]
}
}
generatedDecls[name] = nil
decl = preprocess(decl, info, parentTypeParams).(*ast.FuncDecl)
decl.TParams = nil
generatedDecls[name] = decl
}
instantiatedFunctions = map[string]instantiation{}

// Generate types
for len(instantiatedTypes) > 0 {
temp := instantiatedTypes
instantiatedTypes = map[string]instantiation{}
for name, f := range temp {
if _, ok := generatedDecls[name]; ok {
continue
}

template := astcopy.TypeSpec(typeTemplates[f.name])
spec := astutil.Apply(template, func(c *astutil.Cursor) bool {
if call, ok := c.Node().(*ast.CallExpr); ok {
if ident, ok := call.Fun.(*ast.Ident); ok && isParameterized(info.Uses[ident]) {
if targs := info.TypeArguments[ident]; targs != nil {
resolvedArgs := make([]types.Type, len(targs))
copy(resolvedArgs, targs)
for i, targ := range targs {
if tp, ok := targ.(*types.TypeParam); ok {
for j, tparam := range template.TParams.List {
for k, tpident := range tparam.Names {
if tp.String() == tpident.Name {
resolvedArgs[i] = f.targs[j+k]
}
}
}
}
c.Replace(ast.NewIdent(name))
return true
}
targs = resolvedArgs
name := generatedName(ident.Name, targs)
instantiatedTypes[name] = instantiation{
name: ident.Name,
targs: targs,
}
c.Replace(ast.NewIdent(name))
return false
}
c.Replace(ast.NewIdent(f.targs[i+j].String()))
}
}

ident, ok := c.Node().(*ast.Ident)
if !ok {
return true
}
t := info.Types[ident]
if !t.IsType() {
return true
}
if info.Uses[ident] == info.Defs[template.Name] {
c.Replace(ast.NewIdent(name))
return true
}
for i, tparam := range template.TParams.List {
for j, tpident := range tparam.Names {
if info.Defs[tpident] == info.Uses[ident] {
if named, ok := f.targs[i+j].(*types.Named); ok {
if targs, ok := info.NamedTypeArguments[named]; ok {
templateName := strings.Split(strings.Split(named.String(), ".")[1], "<")[0]
name := generatedName(templateName, targs)
instantiatedTypes[name] = instantiation{
name: templateName,
targs: targs,
}
c.Replace(ast.NewIdent(name))
return true
}
}
c.Replace(ast.NewIdent(f.targs[i+j].String()))
}
}
}
return true
}, nil).(*ast.TypeSpec)
spec.TParams = nil
spec.Name = ast.NewIdent(name)
generatedDecls[name] = &ast.GenDecl{
Tok: token.TYPE,
Specs: []ast.Spec{spec},
}
for _, template := range methodTemplates[f.name] {
name := generatedName(template.Name.Name, f.targs)
instantiatedMethods[name] = instantiation{
name: f.name + "." + template.Name.Name,
targs: f.targs,
}
}
return true
}, nil).(*ast.TypeSpec)
spec.TParams = nil
spec.Name = ast.NewIdent(name)
generatedDecls[name] = &ast.GenDecl{
Tok: token.TYPE,
Specs: []ast.Spec{spec},
}
}

// Generate methods
for name, _ := range instantiatedMethods {
if _, ok := generatedDecls[name]; ok {
continue
}
panic("methods with generic receivers are not yet implemented")
}
instantiatedMethods = map[string]instantiation{}
}

return node
Expand Down Expand Up @@ -307,6 +365,15 @@ func main() {
functionTemplates[node.Name.Name] = node
c.Delete()
return false
} else if node.Recv != nil {
t := info.Types[node.Recv.List[0].Type]
if strings.Contains(t.Type.String(), "(") {
receiverName := strings.Split(strings.TrimPrefix(strings.Split(t.Type.String(), "(")[0], "*"), ".")[1]
methodTemplates[receiverName] = append(methodTemplates[receiverName], node)
functionTemplates[receiverName+"."+node.Name.Name] = node
c.Delete()
return false
}
}
case *ast.GenDecl:
var newSpecs []ast.Spec
Expand Down
14 changes: 14 additions & 0 deletions experimental/generics/preprocessor/testdata/nested_types.go2
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package main

type Foo(type T) struct {
bar Bar(T, T)
}

type Bar(type T, U) struct {
v U
}

func main() {
f := Foo(int){}
println(f.bar.v)
}

0 comments on commit a58af6b

Please sign in to comment.